From 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 Mon Sep 17 00:00:00 2001 From: cyfraeviolae Date: Wed, 3 Apr 2024 03:10:44 -0400 Subject: venv --- .../site-packages/sqlalchemy/__init__.py | 294 + .../__pycache__/__init__.cpython-311.pyc | Bin 0 -> 13956 bytes .../sqlalchemy/__pycache__/events.cpython-311.pyc | Bin 0 -> 671 bytes .../sqlalchemy/__pycache__/exc.cpython-311.pyc | Bin 0 -> 34771 bytes .../__pycache__/inspection.cpython-311.pyc | Bin 0 -> 7435 bytes .../sqlalchemy/__pycache__/log.cpython-311.pyc | Bin 0 -> 12306 bytes .../sqlalchemy/__pycache__/schema.cpython-311.pyc | Bin 0 -> 3177 bytes .../sqlalchemy/__pycache__/types.cpython-311.pyc | Bin 0 -> 3240 bytes .../sqlalchemy/connectors/__init__.py | 18 + .../__pycache__/__init__.cpython-311.pyc | Bin 0 -> 682 bytes .../connectors/__pycache__/aioodbc.cpython-311.pyc | Bin 0 -> 8027 bytes .../connectors/__pycache__/asyncio.cpython-311.pyc | Bin 0 -> 12757 bytes .../connectors/__pycache__/pyodbc.cpython-311.pyc | Bin 0 -> 11185 bytes .../site-packages/sqlalchemy/connectors/aioodbc.py | 174 + .../site-packages/sqlalchemy/connectors/asyncio.py | 208 + .../site-packages/sqlalchemy/connectors/pyodbc.py | 249 + .../sqlalchemy/cyextension/__init__.py | 6 + .../__pycache__/__init__.cpython-311.pyc | Bin 0 -> 206 bytes .../collections.cpython-311-x86_64-linux-gnu.so | Bin 0 -> 2019496 bytes .../sqlalchemy/cyextension/collections.pyx | 409 + .../immutabledict.cpython-311-x86_64-linux-gnu.so | Bin 0 -> 703720 bytes .../sqlalchemy/cyextension/immutabledict.pxd | 8 + .../sqlalchemy/cyextension/immutabledict.pyx | 133 + .../processors.cpython-311-x86_64-linux-gnu.so | Bin 0 -> 509544 bytes .../sqlalchemy/cyextension/processors.pyx | 68 + .../resultproxy.cpython-311-x86_64-linux-gnu.so | Bin 0 -> 586752 bytes .../sqlalchemy/cyextension/resultproxy.pyx | 102 + .../util.cpython-311-x86_64-linux-gnu.so | Bin 0 -> 870128 bytes .../site-packages/sqlalchemy/cyextension/util.pyx | 91 + .../site-packages/sqlalchemy/dialects/__init__.py | 61 + .../dialects/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 2097 bytes .../dialects/__pycache__/_typing.cpython-311.pyc | Bin 0 -> 1094 bytes .../site-packages/sqlalchemy/dialects/_typing.py | 25 + .../sqlalchemy/dialects/mssql/__init__.py | 88 + .../mssql/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 2226 bytes .../mssql/__pycache__/aioodbc.cpython-311.pyc | Bin 0 -> 2591 bytes .../mssql/__pycache__/base.cpython-311.pyc | Bin 0 -> 157867 bytes .../__pycache__/information_schema.cpython-311.pyc | Bin 0 -> 9859 bytes .../mssql/__pycache__/json.cpython-311.pyc | Bin 0 -> 5828 bytes .../mssql/__pycache__/provision.cpython-311.pyc | Bin 0 -> 8399 bytes .../mssql/__pycache__/pymssql.cpython-311.pyc | Bin 0 -> 6678 bytes .../mssql/__pycache__/pyodbc.cpython-311.pyc | Bin 0 -> 33146 bytes .../sqlalchemy/dialects/mssql/aioodbc.py | 64 + .../sqlalchemy/dialects/mssql/base.py | 4007 ++++++++++ .../dialects/mssql/information_schema.py | 254 + .../sqlalchemy/dialects/mssql/json.py | 133 + .../sqlalchemy/dialects/mssql/provision.py | 155 + .../sqlalchemy/dialects/mssql/pymssql.py | 125 + .../sqlalchemy/dialects/mssql/pyodbc.py | 745 ++ .../sqlalchemy/dialects/mysql/__init__.py | 101 + .../mysql/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 2654 bytes .../mysql/__pycache__/aiomysql.cpython-311.pyc | Bin 0 -> 18123 bytes .../mysql/__pycache__/asyncmy.cpython-311.pyc | Bin 0 -> 18787 bytes .../mysql/__pycache__/base.cpython-311.pyc | Bin 0 -> 145328 bytes .../mysql/__pycache__/cymysql.cpython-311.pyc | Bin 0 -> 3348 bytes .../dialects/mysql/__pycache__/dml.cpython-311.pyc | Bin 0 -> 9104 bytes .../mysql/__pycache__/enumerated.cpython-311.pyc | Bin 0 -> 11258 bytes .../mysql/__pycache__/expression.cpython-311.pyc | Bin 0 -> 5392 bytes .../mysql/__pycache__/json.cpython-311.pyc | Bin 0 -> 3982 bytes .../mysql/__pycache__/mariadb.cpython-311.pyc | Bin 0 -> 1171 bytes .../__pycache__/mariadbconnector.cpython-311.pyc | Bin 0 -> 12859 bytes .../__pycache__/mysqlconnector.cpython-311.pyc | Bin 0 -> 9700 bytes .../mysql/__pycache__/mysqldb.cpython-311.pyc | Bin 0 -> 12714 bytes .../mysql/__pycache__/provision.cpython-311.pyc | Bin 0 -> 4848 bytes .../mysql/__pycache__/pymysql.cpython-311.pyc | Bin 0 -> 5649 bytes .../mysql/__pycache__/pyodbc.cpython-311.pyc | Bin 0 -> 5859 bytes .../mysql/__pycache__/reflection.cpython-311.pyc | Bin 0 -> 27141 bytes .../__pycache__/reserved_words.cpython-311.pyc | Bin 0 -> 4446 bytes .../mysql/__pycache__/types.cpython-311.pyc | Bin 0 -> 33731 bytes .../sqlalchemy/dialects/mysql/aiomysql.py | 332 + .../sqlalchemy/dialects/mysql/asyncmy.py | 337 + .../sqlalchemy/dialects/mysql/base.py | 3447 +++++++++ .../sqlalchemy/dialects/mysql/cymysql.py | 84 + .../site-packages/sqlalchemy/dialects/mysql/dml.py | 219 + .../sqlalchemy/dialects/mysql/enumerated.py | 244 + .../sqlalchemy/dialects/mysql/expression.py | 141 + .../sqlalchemy/dialects/mysql/json.py | 81 + .../sqlalchemy/dialects/mysql/mariadb.py | 32 + .../sqlalchemy/dialects/mysql/mariadbconnector.py | 275 + .../sqlalchemy/dialects/mysql/mysqlconnector.py | 179 + .../sqlalchemy/dialects/mysql/mysqldb.py | 303 + .../sqlalchemy/dialects/mysql/provision.py | 107 + .../sqlalchemy/dialects/mysql/pymysql.py | 137 + .../sqlalchemy/dialects/mysql/pyodbc.py | 138 + .../sqlalchemy/dialects/mysql/reflection.py | 677 ++ .../sqlalchemy/dialects/mysql/reserved_words.py | 571 ++ .../sqlalchemy/dialects/mysql/types.py | 774 ++ .../sqlalchemy/dialects/oracle/__init__.py | 67 + .../oracle/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 1731 bytes .../oracle/__pycache__/base.cpython-311.pyc | Bin 0 -> 136700 bytes .../oracle/__pycache__/cx_oracle.cpython-311.pyc | Bin 0 -> 62959 bytes .../oracle/__pycache__/dictionary.cpython-311.pyc | Bin 0 -> 32392 bytes .../oracle/__pycache__/oracledb.cpython-311.pyc | Bin 0 -> 15311 bytes .../oracle/__pycache__/provision.cpython-311.pyc | Bin 0 -> 12688 bytes .../oracle/__pycache__/types.cpython-311.pyc | Bin 0 -> 13844 bytes .../sqlalchemy/dialects/oracle/base.py | 3240 ++++++++ .../sqlalchemy/dialects/oracle/cx_oracle.py | 1492 ++++ .../sqlalchemy/dialects/oracle/dictionary.py | 507 ++ .../sqlalchemy/dialects/oracle/oracledb.py | 311 + .../sqlalchemy/dialects/oracle/provision.py | 220 + .../sqlalchemy/dialects/oracle/types.py | 287 + .../sqlalchemy/dialects/postgresql/__init__.py | 167 + .../__pycache__/__init__.cpython-311.pyc | Bin 0 -> 4640 bytes .../__pycache__/_psycopg_common.cpython-311.pyc | Bin 0 -> 8755 bytes .../postgresql/__pycache__/array.cpython-311.pyc | Bin 0 -> 17843 bytes .../postgresql/__pycache__/asyncpg.cpython-311.pyc | Bin 0 -> 61182 bytes .../postgresql/__pycache__/base.cpython-311.pyc | Bin 0 -> 207977 bytes .../postgresql/__pycache__/dml.cpython-311.pyc | Bin 0 -> 12630 bytes .../postgresql/__pycache__/ext.cpython-311.pyc | Bin 0 -> 20760 bytes .../postgresql/__pycache__/hstore.cpython-311.pyc | Bin 0 -> 16586 bytes .../postgresql/__pycache__/json.cpython-311.pyc | Bin 0 -> 14343 bytes .../__pycache__/named_types.cpython-311.pyc | Bin 0 -> 25105 bytes .../__pycache__/operators.cpython-311.pyc | Bin 0 -> 2188 bytes .../postgresql/__pycache__/pg8000.cpython-311.pyc | Bin 0 -> 32853 bytes .../__pycache__/pg_catalog.cpython-311.pyc | Bin 0 -> 13658 bytes .../__pycache__/provision.cpython-311.pyc | Bin 0 -> 9178 bytes .../postgresql/__pycache__/psycopg.cpython-311.pyc | Bin 0 -> 39132 bytes .../__pycache__/psycopg2.cpython-311.pyc | Bin 0 -> 36923 bytes .../__pycache__/psycopg2cffi.cpython-311.pyc | Bin 0 -> 2307 bytes .../postgresql/__pycache__/ranges.cpython-311.pyc | Bin 0 -> 37622 bytes .../postgresql/__pycache__/types.cpython-311.pyc | Bin 0 -> 12406 bytes .../dialects/postgresql/_psycopg_common.py | 187 + .../sqlalchemy/dialects/postgresql/array.py | 425 ++ .../sqlalchemy/dialects/postgresql/asyncpg.py | 1262 ++++ .../sqlalchemy/dialects/postgresql/base.py | 5007 +++++++++++++ .../sqlalchemy/dialects/postgresql/dml.py | 310 + .../sqlalchemy/dialects/postgresql/ext.py | 496 ++ .../sqlalchemy/dialects/postgresql/hstore.py | 397 + .../sqlalchemy/dialects/postgresql/json.py | 325 + .../sqlalchemy/dialects/postgresql/named_types.py | 509 ++ .../sqlalchemy/dialects/postgresql/operators.py | 129 + .../sqlalchemy/dialects/postgresql/pg8000.py | 662 ++ .../sqlalchemy/dialects/postgresql/pg_catalog.py | 300 + .../sqlalchemy/dialects/postgresql/provision.py | 175 + .../sqlalchemy/dialects/postgresql/psycopg.py | 749 ++ .../sqlalchemy/dialects/postgresql/psycopg2.py | 876 +++ .../sqlalchemy/dialects/postgresql/psycopg2cffi.py | 61 + .../sqlalchemy/dialects/postgresql/ranges.py | 1029 +++ .../sqlalchemy/dialects/postgresql/types.py | 303 + .../sqlalchemy/dialects/sqlite/__init__.py | 57 + .../sqlite/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 1381 bytes .../sqlite/__pycache__/aiosqlite.cpython-311.pyc | Bin 0 -> 19317 bytes .../sqlite/__pycache__/base.cpython-311.pyc | Bin 0 -> 104751 bytes .../sqlite/__pycache__/dml.cpython-311.pyc | Bin 0 -> 10200 bytes .../sqlite/__pycache__/json.cpython-311.pyc | Bin 0 -> 4319 bytes .../sqlite/__pycache__/provision.cpython-311.pyc | Bin 0 -> 7990 bytes .../sqlite/__pycache__/pysqlcipher.cpython-311.pyc | Bin 0 -> 6604 bytes .../sqlite/__pycache__/pysqlite.cpython-311.pyc | Bin 0 -> 33807 bytes .../sqlalchemy/dialects/sqlite/aiosqlite.py | 396 + .../sqlalchemy/dialects/sqlite/base.py | 2782 +++++++ .../sqlalchemy/dialects/sqlite/dml.py | 240 + .../sqlalchemy/dialects/sqlite/json.py | 92 + .../sqlalchemy/dialects/sqlite/provision.py | 198 + .../sqlalchemy/dialects/sqlite/pysqlcipher.py | 155 + .../sqlalchemy/dialects/sqlite/pysqlite.py | 756 ++ .../dialects/type_migration_guidelines.txt | 145 + .../site-packages/sqlalchemy/engine/__init__.py | 62 + .../engine/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 2961 bytes .../__pycache__/_py_processors.cpython-311.pyc | Bin 0 -> 5171 bytes .../engine/__pycache__/_py_row.cpython-311.pyc | Bin 0 -> 6772 bytes .../engine/__pycache__/_py_util.cpython-311.pyc | Bin 0 -> 2564 bytes .../engine/__pycache__/base.cpython-311.pyc | Bin 0 -> 133235 bytes .../__pycache__/characteristics.cpython-311.pyc | Bin 0 -> 3920 bytes .../engine/__pycache__/create.cpython-311.pyc | Bin 0 -> 35466 bytes .../engine/__pycache__/cursor.cpython-311.pyc | Bin 0 -> 87327 bytes .../engine/__pycache__/default.cpython-311.pyc | Bin 0 -> 93432 bytes .../engine/__pycache__/events.cpython-311.pyc | Bin 0 -> 40574 bytes .../engine/__pycache__/interfaces.cpython-311.pyc | Bin 0 -> 103491 bytes .../engine/__pycache__/mock.cpython-311.pyc | Bin 0 -> 6280 bytes .../engine/__pycache__/processors.cpython-311.pyc | Bin 0 -> 1664 bytes .../engine/__pycache__/reflection.cpython-311.pyc | Bin 0 -> 85712 bytes .../engine/__pycache__/result.cpython-311.pyc | Bin 0 -> 101187 bytes .../engine/__pycache__/row.cpython-311.pyc | Bin 0 -> 19534 bytes .../engine/__pycache__/strategies.cpython-311.pyc | Bin 0 -> 653 bytes .../engine/__pycache__/url.cpython-311.pyc | Bin 0 -> 36722 bytes .../engine/__pycache__/util.cpython-311.pyc | Bin 0 -> 7641 bytes .../sqlalchemy/engine/_py_processors.py | 136 + .../site-packages/sqlalchemy/engine/_py_row.py | 128 + .../site-packages/sqlalchemy/engine/_py_util.py | 74 + .../site-packages/sqlalchemy/engine/base.py | 3377 +++++++++ .../sqlalchemy/engine/characteristics.py | 81 + .../site-packages/sqlalchemy/engine/create.py | 875 +++ .../site-packages/sqlalchemy/engine/cursor.py | 2178 ++++++ .../site-packages/sqlalchemy/engine/default.py | 2343 ++++++ .../site-packages/sqlalchemy/engine/events.py | 951 +++ .../site-packages/sqlalchemy/engine/interfaces.py | 3395 +++++++++ .../site-packages/sqlalchemy/engine/mock.py | 131 + .../site-packages/sqlalchemy/engine/processors.py | 61 + .../site-packages/sqlalchemy/engine/reflection.py | 2089 ++++++ .../site-packages/sqlalchemy/engine/result.py | 2382 ++++++ .../site-packages/sqlalchemy/engine/row.py | 401 + .../site-packages/sqlalchemy/engine/strategies.py | 19 + .../site-packages/sqlalchemy/engine/url.py | 910 +++ .../site-packages/sqlalchemy/engine/util.py | 167 + .../site-packages/sqlalchemy/event/__init__.py | 25 + .../event/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 1099 bytes .../event/__pycache__/api.cpython-311.pyc | Bin 0 -> 9489 bytes .../event/__pycache__/attr.cpython-311.pyc | Bin 0 -> 33953 bytes .../event/__pycache__/base.cpython-311.pyc | Bin 0 -> 21931 bytes .../event/__pycache__/legacy.cpython-311.pyc | Bin 0 -> 10089 bytes .../event/__pycache__/registry.cpython-311.pyc | Bin 0 -> 13444 bytes .../site-packages/sqlalchemy/event/api.py | 225 + .../site-packages/sqlalchemy/event/attr.py | 655 ++ .../site-packages/sqlalchemy/event/base.py | 462 ++ .../site-packages/sqlalchemy/event/legacy.py | 246 + .../site-packages/sqlalchemy/event/registry.py | 386 + .../python3.11/site-packages/sqlalchemy/events.py | 17 + .../lib/python3.11/site-packages/sqlalchemy/exc.py | 830 +++ .../site-packages/sqlalchemy/ext/__init__.py | 11 + .../ext/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 373 bytes .../__pycache__/associationproxy.cpython-311.pyc | Bin 0 -> 94141 bytes .../ext/__pycache__/automap.cpython-311.pyc | Bin 0 -> 58905 bytes .../ext/__pycache__/baked.cpython-311.pyc | Bin 0 -> 25104 bytes .../ext/__pycache__/compiler.cpython-311.pyc | Bin 0 -> 21019 bytes .../__pycache__/horizontal_shard.cpython-311.pyc | Bin 0 -> 19003 bytes .../ext/__pycache__/hybrid.cpython-311.pyc | Bin 0 -> 62274 bytes .../ext/__pycache__/indexable.cpython-311.pyc | Bin 0 -> 12654 bytes .../__pycache__/instrumentation.cpython-311.pyc | Bin 0 -> 21569 bytes .../ext/__pycache__/mutable.cpython-311.pyc | Bin 0 -> 50959 bytes .../ext/__pycache__/orderinglist.cpython-311.pyc | Bin 0 -> 18761 bytes .../ext/__pycache__/serializer.cpython-311.pyc | Bin 0 -> 8182 bytes .../sqlalchemy/ext/associationproxy.py | 2005 +++++ .../sqlalchemy/ext/asyncio/__init__.py | 25 + .../asyncio/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 1250 bytes .../ext/asyncio/__pycache__/base.cpython-311.pyc | Bin 0 -> 12227 bytes .../ext/asyncio/__pycache__/engine.cpython-311.pyc | Bin 0 -> 59576 bytes .../ext/asyncio/__pycache__/exc.cpython-311.pyc | Bin 0 -> 1191 bytes .../ext/asyncio/__pycache__/result.cpython-311.pyc | Bin 0 -> 39306 bytes .../asyncio/__pycache__/scoping.cpython-311.pyc | Bin 0 -> 57031 bytes .../asyncio/__pycache__/session.cpython-311.pyc | Bin 0 -> 72886 bytes .../site-packages/sqlalchemy/ext/asyncio/base.py | 279 + .../site-packages/sqlalchemy/ext/asyncio/engine.py | 1466 ++++ .../site-packages/sqlalchemy/ext/asyncio/exc.py | 21 + .../site-packages/sqlalchemy/ext/asyncio/result.py | 961 +++ .../sqlalchemy/ext/asyncio/scoping.py | 1614 ++++ .../sqlalchemy/ext/asyncio/session.py | 1936 +++++ .../site-packages/sqlalchemy/ext/automap.py | 1658 +++++ .../site-packages/sqlalchemy/ext/baked.py | 574 ++ .../site-packages/sqlalchemy/ext/compiler.py | 555 ++ .../sqlalchemy/ext/declarative/__init__.py | 65 + .../__pycache__/__init__.cpython-311.pyc | Bin 0 -> 2254 bytes .../__pycache__/extensions.cpython-311.pyc | Bin 0 -> 22442 bytes .../sqlalchemy/ext/declarative/extensions.py | 548 ++ .../sqlalchemy/ext/horizontal_shard.py | 481 ++ .../site-packages/sqlalchemy/ext/hybrid.py | 1514 ++++ .../site-packages/sqlalchemy/ext/indexable.py | 341 + .../sqlalchemy/ext/instrumentation.py | 450 ++ .../site-packages/sqlalchemy/ext/mutable.py | 1073 +++ .../site-packages/sqlalchemy/ext/mypy/__init__.py | 6 + .../ext/mypy/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 203 bytes .../ext/mypy/__pycache__/apply.cpython-311.pyc | Bin 0 -> 11087 bytes .../mypy/__pycache__/decl_class.cpython-311.pyc | Bin 0 -> 15907 bytes .../ext/mypy/__pycache__/infer.cpython-311.pyc | Bin 0 -> 16021 bytes .../ext/mypy/__pycache__/names.cpython-311.pyc | Bin 0 -> 11696 bytes .../ext/mypy/__pycache__/plugin.cpython-311.pyc | Bin 0 -> 13050 bytes .../ext/mypy/__pycache__/util.cpython-311.pyc | Bin 0 -> 14939 bytes .../site-packages/sqlalchemy/ext/mypy/apply.py | 320 + .../sqlalchemy/ext/mypy/decl_class.py | 515 ++ .../site-packages/sqlalchemy/ext/mypy/infer.py | 590 ++ .../site-packages/sqlalchemy/ext/mypy/names.py | 335 + .../site-packages/sqlalchemy/ext/mypy/plugin.py | 303 + .../site-packages/sqlalchemy/ext/mypy/util.py | 338 + .../site-packages/sqlalchemy/ext/orderinglist.py | 416 ++ .../site-packages/sqlalchemy/ext/serializer.py | 185 + .../site-packages/sqlalchemy/future/__init__.py | 16 + .../future/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 532 bytes .../future/__pycache__/engine.cpython-311.pyc | Bin 0 -> 450 bytes .../site-packages/sqlalchemy/future/engine.py | 15 + .../site-packages/sqlalchemy/inspection.py | 174 + .../lib/python3.11/site-packages/sqlalchemy/log.py | 288 + .../site-packages/sqlalchemy/orm/__init__.py | 170 + .../orm/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 8557 bytes .../__pycache__/_orm_constructors.cpython-311.pyc | Bin 0 -> 100035 bytes .../orm/__pycache__/_typing.cpython-311.pyc | Bin 0 -> 7834 bytes .../orm/__pycache__/attributes.cpython-311.pyc | Bin 0 -> 104505 bytes .../orm/__pycache__/base.cpython-311.pyc | Bin 0 -> 32806 bytes .../__pycache__/bulk_persistence.cpython-311.pyc | Bin 0 -> 70614 bytes .../orm/__pycache__/clsregistry.cpython-311.pyc | Bin 0 -> 26896 bytes .../orm/__pycache__/collections.cpython-311.pyc | Bin 0 -> 68350 bytes .../orm/__pycache__/context.cpython-311.pyc | Bin 0 -> 103753 bytes .../orm/__pycache__/decl_api.cpython-311.pyc | Bin 0 -> 70976 bytes .../orm/__pycache__/decl_base.cpython-311.pyc | Bin 0 -> 76416 bytes .../orm/__pycache__/dependency.cpython-311.pyc | Bin 0 -> 44509 bytes .../__pycache__/descriptor_props.cpython-311.pyc | Bin 0 -> 53451 bytes .../orm/__pycache__/dynamic.cpython-311.pyc | Bin 0 -> 14108 bytes .../orm/__pycache__/evaluator.cpython-311.pyc | Bin 0 -> 17661 bytes .../orm/__pycache__/events.cpython-311.pyc | Bin 0 -> 140389 bytes .../sqlalchemy/orm/__pycache__/exc.cpython-311.pyc | Bin 0 -> 11103 bytes .../orm/__pycache__/identity.cpython-311.pyc | Bin 0 -> 13935 bytes .../__pycache__/instrumentation.cpython-311.pyc | Bin 0 -> 33762 bytes .../orm/__pycache__/interfaces.cpython-311.pyc | Bin 0 -> 56553 bytes .../orm/__pycache__/loading.cpython-311.pyc | Bin 0 -> 51952 bytes .../__pycache__/mapped_collection.cpython-311.pyc | Bin 0 -> 23740 bytes .../orm/__pycache__/mapper.cpython-311.pyc | Bin 0 -> 175532 bytes .../orm/__pycache__/path_registry.cpython-311.pyc | Bin 0 -> 34743 bytes .../orm/__pycache__/persistence.cpython-311.pyc | Bin 0 -> 50760 bytes .../orm/__pycache__/properties.cpython-311.pyc | Bin 0 -> 34341 bytes .../orm/__pycache__/query.cpython-311.pyc | Bin 0 -> 132189 bytes .../orm/__pycache__/relationships.cpython-311.pyc | Bin 0 -> 135454 bytes .../orm/__pycache__/scoping.cpython-311.pyc | Bin 0 -> 84345 bytes .../orm/__pycache__/session.cpython-311.pyc | Bin 0 -> 205815 bytes .../orm/__pycache__/state.cpython-311.pyc | Bin 0 -> 47732 bytes .../orm/__pycache__/state_changes.cpython-311.pyc | Bin 0 -> 7462 bytes .../orm/__pycache__/strategies.cpython-311.pyc | Bin 0 -> 109480 bytes .../__pycache__/strategy_options.cpython-311.pyc | Bin 0 -> 90575 bytes .../orm/__pycache__/sync.cpython-311.pyc | Bin 0 -> 6977 bytes .../orm/__pycache__/unitofwork.cpython-311.pyc | Bin 0 -> 37079 bytes .../orm/__pycache__/util.cpython-311.pyc | Bin 0 -> 92891 bytes .../orm/__pycache__/writeonly.cpython-311.pyc | Bin 0 -> 29722 bytes .../sqlalchemy/orm/_orm_constructors.py | 2471 +++++++ .../site-packages/sqlalchemy/orm/_typing.py | 179 + .../site-packages/sqlalchemy/orm/attributes.py | 2835 +++++++ .../site-packages/sqlalchemy/orm/base.py | 971 +++ .../sqlalchemy/orm/bulk_persistence.py | 2048 +++++ .../site-packages/sqlalchemy/orm/clsregistry.py | 570 ++ .../site-packages/sqlalchemy/orm/collections.py | 1618 ++++ .../site-packages/sqlalchemy/orm/context.py | 3243 ++++++++ .../site-packages/sqlalchemy/orm/decl_api.py | 1875 +++++ .../site-packages/sqlalchemy/orm/decl_base.py | 2152 ++++++ .../site-packages/sqlalchemy/orm/dependency.py | 1304 ++++ .../sqlalchemy/orm/descriptor_props.py | 1074 +++ .../site-packages/sqlalchemy/orm/dynamic.py | 298 + .../site-packages/sqlalchemy/orm/evaluator.py | 368 + .../site-packages/sqlalchemy/orm/events.py | 3259 ++++++++ .../python3.11/site-packages/sqlalchemy/orm/exc.py | 228 + .../site-packages/sqlalchemy/orm/identity.py | 302 + .../sqlalchemy/orm/instrumentation.py | 754 ++ .../site-packages/sqlalchemy/orm/interfaces.py | 1469 ++++ .../site-packages/sqlalchemy/orm/loading.py | 1665 +++++ .../sqlalchemy/orm/mapped_collection.py | 560 ++ .../site-packages/sqlalchemy/orm/mapper.py | 4420 +++++++++++ .../site-packages/sqlalchemy/orm/path_registry.py | 808 ++ .../site-packages/sqlalchemy/orm/persistence.py | 1782 +++++ .../site-packages/sqlalchemy/orm/properties.py | 886 +++ .../site-packages/sqlalchemy/orm/query.py | 3394 +++++++++ .../site-packages/sqlalchemy/orm/relationships.py | 3500 +++++++++ .../site-packages/sqlalchemy/orm/scoping.py | 2165 ++++++ .../site-packages/sqlalchemy/orm/session.py | 5238 +++++++++++++ .../site-packages/sqlalchemy/orm/state.py | 1136 +++ .../site-packages/sqlalchemy/orm/state_changes.py | 198 + .../site-packages/sqlalchemy/orm/strategies.py | 3344 +++++++++ .../sqlalchemy/orm/strategy_options.py | 2555 +++++++ .../site-packages/sqlalchemy/orm/sync.py | 164 + .../site-packages/sqlalchemy/orm/unitofwork.py | 796 ++ .../site-packages/sqlalchemy/orm/util.py | 2416 ++++++ .../site-packages/sqlalchemy/orm/writeonly.py | 678 ++ .../site-packages/sqlalchemy/pool/__init__.py | 44 + .../pool/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 1878 bytes .../pool/__pycache__/base.cpython-311.pyc | Bin 0 -> 59317 bytes .../pool/__pycache__/events.cpython-311.pyc | Bin 0 -> 14483 bytes .../pool/__pycache__/impl.cpython-311.pyc | Bin 0 -> 27550 bytes .../site-packages/sqlalchemy/pool/base.py | 1515 ++++ .../site-packages/sqlalchemy/pool/events.py | 370 + .../site-packages/sqlalchemy/pool/impl.py | 581 ++ .../python3.11/site-packages/sqlalchemy/py.typed | 0 .../python3.11/site-packages/sqlalchemy/schema.py | 70 + .../site-packages/sqlalchemy/sql/__init__.py | 145 + .../sql/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 6476 bytes .../__pycache__/_dml_constructors.cpython-311.pyc | Bin 0 -> 4303 bytes .../_elements_constructors.cpython-311.pyc | Bin 0 -> 66541 bytes .../sql/__pycache__/_orm_types.cpython-311.pyc | Bin 0 -> 678 bytes .../sql/__pycache__/_py_util.cpython-311.pyc | Bin 0 -> 3351 bytes .../_selectable_constructors.cpython-311.pyc | Bin 0 -> 21634 bytes .../sql/__pycache__/_typing.cpython-311.pyc | Bin 0 -> 16730 bytes .../sql/__pycache__/annotation.cpython-311.pyc | Bin 0 -> 23147 bytes .../sql/__pycache__/base.cpython-311.pyc | Bin 0 -> 107198 bytes .../sql/__pycache__/cache_key.cpython-311.pyc | Bin 0 -> 39232 bytes .../sql/__pycache__/coercions.cpython-311.pyc | Bin 0 -> 53716 bytes .../sql/__pycache__/compiler.cpython-311.pyc | Bin 0 -> 286030 bytes .../sql/__pycache__/crud.cpython-311.pyc | Bin 0 -> 47495 bytes .../sqlalchemy/sql/__pycache__/ddl.cpython-311.pyc | Bin 0 -> 62979 bytes .../__pycache__/default_comparator.cpython-311.pyc | Bin 0 -> 18086 bytes .../sqlalchemy/sql/__pycache__/dml.cpython-311.pyc | Bin 0 -> 77526 bytes .../sql/__pycache__/elements.cpython-311.pyc | Bin 0 -> 217761 bytes .../sql/__pycache__/events.cpython-311.pyc | Bin 0 -> 19364 bytes .../sql/__pycache__/expression.cpython-311.pyc | Bin 0 -> 7230 bytes .../sql/__pycache__/functions.cpython-311.pyc | Bin 0 -> 80876 bytes .../sql/__pycache__/lambdas.cpython-311.pyc | Bin 0 -> 59594 bytes .../sql/__pycache__/naming.cpython-311.pyc | Bin 0 -> 9203 bytes .../sql/__pycache__/operators.cpython-311.pyc | Bin 0 -> 93289 bytes .../sql/__pycache__/roles.cpython-311.pyc | Bin 0 -> 14938 bytes .../sql/__pycache__/schema.cpython-311.pyc | Bin 0 -> 255679 bytes .../sql/__pycache__/selectable.cpython-311.pyc | Bin 0 -> 272258 bytes .../sql/__pycache__/sqltypes.cpython-311.pyc | Bin 0 -> 158830 bytes .../sql/__pycache__/traversals.cpython-311.pyc | Bin 0 -> 49251 bytes .../sql/__pycache__/type_api.cpython-311.pyc | Bin 0 -> 87962 bytes .../sql/__pycache__/util.cpython-311.pyc | Bin 0 -> 59936 bytes .../sql/__pycache__/visitors.cpython-311.pyc | Bin 0 -> 38853 bytes .../sqlalchemy/sql/_dml_constructors.py | 140 + .../sqlalchemy/sql/_elements_constructors.py | 1840 +++++ .../site-packages/sqlalchemy/sql/_orm_types.py | 20 + .../site-packages/sqlalchemy/sql/_py_util.py | 75 + .../sqlalchemy/sql/_selectable_constructors.py | 635 ++ .../site-packages/sqlalchemy/sql/_typing.py | 457 ++ .../site-packages/sqlalchemy/sql/annotation.py | 585 ++ .../site-packages/sqlalchemy/sql/base.py | 2180 ++++++ .../site-packages/sqlalchemy/sql/cache_key.py | 1057 +++ .../site-packages/sqlalchemy/sql/coercions.py | 1389 ++++ .../site-packages/sqlalchemy/sql/compiler.py | 7811 ++++++++++++++++++++ .../site-packages/sqlalchemy/sql/crud.py | 1669 +++++ .../python3.11/site-packages/sqlalchemy/sql/ddl.py | 1378 ++++ .../sqlalchemy/sql/default_comparator.py | 552 ++ .../python3.11/site-packages/sqlalchemy/sql/dml.py | 1817 +++++ .../site-packages/sqlalchemy/sql/elements.py | 5405 ++++++++++++++ .../site-packages/sqlalchemy/sql/events.py | 455 ++ .../site-packages/sqlalchemy/sql/expression.py | 162 + .../site-packages/sqlalchemy/sql/functions.py | 2052 +++++ .../site-packages/sqlalchemy/sql/lambdas.py | 1449 ++++ .../site-packages/sqlalchemy/sql/naming.py | 212 + .../site-packages/sqlalchemy/sql/operators.py | 2573 +++++++ .../site-packages/sqlalchemy/sql/roles.py | 323 + .../site-packages/sqlalchemy/sql/schema.py | 6115 +++++++++++++++ .../site-packages/sqlalchemy/sql/selectable.py | 6913 +++++++++++++++++ .../site-packages/sqlalchemy/sql/sqltypes.py | 3786 ++++++++++ .../site-packages/sqlalchemy/sql/traversals.py | 1022 +++ .../site-packages/sqlalchemy/sql/type_api.py | 2303 ++++++ .../site-packages/sqlalchemy/sql/util.py | 1486 ++++ .../site-packages/sqlalchemy/sql/visitors.py | 1165 +++ .../site-packages/sqlalchemy/testing/__init__.py | 95 + .../testing/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 4472 bytes .../testing/__pycache__/assertions.cpython-311.pyc | Bin 0 -> 49524 bytes .../testing/__pycache__/assertsql.cpython-311.pyc | Bin 0 -> 22362 bytes .../testing/__pycache__/asyncio.cpython-311.pyc | Bin 0 -> 4499 bytes .../testing/__pycache__/config.cpython-311.pyc | Bin 0 -> 19660 bytes .../testing/__pycache__/engines.cpython-311.pyc | Bin 0 -> 23298 bytes .../testing/__pycache__/entities.cpython-311.pyc | Bin 0 -> 6025 bytes .../testing/__pycache__/exclusions.cpython-311.pyc | Bin 0 -> 24306 bytes .../testing/__pycache__/pickleable.cpython-311.pyc | Bin 0 -> 7604 bytes .../testing/__pycache__/profiling.cpython-311.pyc | Bin 0 -> 14276 bytes .../testing/__pycache__/provision.cpython-311.pyc | Bin 0 -> 23283 bytes .../__pycache__/requirements.cpython-311.pyc | Bin 0 -> 87090 bytes .../testing/__pycache__/schema.cpython-311.pyc | Bin 0 -> 9944 bytes .../testing/__pycache__/util.cpython-311.pyc | Bin 0 -> 24674 bytes .../testing/__pycache__/warnings.cpython-311.pyc | Bin 0 -> 2272 bytes .../site-packages/sqlalchemy/testing/assertions.py | 989 +++ .../site-packages/sqlalchemy/testing/assertsql.py | 516 ++ .../site-packages/sqlalchemy/testing/asyncio.py | 135 + .../site-packages/sqlalchemy/testing/config.py | 427 ++ .../site-packages/sqlalchemy/testing/engines.py | 472 ++ .../site-packages/sqlalchemy/testing/entities.py | 117 + .../site-packages/sqlalchemy/testing/exclusions.py | 435 ++ .../sqlalchemy/testing/fixtures/__init__.py | 28 + .../fixtures/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 1191 bytes .../fixtures/__pycache__/base.cpython-311.pyc | Bin 0 -> 15739 bytes .../fixtures/__pycache__/mypy.cpython-311.pyc | Bin 0 -> 14625 bytes .../fixtures/__pycache__/orm.cpython-311.pyc | Bin 0 -> 12949 bytes .../fixtures/__pycache__/sql.cpython-311.pyc | Bin 0 -> 25189 bytes .../sqlalchemy/testing/fixtures/base.py | 366 + .../sqlalchemy/testing/fixtures/mypy.py | 312 + .../sqlalchemy/testing/fixtures/orm.py | 227 + .../sqlalchemy/testing/fixtures/sql.py | 493 ++ .../site-packages/sqlalchemy/testing/pickleable.py | 155 + .../sqlalchemy/testing/plugin/__init__.py | 6 + .../plugin/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 209 bytes .../plugin/__pycache__/bootstrap.cpython-311.pyc | Bin 0 -> 2230 bytes .../plugin/__pycache__/plugin_base.cpython-311.pyc | Bin 0 -> 31509 bytes .../__pycache__/pytestplugin.cpython-311.pyc | Bin 0 -> 37248 bytes .../sqlalchemy/testing/plugin/bootstrap.py | 51 + .../sqlalchemy/testing/plugin/plugin_base.py | 779 ++ .../sqlalchemy/testing/plugin/pytestplugin.py | 868 +++ .../site-packages/sqlalchemy/testing/profiling.py | 324 + .../site-packages/sqlalchemy/testing/provision.py | 496 ++ .../sqlalchemy/testing/requirements.py | 1783 +++++ .../site-packages/sqlalchemy/testing/schema.py | 224 + .../sqlalchemy/testing/suite/__init__.py | 19 + .../suite/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 663 bytes .../suite/__pycache__/test_cte.cpython-311.pyc | Bin 0 -> 10661 bytes .../suite/__pycache__/test_ddl.cpython-311.pyc | Bin 0 -> 21767 bytes .../__pycache__/test_deprecations.cpython-311.pyc | Bin 0 -> 10015 bytes .../suite/__pycache__/test_dialect.cpython-311.pyc | Bin 0 -> 39595 bytes .../suite/__pycache__/test_insert.cpython-311.pyc | Bin 0 -> 28272 bytes .../__pycache__/test_reflection.cpython-311.pyc | Bin 0 -> 152512 bytes .../suite/__pycache__/test_results.cpython-311.pyc | Bin 0 -> 26509 bytes .../__pycache__/test_rowcount.cpython-311.pyc | Bin 0 -> 11067 bytes .../suite/__pycache__/test_select.cpython-311.pyc | Bin 0 -> 115631 bytes .../__pycache__/test_sequence.cpython-311.pyc | Bin 0 -> 16954 bytes .../suite/__pycache__/test_types.cpython-311.pyc | Bin 0 -> 107496 bytes .../__pycache__/test_unicode_ddl.cpython-311.pyc | Bin 0 -> 9173 bytes .../__pycache__/test_update_delete.cpython-311.pyc | Bin 0 -> 7763 bytes .../sqlalchemy/testing/suite/test_cte.py | 211 + .../sqlalchemy/testing/suite/test_ddl.py | 389 + .../sqlalchemy/testing/suite/test_deprecations.py | 153 + .../sqlalchemy/testing/suite/test_dialect.py | 740 ++ .../sqlalchemy/testing/suite/test_insert.py | 630 ++ .../sqlalchemy/testing/suite/test_reflection.py | 3128 ++++++++ .../sqlalchemy/testing/suite/test_results.py | 468 ++ .../sqlalchemy/testing/suite/test_rowcount.py | 258 + .../sqlalchemy/testing/suite/test_select.py | 1888 +++++ .../sqlalchemy/testing/suite/test_sequence.py | 317 + .../sqlalchemy/testing/suite/test_types.py | 2071 ++++++ .../sqlalchemy/testing/suite/test_unicode_ddl.py | 189 + .../sqlalchemy/testing/suite/test_update_delete.py | 139 + .../site-packages/sqlalchemy/testing/util.py | 519 ++ .../site-packages/sqlalchemy/testing/warnings.py | 52 + .../python3.11/site-packages/sqlalchemy/types.py | 76 + .../site-packages/sqlalchemy/util/__init__.py | 159 + .../util/__pycache__/__init__.cpython-311.pyc | Bin 0 -> 7824 bytes .../util/__pycache__/_collections.cpython-311.pyc | Bin 0 -> 36530 bytes .../__pycache__/_concurrency_py3k.cpython-311.pyc | Bin 0 -> 12296 bytes .../util/__pycache__/_has_cy.cpython-311.pyc | Bin 0 -> 1273 bytes .../__pycache__/_py_collections.cpython-311.pyc | Bin 0 -> 35537 bytes .../util/__pycache__/compat.cpython-311.pyc | Bin 0 -> 13640 bytes .../util/__pycache__/concurrency.cpython-311.pyc | Bin 0 -> 4664 bytes .../util/__pycache__/deprecations.cpython-311.pyc | Bin 0 -> 14959 bytes .../util/__pycache__/langhelpers.cpython-311.pyc | Bin 0 -> 94945 bytes .../util/__pycache__/preloaded.cpython-311.pyc | Bin 0 -> 6691 bytes .../util/__pycache__/queue.cpython-311.pyc | Bin 0 -> 16638 bytes .../util/__pycache__/tool_support.cpython-311.pyc | Bin 0 -> 9675 bytes .../util/__pycache__/topological.cpython-311.pyc | Bin 0 -> 4760 bytes .../util/__pycache__/typing.cpython-311.pyc | Bin 0 -> 23055 bytes .../site-packages/sqlalchemy/util/_collections.py | 715 ++ .../sqlalchemy/util/_concurrency_py3k.py | 290 + .../site-packages/sqlalchemy/util/_has_cy.py | 40 + .../sqlalchemy/util/_py_collections.py | 541 ++ .../site-packages/sqlalchemy/util/compat.py | 300 + .../site-packages/sqlalchemy/util/concurrency.py | 108 + .../site-packages/sqlalchemy/util/deprecations.py | 401 + .../site-packages/sqlalchemy/util/langhelpers.py | 2211 ++++++ .../site-packages/sqlalchemy/util/preloaded.py | 150 + .../site-packages/sqlalchemy/util/queue.py | 322 + .../site-packages/sqlalchemy/util/tool_support.py | 201 + .../site-packages/sqlalchemy/util/topological.py | 120 + .../site-packages/sqlalchemy/util/typing.py | 580 ++ 523 files changed, 230921 insertions(+) create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/events.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/exc.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/inspection.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/log.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/schema.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/types.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/connectors/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/aioodbc.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/asyncio.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/pyodbc.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/connectors/aioodbc.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/connectors/asyncio.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/connectors/pyodbc.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/__pycache__/__init__.cpython-311.pyc create mode 100755 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/collections.cpython-311-x86_64-linux-gnu.so create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/collections.pyx create mode 100755 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/immutabledict.cpython-311-x86_64-linux-gnu.so create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/immutabledict.pxd create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/immutabledict.pyx create mode 100755 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/processors.cpython-311-x86_64-linux-gnu.so create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/processors.pyx create mode 100755 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/resultproxy.cpython-311-x86_64-linux-gnu.so create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/resultproxy.pyx create mode 100755 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/util.cpython-311-x86_64-linux-gnu.so create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/cyextension/util.pyx create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/__pycache__/_typing.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/_typing.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/aioodbc.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/information_schema.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/json.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/provision.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/pymssql.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/pyodbc.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/aioodbc.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/information_schema.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/json.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/provision.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/pymssql.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/pyodbc.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/aiomysql.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/asyncmy.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/cymysql.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/dml.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/enumerated.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/expression.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/json.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mariadb.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mariadbconnector.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mysqlconnector.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mysqldb.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/provision.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/pymysql.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/pyodbc.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/reflection.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/reserved_words.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/types.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/aiomysql.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/asyncmy.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/cymysql.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/dml.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/enumerated.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/expression.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/json.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mariadb.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mariadbconnector.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mysqlconnector.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mysqldb.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/provision.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/pymysql.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/pyodbc.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/reflection.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/reserved_words.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/types.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/cx_oracle.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/dictionary.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/oracledb.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/provision.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/types.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/cx_oracle.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/dictionary.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/oracledb.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/provision.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/types.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/_psycopg_common.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/array.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/asyncpg.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/dml.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/ext.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/hstore.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/json.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/named_types.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/operators.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/pg8000.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/pg_catalog.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/provision.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/psycopg.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/psycopg2.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/psycopg2cffi.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/ranges.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/types.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/_psycopg_common.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/array.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/asyncpg.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/dml.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/ext.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/hstore.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/json.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/named_types.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/operators.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/pg8000.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/pg_catalog.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/provision.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/psycopg.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/psycopg2.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/psycopg2cffi.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/ranges.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/types.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/aiosqlite.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/dml.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/json.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/provision.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/pysqlcipher.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/pysqlite.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/aiosqlite.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/dml.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/json.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/provision.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/pysqlcipher.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/pysqlite.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/dialects/type_migration_guidelines.txt create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/_py_processors.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/_py_row.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/_py_util.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/characteristics.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/create.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/cursor.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/default.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/events.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/interfaces.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/mock.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/processors.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/reflection.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/result.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/row.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/strategies.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/url.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/util.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/_py_processors.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/_py_row.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/_py_util.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/characteristics.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/create.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/cursor.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/default.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/events.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/interfaces.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/mock.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/processors.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/reflection.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/result.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/row.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/strategies.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/url.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/engine/util.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/event/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/api.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/attr.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/legacy.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/registry.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/event/api.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/event/attr.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/event/base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/event/legacy.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/event/registry.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/events.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/exc.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/associationproxy.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/automap.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/baked.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/compiler.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/horizontal_shard.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/hybrid.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/indexable.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/instrumentation.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/mutable.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/orderinglist.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/serializer.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/associationproxy.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/engine.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/exc.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/result.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/scoping.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/session.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/engine.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/exc.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/result.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/scoping.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/session.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/automap.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/baked.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/compiler.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/extensions.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/extensions.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/horizontal_shard.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/hybrid.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/indexable.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/instrumentation.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mutable.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/apply.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/decl_class.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/infer.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/names.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/plugin.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/util.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/apply.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/decl_class.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/infer.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/names.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/plugin.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/util.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/orderinglist.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/ext/serializer.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/future/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/future/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/future/__pycache__/engine.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/future/engine.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/inspection.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/log.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/_orm_constructors.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/_typing.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/attributes.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/bulk_persistence.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/clsregistry.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/collections.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/context.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/decl_api.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/decl_base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/dependency.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/descriptor_props.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/dynamic.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/evaluator.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/events.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/exc.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/identity.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/instrumentation.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/interfaces.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/loading.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/mapped_collection.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/mapper.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/path_registry.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/persistence.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/properties.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/query.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/relationships.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/scoping.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/session.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/state.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/state_changes.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/strategies.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/strategy_options.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/sync.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/unitofwork.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/util.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/writeonly.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/_orm_constructors.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/_typing.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/attributes.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/bulk_persistence.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/clsregistry.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/collections.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/context.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/decl_api.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/decl_base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/dependency.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/descriptor_props.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/dynamic.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/evaluator.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/events.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/exc.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/identity.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/instrumentation.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/interfaces.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/loading.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/mapped_collection.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/mapper.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/path_registry.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/properties.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/query.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/relationships.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/scoping.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/session.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/state.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/state_changes.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/strategies.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/strategy_options.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/sync.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/unitofwork.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/util.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/orm/writeonly.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/pool/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/events.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/impl.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/pool/base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/pool/events.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/pool/impl.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/py.typed create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/schema.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_dml_constructors.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_elements_constructors.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_orm_types.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_py_util.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_selectable_constructors.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_typing.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/annotation.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/cache_key.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/coercions.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/compiler.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/crud.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/ddl.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/default_comparator.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/dml.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/elements.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/events.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/expression.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/functions.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/lambdas.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/naming.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/operators.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/roles.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/schema.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/selectable.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/sqltypes.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/traversals.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/type_api.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/util.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/visitors.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/_dml_constructors.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/_elements_constructors.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/_orm_types.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/_py_util.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/_selectable_constructors.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/_typing.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/annotation.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/cache_key.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/coercions.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/compiler.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/crud.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/ddl.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/default_comparator.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/dml.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/elements.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/events.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/expression.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/functions.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/lambdas.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/naming.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/operators.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/roles.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/schema.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/selectable.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/sqltypes.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/traversals.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/type_api.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/util.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/sql/visitors.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/assertions.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/assertsql.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/asyncio.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/config.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/engines.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/entities.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/exclusions.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/pickleable.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/profiling.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/provision.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/requirements.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/schema.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/util.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/__pycache__/warnings.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/assertions.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/assertsql.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/asyncio.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/config.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/engines.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/entities.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/exclusions.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/mypy.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/orm.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/__pycache__/sql.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/mypy.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/orm.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/fixtures/sql.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/pickleable.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/bootstrap.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/plugin_base.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/__pycache__/pytestplugin.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/bootstrap.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/plugin_base.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/plugin/pytestplugin.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/profiling.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/provision.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/requirements.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/schema.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_cte.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_ddl.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_deprecations.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_dialect.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_insert.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_reflection.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_results.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_rowcount.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_select.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_sequence.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_types.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_unicode_ddl.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/__pycache__/test_update_delete.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_cte.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_ddl.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_deprecations.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_dialect.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_insert.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_reflection.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_results.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_rowcount.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_select.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_sequence.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_types.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_unicode_ddl.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/suite/test_update_delete.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/util.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/testing/warnings.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/types.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__init__.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/__init__.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/_collections.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/_concurrency_py3k.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/_has_cy.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/_py_collections.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/compat.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/concurrency.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/deprecations.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/langhelpers.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/preloaded.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/queue.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/tool_support.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/topological.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/__pycache__/typing.cpython-311.pyc create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/_collections.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/_concurrency_py3k.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/_has_cy.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/_py_collections.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/compat.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/concurrency.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/deprecations.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/langhelpers.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/preloaded.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/queue.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/tool_support.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/topological.py create mode 100644 venv/lib/python3.11/site-packages/sqlalchemy/util/typing.py (limited to 'venv/lib/python3.11/site-packages/sqlalchemy') diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/__init__.py new file mode 100644 index 0000000..9e983d0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/__init__.py @@ -0,0 +1,294 @@ +# __init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Any + +from . import util as _util +from .engine import AdaptedConnection as AdaptedConnection +from .engine import BaseRow as BaseRow +from .engine import BindTyping as BindTyping +from .engine import ChunkedIteratorResult as ChunkedIteratorResult +from .engine import Compiled as Compiled +from .engine import Connection as Connection +from .engine import create_engine as create_engine +from .engine import create_mock_engine as create_mock_engine +from .engine import create_pool_from_url as create_pool_from_url +from .engine import CreateEnginePlugin as CreateEnginePlugin +from .engine import CursorResult as CursorResult +from .engine import Dialect as Dialect +from .engine import Engine as Engine +from .engine import engine_from_config as engine_from_config +from .engine import ExceptionContext as ExceptionContext +from .engine import ExecutionContext as ExecutionContext +from .engine import FrozenResult as FrozenResult +from .engine import Inspector as Inspector +from .engine import IteratorResult as IteratorResult +from .engine import make_url as make_url +from .engine import MappingResult as MappingResult +from .engine import MergedResult as MergedResult +from .engine import NestedTransaction as NestedTransaction +from .engine import Result as Result +from .engine import result_tuple as result_tuple +from .engine import ResultProxy as ResultProxy +from .engine import RootTransaction as RootTransaction +from .engine import Row as Row +from .engine import RowMapping as RowMapping +from .engine import ScalarResult as ScalarResult +from .engine import Transaction as Transaction +from .engine import TwoPhaseTransaction as TwoPhaseTransaction +from .engine import TypeCompiler as TypeCompiler +from .engine import URL as URL +from .inspection import inspect as inspect +from .pool import AssertionPool as AssertionPool +from .pool import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool +from .pool import ( + FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool, +) +from .pool import NullPool as NullPool +from .pool import Pool as Pool +from .pool import PoolProxiedConnection as PoolProxiedConnection +from .pool import PoolResetState as PoolResetState +from .pool import QueuePool as QueuePool +from .pool import SingletonThreadPool as SingletonThreadPool +from .pool import StaticPool as StaticPool +from .schema import BaseDDLElement as BaseDDLElement +from .schema import BLANK_SCHEMA as BLANK_SCHEMA +from .schema import CheckConstraint as CheckConstraint +from .schema import Column as Column +from .schema import ColumnDefault as ColumnDefault +from .schema import Computed as Computed +from .schema import Constraint as Constraint +from .schema import DDL as DDL +from .schema import DDLElement as DDLElement +from .schema import DefaultClause as DefaultClause +from .schema import ExecutableDDLElement as ExecutableDDLElement +from .schema import FetchedValue as FetchedValue +from .schema import ForeignKey as ForeignKey +from .schema import ForeignKeyConstraint as ForeignKeyConstraint +from .schema import Identity as Identity +from .schema import Index as Index +from .schema import insert_sentinel as insert_sentinel +from .schema import MetaData as MetaData +from .schema import PrimaryKeyConstraint as PrimaryKeyConstraint +from .schema import Sequence as Sequence +from .schema import Table as Table +from .schema import UniqueConstraint as UniqueConstraint +from .sql import ColumnExpressionArgument as ColumnExpressionArgument +from .sql import NotNullable as NotNullable +from .sql import Nullable as Nullable +from .sql import SelectLabelStyle as SelectLabelStyle +from .sql.expression import Alias as Alias +from .sql.expression import alias as alias +from .sql.expression import AliasedReturnsRows as AliasedReturnsRows +from .sql.expression import all_ as all_ +from .sql.expression import and_ as and_ +from .sql.expression import any_ as any_ +from .sql.expression import asc as asc +from .sql.expression import between as between +from .sql.expression import BinaryExpression as BinaryExpression +from .sql.expression import bindparam as bindparam +from .sql.expression import BindParameter as BindParameter +from .sql.expression import bitwise_not as bitwise_not +from .sql.expression import BooleanClauseList as BooleanClauseList +from .sql.expression import CacheKey as CacheKey +from .sql.expression import Case as Case +from .sql.expression import case as case +from .sql.expression import Cast as Cast +from .sql.expression import cast as cast +from .sql.expression import ClauseElement as ClauseElement +from .sql.expression import ClauseList as ClauseList +from .sql.expression import collate as collate +from .sql.expression import CollectionAggregate as CollectionAggregate +from .sql.expression import column as column +from .sql.expression import ColumnClause as ColumnClause +from .sql.expression import ColumnCollection as ColumnCollection +from .sql.expression import ColumnElement as ColumnElement +from .sql.expression import ColumnOperators as ColumnOperators +from .sql.expression import CompoundSelect as CompoundSelect +from .sql.expression import CTE as CTE +from .sql.expression import cte as cte +from .sql.expression import custom_op as custom_op +from .sql.expression import Delete as Delete +from .sql.expression import delete as delete +from .sql.expression import desc as desc +from .sql.expression import distinct as distinct +from .sql.expression import except_ as except_ +from .sql.expression import except_all as except_all +from .sql.expression import Executable as Executable +from .sql.expression import Exists as Exists +from .sql.expression import exists as exists +from .sql.expression import Extract as Extract +from .sql.expression import extract as extract +from .sql.expression import false as false +from .sql.expression import False_ as False_ +from .sql.expression import FromClause as FromClause +from .sql.expression import FromGrouping as FromGrouping +from .sql.expression import func as func +from .sql.expression import funcfilter as funcfilter +from .sql.expression import Function as Function +from .sql.expression import FunctionElement as FunctionElement +from .sql.expression import FunctionFilter as FunctionFilter +from .sql.expression import GenerativeSelect as GenerativeSelect +from .sql.expression import Grouping as Grouping +from .sql.expression import HasCTE as HasCTE +from .sql.expression import HasPrefixes as HasPrefixes +from .sql.expression import HasSuffixes as HasSuffixes +from .sql.expression import Insert as Insert +from .sql.expression import insert as insert +from .sql.expression import intersect as intersect +from .sql.expression import intersect_all as intersect_all +from .sql.expression import Join as Join +from .sql.expression import join as join +from .sql.expression import Label as Label +from .sql.expression import label as label +from .sql.expression import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT +from .sql.expression import ( + LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY, +) +from .sql.expression import LABEL_STYLE_NONE as LABEL_STYLE_NONE +from .sql.expression import ( + LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, +) +from .sql.expression import lambda_stmt as lambda_stmt +from .sql.expression import LambdaElement as LambdaElement +from .sql.expression import Lateral as Lateral +from .sql.expression import lateral as lateral +from .sql.expression import literal as literal +from .sql.expression import literal_column as literal_column +from .sql.expression import modifier as modifier +from .sql.expression import not_ as not_ +from .sql.expression import Null as Null +from .sql.expression import null as null +from .sql.expression import nulls_first as nulls_first +from .sql.expression import nulls_last as nulls_last +from .sql.expression import nullsfirst as nullsfirst +from .sql.expression import nullslast as nullslast +from .sql.expression import Operators as Operators +from .sql.expression import or_ as or_ +from .sql.expression import outerjoin as outerjoin +from .sql.expression import outparam as outparam +from .sql.expression import Over as Over +from .sql.expression import over as over +from .sql.expression import quoted_name as quoted_name +from .sql.expression import ReleaseSavepointClause as ReleaseSavepointClause +from .sql.expression import ReturnsRows as ReturnsRows +from .sql.expression import ( + RollbackToSavepointClause as RollbackToSavepointClause, +) +from .sql.expression import SavepointClause as SavepointClause +from .sql.expression import ScalarSelect as ScalarSelect +from .sql.expression import Select as Select +from .sql.expression import select as select +from .sql.expression import Selectable as Selectable +from .sql.expression import SelectBase as SelectBase +from .sql.expression import SQLColumnExpression as SQLColumnExpression +from .sql.expression import StatementLambdaElement as StatementLambdaElement +from .sql.expression import Subquery as Subquery +from .sql.expression import table as table +from .sql.expression import TableClause as TableClause +from .sql.expression import TableSample as TableSample +from .sql.expression import tablesample as tablesample +from .sql.expression import TableValuedAlias as TableValuedAlias +from .sql.expression import text as text +from .sql.expression import TextAsFrom as TextAsFrom +from .sql.expression import TextClause as TextClause +from .sql.expression import TextualSelect as TextualSelect +from .sql.expression import true as true +from .sql.expression import True_ as True_ +from .sql.expression import try_cast as try_cast +from .sql.expression import TryCast as TryCast +from .sql.expression import Tuple as Tuple +from .sql.expression import tuple_ as tuple_ +from .sql.expression import type_coerce as type_coerce +from .sql.expression import TypeClause as TypeClause +from .sql.expression import TypeCoerce as TypeCoerce +from .sql.expression import UnaryExpression as UnaryExpression +from .sql.expression import union as union +from .sql.expression import union_all as union_all +from .sql.expression import Update as Update +from .sql.expression import update as update +from .sql.expression import UpdateBase as UpdateBase +from .sql.expression import Values as Values +from .sql.expression import values as values +from .sql.expression import ValuesBase as ValuesBase +from .sql.expression import Visitable as Visitable +from .sql.expression import within_group as within_group +from .sql.expression import WithinGroup as WithinGroup +from .types import ARRAY as ARRAY +from .types import BIGINT as BIGINT +from .types import BigInteger as BigInteger +from .types import BINARY as BINARY +from .types import BLOB as BLOB +from .types import BOOLEAN as BOOLEAN +from .types import Boolean as Boolean +from .types import CHAR as CHAR +from .types import CLOB as CLOB +from .types import DATE as DATE +from .types import Date as Date +from .types import DATETIME as DATETIME +from .types import DateTime as DateTime +from .types import DECIMAL as DECIMAL +from .types import DOUBLE as DOUBLE +from .types import Double as Double +from .types import DOUBLE_PRECISION as DOUBLE_PRECISION +from .types import Enum as Enum +from .types import FLOAT as FLOAT +from .types import Float as Float +from .types import INT as INT +from .types import INTEGER as INTEGER +from .types import Integer as Integer +from .types import Interval as Interval +from .types import JSON as JSON +from .types import LargeBinary as LargeBinary +from .types import NCHAR as NCHAR +from .types import NUMERIC as NUMERIC +from .types import Numeric as Numeric +from .types import NVARCHAR as NVARCHAR +from .types import PickleType as PickleType +from .types import REAL as REAL +from .types import SMALLINT as SMALLINT +from .types import SmallInteger as SmallInteger +from .types import String as String +from .types import TEXT as TEXT +from .types import Text as Text +from .types import TIME as TIME +from .types import Time as Time +from .types import TIMESTAMP as TIMESTAMP +from .types import TupleType as TupleType +from .types import TypeDecorator as TypeDecorator +from .types import Unicode as Unicode +from .types import UnicodeText as UnicodeText +from .types import UUID as UUID +from .types import Uuid as Uuid +from .types import VARBINARY as VARBINARY +from .types import VARCHAR as VARCHAR + +__version__ = "2.0.29" + + +def __go(lcls: Any) -> None: + _util.preloaded.import_prefix("sqlalchemy") + + from . import exc + + exc._version_token = "".join(__version__.split(".")[0:2]) + + +__go(locals()) + + +def __getattr__(name: str) -> Any: + if name == "SingleonThreadPool": + _util.warn_deprecated( + "SingleonThreadPool was a typo in the v2 series. " + "Please use the correct SingletonThreadPool name.", + "2.0.24", + ) + return SingletonThreadPool + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..10ae04b Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/events.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/events.cpython-311.pyc new file mode 100644 index 0000000..bdb0f5e Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/events.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/exc.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/exc.cpython-311.pyc new file mode 100644 index 0000000..e1feac8 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/exc.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/inspection.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/inspection.cpython-311.pyc new file mode 100644 index 0000000..c722a5b Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/inspection.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/log.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/log.cpython-311.pyc new file mode 100644 index 0000000..71fbe24 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/log.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/schema.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/schema.cpython-311.pyc new file mode 100644 index 0000000..71308b0 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/schema.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/types.cpython-311.pyc new file mode 100644 index 0000000..a554e2d Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/__pycache__/types.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__init__.py new file mode 100644 index 0000000..f1cae0b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__init__.py @@ -0,0 +1,18 @@ +# connectors/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +from ..engine.interfaces import Dialect + + +class Connector(Dialect): + """Base class for dialect mixins, for DBAPIs that work + across entirely different database backends. + + Currently the only such mixin is pyodbc. + + """ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..172b726 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/aioodbc.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/aioodbc.cpython-311.pyc new file mode 100644 index 0000000..86bb366 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/aioodbc.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/asyncio.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/asyncio.cpython-311.pyc new file mode 100644 index 0000000..c9b451d Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/asyncio.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/pyodbc.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/pyodbc.cpython-311.pyc new file mode 100644 index 0000000..5805308 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/__pycache__/pyodbc.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/connectors/aioodbc.py b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/aioodbc.py new file mode 100644 index 0000000..3b5c3b4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/aioodbc.py @@ -0,0 +1,174 @@ +# connectors/aioodbc.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .asyncio import AsyncAdapt_dbapi_connection +from .asyncio import AsyncAdapt_dbapi_cursor +from .asyncio import AsyncAdapt_dbapi_ss_cursor +from .asyncio import AsyncAdaptFallback_dbapi_connection +from .pyodbc import PyODBCConnector +from .. import pool +from .. import util +from ..util.concurrency import await_fallback +from ..util.concurrency import await_only + +if TYPE_CHECKING: + from ..engine.interfaces import ConnectArgsType + from ..engine.url import URL + + +class AsyncAdapt_aioodbc_cursor(AsyncAdapt_dbapi_cursor): + __slots__ = () + + def setinputsizes(self, *inputsizes): + # see https://github.com/aio-libs/aioodbc/issues/451 + return self._cursor._impl.setinputsizes(*inputsizes) + + # how it's supposed to work + # return self.await_(self._cursor.setinputsizes(*inputsizes)) + + +class AsyncAdapt_aioodbc_ss_cursor( + AsyncAdapt_aioodbc_cursor, AsyncAdapt_dbapi_ss_cursor +): + __slots__ = () + + +class AsyncAdapt_aioodbc_connection(AsyncAdapt_dbapi_connection): + _cursor_cls = AsyncAdapt_aioodbc_cursor + _ss_cursor_cls = AsyncAdapt_aioodbc_ss_cursor + __slots__ = () + + @property + def autocommit(self): + return self._connection.autocommit + + @autocommit.setter + def autocommit(self, value): + # https://github.com/aio-libs/aioodbc/issues/448 + # self._connection.autocommit = value + + self._connection._conn.autocommit = value + + def cursor(self, server_side=False): + # aioodbc sets connection=None when closed and just fails with + # AttributeError here. Here we use the same ProgrammingError + + # message that pyodbc uses, so it triggers is_disconnect() as well. + if self._connection.closed: + raise self.dbapi.ProgrammingError( + "Attempt to use a closed connection." + ) + return super().cursor(server_side=server_side) + + def rollback(self): + # aioodbc sets connection=None when closed and just fails with + # AttributeError here. should be a no-op + if not self._connection.closed: + super().rollback() + + def commit(self): + # aioodbc sets connection=None when closed and just fails with + # AttributeError here. should be a no-op + if not self._connection.closed: + super().commit() + + def close(self): + # aioodbc sets connection=None when closed and just fails with + # AttributeError here. should be a no-op + if not self._connection.closed: + super().close() + + +class AsyncAdaptFallback_aioodbc_connection( + AsyncAdaptFallback_dbapi_connection, AsyncAdapt_aioodbc_connection +): + __slots__ = () + + +class AsyncAdapt_aioodbc_dbapi: + def __init__(self, aioodbc, pyodbc): + self.aioodbc = aioodbc + self.pyodbc = pyodbc + self.paramstyle = pyodbc.paramstyle + self._init_dbapi_attributes() + self.Cursor = AsyncAdapt_dbapi_cursor + self.version = pyodbc.version + + def _init_dbapi_attributes(self): + for name in ( + "Warning", + "Error", + "InterfaceError", + "DataError", + "DatabaseError", + "OperationalError", + "InterfaceError", + "IntegrityError", + "ProgrammingError", + "InternalError", + "NotSupportedError", + "NUMBER", + "STRING", + "DATETIME", + "BINARY", + "Binary", + "BinaryNull", + "SQL_VARCHAR", + "SQL_WVARCHAR", + ): + setattr(self, name, getattr(self.pyodbc, name)) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("async_creator_fn", self.aioodbc.connect) + + if util.asbool(async_fallback): + return AsyncAdaptFallback_aioodbc_connection( + self, + await_fallback(creator_fn(*arg, **kw)), + ) + else: + return AsyncAdapt_aioodbc_connection( + self, + await_only(creator_fn(*arg, **kw)), + ) + + +class aiodbcConnector(PyODBCConnector): + is_async = True + supports_statement_cache = True + + supports_server_side_cursors = True + + @classmethod + def import_dbapi(cls): + return AsyncAdapt_aioodbc_dbapi( + __import__("aioodbc"), __import__("pyodbc") + ) + + def create_connect_args(self, url: URL) -> ConnectArgsType: + arg, kw = super().create_connect_args(url) + if arg and arg[0]: + kw["dsn"] = arg[0] + + return (), kw + + @classmethod + def get_pool_class(cls, url): + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def get_driver_connection(self, connection): + return connection._connection diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/connectors/asyncio.py b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/asyncio.py new file mode 100644 index 0000000..0b44f23 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/asyncio.py @@ -0,0 +1,208 @@ +# connectors/asyncio.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""generic asyncio-adapted versions of DBAPI connection and cursor""" + +from __future__ import annotations + +import collections +import itertools + +from ..engine import AdaptedConnection +from ..util.concurrency import asyncio +from ..util.concurrency import await_fallback +from ..util.concurrency import await_only + + +class AsyncAdapt_dbapi_cursor: + server_side = False + __slots__ = ( + "_adapt_connection", + "_connection", + "await_", + "_cursor", + "_rows", + ) + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor() + self._cursor = self._aenter_cursor(cursor) + + self._rows = collections.deque() + + def _aenter_cursor(self, cursor): + return self.await_(cursor.__aenter__()) + + @property + def description(self): + return self._cursor.description + + @property + def rowcount(self): + return self._cursor.rowcount + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + @property + def lastrowid(self): + return self._cursor.lastrowid + + def close(self): + # note we aren't actually closing the cursor here, + # we are just letting GC do it. see notes in aiomysql dialect + self._rows.clear() + + def execute(self, operation, parameters=None): + return self.await_(self._execute_async(operation, parameters)) + + def executemany(self, operation, seq_of_parameters): + return self.await_( + self._executemany_async(operation, seq_of_parameters) + ) + + async def _execute_async(self, operation, parameters): + async with self._adapt_connection._execute_mutex: + result = await self._cursor.execute(operation, parameters or ()) + + if self._cursor.description and not self.server_side: + self._rows = collections.deque(await self._cursor.fetchall()) + return result + + async def _executemany_async(self, operation, seq_of_parameters): + async with self._adapt_connection._execute_mutex: + return await self._cursor.executemany(operation, seq_of_parameters) + + def nextset(self): + self.await_(self._cursor.nextset()) + if self._cursor.description and not self.server_side: + self._rows = collections.deque( + self.await_(self._cursor.fetchall()) + ) + + def setinputsizes(self, *inputsizes): + # NOTE: this is overrridden in aioodbc due to + # see https://github.com/aio-libs/aioodbc/issues/451 + # right now + + return self.await_(self._cursor.setinputsizes(*inputsizes)) + + def __iter__(self): + while self._rows: + yield self._rows.popleft() + + def fetchone(self): + if self._rows: + return self._rows.popleft() + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + rr = iter(self._rows) + retval = list(itertools.islice(rr, 0, size)) + self._rows = collections.deque(rr) + return retval + + def fetchall(self): + retval = list(self._rows) + self._rows.clear() + return retval + + +class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor): + __slots__ = () + server_side = True + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor() + + self._cursor = self.await_(cursor.__aenter__()) + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_dbapi_connection(AdaptedConnection): + _cursor_cls = AsyncAdapt_dbapi_cursor + _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor + + await_ = staticmethod(await_only) + __slots__ = ("dbapi", "_execute_mutex") + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + self._execute_mutex = asyncio.Lock() + + def ping(self, reconnect): + return self.await_(self._connection.ping(reconnect)) + + def add_output_converter(self, *arg, **kw): + self._connection.add_output_converter(*arg, **kw) + + def character_set_name(self): + return self._connection.character_set_name() + + @property + def autocommit(self): + return self._connection.autocommit + + @autocommit.setter + def autocommit(self, value): + # https://github.com/aio-libs/aioodbc/issues/448 + # self._connection.autocommit = value + + self._connection._conn.autocommit = value + + def cursor(self, server_side=False): + if server_side: + return self._ss_cursor_cls(self) + else: + return self._cursor_cls(self) + + def rollback(self): + self.await_(self._connection.rollback()) + + def commit(self): + self.await_(self._connection.commit()) + + def close(self): + self.await_(self._connection.close()) + + +class AsyncAdaptFallback_dbapi_connection(AsyncAdapt_dbapi_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/connectors/pyodbc.py b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/pyodbc.py new file mode 100644 index 0000000..f204d80 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/connectors/pyodbc.py @@ -0,0 +1,249 @@ +# connectors/pyodbc.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import re +from types import ModuleType +import typing +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union +from urllib.parse import unquote_plus + +from . import Connector +from .. import ExecutionContext +from .. import pool +from .. import util +from ..engine import ConnectArgsType +from ..engine import Connection +from ..engine import interfaces +from ..engine import URL +from ..sql.type_api import TypeEngine + +if typing.TYPE_CHECKING: + from ..engine.interfaces import IsolationLevel + + +class PyODBCConnector(Connector): + driver = "pyodbc" + + # this is no longer False for pyodbc in general + supports_sane_rowcount_returning = True + supports_sane_multi_rowcount = False + + supports_native_decimal = True + default_paramstyle = "named" + + fast_executemany = False + + # for non-DSN connections, this *may* be used to + # hold the desired driver name + pyodbc_driver_name: Optional[str] = None + + dbapi: ModuleType + + def __init__(self, use_setinputsizes: bool = False, **kw: Any): + super().__init__(**kw) + if use_setinputsizes: + self.bind_typing = interfaces.BindTyping.SETINPUTSIZES + + @classmethod + def import_dbapi(cls) -> ModuleType: + return __import__("pyodbc") + + def create_connect_args(self, url: URL) -> ConnectArgsType: + opts = url.translate_connect_args(username="user") + opts.update(url.query) + + keys = opts + + query = url.query + + connect_args: Dict[str, Any] = {} + connectors: List[str] + + for param in ("ansi", "unicode_results", "autocommit"): + if param in keys: + connect_args[param] = util.asbool(keys.pop(param)) + + if "odbc_connect" in keys: + connectors = [unquote_plus(keys.pop("odbc_connect"))] + else: + + def check_quote(token: str) -> str: + if ";" in str(token) or str(token).startswith("{"): + token = "{%s}" % token.replace("}", "}}") + return token + + keys = {k: check_quote(v) for k, v in keys.items()} + + dsn_connection = "dsn" in keys or ( + "host" in keys and "database" not in keys + ) + if dsn_connection: + connectors = [ + "dsn=%s" % (keys.pop("host", "") or keys.pop("dsn", "")) + ] + else: + port = "" + if "port" in keys and "port" not in query: + port = ",%d" % int(keys.pop("port")) + + connectors = [] + driver = keys.pop("driver", self.pyodbc_driver_name) + if driver is None and keys: + # note if keys is empty, this is a totally blank URL + util.warn( + "No driver name specified; " + "this is expected by PyODBC when using " + "DSN-less connections" + ) + else: + connectors.append("DRIVER={%s}" % driver) + + connectors.extend( + [ + "Server=%s%s" % (keys.pop("host", ""), port), + "Database=%s" % keys.pop("database", ""), + ] + ) + + user = keys.pop("user", None) + if user: + connectors.append("UID=%s" % user) + pwd = keys.pop("password", "") + if pwd: + connectors.append("PWD=%s" % pwd) + else: + authentication = keys.pop("authentication", None) + if authentication: + connectors.append("Authentication=%s" % authentication) + else: + connectors.append("Trusted_Connection=Yes") + + # if set to 'Yes', the ODBC layer will try to automagically + # convert textual data from your database encoding to your + # client encoding. This should obviously be set to 'No' if + # you query a cp1253 encoded database from a latin1 client... + if "odbc_autotranslate" in keys: + connectors.append( + "AutoTranslate=%s" % keys.pop("odbc_autotranslate") + ) + + connectors.extend(["%s=%s" % (k, v) for k, v in keys.items()]) + + return ((";".join(connectors),), connect_args) + + def is_disconnect( + self, + e: Exception, + connection: Optional[ + Union[pool.PoolProxiedConnection, interfaces.DBAPIConnection] + ], + cursor: Optional[interfaces.DBAPICursor], + ) -> bool: + if isinstance(e, self.dbapi.ProgrammingError): + return "The cursor's connection has been closed." in str( + e + ) or "Attempt to use a closed connection." in str(e) + else: + return False + + def _dbapi_version(self) -> interfaces.VersionInfoType: + if not self.dbapi: + return () + return self._parse_dbapi_version(self.dbapi.version) + + def _parse_dbapi_version(self, vers: str) -> interfaces.VersionInfoType: + m = re.match(r"(?:py.*-)?([\d\.]+)(?:-(\w+))?", vers) + if not m: + return () + vers_tuple: interfaces.VersionInfoType = tuple( + [int(x) for x in m.group(1).split(".")] + ) + if m.group(2): + vers_tuple += (m.group(2),) + return vers_tuple + + def _get_server_version_info( + self, connection: Connection + ) -> interfaces.VersionInfoType: + # NOTE: this function is not reliable, particularly when + # freetds is in use. Implement database-specific server version + # queries. + dbapi_con = connection.connection.dbapi_connection + version: Tuple[Union[int, str], ...] = () + r = re.compile(r"[.\-]") + for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): # type: ignore[union-attr] # noqa: E501 + try: + version += (int(n),) + except ValueError: + pass + return tuple(version) + + def do_set_input_sizes( + self, + cursor: interfaces.DBAPICursor, + list_of_tuples: List[Tuple[str, Any, TypeEngine[Any]]], + context: ExecutionContext, + ) -> None: + # the rules for these types seems a little strange, as you can pass + # non-tuples as well as tuples, however it seems to assume "0" + # for the subsequent values if you don't pass a tuple which fails + # for types such as pyodbc.SQL_WLONGVARCHAR, which is the datatype + # that ticket #5649 is targeting. + + # NOTE: as of #6058, this won't be called if the use_setinputsizes + # parameter were not passed to the dialect, or if no types were + # specified in list_of_tuples + + # as of #8177 for 2.0 we assume use_setinputsizes=True and only + # omit the setinputsizes calls for .executemany() with + # fast_executemany=True + + if ( + context.execute_style is interfaces.ExecuteStyle.EXECUTEMANY + and self.fast_executemany + ): + return + + cursor.setinputsizes( + [ + ( + (dbtype, None, None) + if not isinstance(dbtype, tuple) + else dbtype + ) + for key, dbtype, sqltype in list_of_tuples + ] + ) + + def get_isolation_level_values( + self, dbapi_connection: interfaces.DBAPIConnection + ) -> List[IsolationLevel]: + return super().get_isolation_level_values(dbapi_connection) + [ + "AUTOCOMMIT" + ] + + def set_isolation_level( + self, + dbapi_connection: interfaces.DBAPIConnection, + level: IsolationLevel, + ) -> None: + # adjust for ConnectionFairy being present + # allows attribute set e.g. "connection.autocommit = True" + # to work properly + + if level == "AUTOCOMMIT": + dbapi_connection.autocommit = True + else: + dbapi_connection.autocommit = False + super().set_isolation_level(dbapi_connection, level) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/__init__.py new file mode 100644 index 0000000..88a4d90 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/__init__.py @@ -0,0 +1,6 @@ +# cyextension/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..2c16dd2 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/collections.cpython-311-x86_64-linux-gnu.so b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/collections.cpython-311-x86_64-linux-gnu.so new file mode 100755 index 0000000..71f55a1 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/collections.cpython-311-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/collections.pyx b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/collections.pyx new file mode 100644 index 0000000..86d2485 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/collections.pyx @@ -0,0 +1,409 @@ +# cyextension/collections.pyx +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +cimport cython +from cpython.long cimport PyLong_FromLongLong +from cpython.set cimport PySet_Add + +from collections.abc import Collection +from itertools import filterfalse + +cdef bint add_not_present(set seen, object item, hashfunc): + hash_value = hashfunc(item) + if hash_value not in seen: + PySet_Add(seen, hash_value) + return True + else: + return False + +cdef list cunique_list(seq, hashfunc=None): + cdef set seen = set() + if not hashfunc: + return [x for x in seq if x not in seen and not PySet_Add(seen, x)] + else: + return [x for x in seq if add_not_present(seen, x, hashfunc)] + +def unique_list(seq, hashfunc=None): + return cunique_list(seq, hashfunc) + +cdef class OrderedSet(set): + + cdef list _list + + @classmethod + def __class_getitem__(cls, key): + return cls + + def __init__(self, d=None): + set.__init__(self) + if d is not None: + self._list = cunique_list(d) + set.update(self, self._list) + else: + self._list = [] + + cpdef OrderedSet copy(self): + cdef OrderedSet cp = OrderedSet.__new__(OrderedSet) + cp._list = list(self._list) + set.update(cp, cp._list) + return cp + + @cython.final + cdef OrderedSet _from_list(self, list new_list): + cdef OrderedSet new = OrderedSet.__new__(OrderedSet) + new._list = new_list + set.update(new, new_list) + return new + + def add(self, element): + if element not in self: + self._list.append(element) + PySet_Add(self, element) + + def remove(self, element): + # set.remove will raise if element is not in self + set.remove(self, element) + self._list.remove(element) + + def pop(self): + try: + value = self._list.pop() + except IndexError: + raise KeyError("pop from an empty set") from None + set.remove(self, value) + return value + + def insert(self, Py_ssize_t pos, element): + if element not in self: + self._list.insert(pos, element) + PySet_Add(self, element) + + def discard(self, element): + if element in self: + set.remove(self, element) + self._list.remove(element) + + def clear(self): + set.clear(self) + self._list = [] + + def __getitem__(self, key): + return self._list[key] + + def __iter__(self): + return iter(self._list) + + def __add__(self, other): + return self.union(other) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._list) + + __str__ = __repr__ + + def update(self, *iterables): + for iterable in iterables: + for e in iterable: + if e not in self: + self._list.append(e) + set.add(self, e) + + def __ior__(self, iterable): + self.update(iterable) + return self + + def union(self, *other): + result = self.copy() + result.update(*other) + return result + + def __or__(self, other): + return self.union(other) + + def intersection(self, *other): + cdef set other_set = set.intersection(self, *other) + return self._from_list([a for a in self._list if a in other_set]) + + def __and__(self, other): + return self.intersection(other) + + def symmetric_difference(self, other): + cdef set other_set + if isinstance(other, set): + other_set = other + collection = other_set + elif isinstance(other, Collection): + collection = other + other_set = set(other) + else: + collection = list(other) + other_set = set(collection) + result = self._from_list([a for a in self._list if a not in other_set]) + result.update(a for a in collection if a not in self) + return result + + def __xor__(self, other): + return self.symmetric_difference(other) + + def difference(self, *other): + cdef set other_set = set.difference(self, *other) + return self._from_list([a for a in self._list if a in other_set]) + + def __sub__(self, other): + return self.difference(other) + + def intersection_update(self, *other): + set.intersection_update(self, *other) + self._list = [a for a in self._list if a in self] + + def __iand__(self, other): + self.intersection_update(other) + return self + + cpdef symmetric_difference_update(self, other): + collection = other if isinstance(other, Collection) else list(other) + set.symmetric_difference_update(self, collection) + self._list = [a for a in self._list if a in self] + self._list += [a for a in collection if a in self] + + def __ixor__(self, other): + self.symmetric_difference_update(other) + return self + + def difference_update(self, *other): + set.difference_update(self, *other) + self._list = [a for a in self._list if a in self] + + def __isub__(self, other): + self.difference_update(other) + return self + +cdef object cy_id(object item): + return PyLong_FromLongLong( (item)) + +# NOTE: cython 0.x will call __add__, __sub__, etc with the parameter swapped +# instead of the __rmeth__, so they need to check that also self is of the +# correct type. This is fixed in cython 3.x. See: +# https://docs.cython.org/en/latest/src/userguide/special_methods.html#arithmetic-methods +cdef class IdentitySet: + """A set that considers only object id() for uniqueness. + + This strategy has edge cases for builtin types- it's possible to have + two 'foo' strings in one of these sets, for example. Use sparingly. + + """ + + cdef dict _members + + def __init__(self, iterable=None): + self._members = {} + if iterable: + self.update(iterable) + + def add(self, value): + self._members[cy_id(value)] = value + + def __contains__(self, value): + return cy_id(value) in self._members + + cpdef remove(self, value): + del self._members[cy_id(value)] + + def discard(self, value): + try: + self.remove(value) + except KeyError: + pass + + def pop(self): + cdef tuple pair + try: + pair = self._members.popitem() + return pair[1] + except KeyError: + raise KeyError("pop from an empty set") + + def clear(self): + self._members.clear() + + def __eq__(self, other): + cdef IdentitySet other_ + if isinstance(other, IdentitySet): + other_ = other + return self._members == other_._members + else: + return False + + def __ne__(self, other): + cdef IdentitySet other_ + if isinstance(other, IdentitySet): + other_ = other + return self._members != other_._members + else: + return True + + cpdef issubset(self, iterable): + cdef IdentitySet other + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) + + if len(self) > len(other): + return False + for m in filterfalse(other._members.__contains__, self._members): + return False + return True + + def __le__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issubset(other) + + def __lt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) < len(other) and self.issubset(other) + + cpdef issuperset(self, iterable): + cdef IdentitySet other + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) + + if len(self) < len(other): + return False + for m in filterfalse(self._members.__contains__, other._members): + return False + return True + + def __ge__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issuperset(other) + + def __gt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) > len(other) and self.issuperset(other) + + cpdef IdentitySet union(self, iterable): + cdef IdentitySet result = self.__class__() + result._members.update(self._members) + result.update(iterable) + return result + + def __or__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.union(other) + + cpdef update(self, iterable): + for obj in iterable: + self._members[cy_id(obj)] = obj + + def __ior__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.update(other) + return self + + cpdef IdentitySet difference(self, iterable): + cdef IdentitySet result = self.__new__(self.__class__) + if isinstance(iterable, self.__class__): + other = (iterable)._members + else: + other = {cy_id(obj) for obj in iterable} + result._members = {k:v for k, v in self._members.items() if k not in other} + return result + + def __sub__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.difference(other) + + cpdef difference_update(self, iterable): + cdef IdentitySet other = self.difference(iterable) + self._members = other._members + + def __isub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.difference_update(other) + return self + + cpdef IdentitySet intersection(self, iterable): + cdef IdentitySet result = self.__new__(self.__class__) + if isinstance(iterable, self.__class__): + other = (iterable)._members + else: + other = {cy_id(obj) for obj in iterable} + result._members = {k: v for k, v in self._members.items() if k in other} + return result + + def __and__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.intersection(other) + + cpdef intersection_update(self, iterable): + cdef IdentitySet other = self.intersection(iterable) + self._members = other._members + + def __iand__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.intersection_update(other) + return self + + cpdef IdentitySet symmetric_difference(self, iterable): + cdef IdentitySet result = self.__new__(self.__class__) + cdef dict other + if isinstance(iterable, self.__class__): + other = (iterable)._members + else: + other = {cy_id(obj): obj for obj in iterable} + result._members = {k: v for k, v in self._members.items() if k not in other} + result._members.update( + [(k, v) for k, v in other.items() if k not in self._members] + ) + return result + + def __xor__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.symmetric_difference(other) + + cpdef symmetric_difference_update(self, iterable): + cdef IdentitySet other = self.symmetric_difference(iterable) + self._members = other._members + + def __ixor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.symmetric_difference(other) + return self + + cpdef IdentitySet copy(self): + cdef IdentitySet cp = self.__new__(self.__class__) + cp._members = self._members.copy() + return cp + + def __copy__(self): + return self.copy() + + def __len__(self): + return len(self._members) + + def __iter__(self): + return iter(self._members.values()) + + def __hash__(self): + raise TypeError("set objects are unhashable") + + def __repr__(self): + return "%s(%r)" % (type(self).__name__, list(self._members.values())) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/immutabledict.cpython-311-x86_64-linux-gnu.so b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/immutabledict.cpython-311-x86_64-linux-gnu.so new file mode 100755 index 0000000..bc41cd9 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/immutabledict.cpython-311-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/immutabledict.pxd b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/immutabledict.pxd new file mode 100644 index 0000000..76f2289 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/immutabledict.pxd @@ -0,0 +1,8 @@ +# cyextension/immutabledict.pxd +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +cdef class immutabledict(dict): + pass diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/immutabledict.pyx b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/immutabledict.pyx new file mode 100644 index 0000000..b37eccc --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/immutabledict.pyx @@ -0,0 +1,133 @@ +# cyextension/immutabledict.pyx +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from cpython.dict cimport PyDict_New, PyDict_Update, PyDict_Size + + +def _readonly_fn(obj): + raise TypeError( + "%s object is immutable and/or readonly" % obj.__class__.__name__) + + +def _immutable_fn(obj): + raise TypeError( + "%s object is immutable" % obj.__class__.__name__) + + +class ReadOnlyContainer: + + __slots__ = () + + def _readonly(self, *a,**kw): + _readonly_fn(self) + + __delitem__ = __setitem__ = __setattr__ = _readonly + + +class ImmutableDictBase(dict): + def _immutable(self, *a,**kw): + _immutable_fn(self) + + @classmethod + def __class_getitem__(cls, key): + return cls + + __delitem__ = __setitem__ = __setattr__ = _immutable + clear = pop = popitem = setdefault = update = _immutable + + +cdef class immutabledict(dict): + def __repr__(self): + return f"immutabledict({dict.__repr__(self)})" + + @classmethod + def __class_getitem__(cls, key): + return cls + + def union(self, *args, **kw): + cdef dict to_merge = None + cdef immutabledict result + cdef Py_ssize_t args_len = len(args) + if args_len > 1: + raise TypeError( + f'union expected at most 1 argument, got {args_len}' + ) + if args_len == 1: + attribute = args[0] + if isinstance(attribute, dict): + to_merge = attribute + if to_merge is None: + to_merge = dict(*args, **kw) + + if PyDict_Size(to_merge) == 0: + return self + + # new + update is faster than immutabledict(self) + result = immutabledict() + PyDict_Update(result, self) + PyDict_Update(result, to_merge) + return result + + def merge_with(self, *other): + cdef immutabledict result = None + cdef object d + cdef bint update = False + if not other: + return self + for d in other: + if d: + if update == False: + update = True + # new + update is faster than immutabledict(self) + result = immutabledict() + PyDict_Update(result, self) + PyDict_Update( + result, (d if isinstance(d, dict) else dict(d)) + ) + + return self if update == False else result + + def copy(self): + return self + + def __reduce__(self): + return immutabledict, (dict(self), ) + + def __delitem__(self, k): + _immutable_fn(self) + + def __setitem__(self, k, v): + _immutable_fn(self) + + def __setattr__(self, k, v): + _immutable_fn(self) + + def clear(self, *args, **kw): + _immutable_fn(self) + + def pop(self, *args, **kw): + _immutable_fn(self) + + def popitem(self, *args, **kw): + _immutable_fn(self) + + def setdefault(self, *args, **kw): + _immutable_fn(self) + + def update(self, *args, **kw): + _immutable_fn(self) + + # PEP 584 + def __ior__(self, other): + _immutable_fn(self) + + def __or__(self, other): + return immutabledict(dict.__or__(self, other)) + + def __ror__(self, other): + # NOTE: this is used only in cython 3.x; + # version 0.x will call __or__ with args inversed + return immutabledict(dict.__ror__(self, other)) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/processors.cpython-311-x86_64-linux-gnu.so b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/processors.cpython-311-x86_64-linux-gnu.so new file mode 100755 index 0000000..1d86a7a Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/processors.cpython-311-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/processors.pyx b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/processors.pyx new file mode 100644 index 0000000..3d71456 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/processors.pyx @@ -0,0 +1,68 @@ +# cyextension/processors.pyx +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +import datetime +from datetime import datetime as datetime_cls +from datetime import time as time_cls +from datetime import date as date_cls +import re + +from cpython.object cimport PyObject_Str +from cpython.unicode cimport PyUnicode_AsASCIIString, PyUnicode_Check, PyUnicode_Decode +from libc.stdio cimport sscanf + + +def int_to_boolean(value): + if value is None: + return None + return True if value else False + +def to_str(value): + return PyObject_Str(value) if value is not None else None + +def to_float(value): + return float(value) if value is not None else None + +cdef inline bytes to_bytes(object value, str type_name): + try: + return PyUnicode_AsASCIIString(value) + except Exception as e: + raise ValueError( + f"Couldn't parse {type_name} string '{value!r}' " + "- value is not a string." + ) from e + +def str_to_datetime(value): + if value is not None: + value = datetime_cls.fromisoformat(value) + return value + +def str_to_time(value): + if value is not None: + value = time_cls.fromisoformat(value) + return value + + +def str_to_date(value): + if value is not None: + value = date_cls.fromisoformat(value) + return value + + + +cdef class DecimalResultProcessor: + cdef object type_ + cdef str format_ + + def __cinit__(self, type_, format_): + self.type_ = type_ + self.format_ = format_ + + def process(self, object value): + if value is None: + return None + else: + return self.type_(self.format_ % value) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/resultproxy.cpython-311-x86_64-linux-gnu.so b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/resultproxy.cpython-311-x86_64-linux-gnu.so new file mode 100755 index 0000000..8e7c46c Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/resultproxy.cpython-311-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/resultproxy.pyx b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/resultproxy.pyx new file mode 100644 index 0000000..b6e357a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/resultproxy.pyx @@ -0,0 +1,102 @@ +# cyextension/resultproxy.pyx +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +import operator + +cdef class BaseRow: + cdef readonly object _parent + cdef readonly dict _key_to_index + cdef readonly tuple _data + + def __init__(self, object parent, object processors, dict key_to_index, object data): + """Row objects are constructed by CursorResult objects.""" + + self._parent = parent + + self._key_to_index = key_to_index + + if processors: + self._data = _apply_processors(processors, data) + else: + self._data = tuple(data) + + def __reduce__(self): + return ( + rowproxy_reconstructor, + (self.__class__, self.__getstate__()), + ) + + def __getstate__(self): + return {"_parent": self._parent, "_data": self._data} + + def __setstate__(self, dict state): + parent = state["_parent"] + self._parent = parent + self._data = state["_data"] + self._key_to_index = parent._key_to_index + + def _values_impl(self): + return list(self) + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def __hash__(self): + return hash(self._data) + + def __getitem__(self, index): + return self._data[index] + + def _get_by_key_impl_mapping(self, key): + return self._get_by_key_impl(key, 0) + + cdef _get_by_key_impl(self, object key, int attr_err): + index = self._key_to_index.get(key) + if index is not None: + return self._data[index] + self._parent._key_not_found(key, attr_err != 0) + + def __getattr__(self, name): + return self._get_by_key_impl(name, 1) + + def _to_tuple_instance(self): + return self._data + + +cdef tuple _apply_processors(proc, data): + res = [] + for i in range(len(proc)): + p = proc[i] + if p is None: + res.append(data[i]) + else: + res.append(p(data[i])) + return tuple(res) + + +def rowproxy_reconstructor(cls, state): + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + + +cdef int is_contiguous(tuple indexes): + cdef int i + for i in range(1, len(indexes)): + if indexes[i-1] != indexes[i] -1: + return 0 + return 1 + + +def tuplegetter(*indexes): + if len(indexes) == 1 or is_contiguous(indexes) != 0: + # slice form is faster but returns a list if input is list + return operator.itemgetter(slice(indexes[0], indexes[-1] + 1)) + else: + return operator.itemgetter(*indexes) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/util.cpython-311-x86_64-linux-gnu.so b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/util.cpython-311-x86_64-linux-gnu.so new file mode 100755 index 0000000..db67d3e Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/util.cpython-311-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/util.pyx b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/util.pyx new file mode 100644 index 0000000..cb17acd --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/cyextension/util.pyx @@ -0,0 +1,91 @@ +# cyextension/util.pyx +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from collections.abc import Mapping + +from sqlalchemy import exc + +cdef tuple _Empty_Tuple = () + +cdef inline bint _mapping_or_tuple(object value): + return isinstance(value, dict) or isinstance(value, tuple) or isinstance(value, Mapping) + +cdef inline bint _check_item(object params) except 0: + cdef object item + cdef bint ret = 1 + if params: + item = params[0] + if not _mapping_or_tuple(item): + ret = 0 + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + return ret + +def _distill_params_20(object params): + if params is None: + return _Empty_Tuple + elif isinstance(params, list) or isinstance(params, tuple): + _check_item(params) + return params + elif isinstance(params, dict) or isinstance(params, Mapping): + return [params] + else: + raise exc.ArgumentError("mapping or list expected for parameters") + + +def _distill_raw_params(object params): + if params is None: + return _Empty_Tuple + elif isinstance(params, list): + _check_item(params) + return params + elif _mapping_or_tuple(params): + return [params] + else: + raise exc.ArgumentError("mapping or sequence expected for parameters") + +cdef class prefix_anon_map(dict): + def __missing__(self, str key): + cdef str derived + cdef int anonymous_counter + cdef dict self_dict = self + + derived = key.split(" ", 1)[1] + + anonymous_counter = self_dict.get(derived, 1) + self_dict[derived] = anonymous_counter + 1 + value = f"{derived}_{anonymous_counter}" + self_dict[key] = value + return value + + +cdef class cache_anon_map(dict): + cdef int _index + + def __init__(self): + self._index = 0 + + def get_anon(self, obj): + cdef long long idself + cdef str id_ + cdef dict self_dict = self + + idself = id(obj) + if idself in self_dict: + return self_dict[idself], True + else: + id_ = self.__missing__(idself) + return id_, False + + def __missing__(self, key): + cdef str val + cdef dict self_dict = self + + self_dict[key] = val = str(self._index) + self._index += 1 + return val + diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/__init__.py new file mode 100644 index 0000000..7d5cc1c --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/__init__.py @@ -0,0 +1,61 @@ +# dialects/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Callable +from typing import Optional +from typing import Type +from typing import TYPE_CHECKING + +from .. import util + +if TYPE_CHECKING: + from ..engine.interfaces import Dialect + +__all__ = ("mssql", "mysql", "oracle", "postgresql", "sqlite") + + +def _auto_fn(name: str) -> Optional[Callable[[], Type[Dialect]]]: + """default dialect importer. + + plugs into the :class:`.PluginLoader` + as a first-hit system. + + """ + if "." in name: + dialect, driver = name.split(".") + else: + dialect = name + driver = "base" + + try: + if dialect == "mariadb": + # it's "OK" for us to hardcode here since _auto_fn is already + # hardcoded. if mysql / mariadb etc were third party dialects + # they would just publish all the entrypoints, which would actually + # look much nicer. + module = __import__( + "sqlalchemy.dialects.mysql.mariadb" + ).dialects.mysql.mariadb + return module.loader(driver) # type: ignore + else: + module = __import__("sqlalchemy.dialects.%s" % (dialect,)).dialects + module = getattr(module, dialect) + except ImportError: + return None + + if hasattr(module, driver): + module = getattr(module, driver) + return lambda: module.dialect + else: + return None + + +registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn) + +plugins = util.PluginLoader("sqlalchemy.plugins") diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..5287370 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/__pycache__/_typing.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/__pycache__/_typing.cpython-311.pyc new file mode 100644 index 0000000..c91d658 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/__pycache__/_typing.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/_typing.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/_typing.py new file mode 100644 index 0000000..9ee6e4b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/_typing.py @@ -0,0 +1,25 @@ +# dialects/_typing.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +from typing import Any +from typing import Iterable +from typing import Mapping +from typing import Optional +from typing import Union + +from ..sql._typing import _DDLColumnArgument +from ..sql.elements import DQLDMLClauseElement +from ..sql.schema import ColumnCollectionConstraint +from ..sql.schema import Index + + +_OnConflictConstraintT = Union[str, ColumnCollectionConstraint, Index, None] +_OnConflictIndexElementsT = Optional[Iterable[_DDLColumnArgument]] +_OnConflictIndexWhereT = Optional[DQLDMLClauseElement] +_OnConflictSetT = Optional[Mapping[Any, Any]] +_OnConflictWhereT = Union[DQLDMLClauseElement, str, None] diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__init__.py new file mode 100644 index 0000000..19ab7c4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__init__.py @@ -0,0 +1,88 @@ +# dialects/mssql/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from . import aioodbc # noqa +from . import base # noqa +from . import pymssql # noqa +from . import pyodbc # noqa +from .base import BIGINT +from .base import BINARY +from .base import BIT +from .base import CHAR +from .base import DATE +from .base import DATETIME +from .base import DATETIME2 +from .base import DATETIMEOFFSET +from .base import DECIMAL +from .base import DOUBLE_PRECISION +from .base import FLOAT +from .base import IMAGE +from .base import INTEGER +from .base import JSON +from .base import MONEY +from .base import NCHAR +from .base import NTEXT +from .base import NUMERIC +from .base import NVARCHAR +from .base import REAL +from .base import ROWVERSION +from .base import SMALLDATETIME +from .base import SMALLINT +from .base import SMALLMONEY +from .base import SQL_VARIANT +from .base import TEXT +from .base import TIME +from .base import TIMESTAMP +from .base import TINYINT +from .base import UNIQUEIDENTIFIER +from .base import VARBINARY +from .base import VARCHAR +from .base import XML +from ...sql import try_cast + + +base.dialect = dialect = pyodbc.dialect + + +__all__ = ( + "JSON", + "INTEGER", + "BIGINT", + "SMALLINT", + "TINYINT", + "VARCHAR", + "NVARCHAR", + "CHAR", + "NCHAR", + "TEXT", + "NTEXT", + "DECIMAL", + "NUMERIC", + "FLOAT", + "DATETIME", + "DATETIME2", + "DATETIMEOFFSET", + "DATE", + "DOUBLE_PRECISION", + "TIME", + "SMALLDATETIME", + "BINARY", + "VARBINARY", + "BIT", + "REAL", + "IMAGE", + "TIMESTAMP", + "ROWVERSION", + "MONEY", + "SMALLMONEY", + "UNIQUEIDENTIFIER", + "SQL_VARIANT", + "XML", + "dialect", + "try_cast", +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..225eebc Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/aioodbc.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/aioodbc.cpython-311.pyc new file mode 100644 index 0000000..3a0585c Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/aioodbc.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..527fcdf Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/base.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/information_schema.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/information_schema.cpython-311.pyc new file mode 100644 index 0000000..88bc790 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/information_schema.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/json.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/json.cpython-311.pyc new file mode 100644 index 0000000..6312e58 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/json.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/provision.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/provision.cpython-311.pyc new file mode 100644 index 0000000..fd72045 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/provision.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/pymssql.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/pymssql.cpython-311.pyc new file mode 100644 index 0000000..84555dc Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/pymssql.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/pyodbc.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/pyodbc.cpython-311.pyc new file mode 100644 index 0000000..19ecd43 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/__pycache__/pyodbc.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/aioodbc.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/aioodbc.py new file mode 100644 index 0000000..65945d9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/aioodbc.py @@ -0,0 +1,64 @@ +# dialects/mssql/aioodbc.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +r""" +.. dialect:: mssql+aioodbc + :name: aioodbc + :dbapi: aioodbc + :connectstring: mssql+aioodbc://:@ + :url: https://pypi.org/project/aioodbc/ + + +Support for the SQL Server database in asyncio style, using the aioodbc +driver which itself is a thread-wrapper around pyodbc. + +.. versionadded:: 2.0.23 Added the mssql+aioodbc dialect which builds + on top of the pyodbc and general aio* dialect architecture. + +Using a special asyncio mediation layer, the aioodbc dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio ` +extension package. + +Most behaviors and caveats for this driver are the same as that of the +pyodbc dialect used on SQL Server; see :ref:`mssql_pyodbc` for general +background. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function; connection +styles are otherwise equivalent to those documented in the pyodbc section:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine( + "mssql+aioodbc://scott:tiger@mssql2017:1433/test?" + "driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes" + ) + + + +""" + +from __future__ import annotations + +from .pyodbc import MSDialect_pyodbc +from .pyodbc import MSExecutionContext_pyodbc +from ...connectors.aioodbc import aiodbcConnector + + +class MSExecutionContext_aioodbc(MSExecutionContext_pyodbc): + def create_server_side_cursor(self): + return self._dbapi_connection.cursor(server_side=True) + + +class MSDialectAsync_aioodbc(aiodbcConnector, MSDialect_pyodbc): + driver = "aioodbc" + + supports_statement_cache = True + + execution_ctx_cls = MSExecutionContext_aioodbc + + +dialect = MSDialectAsync_aioodbc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/base.py new file mode 100644 index 0000000..872f858 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/base.py @@ -0,0 +1,4007 @@ +# dialects/mssql/base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +""" +.. dialect:: mssql + :name: Microsoft SQL Server + :full_support: 2017 + :normal_support: 2012+ + :best_effort: 2005+ + +.. _mssql_external_dialects: + +External Dialects +----------------- + +In addition to the above DBAPI layers with native SQLAlchemy support, there +are third-party dialects for other DBAPI layers that are compatible +with SQL Server. See the "External Dialects" list on the +:ref:`dialect_toplevel` page. + +.. _mssql_identity: + +Auto Increment Behavior / IDENTITY Columns +------------------------------------------ + +SQL Server provides so-called "auto incrementing" behavior using the +``IDENTITY`` construct, which can be placed on any single integer column in a +table. SQLAlchemy considers ``IDENTITY`` within its default "autoincrement" +behavior for an integer primary key column, described at +:paramref:`_schema.Column.autoincrement`. This means that by default, +the first integer primary key column in a :class:`_schema.Table` will be +considered to be the identity column - unless it is associated with a +:class:`.Sequence` - and will generate DDL as such:: + + from sqlalchemy import Table, MetaData, Column, Integer + + m = MetaData() + t = Table('t', m, + Column('id', Integer, primary_key=True), + Column('x', Integer)) + m.create_all(engine) + +The above example will generate DDL as: + +.. sourcecode:: sql + + CREATE TABLE t ( + id INTEGER NOT NULL IDENTITY, + x INTEGER NULL, + PRIMARY KEY (id) + ) + +For the case where this default generation of ``IDENTITY`` is not desired, +specify ``False`` for the :paramref:`_schema.Column.autoincrement` flag, +on the first integer primary key column:: + + m = MetaData() + t = Table('t', m, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('x', Integer)) + m.create_all(engine) + +To add the ``IDENTITY`` keyword to a non-primary key column, specify +``True`` for the :paramref:`_schema.Column.autoincrement` flag on the desired +:class:`_schema.Column` object, and ensure that +:paramref:`_schema.Column.autoincrement` +is set to ``False`` on any integer primary key column:: + + m = MetaData() + t = Table('t', m, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('x', Integer, autoincrement=True)) + m.create_all(engine) + +.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct + in a :class:`_schema.Column` to specify the start and increment + parameters of an IDENTITY. These replace + the use of the :class:`.Sequence` object in order to specify these values. + +.. deprecated:: 1.4 + + The ``mssql_identity_start`` and ``mssql_identity_increment`` parameters + to :class:`_schema.Column` are deprecated and should we replaced by + an :class:`_schema.Identity` object. Specifying both ways of configuring + an IDENTITY will result in a compile error. + These options are also no longer returned as part of the + ``dialect_options`` key in :meth:`_reflection.Inspector.get_columns`. + Use the information in the ``identity`` key instead. + +.. deprecated:: 1.3 + + The use of :class:`.Sequence` to specify IDENTITY characteristics is + deprecated and will be removed in a future release. Please use + the :class:`_schema.Identity` object parameters + :paramref:`_schema.Identity.start` and + :paramref:`_schema.Identity.increment`. + +.. versionchanged:: 1.4 Removed the ability to use a :class:`.Sequence` + object to modify IDENTITY characteristics. :class:`.Sequence` objects + now only manipulate true T-SQL SEQUENCE types. + +.. note:: + + There can only be one IDENTITY column on the table. When using + ``autoincrement=True`` to enable the IDENTITY keyword, SQLAlchemy does not + guard against multiple columns specifying the option simultaneously. The + SQL Server database will instead reject the ``CREATE TABLE`` statement. + +.. note:: + + An INSERT statement which attempts to provide a value for a column that is + marked with IDENTITY will be rejected by SQL Server. In order for the + value to be accepted, a session-level option "SET IDENTITY_INSERT" must be + enabled. The SQLAlchemy SQL Server dialect will perform this operation + automatically when using a core :class:`_expression.Insert` + construct; if the + execution specifies a value for the IDENTITY column, the "IDENTITY_INSERT" + option will be enabled for the span of that statement's invocation.However, + this scenario is not high performing and should not be relied upon for + normal use. If a table doesn't actually require IDENTITY behavior in its + integer primary key column, the keyword should be disabled when creating + the table by ensuring that ``autoincrement=False`` is set. + +Controlling "Start" and "Increment" +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Specific control over the "start" and "increment" values for +the ``IDENTITY`` generator are provided using the +:paramref:`_schema.Identity.start` and :paramref:`_schema.Identity.increment` +parameters passed to the :class:`_schema.Identity` object:: + + from sqlalchemy import Table, Integer, Column, Identity + + test = Table( + 'test', metadata, + Column( + 'id', + Integer, + primary_key=True, + Identity(start=100, increment=10) + ), + Column('name', String(20)) + ) + +The CREATE TABLE for the above :class:`_schema.Table` object would be: + +.. sourcecode:: sql + + CREATE TABLE test ( + id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, + name VARCHAR(20) NULL, + ) + +.. note:: + + The :class:`_schema.Identity` object supports many other parameter in + addition to ``start`` and ``increment``. These are not supported by + SQL Server and will be ignored when generating the CREATE TABLE ddl. + +.. versionchanged:: 1.3.19 The :class:`_schema.Identity` object is + now used to affect the + ``IDENTITY`` generator for a :class:`_schema.Column` under SQL Server. + Previously, the :class:`.Sequence` object was used. As SQL Server now + supports real sequences as a separate construct, :class:`.Sequence` will be + functional in the normal way starting from SQLAlchemy version 1.4. + + +Using IDENTITY with Non-Integer numeric types +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +SQL Server also allows ``IDENTITY`` to be used with ``NUMERIC`` columns. To +implement this pattern smoothly in SQLAlchemy, the primary datatype of the +column should remain as ``Integer``, however the underlying implementation +type deployed to the SQL Server database can be specified as ``Numeric`` using +:meth:`.TypeEngine.with_variant`:: + + from sqlalchemy import Column + from sqlalchemy import Integer + from sqlalchemy import Numeric + from sqlalchemy import String + from sqlalchemy.ext.declarative import declarative_base + + Base = declarative_base() + + class TestTable(Base): + __tablename__ = "test" + id = Column( + Integer().with_variant(Numeric(10, 0), "mssql"), + primary_key=True, + autoincrement=True, + ) + name = Column(String) + +In the above example, ``Integer().with_variant()`` provides clear usage +information that accurately describes the intent of the code. The general +restriction that ``autoincrement`` only applies to ``Integer`` is established +at the metadata level and not at the per-dialect level. + +When using the above pattern, the primary key identifier that comes back from +the insertion of a row, which is also the value that would be assigned to an +ORM object such as ``TestTable`` above, will be an instance of ``Decimal()`` +and not ``int`` when using SQL Server. The numeric return type of the +:class:`_types.Numeric` type can be changed to return floats by passing False +to :paramref:`_types.Numeric.asdecimal`. To normalize the return type of the +above ``Numeric(10, 0)`` to return Python ints (which also support "long" +integer values in Python 3), use :class:`_types.TypeDecorator` as follows:: + + from sqlalchemy import TypeDecorator + + class NumericAsInteger(TypeDecorator): + '''normalize floating point return values into ints''' + + impl = Numeric(10, 0, asdecimal=False) + cache_ok = True + + def process_result_value(self, value, dialect): + if value is not None: + value = int(value) + return value + + class TestTable(Base): + __tablename__ = "test" + id = Column( + Integer().with_variant(NumericAsInteger, "mssql"), + primary_key=True, + autoincrement=True, + ) + name = Column(String) + +.. _mssql_insert_behavior: + +INSERT behavior +^^^^^^^^^^^^^^^^ + +Handling of the ``IDENTITY`` column at INSERT time involves two key +techniques. The most common is being able to fetch the "last inserted value" +for a given ``IDENTITY`` column, a process which SQLAlchemy performs +implicitly in many cases, most importantly within the ORM. + +The process for fetching this value has several variants: + +* In the vast majority of cases, RETURNING is used in conjunction with INSERT + statements on SQL Server in order to get newly generated primary key values: + + .. sourcecode:: sql + + INSERT INTO t (x) OUTPUT inserted.id VALUES (?) + + As of SQLAlchemy 2.0, the :ref:`engine_insertmanyvalues` feature is also + used by default to optimize many-row INSERT statements; for SQL Server + the feature takes place for both RETURNING and-non RETURNING + INSERT statements. + + .. versionchanged:: 2.0.10 The :ref:`engine_insertmanyvalues` feature for + SQL Server was temporarily disabled for SQLAlchemy version 2.0.9 due to + issues with row ordering. As of 2.0.10 the feature is re-enabled, with + special case handling for the unit of work's requirement for RETURNING to + be ordered. + +* When RETURNING is not available or has been disabled via + ``implicit_returning=False``, either the ``scope_identity()`` function or + the ``@@identity`` variable is used; behavior varies by backend: + + * when using PyODBC, the phrase ``; select scope_identity()`` will be + appended to the end of the INSERT statement; a second result set will be + fetched in order to receive the value. Given a table as:: + + t = Table( + 't', + metadata, + Column('id', Integer, primary_key=True), + Column('x', Integer), + implicit_returning=False + ) + + an INSERT will look like: + + .. sourcecode:: sql + + INSERT INTO t (x) VALUES (?); select scope_identity() + + * Other dialects such as pymssql will call upon + ``SELECT scope_identity() AS lastrowid`` subsequent to an INSERT + statement. If the flag ``use_scope_identity=False`` is passed to + :func:`_sa.create_engine`, + the statement ``SELECT @@identity AS lastrowid`` + is used instead. + +A table that contains an ``IDENTITY`` column will prohibit an INSERT statement +that refers to the identity column explicitly. The SQLAlchemy dialect will +detect when an INSERT construct, created using a core +:func:`_expression.insert` +construct (not a plain string SQL), refers to the identity column, and +in this case will emit ``SET IDENTITY_INSERT ON`` prior to the insert +statement proceeding, and ``SET IDENTITY_INSERT OFF`` subsequent to the +execution. Given this example:: + + m = MetaData() + t = Table('t', m, Column('id', Integer, primary_key=True), + Column('x', Integer)) + m.create_all(engine) + + with engine.begin() as conn: + conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2}) + +The above column will be created with IDENTITY, however the INSERT statement +we emit is specifying explicit values. In the echo output we can see +how SQLAlchemy handles this: + +.. sourcecode:: sql + + CREATE TABLE t ( + id INTEGER NOT NULL IDENTITY(1,1), + x INTEGER NULL, + PRIMARY KEY (id) + ) + + COMMIT + SET IDENTITY_INSERT t ON + INSERT INTO t (id, x) VALUES (?, ?) + ((1, 1), (2, 2)) + SET IDENTITY_INSERT t OFF + COMMIT + + + +This is an auxiliary use case suitable for testing and bulk insert scenarios. + +SEQUENCE support +---------------- + +The :class:`.Sequence` object creates "real" sequences, i.e., +``CREATE SEQUENCE``: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import Sequence + >>> from sqlalchemy.schema import CreateSequence + >>> from sqlalchemy.dialects import mssql + >>> print(CreateSequence(Sequence("my_seq", start=1)).compile(dialect=mssql.dialect())) + {printsql}CREATE SEQUENCE my_seq START WITH 1 + +For integer primary key generation, SQL Server's ``IDENTITY`` construct should +generally be preferred vs. sequence. + +.. tip:: + + The default start value for T-SQL is ``-2**63`` instead of 1 as + in most other SQL databases. Users should explicitly set the + :paramref:`.Sequence.start` to 1 if that's the expected default:: + + seq = Sequence("my_sequence", start=1) + +.. versionadded:: 1.4 added SQL Server support for :class:`.Sequence` + +.. versionchanged:: 2.0 The SQL Server dialect will no longer implicitly + render "START WITH 1" for ``CREATE SEQUENCE``, which was the behavior + first implemented in version 1.4. + +MAX on VARCHAR / NVARCHAR +------------------------- + +SQL Server supports the special string "MAX" within the +:class:`_types.VARCHAR` and :class:`_types.NVARCHAR` datatypes, +to indicate "maximum length possible". The dialect currently handles this as +a length of "None" in the base type, rather than supplying a +dialect-specific version of these types, so that a base type +specified such as ``VARCHAR(None)`` can assume "unlengthed" behavior on +more than one backend without using dialect-specific types. + +To build a SQL Server VARCHAR or NVARCHAR with MAX length, use None:: + + my_table = Table( + 'my_table', metadata, + Column('my_data', VARCHAR(None)), + Column('my_n_data', NVARCHAR(None)) + ) + + +Collation Support +----------------- + +Character collations are supported by the base string types, +specified by the string argument "collation":: + + from sqlalchemy import VARCHAR + Column('login', VARCHAR(32, collation='Latin1_General_CI_AS')) + +When such a column is associated with a :class:`_schema.Table`, the +CREATE TABLE statement for this column will yield:: + + login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL + +LIMIT/OFFSET Support +-------------------- + +MSSQL has added support for LIMIT / OFFSET as of SQL Server 2012, via the +"OFFSET n ROWS" and "FETCH NEXT n ROWS" clauses. SQLAlchemy supports these +syntaxes automatically if SQL Server 2012 or greater is detected. + +.. versionchanged:: 1.4 support added for SQL Server "OFFSET n ROWS" and + "FETCH NEXT n ROWS" syntax. + +For statements that specify only LIMIT and no OFFSET, all versions of SQL +Server support the TOP keyword. This syntax is used for all SQL Server +versions when no OFFSET clause is present. A statement such as:: + + select(some_table).limit(5) + +will render similarly to:: + + SELECT TOP 5 col1, col2.. FROM table + +For versions of SQL Server prior to SQL Server 2012, a statement that uses +LIMIT and OFFSET, or just OFFSET alone, will be rendered using the +``ROW_NUMBER()`` window function. A statement such as:: + + select(some_table).order_by(some_table.c.col3).limit(5).offset(10) + +will render similarly to:: + + SELECT anon_1.col1, anon_1.col2 FROM (SELECT col1, col2, + ROW_NUMBER() OVER (ORDER BY col3) AS + mssql_rn FROM table WHERE t.x = :x_1) AS + anon_1 WHERE mssql_rn > :param_1 AND mssql_rn <= :param_2 + :param_1 + +Note that when using LIMIT and/or OFFSET, whether using the older +or newer SQL Server syntaxes, the statement must have an ORDER BY as well, +else a :class:`.CompileError` is raised. + +.. _mssql_comment_support: + +DDL Comment Support +-------------------- + +Comment support, which includes DDL rendering for attributes such as +:paramref:`_schema.Table.comment` and :paramref:`_schema.Column.comment`, as +well as the ability to reflect these comments, is supported assuming a +supported version of SQL Server is in use. If a non-supported version such as +Azure Synapse is detected at first-connect time (based on the presence +of the ``fn_listextendedproperty`` SQL function), comment support including +rendering and table-comment reflection is disabled, as both features rely upon +SQL Server stored procedures and functions that are not available on all +backend types. + +To force comment support to be on or off, bypassing autodetection, set the +parameter ``supports_comments`` within :func:`_sa.create_engine`:: + + e = create_engine("mssql+pyodbc://u:p@dsn", supports_comments=False) + +.. versionadded:: 2.0 Added support for table and column comments for + the SQL Server dialect, including DDL generation and reflection. + +.. _mssql_isolation_level: + +Transaction Isolation Level +--------------------------- + +All SQL Server dialects support setting of transaction isolation level +both via a dialect-specific parameter +:paramref:`_sa.create_engine.isolation_level` +accepted by :func:`_sa.create_engine`, +as well as the :paramref:`.Connection.execution_options.isolation_level` +argument as passed to +:meth:`_engine.Connection.execution_options`. +This feature works by issuing the +command ``SET TRANSACTION ISOLATION LEVEL `` for +each new connection. + +To set isolation level using :func:`_sa.create_engine`:: + + engine = create_engine( + "mssql+pyodbc://scott:tiger@ms_2008", + isolation_level="REPEATABLE READ" + ) + +To set using per-connection execution options:: + + connection = engine.connect() + connection = connection.execution_options( + isolation_level="READ COMMITTED" + ) + +Valid values for ``isolation_level`` include: + +* ``AUTOCOMMIT`` - pyodbc / pymssql-specific +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``SNAPSHOT`` - specific to SQL Server + +There are also more options for isolation level configurations, such as +"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply +different isolation level settings. See the discussion at +:ref:`dbapi_autocommit` for background. + +.. seealso:: + + :ref:`dbapi_autocommit` + +.. _mssql_reset_on_return: + +Temporary Table / Resource Reset for Connection Pooling +------------------------------------------------------- + +The :class:`.QueuePool` connection pool implementation used +by the SQLAlchemy :class:`.Engine` object includes +:ref:`reset on return ` behavior that will invoke +the DBAPI ``.rollback()`` method when connections are returned to the pool. +While this rollback will clear out the immediate state used by the previous +transaction, it does not cover a wider range of session-level state, including +temporary tables as well as other server state such as prepared statement +handles and statement caches. An undocumented SQL Server procedure known +as ``sp_reset_connection`` is known to be a workaround for this issue which +will reset most of the session state that builds up on a connection, including +temporary tables. + +To install ``sp_reset_connection`` as the means of performing reset-on-return, +the :meth:`.PoolEvents.reset` event hook may be used, as demonstrated in the +example below. The :paramref:`_sa.create_engine.pool_reset_on_return` parameter +is set to ``None`` so that the custom scheme can replace the default behavior +completely. The custom hook implementation calls ``.rollback()`` in any case, +as it's usually important that the DBAPI's own tracking of commit/rollback +will remain consistent with the state of the transaction:: + + from sqlalchemy import create_engine + from sqlalchemy import event + + mssql_engine = create_engine( + "mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+17+for+SQL+Server", + + # disable default reset-on-return scheme + pool_reset_on_return=None, + ) + + + @event.listens_for(mssql_engine, "reset") + def _reset_mssql(dbapi_connection, connection_record, reset_state): + if not reset_state.terminate_only: + dbapi_connection.execute("{call sys.sp_reset_connection}") + + # so that the DBAPI itself knows that the connection has been + # reset + dbapi_connection.rollback() + +.. versionchanged:: 2.0.0b3 Added additional state arguments to + the :meth:`.PoolEvents.reset` event and additionally ensured the event + is invoked for all "reset" occurrences, so that it's appropriate + as a place for custom "reset" handlers. Previous schemes which + use the :meth:`.PoolEvents.checkin` handler remain usable as well. + +.. seealso:: + + :ref:`pool_reset_on_return` - in the :ref:`pooling_toplevel` documentation + +Nullability +----------- +MSSQL has support for three levels of column nullability. The default +nullability allows nulls and is explicit in the CREATE TABLE +construct:: + + name VARCHAR(20) NULL + +If ``nullable=None`` is specified then no specification is made. In +other words the database's configured default is used. This will +render:: + + name VARCHAR(20) + +If ``nullable`` is ``True`` or ``False`` then the column will be +``NULL`` or ``NOT NULL`` respectively. + +Date / Time Handling +-------------------- +DATE and TIME are supported. Bind parameters are converted +to datetime.datetime() objects as required by most MSSQL drivers, +and results are processed from strings if needed. +The DATE and TIME types are not available for MSSQL 2005 and +previous - if a server version below 2008 is detected, DDL +for these types will be issued as DATETIME. + +.. _mssql_large_type_deprecation: + +Large Text/Binary Type Deprecation +---------------------------------- + +Per +`SQL Server 2012/2014 Documentation `_, +the ``NTEXT``, ``TEXT`` and ``IMAGE`` datatypes are to be removed from SQL +Server in a future release. SQLAlchemy normally relates these types to the +:class:`.UnicodeText`, :class:`_expression.TextClause` and +:class:`.LargeBinary` datatypes. + +In order to accommodate this change, a new flag ``deprecate_large_types`` +is added to the dialect, which will be automatically set based on detection +of the server version in use, if not otherwise set by the user. The +behavior of this flag is as follows: + +* When this flag is ``True``, the :class:`.UnicodeText`, + :class:`_expression.TextClause` and + :class:`.LargeBinary` datatypes, when used to render DDL, will render the + types ``NVARCHAR(max)``, ``VARCHAR(max)``, and ``VARBINARY(max)``, + respectively. This is a new behavior as of the addition of this flag. + +* When this flag is ``False``, the :class:`.UnicodeText`, + :class:`_expression.TextClause` and + :class:`.LargeBinary` datatypes, when used to render DDL, will render the + types ``NTEXT``, ``TEXT``, and ``IMAGE``, + respectively. This is the long-standing behavior of these types. + +* The flag begins with the value ``None``, before a database connection is + established. If the dialect is used to render DDL without the flag being + set, it is interpreted the same as ``False``. + +* On first connection, the dialect detects if SQL Server version 2012 or + greater is in use; if the flag is still at ``None``, it sets it to ``True`` + or ``False`` based on whether 2012 or greater is detected. + +* The flag can be set to either ``True`` or ``False`` when the dialect + is created, typically via :func:`_sa.create_engine`:: + + eng = create_engine("mssql+pymssql://user:pass@host/db", + deprecate_large_types=True) + +* Complete control over whether the "old" or "new" types are rendered is + available in all SQLAlchemy versions by using the UPPERCASE type objects + instead: :class:`_types.NVARCHAR`, :class:`_types.VARCHAR`, + :class:`_types.VARBINARY`, :class:`_types.TEXT`, :class:`_mssql.NTEXT`, + :class:`_mssql.IMAGE` + will always remain fixed and always output exactly that + type. + +.. _multipart_schema_names: + +Multipart Schema Names +---------------------- + +SQL Server schemas sometimes require multiple parts to their "schema" +qualifier, that is, including the database name and owner name as separate +tokens, such as ``mydatabase.dbo.some_table``. These multipart names can be set +at once using the :paramref:`_schema.Table.schema` argument of +:class:`_schema.Table`:: + + Table( + "some_table", metadata, + Column("q", String(50)), + schema="mydatabase.dbo" + ) + +When performing operations such as table or component reflection, a schema +argument that contains a dot will be split into separate +"database" and "owner" components in order to correctly query the SQL +Server information schema tables, as these two values are stored separately. +Additionally, when rendering the schema name for DDL or SQL, the two +components will be quoted separately for case sensitive names and other +special characters. Given an argument as below:: + + Table( + "some_table", metadata, + Column("q", String(50)), + schema="MyDataBase.dbo" + ) + +The above schema would be rendered as ``[MyDataBase].dbo``, and also in +reflection, would be reflected using "dbo" as the owner and "MyDataBase" +as the database name. + +To control how the schema name is broken into database / owner, +specify brackets (which in SQL Server are quoting characters) in the name. +Below, the "owner" will be considered as ``MyDataBase.dbo`` and the +"database" will be None:: + + Table( + "some_table", metadata, + Column("q", String(50)), + schema="[MyDataBase.dbo]" + ) + +To individually specify both database and owner name with special characters +or embedded dots, use two sets of brackets:: + + Table( + "some_table", metadata, + Column("q", String(50)), + schema="[MyDataBase.Period].[MyOwner.Dot]" + ) + + +.. versionchanged:: 1.2 the SQL Server dialect now treats brackets as + identifier delimiters splitting the schema into separate database + and owner tokens, to allow dots within either name itself. + +.. _legacy_schema_rendering: + +Legacy Schema Mode +------------------ + +Very old versions of the MSSQL dialect introduced the behavior such that a +schema-qualified table would be auto-aliased when used in a +SELECT statement; given a table:: + + account_table = Table( + 'account', metadata, + Column('id', Integer, primary_key=True), + Column('info', String(100)), + schema="customer_schema" + ) + +this legacy mode of rendering would assume that "customer_schema.account" +would not be accepted by all parts of the SQL statement, as illustrated +below: + +.. sourcecode:: pycon+sql + + >>> eng = create_engine("mssql+pymssql://mydsn", legacy_schema_aliasing=True) + >>> print(account_table.select().compile(eng)) + {printsql}SELECT account_1.id, account_1.info + FROM customer_schema.account AS account_1 + +This mode of behavior is now off by default, as it appears to have served +no purpose; however in the case that legacy applications rely upon it, +it is available using the ``legacy_schema_aliasing`` argument to +:func:`_sa.create_engine` as illustrated above. + +.. deprecated:: 1.4 + + The ``legacy_schema_aliasing`` flag is now + deprecated and will be removed in a future release. + +.. _mssql_indexes: + +Clustered Index Support +----------------------- + +The MSSQL dialect supports clustered indexes (and primary keys) via the +``mssql_clustered`` option. This option is available to :class:`.Index`, +:class:`.UniqueConstraint`. and :class:`.PrimaryKeyConstraint`. +For indexes this option can be combined with the ``mssql_columnstore`` one +to create a clustered columnstore index. + +To generate a clustered index:: + + Index("my_index", table.c.x, mssql_clustered=True) + +which renders the index as ``CREATE CLUSTERED INDEX my_index ON table (x)``. + +To generate a clustered primary key use:: + + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x", "y", mssql_clustered=True)) + +which will render the table, for example, as:: + + CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, + PRIMARY KEY CLUSTERED (x, y)) + +Similarly, we can generate a clustered unique constraint using:: + + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x"), + UniqueConstraint("y", mssql_clustered=True), + ) + +To explicitly request a non-clustered primary key (for example, when +a separate clustered index is desired), use:: + + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x", "y", mssql_clustered=False)) + +which will render the table, for example, as:: + + CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, + PRIMARY KEY NONCLUSTERED (x, y)) + +Columnstore Index Support +------------------------- + +The MSSQL dialect supports columnstore indexes via the ``mssql_columnstore`` +option. This option is available to :class:`.Index`. It be combined with +the ``mssql_clustered`` option to create a clustered columnstore index. + +To generate a columnstore index:: + + Index("my_index", table.c.x, mssql_columnstore=True) + +which renders the index as ``CREATE COLUMNSTORE INDEX my_index ON table (x)``. + +To generate a clustered columnstore index provide no columns:: + + idx = Index("my_index", mssql_clustered=True, mssql_columnstore=True) + # required to associate the index with the table + table.append_constraint(idx) + +the above renders the index as +``CREATE CLUSTERED COLUMNSTORE INDEX my_index ON table``. + +.. versionadded:: 2.0.18 + +MSSQL-Specific Index Options +----------------------------- + +In addition to clustering, the MSSQL dialect supports other special options +for :class:`.Index`. + +INCLUDE +^^^^^^^ + +The ``mssql_include`` option renders INCLUDE(colname) for the given string +names:: + + Index("my_index", table.c.x, mssql_include=['y']) + +would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)`` + +.. _mssql_index_where: + +Filtered Indexes +^^^^^^^^^^^^^^^^ + +The ``mssql_where`` option renders WHERE(condition) for the given string +names:: + + Index("my_index", table.c.x, mssql_where=table.c.x > 10) + +would render the index as ``CREATE INDEX my_index ON table (x) WHERE x > 10``. + +.. versionadded:: 1.3.4 + +Index ordering +^^^^^^^^^^^^^^ + +Index ordering is available via functional expressions, such as:: + + Index("my_index", table.c.x.desc()) + +would render the index as ``CREATE INDEX my_index ON table (x DESC)`` + +.. seealso:: + + :ref:`schema_indexes_functional` + +Compatibility Levels +-------------------- +MSSQL supports the notion of setting compatibility levels at the +database level. This allows, for instance, to run a database that +is compatible with SQL2000 while running on a SQL2005 database +server. ``server_version_info`` will always return the database +server version information (in this case SQL2005) and not the +compatibility level information. Because of this, if running under +a backwards compatibility mode SQLAlchemy may attempt to use T-SQL +statements that are unable to be parsed by the database server. + +.. _mssql_triggers: + +Triggers +-------- + +SQLAlchemy by default uses OUTPUT INSERTED to get at newly +generated primary key values via IDENTITY columns or other +server side defaults. MS-SQL does not +allow the usage of OUTPUT INSERTED on tables that have triggers. +To disable the usage of OUTPUT INSERTED on a per-table basis, +specify ``implicit_returning=False`` for each :class:`_schema.Table` +which has triggers:: + + Table('mytable', metadata, + Column('id', Integer, primary_key=True), + # ..., + implicit_returning=False + ) + +Declarative form:: + + class MyClass(Base): + # ... + __table_args__ = {'implicit_returning':False} + + +.. _mssql_rowcount_versioning: + +Rowcount Support / ORM Versioning +--------------------------------- + +The SQL Server drivers may have limited ability to return the number +of rows updated from an UPDATE or DELETE statement. + +As of this writing, the PyODBC driver is not able to return a rowcount when +OUTPUT INSERTED is used. Previous versions of SQLAlchemy therefore had +limitations for features such as the "ORM Versioning" feature that relies upon +accurate rowcounts in order to match version numbers with matched rows. + +SQLAlchemy 2.0 now retrieves the "rowcount" manually for these particular use +cases based on counting the rows that arrived back within RETURNING; so while +the driver still has this limitation, the ORM Versioning feature is no longer +impacted by it. As of SQLAlchemy 2.0.5, ORM versioning has been fully +re-enabled for the pyodbc driver. + +.. versionchanged:: 2.0.5 ORM versioning support is restored for the pyodbc + driver. Previously, a warning would be emitted during ORM flush that + versioning was not supported. + + +Enabling Snapshot Isolation +--------------------------- + +SQL Server has a default transaction +isolation mode that locks entire tables, and causes even mildly concurrent +applications to have long held locks and frequent deadlocks. +Enabling snapshot isolation for the database as a whole is recommended +for modern levels of concurrency support. This is accomplished via the +following ALTER DATABASE commands executed at the SQL prompt:: + + ALTER DATABASE MyDatabase SET ALLOW_SNAPSHOT_ISOLATION ON + + ALTER DATABASE MyDatabase SET READ_COMMITTED_SNAPSHOT ON + +Background on SQL Server snapshot isolation is available at +https://msdn.microsoft.com/en-us/library/ms175095.aspx. + +""" # noqa + +from __future__ import annotations + +import codecs +import datetime +import operator +import re +from typing import overload +from typing import TYPE_CHECKING +from uuid import UUID as _python_UUID + +from . import information_schema as ischema +from .json import JSON +from .json import JSONIndexType +from .json import JSONPathType +from ... import exc +from ... import Identity +from ... import schema as sa_schema +from ... import Sequence +from ... import sql +from ... import text +from ... import util +from ...engine import cursor as _cursor +from ...engine import default +from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import coercions +from ...sql import compiler +from ...sql import elements +from ...sql import expression +from ...sql import func +from ...sql import quoted_name +from ...sql import roles +from ...sql import sqltypes +from ...sql import try_cast as try_cast # noqa: F401 +from ...sql import util as sql_util +from ...sql._typing import is_sql_compiler +from ...sql.compiler import InsertmanyvaluesSentinelOpts +from ...sql.elements import TryCast as TryCast # noqa: F401 +from ...types import BIGINT +from ...types import BINARY +from ...types import CHAR +from ...types import DATE +from ...types import DATETIME +from ...types import DECIMAL +from ...types import FLOAT +from ...types import INTEGER +from ...types import NCHAR +from ...types import NUMERIC +from ...types import NVARCHAR +from ...types import SMALLINT +from ...types import TEXT +from ...types import VARCHAR +from ...util import update_wrapper +from ...util.typing import Literal + +if TYPE_CHECKING: + from ...sql.dml import DMLState + from ...sql.selectable import TableClause + +# https://sqlserverbuilds.blogspot.com/ +MS_2017_VERSION = (14,) +MS_2016_VERSION = (13,) +MS_2014_VERSION = (12,) +MS_2012_VERSION = (11,) +MS_2008_VERSION = (10,) +MS_2005_VERSION = (9,) +MS_2000_VERSION = (8,) + +RESERVED_WORDS = { + "add", + "all", + "alter", + "and", + "any", + "as", + "asc", + "authorization", + "backup", + "begin", + "between", + "break", + "browse", + "bulk", + "by", + "cascade", + "case", + "check", + "checkpoint", + "close", + "clustered", + "coalesce", + "collate", + "column", + "commit", + "compute", + "constraint", + "contains", + "containstable", + "continue", + "convert", + "create", + "cross", + "current", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "dbcc", + "deallocate", + "declare", + "default", + "delete", + "deny", + "desc", + "disk", + "distinct", + "distributed", + "double", + "drop", + "dump", + "else", + "end", + "errlvl", + "escape", + "except", + "exec", + "execute", + "exists", + "exit", + "external", + "fetch", + "file", + "fillfactor", + "for", + "foreign", + "freetext", + "freetexttable", + "from", + "full", + "function", + "goto", + "grant", + "group", + "having", + "holdlock", + "identity", + "identity_insert", + "identitycol", + "if", + "in", + "index", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "key", + "kill", + "left", + "like", + "lineno", + "load", + "merge", + "national", + "nocheck", + "nonclustered", + "not", + "null", + "nullif", + "of", + "off", + "offsets", + "on", + "open", + "opendatasource", + "openquery", + "openrowset", + "openxml", + "option", + "or", + "order", + "outer", + "over", + "percent", + "pivot", + "plan", + "precision", + "primary", + "print", + "proc", + "procedure", + "public", + "raiserror", + "read", + "readtext", + "reconfigure", + "references", + "replication", + "restore", + "restrict", + "return", + "revert", + "revoke", + "right", + "rollback", + "rowcount", + "rowguidcol", + "rule", + "save", + "schema", + "securityaudit", + "select", + "session_user", + "set", + "setuser", + "shutdown", + "some", + "statistics", + "system_user", + "table", + "tablesample", + "textsize", + "then", + "to", + "top", + "tran", + "transaction", + "trigger", + "truncate", + "tsequal", + "union", + "unique", + "unpivot", + "update", + "updatetext", + "use", + "user", + "values", + "varying", + "view", + "waitfor", + "when", + "where", + "while", + "with", + "writetext", +} + + +class REAL(sqltypes.REAL): + """the SQL Server REAL datatype.""" + + def __init__(self, **kw): + # REAL is a synonym for FLOAT(24) on SQL server. + # it is only accepted as the word "REAL" in DDL, the numeric + # precision value is not allowed to be present + kw.setdefault("precision", 24) + super().__init__(**kw) + + +class DOUBLE_PRECISION(sqltypes.DOUBLE_PRECISION): + """the SQL Server DOUBLE PRECISION datatype. + + .. versionadded:: 2.0.11 + + """ + + def __init__(self, **kw): + # DOUBLE PRECISION is a synonym for FLOAT(53) on SQL server. + # it is only accepted as the word "DOUBLE PRECISION" in DDL, + # the numeric precision value is not allowed to be present + kw.setdefault("precision", 53) + super().__init__(**kw) + + +class TINYINT(sqltypes.Integer): + __visit_name__ = "TINYINT" + + +# MSSQL DATE/TIME types have varied behavior, sometimes returning +# strings. MSDate/TIME check for everything, and always +# filter bind parameters into datetime objects (required by pyodbc, +# not sure about other dialects). + + +class _MSDate(sqltypes.Date): + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + + return process + + _reg = re.compile(r"(\d+)-(\d+)-(\d+)") + + def result_processor(self, dialect, coltype): + def process(value): + if isinstance(value, datetime.datetime): + return value.date() + elif isinstance(value, str): + m = self._reg.match(value) + if not m: + raise ValueError( + "could not parse %r as a date value" % (value,) + ) + return datetime.date(*[int(x or 0) for x in m.groups()]) + else: + return value + + return process + + +class TIME(sqltypes.TIME): + def __init__(self, precision=None, **kwargs): + self.precision = precision + super().__init__() + + __zero_date = datetime.date(1900, 1, 1) + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, datetime.datetime): + value = datetime.datetime.combine( + self.__zero_date, value.time() + ) + elif isinstance(value, datetime.time): + """issue #5339 + per: https://github.com/mkleehammer/pyodbc/wiki/Tips-and-Tricks-by-Database-Platform#time-columns + pass TIME value as string + """ # noqa + value = str(value) + return value + + return process + + _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?") + + def result_processor(self, dialect, coltype): + def process(value): + if isinstance(value, datetime.datetime): + return value.time() + elif isinstance(value, str): + m = self._reg.match(value) + if not m: + raise ValueError( + "could not parse %r as a time value" % (value,) + ) + return datetime.time(*[int(x or 0) for x in m.groups()]) + else: + return value + + return process + + +_MSTime = TIME + + +class _BASETIMEIMPL(TIME): + __visit_name__ = "_BASETIMEIMPL" + + +class _DateTimeBase: + def bind_processor(self, dialect): + def process(value): + if type(value) == datetime.date: + return datetime.datetime(value.year, value.month, value.day) + else: + return value + + return process + + +class _MSDateTime(_DateTimeBase, sqltypes.DateTime): + pass + + +class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = "SMALLDATETIME" + + +class DATETIME2(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = "DATETIME2" + + def __init__(self, precision=None, **kw): + super().__init__(**kw) + self.precision = precision + + +class DATETIMEOFFSET(_DateTimeBase, sqltypes.DateTime): + __visit_name__ = "DATETIMEOFFSET" + + def __init__(self, precision=None, **kw): + super().__init__(**kw) + self.precision = precision + + +class _UnicodeLiteral: + def literal_processor(self, dialect): + def process(value): + value = value.replace("'", "''") + + if dialect.identifier_preparer._double_percents: + value = value.replace("%", "%%") + + return "N'%s'" % value + + return process + + +class _MSUnicode(_UnicodeLiteral, sqltypes.Unicode): + pass + + +class _MSUnicodeText(_UnicodeLiteral, sqltypes.UnicodeText): + pass + + +class TIMESTAMP(sqltypes._Binary): + """Implement the SQL Server TIMESTAMP type. + + Note this is **completely different** than the SQL Standard + TIMESTAMP type, which is not supported by SQL Server. It + is a read-only datatype that does not support INSERT of values. + + .. versionadded:: 1.2 + + .. seealso:: + + :class:`_mssql.ROWVERSION` + + """ + + __visit_name__ = "TIMESTAMP" + + # expected by _Binary to be present + length = None + + def __init__(self, convert_int=False): + """Construct a TIMESTAMP or ROWVERSION type. + + :param convert_int: if True, binary integer values will + be converted to integers on read. + + .. versionadded:: 1.2 + + """ + self.convert_int = convert_int + + def result_processor(self, dialect, coltype): + super_ = super().result_processor(dialect, coltype) + if self.convert_int: + + def process(value): + if super_: + value = super_(value) + if value is not None: + # https://stackoverflow.com/a/30403242/34549 + value = int(codecs.encode(value, "hex"), 16) + return value + + return process + else: + return super_ + + +class ROWVERSION(TIMESTAMP): + """Implement the SQL Server ROWVERSION type. + + The ROWVERSION datatype is a SQL Server synonym for the TIMESTAMP + datatype, however current SQL Server documentation suggests using + ROWVERSION for new datatypes going forward. + + The ROWVERSION datatype does **not** reflect (e.g. introspect) from the + database as itself; the returned datatype will be + :class:`_mssql.TIMESTAMP`. + + This is a read-only datatype that does not support INSERT of values. + + .. versionadded:: 1.2 + + .. seealso:: + + :class:`_mssql.TIMESTAMP` + + """ + + __visit_name__ = "ROWVERSION" + + +class NTEXT(sqltypes.UnicodeText): + """MSSQL NTEXT type, for variable-length unicode text up to 2^30 + characters.""" + + __visit_name__ = "NTEXT" + + +class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): + """The MSSQL VARBINARY type. + + This type adds additional features to the core :class:`_types.VARBINARY` + type, including "deprecate_large_types" mode where + either ``VARBINARY(max)`` or IMAGE is rendered, as well as the SQL + Server ``FILESTREAM`` option. + + .. seealso:: + + :ref:`mssql_large_type_deprecation` + + """ + + __visit_name__ = "VARBINARY" + + def __init__(self, length=None, filestream=False): + """ + Construct a VARBINARY type. + + :param length: optional, a length for the column for use in + DDL statements, for those binary types that accept a length, + such as the MySQL BLOB type. + + :param filestream=False: if True, renders the ``FILESTREAM`` keyword + in the table definition. In this case ``length`` must be ``None`` + or ``'max'``. + + .. versionadded:: 1.4.31 + + """ + + self.filestream = filestream + if self.filestream and length not in (None, "max"): + raise ValueError( + "length must be None or 'max' when setting filestream" + ) + super().__init__(length=length) + + +class IMAGE(sqltypes.LargeBinary): + __visit_name__ = "IMAGE" + + +class XML(sqltypes.Text): + """MSSQL XML type. + + This is a placeholder type for reflection purposes that does not include + any Python-side datatype support. It also does not currently support + additional arguments, such as "CONTENT", "DOCUMENT", + "xml_schema_collection". + + """ + + __visit_name__ = "XML" + + +class BIT(sqltypes.Boolean): + """MSSQL BIT type. + + Both pyodbc and pymssql return values from BIT columns as + Python so just subclass Boolean. + + """ + + __visit_name__ = "BIT" + + +class MONEY(sqltypes.TypeEngine): + __visit_name__ = "MONEY" + + +class SMALLMONEY(sqltypes.TypeEngine): + __visit_name__ = "SMALLMONEY" + + +class MSUUid(sqltypes.Uuid): + def bind_processor(self, dialect): + if self.native_uuid: + # this is currently assuming pyodbc; might not work for + # some other mssql driver + return None + else: + if self.as_uuid: + + def process(value): + if value is not None: + value = value.hex + return value + + return process + else: + + def process(value): + if value is not None: + value = value.replace("-", "").replace("''", "'") + return value + + return process + + def literal_processor(self, dialect): + if self.native_uuid: + + def process(value): + return f"""'{str(value).replace("''", "'")}'""" + + return process + else: + if self.as_uuid: + + def process(value): + return f"""'{value.hex}'""" + + return process + else: + + def process(value): + return f"""'{ + value.replace("-", "").replace("'", "''") + }'""" + + return process + + +class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]): + __visit_name__ = "UNIQUEIDENTIFIER" + + @overload + def __init__( + self: UNIQUEIDENTIFIER[_python_UUID], as_uuid: Literal[True] = ... + ): ... + + @overload + def __init__( + self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ... + ): ... + + def __init__(self, as_uuid: bool = True): + """Construct a :class:`_mssql.UNIQUEIDENTIFIER` type. + + + :param as_uuid=True: if True, values will be interpreted + as Python uuid objects, converting to/from string via the + DBAPI. + + .. versionchanged: 2.0 Added direct "uuid" support to the + :class:`_mssql.UNIQUEIDENTIFIER` datatype; uuid interpretation + defaults to ``True``. + + """ + self.as_uuid = as_uuid + self.native_uuid = True + + +class SQL_VARIANT(sqltypes.TypeEngine): + __visit_name__ = "SQL_VARIANT" + + +# old names. +MSDateTime = _MSDateTime +MSDate = _MSDate +MSReal = REAL +MSTinyInteger = TINYINT +MSTime = TIME +MSSmallDateTime = SMALLDATETIME +MSDateTime2 = DATETIME2 +MSDateTimeOffset = DATETIMEOFFSET +MSText = TEXT +MSNText = NTEXT +MSString = VARCHAR +MSNVarchar = NVARCHAR +MSChar = CHAR +MSNChar = NCHAR +MSBinary = BINARY +MSVarBinary = VARBINARY +MSImage = IMAGE +MSBit = BIT +MSMoney = MONEY +MSSmallMoney = SMALLMONEY +MSUniqueIdentifier = UNIQUEIDENTIFIER +MSVariant = SQL_VARIANT + +ischema_names = { + "int": INTEGER, + "bigint": BIGINT, + "smallint": SMALLINT, + "tinyint": TINYINT, + "varchar": VARCHAR, + "nvarchar": NVARCHAR, + "char": CHAR, + "nchar": NCHAR, + "text": TEXT, + "ntext": NTEXT, + "decimal": DECIMAL, + "numeric": NUMERIC, + "float": FLOAT, + "datetime": DATETIME, + "datetime2": DATETIME2, + "datetimeoffset": DATETIMEOFFSET, + "date": DATE, + "time": TIME, + "smalldatetime": SMALLDATETIME, + "binary": BINARY, + "varbinary": VARBINARY, + "bit": BIT, + "real": REAL, + "double precision": DOUBLE_PRECISION, + "image": IMAGE, + "xml": XML, + "timestamp": TIMESTAMP, + "money": MONEY, + "smallmoney": SMALLMONEY, + "uniqueidentifier": UNIQUEIDENTIFIER, + "sql_variant": SQL_VARIANT, +} + + +class MSTypeCompiler(compiler.GenericTypeCompiler): + def _extend(self, spec, type_, length=None): + """Extend a string-type declaration with standard SQL + COLLATE annotations. + + """ + + if getattr(type_, "collation", None): + collation = "COLLATE %s" % type_.collation + else: + collation = None + + if not length: + length = type_.length + + if length: + spec = spec + "(%s)" % length + + return " ".join([c for c in (spec, collation) if c is not None]) + + def visit_double(self, type_, **kw): + return self.visit_DOUBLE_PRECISION(type_, **kw) + + def visit_FLOAT(self, type_, **kw): + precision = getattr(type_, "precision", None) + if precision is None: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {"precision": precision} + + def visit_TINYINT(self, type_, **kw): + return "TINYINT" + + def visit_TIME(self, type_, **kw): + precision = getattr(type_, "precision", None) + if precision is not None: + return "TIME(%s)" % precision + else: + return "TIME" + + def visit_TIMESTAMP(self, type_, **kw): + return "TIMESTAMP" + + def visit_ROWVERSION(self, type_, **kw): + return "ROWVERSION" + + def visit_datetime(self, type_, **kw): + if type_.timezone: + return self.visit_DATETIMEOFFSET(type_, **kw) + else: + return self.visit_DATETIME(type_, **kw) + + def visit_DATETIMEOFFSET(self, type_, **kw): + precision = getattr(type_, "precision", None) + if precision is not None: + return "DATETIMEOFFSET(%s)" % type_.precision + else: + return "DATETIMEOFFSET" + + def visit_DATETIME2(self, type_, **kw): + precision = getattr(type_, "precision", None) + if precision is not None: + return "DATETIME2(%s)" % precision + else: + return "DATETIME2" + + def visit_SMALLDATETIME(self, type_, **kw): + return "SMALLDATETIME" + + def visit_unicode(self, type_, **kw): + return self.visit_NVARCHAR(type_, **kw) + + def visit_text(self, type_, **kw): + if self.dialect.deprecate_large_types: + return self.visit_VARCHAR(type_, **kw) + else: + return self.visit_TEXT(type_, **kw) + + def visit_unicode_text(self, type_, **kw): + if self.dialect.deprecate_large_types: + return self.visit_NVARCHAR(type_, **kw) + else: + return self.visit_NTEXT(type_, **kw) + + def visit_NTEXT(self, type_, **kw): + return self._extend("NTEXT", type_) + + def visit_TEXT(self, type_, **kw): + return self._extend("TEXT", type_) + + def visit_VARCHAR(self, type_, **kw): + return self._extend("VARCHAR", type_, length=type_.length or "max") + + def visit_CHAR(self, type_, **kw): + return self._extend("CHAR", type_) + + def visit_NCHAR(self, type_, **kw): + return self._extend("NCHAR", type_) + + def visit_NVARCHAR(self, type_, **kw): + return self._extend("NVARCHAR", type_, length=type_.length or "max") + + def visit_date(self, type_, **kw): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_, **kw) + else: + return self.visit_DATE(type_, **kw) + + def visit__BASETIMEIMPL(self, type_, **kw): + return self.visit_time(type_, **kw) + + def visit_time(self, type_, **kw): + if self.dialect.server_version_info < MS_2008_VERSION: + return self.visit_DATETIME(type_, **kw) + else: + return self.visit_TIME(type_, **kw) + + def visit_large_binary(self, type_, **kw): + if self.dialect.deprecate_large_types: + return self.visit_VARBINARY(type_, **kw) + else: + return self.visit_IMAGE(type_, **kw) + + def visit_IMAGE(self, type_, **kw): + return "IMAGE" + + def visit_XML(self, type_, **kw): + return "XML" + + def visit_VARBINARY(self, type_, **kw): + text = self._extend("VARBINARY", type_, length=type_.length or "max") + if getattr(type_, "filestream", False): + text += " FILESTREAM" + return text + + def visit_boolean(self, type_, **kw): + return self.visit_BIT(type_) + + def visit_BIT(self, type_, **kw): + return "BIT" + + def visit_JSON(self, type_, **kw): + # this is a bit of a break with SQLAlchemy's convention of + # "UPPERCASE name goes to UPPERCASE type name with no modification" + return self._extend("NVARCHAR", type_, length="max") + + def visit_MONEY(self, type_, **kw): + return "MONEY" + + def visit_SMALLMONEY(self, type_, **kw): + return "SMALLMONEY" + + def visit_uuid(self, type_, **kw): + if type_.native_uuid: + return self.visit_UNIQUEIDENTIFIER(type_, **kw) + else: + return super().visit_uuid(type_, **kw) + + def visit_UNIQUEIDENTIFIER(self, type_, **kw): + return "UNIQUEIDENTIFIER" + + def visit_SQL_VARIANT(self, type_, **kw): + return "SQL_VARIANT" + + +class MSExecutionContext(default.DefaultExecutionContext): + _enable_identity_insert = False + _select_lastrowid = False + _lastrowid = None + + dialect: MSDialect + + def _opt_encode(self, statement): + if self.compiled and self.compiled.schema_translate_map: + rst = self.compiled.preparer._render_schema_translates + statement = rst(statement, self.compiled.schema_translate_map) + + return statement + + def pre_exec(self): + """Activate IDENTITY_INSERT if needed.""" + + if self.isinsert: + if TYPE_CHECKING: + assert is_sql_compiler(self.compiled) + assert isinstance(self.compiled.compile_state, DMLState) + assert isinstance( + self.compiled.compile_state.dml_table, TableClause + ) + + tbl = self.compiled.compile_state.dml_table + id_column = tbl._autoincrement_column + + if id_column is not None and ( + not isinstance(id_column.default, Sequence) + ): + insert_has_identity = True + compile_state = self.compiled.dml_compile_state + self._enable_identity_insert = ( + id_column.key in self.compiled_parameters[0] + ) or ( + compile_state._dict_parameters + and (id_column.key in compile_state._insert_col_keys) + ) + + else: + insert_has_identity = False + self._enable_identity_insert = False + + self._select_lastrowid = ( + not self.compiled.inline + and insert_has_identity + and not self.compiled.effective_returning + and not self._enable_identity_insert + and not self.executemany + ) + + if self._enable_identity_insert: + self.root_connection._cursor_execute( + self.cursor, + self._opt_encode( + "SET IDENTITY_INSERT %s ON" + % self.identifier_preparer.format_table(tbl) + ), + (), + self, + ) + + def post_exec(self): + """Disable IDENTITY_INSERT if enabled.""" + + conn = self.root_connection + + if self.isinsert or self.isupdate or self.isdelete: + self._rowcount = self.cursor.rowcount + + if self._select_lastrowid: + if self.dialect.use_scope_identity: + conn._cursor_execute( + self.cursor, + "SELECT scope_identity() AS lastrowid", + (), + self, + ) + else: + conn._cursor_execute( + self.cursor, "SELECT @@identity AS lastrowid", (), self + ) + # fetchall() ensures the cursor is consumed without closing it + row = self.cursor.fetchall()[0] + self._lastrowid = int(row[0]) + + self.cursor_fetch_strategy = _cursor._NO_CURSOR_DML + elif ( + self.compiled is not None + and is_sql_compiler(self.compiled) + and self.compiled.effective_returning + ): + self.cursor_fetch_strategy = ( + _cursor.FullyBufferedCursorFetchStrategy( + self.cursor, + self.cursor.description, + self.cursor.fetchall(), + ) + ) + + if self._enable_identity_insert: + if TYPE_CHECKING: + assert is_sql_compiler(self.compiled) + assert isinstance(self.compiled.compile_state, DMLState) + assert isinstance( + self.compiled.compile_state.dml_table, TableClause + ) + conn._cursor_execute( + self.cursor, + self._opt_encode( + "SET IDENTITY_INSERT %s OFF" + % self.identifier_preparer.format_table( + self.compiled.compile_state.dml_table + ) + ), + (), + self, + ) + + def get_lastrowid(self): + return self._lastrowid + + def handle_dbapi_exception(self, e): + if self._enable_identity_insert: + try: + self.cursor.execute( + self._opt_encode( + "SET IDENTITY_INSERT %s OFF" + % self.identifier_preparer.format_table( + self.compiled.compile_state.dml_table + ) + ) + ) + except Exception: + pass + + def fire_sequence(self, seq, type_): + return self._execute_scalar( + ( + "SELECT NEXT VALUE FOR %s" + % self.identifier_preparer.format_sequence(seq) + ), + type_, + ) + + def get_insert_default(self, column): + if ( + isinstance(column, sa_schema.Column) + and column is column.table._autoincrement_column + and isinstance(column.default, sa_schema.Sequence) + and column.default.optional + ): + return None + return super().get_insert_default(column) + + +class MSSQLCompiler(compiler.SQLCompiler): + returning_precedes_values = True + + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { + "doy": "dayofyear", + "dow": "weekday", + "milliseconds": "millisecond", + "microseconds": "microsecond", + }, + ) + + def __init__(self, *args, **kwargs): + self.tablealiases = {} + super().__init__(*args, **kwargs) + + def _with_legacy_schema_aliasing(fn): + def decorate(self, *arg, **kw): + if self.dialect.legacy_schema_aliasing: + return fn(self, *arg, **kw) + else: + super_ = getattr(super(MSSQLCompiler, self), fn.__name__) + return super_(*arg, **kw) + + return decorate + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_current_date_func(self, fn, **kw): + return "GETDATE()" + + def visit_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_char_length_func(self, fn, **kw): + return "LEN%s" % self.function_argspec(fn, **kw) + + def visit_aggregate_strings_func(self, fn, **kw): + expr = fn.clauses.clauses[0]._compiler_dispatch(self, **kw) + kw["literal_execute"] = True + delimeter = fn.clauses.clauses[1]._compiler_dispatch(self, **kw) + return f"string_agg({expr}, {delimeter})" + + def visit_concat_op_expression_clauselist( + self, clauselist, operator, **kw + ): + return " + ".join(self.process(elem, **kw) for elem in clauselist) + + def visit_concat_op_binary(self, binary, operator, **kw): + return "%s + %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_true(self, expr, **kw): + return "1" + + def visit_false(self, expr, **kw): + return "0" + + def visit_match_op_binary(self, binary, operator, **kw): + return "CONTAINS (%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def get_select_precolumns(self, select, **kw): + """MS-SQL puts TOP, it's version of LIMIT here""" + + s = super().get_select_precolumns(select, **kw) + + if select._has_row_limiting_clause and self._use_top(select): + # ODBC drivers and possibly others + # don't support bind params in the SELECT clause on SQL Server. + # so have to use literal here. + kw["literal_execute"] = True + s += "TOP %s " % self.process( + self._get_limit_or_fetch(select), **kw + ) + if select._fetch_clause is not None: + if select._fetch_clause_options["percent"]: + s += "PERCENT " + if select._fetch_clause_options["with_ties"]: + s += "WITH TIES " + + return s + + def get_from_hint_text(self, table, text): + return text + + def get_crud_hint_text(self, table, text): + return text + + def _get_limit_or_fetch(self, select): + if select._fetch_clause is None: + return select._limit_clause + else: + return select._fetch_clause + + def _use_top(self, select): + return (select._offset_clause is None) and ( + select._simple_int_clause(select._limit_clause) + or ( + # limit can use TOP with is by itself. fetch only uses TOP + # when it needs to because of PERCENT and/or WITH TIES + # TODO: Why? shouldn't we use TOP always ? + select._simple_int_clause(select._fetch_clause) + and ( + select._fetch_clause_options["percent"] + or select._fetch_clause_options["with_ties"] + ) + ) + ) + + def limit_clause(self, cs, **kwargs): + return "" + + def _check_can_use_fetch_limit(self, select): + # to use ROW_NUMBER(), an ORDER BY is required. + # OFFSET are FETCH are options of the ORDER BY clause + if not select._order_by_clause.clauses: + raise exc.CompileError( + "MSSQL requires an order_by when " + "using an OFFSET or a non-simple " + "LIMIT clause" + ) + + if select._fetch_clause_options is not None and ( + select._fetch_clause_options["percent"] + or select._fetch_clause_options["with_ties"] + ): + raise exc.CompileError( + "MSSQL needs TOP to use PERCENT and/or WITH TIES. " + "Only simple fetch without offset can be used." + ) + + def _row_limit_clause(self, select, **kw): + """MSSQL 2012 supports OFFSET/FETCH operators + Use it instead subquery with row_number + + """ + + if self.dialect._supports_offset_fetch and not self._use_top(select): + self._check_can_use_fetch_limit(select) + + return self.fetch_clause( + select, + fetch_clause=self._get_limit_or_fetch(select), + require_offset=True, + **kw, + ) + + else: + return "" + + def visit_try_cast(self, element, **kw): + return "TRY_CAST (%s AS %s)" % ( + self.process(element.clause, **kw), + self.process(element.typeclause, **kw), + ) + + def translate_select_structure(self, select_stmt, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``row_number()`` criterion. + MSSQL 2012 and above are excluded + + """ + select = select_stmt + + if ( + select._has_row_limiting_clause + and not self.dialect._supports_offset_fetch + and not self._use_top(select) + and not getattr(select, "_mssql_visit", None) + ): + self._check_can_use_fetch_limit(select) + + _order_by_clauses = [ + sql_util.unwrap_label_reference(elem) + for elem in select._order_by_clause.clauses + ] + + limit_clause = self._get_limit_or_fetch(select) + offset_clause = select._offset_clause + + select = select._generate() + select._mssql_visit = True + select = ( + select.add_columns( + sql.func.ROW_NUMBER() + .over(order_by=_order_by_clauses) + .label("mssql_rn") + ) + .order_by(None) + .alias() + ) + + mssql_rn = sql.column("mssql_rn") + limitselect = sql.select( + *[c for c in select.c if c.key != "mssql_rn"] + ) + if offset_clause is not None: + limitselect = limitselect.where(mssql_rn > offset_clause) + if limit_clause is not None: + limitselect = limitselect.where( + mssql_rn <= (limit_clause + offset_clause) + ) + else: + limitselect = limitselect.where(mssql_rn <= (limit_clause)) + return limitselect + else: + return select + + @_with_legacy_schema_aliasing + def visit_table(self, table, mssql_aliased=False, iscrud=False, **kwargs): + if mssql_aliased is table or iscrud: + return super().visit_table(table, **kwargs) + + # alias schema-qualified tables + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=table, **kwargs) + else: + return super().visit_table(table, **kwargs) + + @_with_legacy_schema_aliasing + def visit_alias(self, alias, **kw): + # translate for schema-qualified table aliases + kw["mssql_aliased"] = alias.element + return super().visit_alias(alias, **kw) + + @_with_legacy_schema_aliasing + def visit_column(self, column, add_to_result_map=None, **kw): + if ( + column.table is not None + and (not self.isupdate and not self.isdelete) + or self.is_subquery() + ): + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + converted = elements._corresponding_column_or_error(t, column) + if add_to_result_map is not None: + add_to_result_map( + column.name, + column.name, + (column, column.name, column.key), + column.type, + ) + + return super().visit_column(converted, **kw) + + return super().visit_column( + column, add_to_result_map=add_to_result_map, **kw + ) + + def _schema_aliased_table(self, table): + if getattr(table, "schema", None) is not None: + if table not in self.tablealiases: + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_extract(self, extract, **kw): + field = self.extract_map.get(extract.field, extract.field) + return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw)) + + def visit_savepoint(self, savepoint_stmt, **kw): + return "SAVE TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) + + def visit_rollback_to_savepoint(self, savepoint_stmt, **kw): + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) + + def visit_binary(self, binary, **kwargs): + """Move bind parameters to the right-hand side of an operator, where + possible. + + """ + if ( + isinstance(binary.left, expression.BindParameter) + and binary.operator == operator.eq + and not isinstance(binary.right, expression.BindParameter) + ): + return self.process( + expression.BinaryExpression( + binary.right, binary.left, binary.operator + ), + **kwargs, + ) + return super().visit_binary(binary, **kwargs) + + def returning_clause( + self, stmt, returning_cols, *, populate_result_map, **kw + ): + # SQL server returning clause requires that the columns refer to + # the virtual table names "inserted" or "deleted". Here, we make + # a simple alias of our table with that name, and then adapt the + # columns we have from the list of RETURNING columns to that new name + # so that they render as "inserted." / "deleted.". + + if stmt.is_insert or stmt.is_update: + target = stmt.table.alias("inserted") + elif stmt.is_delete: + target = stmt.table.alias("deleted") + else: + assert False, "expected Insert, Update or Delete statement" + + adapter = sql_util.ClauseAdapter(target) + + # adapter.traverse() takes a column from our target table and returns + # the one that is linked to the "inserted" / "deleted" tables. So in + # order to retrieve these values back from the result (e.g. like + # row[column]), tell the compiler to also add the original unadapted + # column to the result map. Before #4877, these were (unknowingly) + # falling back using string name matching in the result set which + # necessarily used an expensive KeyError in order to match. + + columns = [ + self._label_returning_column( + stmt, + adapter.traverse(column), + populate_result_map, + {"result_map_targets": (column,)}, + fallback_label_name=fallback_label_name, + column_is_repeated=repeated, + name=name, + proxy_name=proxy_name, + **kw, + ) + for ( + name, + proxy_name, + fallback_label_name, + column, + repeated, + ) in stmt._generate_columns_plus_names( + True, cols=expression._select_iterables(returning_cols) + ) + ] + + return "OUTPUT " + ", ".join(columns) + + def get_cte_preamble(self, recursive): + # SQL Server finds it too inconvenient to accept + # an entirely optional, SQL standard specified, + # "RECURSIVE" word with their "WITH", + # so here we go + return "WITH" + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super().label_select_column(select, column, asfrom) + + def for_update_clause(self, select, **kw): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which + # SQLAlchemy doesn't use + return "" + + def order_by_clause(self, select, **kw): + # MSSQL only allows ORDER BY in subqueries if there is a LIMIT: + # "The ORDER BY clause is invalid in views, inline functions, + # derived tables, subqueries, and common table expressions, + # unless TOP, OFFSET or FOR XML is also specified." + if ( + self.is_subquery() + and not self._use_top(select) + and ( + select._offset is None + or not self.dialect._supports_offset_fetch + ) + ): + # avoid processing the order by clause if we won't end up + # using it, because we don't want all the bind params tacked + # onto the positional list if that is what the dbapi requires + return "" + + order_by = self.process(select._order_by_clause, **kw) + + if order_by: + return " ORDER BY " + order_by + else: + return "" + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the UPDATE..FROM clause specific to MSSQL. + + In MSSQL, if the UPDATE statement involves an alias of the table to + be updated, then the table itself must be added to the FROM list as + well. Otherwise, it is optional. Here, we add it regardless. + + """ + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) + + def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): + """If we have extra froms make sure we render any alias as hint.""" + ashint = False + if extra_froms: + ashint = True + return from_table._compiler_dispatch( + self, asfrom=True, iscrud=True, ashint=ashint, **kw + ) + + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the DELETE .. FROM clause specific to MSSQL. + + Yes, it has the FROM keyword twice. + + """ + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) + + def visit_empty_set_expr(self, type_, **kw): + return "SELECT 1 WHERE 1!=1" + + def visit_is_distinct_from_binary(self, binary, operator, **kw): + return "NOT EXISTS (SELECT %s INTERSECT SELECT %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + return "EXISTS (SELECT %s INTERSECT SELECT %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def _render_json_extract_from_binary(self, binary, operator, **kw): + # note we are intentionally calling upon the process() calls in the + # order in which they appear in the SQL String as this is used + # by positional parameter rendering + + if binary.type._type_affinity is sqltypes.JSON: + return "JSON_QUERY(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + # as with other dialects, start with an explicit test for NULL + case_expression = "CASE JSON_VALUE(%s, %s) WHEN NULL THEN NULL" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + if binary.type._type_affinity is sqltypes.Integer: + type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS INTEGER)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + elif binary.type._type_affinity is sqltypes.Numeric: + type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ( + "FLOAT" + if isinstance(binary.type, sqltypes.Float) + else "NUMERIC(%s, %s)" + % (binary.type.precision, binary.type.scale) + ), + ) + elif binary.type._type_affinity is sqltypes.Boolean: + # the NULL handling is particularly weird with boolean, so + # explicitly return numeric (BIT) constants + type_expression = ( + "WHEN 'true' THEN 1 WHEN 'false' THEN 0 ELSE NULL" + ) + elif binary.type._type_affinity is sqltypes.String: + # TODO: does this comment (from mysql) apply to here, too? + # this fails with a JSON value that's a four byte unicode + # string. SQLite has the same problem at the moment + type_expression = "ELSE JSON_VALUE(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + else: + # other affinity....this is not expected right now + type_expression = "ELSE JSON_QUERY(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + return case_expression + " " + type_expression + " END" + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_sequence(self, seq, **kw): + return "NEXT VALUE FOR %s" % self.preparer.format_sequence(seq) + + +class MSSQLStrictCompiler(MSSQLCompiler): + """A subclass of MSSQLCompiler which disables the usage of bind + parameters where not allowed natively by MS-SQL. + + A dialect may use this compiler on a platform where native + binds are used. + + """ + + ansi_bind_rules = True + + def visit_in_op_binary(self, binary, operator, **kw): + kw["literal_execute"] = True + return "%s IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_not_in_op_binary(self, binary, operator, **kw): + kw["literal_execute"] = True + return "%s NOT IN %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def render_literal_value(self, value, type_): + """ + For date and datetime values, convert to a string + format acceptable to MSSQL. That seems to be the + so-called ODBC canonical date format which looks + like this: + + yyyy-mm-dd hh:mi:ss.mmm(24h) + + For other data types, call the base class implementation. + """ + # datetime and date are both subclasses of datetime.date + if issubclass(type(value), datetime.date): + # SQL Server wants single quotes around the date string. + return "'" + str(value) + "'" + else: + return super().render_literal_value(value, type_) + + +class MSDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + + # type is not accepted in a computed column + if column.computed is not None: + colspec += " " + self.process(column.computed) + else: + colspec += " " + self.dialect.type_compiler_instance.process( + column.type, type_expression=column + ) + + if column.nullable is not None: + if ( + not column.nullable + or column.primary_key + or isinstance(column.default, sa_schema.Sequence) + or column.autoincrement is True + or column.identity + ): + colspec += " NOT NULL" + elif column.computed is None: + # don't specify "NULL" for computed columns + colspec += " NULL" + + if column.table is None: + raise exc.CompileError( + "mssql requires Table-bound columns " + "in order to generate DDL" + ) + + d_opt = column.dialect_options["mssql"] + start = d_opt["identity_start"] + increment = d_opt["identity_increment"] + if start is not None or increment is not None: + if column.identity: + raise exc.CompileError( + "Cannot specify options 'mssql_identity_start' and/or " + "'mssql_identity_increment' while also using the " + "'Identity' construct." + ) + util.warn_deprecated( + "The dialect options 'mssql_identity_start' and " + "'mssql_identity_increment' are deprecated. " + "Use the 'Identity' object instead.", + "1.4", + ) + + if column.identity: + colspec += self.process(column.identity, **kwargs) + elif ( + column is column.table._autoincrement_column + or column.autoincrement is True + ) and ( + not isinstance(column.default, Sequence) or column.default.optional + ): + colspec += self.process(Identity(start=start, increment=increment)) + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_create_index(self, create, include_schema=False, **kw): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + + # handle clustering option + clustered = index.dialect_options["mssql"]["clustered"] + if clustered is not None: + if clustered: + text += "CLUSTERED " + else: + text += "NONCLUSTERED " + + # handle columnstore option (has no negative value) + columnstore = index.dialect_options["mssql"]["columnstore"] + if columnstore: + text += "COLUMNSTORE " + + text += "INDEX %s ON %s" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table(index.table), + ) + + # in some case mssql allows indexes with no columns defined + if len(index.expressions) > 0: + text += " (%s)" % ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ) + + # handle other included columns + if index.dialect_options["mssql"]["include"]: + inclusions = [ + index.table.c[col] if isinstance(col, str) else col + for col in index.dialect_options["mssql"]["include"] + ] + + text += " INCLUDE (%s)" % ", ".join( + [preparer.quote(c.name) for c in inclusions] + ) + + whereclause = index.dialect_options["mssql"]["where"] + + if whereclause is not None: + whereclause = coercions.expect( + roles.DDLExpressionRole, whereclause + ) + + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, literal_binds=True + ) + text += " WHERE " + where_compiled + + return text + + def visit_drop_index(self, drop, **kw): + return "\nDROP INDEX %s ON %s" % ( + self._prepared_index_name(drop.element, include_schema=False), + self.preparer.format_table(drop.element.table), + ) + + def visit_primary_key_constraint(self, constraint, **kw): + if len(constraint) == 0: + return "" + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) + text += "PRIMARY KEY " + + clustered = constraint.dialect_options["mssql"]["clustered"] + if clustered is not None: + if clustered: + text += "CLUSTERED " + else: + text += "NONCLUSTERED " + + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_unique_constraint(self, constraint, **kw): + if len(constraint) == 0: + return "" + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "UNIQUE %s" % self.define_unique_constraint_distinct( + constraint, **kw + ) + clustered = constraint.dialect_options["mssql"]["clustered"] + if clustered is not None: + if clustered: + text += "CLUSTERED " + else: + text += "NONCLUSTERED " + + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_computed_column(self, generated, **kw): + text = "AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + # explicitly check for True|False since None means server default + if generated.persisted is True: + text += " PERSISTED" + return text + + def visit_set_table_comment(self, create, **kw): + schema = self.preparer.schema_for_object(create.element) + schema_name = schema if schema else self.dialect.default_schema_name + return ( + "execute sp_addextendedproperty 'MS_Description', " + "{}, 'schema', {}, 'table', {}".format( + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.NVARCHAR() + ), + self.preparer.quote_schema(schema_name), + self.preparer.format_table(create.element, use_schema=False), + ) + ) + + def visit_drop_table_comment(self, drop, **kw): + schema = self.preparer.schema_for_object(drop.element) + schema_name = schema if schema else self.dialect.default_schema_name + return ( + "execute sp_dropextendedproperty 'MS_Description', 'schema', " + "{}, 'table', {}".format( + self.preparer.quote_schema(schema_name), + self.preparer.format_table(drop.element, use_schema=False), + ) + ) + + def visit_set_column_comment(self, create, **kw): + schema = self.preparer.schema_for_object(create.element.table) + schema_name = schema if schema else self.dialect.default_schema_name + return ( + "execute sp_addextendedproperty 'MS_Description', " + "{}, 'schema', {}, 'table', {}, 'column', {}".format( + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.NVARCHAR() + ), + self.preparer.quote_schema(schema_name), + self.preparer.format_table( + create.element.table, use_schema=False + ), + self.preparer.format_column(create.element), + ) + ) + + def visit_drop_column_comment(self, drop, **kw): + schema = self.preparer.schema_for_object(drop.element.table) + schema_name = schema if schema else self.dialect.default_schema_name + return ( + "execute sp_dropextendedproperty 'MS_Description', 'schema', " + "{}, 'table', {}, 'column', {}".format( + self.preparer.quote_schema(schema_name), + self.preparer.format_table( + drop.element.table, use_schema=False + ), + self.preparer.format_column(drop.element), + ) + ) + + def visit_create_sequence(self, create, **kw): + prefix = None + if create.element.data_type is not None: + data_type = create.element.data_type + prefix = " AS %s" % self.type_compiler.process(data_type) + return super().visit_create_sequence(create, prefix=prefix, **kw) + + def visit_identity_column(self, identity, **kw): + text = " IDENTITY" + if identity.start is not None or identity.increment is not None: + start = 1 if identity.start is None else identity.start + increment = 1 if identity.increment is None else identity.increment + text += "(%s,%s)" % (start, increment) + return text + + +class MSIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + def __init__(self, dialect): + super().__init__( + dialect, + initial_quote="[", + final_quote="]", + quote_case_sensitive_collations=False, + ) + + def _escape_identifier(self, value): + return value.replace("]", "]]") + + def _unescape_identifier(self, value): + return value.replace("]]", "]") + + def quote_schema(self, schema, force=None): + """Prepare a quoted table and schema name.""" + + # need to re-implement the deprecation warning entirely + if force is not None: + # not using the util.deprecated_params() decorator in this + # case because of the additional function call overhead on this + # very performance-critical spot. + util.warn_deprecated( + "The IdentifierPreparer.quote_schema.force parameter is " + "deprecated and will be removed in a future release. This " + "flag has no effect on the behavior of the " + "IdentifierPreparer.quote method; please refer to " + "quoted_name().", + version="1.3", + ) + + dbname, owner = _schema_elements(schema) + if dbname: + result = "%s.%s" % (self.quote(dbname), self.quote(owner)) + elif owner: + result = self.quote(owner) + else: + result = "" + return result + + +def _db_plus_owner_listing(fn): + def wrap(dialect, connection, schema=None, **kw): + dbname, owner = _owner_plus_db(dialect, schema) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + dbname, + owner, + schema, + **kw, + ) + + return update_wrapper(wrap, fn) + + +def _db_plus_owner(fn): + def wrap(dialect, connection, tablename, schema=None, **kw): + dbname, owner = _owner_plus_db(dialect, schema) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + tablename, + dbname, + owner, + schema, + **kw, + ) + + return update_wrapper(wrap, fn) + + +def _switch_db(dbname, connection, fn, *arg, **kw): + if dbname: + current_db = connection.exec_driver_sql("select db_name()").scalar() + if current_db != dbname: + connection.exec_driver_sql( + "use %s" % connection.dialect.identifier_preparer.quote(dbname) + ) + try: + return fn(*arg, **kw) + finally: + if dbname and current_db != dbname: + connection.exec_driver_sql( + "use %s" + % connection.dialect.identifier_preparer.quote(current_db) + ) + + +def _owner_plus_db(dialect, schema): + if not schema: + return None, dialect.default_schema_name + else: + return _schema_elements(schema) + + +_memoized_schema = util.LRUCache() + + +def _schema_elements(schema): + if isinstance(schema, quoted_name) and schema.quote: + return None, schema + + if schema in _memoized_schema: + return _memoized_schema[schema] + + # tests for this function are in: + # test/dialect/mssql/test_reflection.py -> + # OwnerPlusDBTest.test_owner_database_pairs + # test/dialect/mssql/test_compiler.py -> test_force_schema_* + # test/dialect/mssql/test_compiler.py -> test_schema_many_tokens_* + # + + if schema.startswith("__[SCHEMA_"): + return None, schema + + push = [] + symbol = "" + bracket = False + has_brackets = False + for token in re.split(r"(\[|\]|\.)", schema): + if not token: + continue + if token == "[": + bracket = True + has_brackets = True + elif token == "]": + bracket = False + elif not bracket and token == ".": + if has_brackets: + push.append("[%s]" % symbol) + else: + push.append(symbol) + symbol = "" + has_brackets = False + else: + symbol += token + if symbol: + push.append(symbol) + if len(push) > 1: + dbname, owner = ".".join(push[0:-1]), push[-1] + + # test for internal brackets + if re.match(r".*\].*\[.*", dbname[1:-1]): + dbname = quoted_name(dbname, quote=False) + else: + dbname = dbname.lstrip("[").rstrip("]") + + elif len(push): + dbname, owner = None, push[0] + else: + dbname, owner = None, None + + _memoized_schema[schema] = dbname, owner + return dbname, owner + + +class MSDialect(default.DefaultDialect): + # will assume it's at least mssql2005 + name = "mssql" + supports_statement_cache = True + supports_default_values = True + supports_empty_insert = False + favor_returning_over_lastrowid = True + + returns_native_bytes = True + + supports_comments = True + supports_default_metavalue = False + """dialect supports INSERT... VALUES (DEFAULT) syntax - + SQL Server **does** support this, but **not** for the IDENTITY column, + so we can't turn this on. + + """ + + # supports_native_uuid is partial here, so we implement our + # own impl type + + execution_ctx_cls = MSExecutionContext + use_scope_identity = True + max_identifier_length = 128 + schema_name = "dbo" + + insert_returning = True + update_returning = True + delete_returning = True + update_returning_multifrom = True + delete_returning_multifrom = True + + colspecs = { + sqltypes.DateTime: _MSDateTime, + sqltypes.Date: _MSDate, + sqltypes.JSON: JSON, + sqltypes.JSON.JSONIndexType: JSONIndexType, + sqltypes.JSON.JSONPathType: JSONPathType, + sqltypes.Time: _BASETIMEIMPL, + sqltypes.Unicode: _MSUnicode, + sqltypes.UnicodeText: _MSUnicodeText, + DATETIMEOFFSET: DATETIMEOFFSET, + DATETIME2: DATETIME2, + SMALLDATETIME: SMALLDATETIME, + DATETIME: DATETIME, + sqltypes.Uuid: MSUUid, + } + + engine_config_types = default.DefaultDialect.engine_config_types.union( + {"legacy_schema_aliasing": util.asbool} + ) + + ischema_names = ischema_names + + supports_sequences = True + sequences_optional = True + # This is actually used for autoincrement, where itentity is used that + # starts with 1. + # for sequences T-SQL's actual default is -9223372036854775808 + default_sequence_base = 1 + + supports_native_boolean = False + non_native_boolean_check_constraint = False + supports_unicode_binds = True + postfetch_lastrowid = True + + # may be changed at server inspection time for older SQL server versions + supports_multivalues_insert = True + + use_insertmanyvalues = True + + # note pyodbc will set this to False if fast_executemany is set, + # as of SQLAlchemy 2.0.9 + use_insertmanyvalues_wo_returning = True + + insertmanyvalues_implicit_sentinel = ( + InsertmanyvaluesSentinelOpts.AUTOINCREMENT + | InsertmanyvaluesSentinelOpts.IDENTITY + | InsertmanyvaluesSentinelOpts.USE_INSERT_FROM_SELECT + ) + + # "The incoming request has too many parameters. The server supports a " + # "maximum of 2100 parameters." + # in fact you can have 2099 parameters. + insertmanyvalues_max_parameters = 2099 + + _supports_offset_fetch = False + _supports_nvarchar_max = False + + legacy_schema_aliasing = False + + server_version_info = () + + statement_compiler = MSSQLCompiler + ddl_compiler = MSDDLCompiler + type_compiler_cls = MSTypeCompiler + preparer = MSIdentifierPreparer + + construct_arguments = [ + (sa_schema.PrimaryKeyConstraint, {"clustered": None}), + (sa_schema.UniqueConstraint, {"clustered": None}), + ( + sa_schema.Index, + { + "clustered": None, + "include": None, + "where": None, + "columnstore": None, + }, + ), + ( + sa_schema.Column, + {"identity_start": None, "identity_increment": None}, + ), + ] + + def __init__( + self, + query_timeout=None, + use_scope_identity=True, + schema_name="dbo", + deprecate_large_types=None, + supports_comments=None, + json_serializer=None, + json_deserializer=None, + legacy_schema_aliasing=None, + ignore_no_transaction_on_rollback=False, + **opts, + ): + self.query_timeout = int(query_timeout or 0) + self.schema_name = schema_name + + self.use_scope_identity = use_scope_identity + self.deprecate_large_types = deprecate_large_types + self.ignore_no_transaction_on_rollback = ( + ignore_no_transaction_on_rollback + ) + self._user_defined_supports_comments = uds = supports_comments + if uds is not None: + self.supports_comments = uds + + if legacy_schema_aliasing is not None: + util.warn_deprecated( + "The legacy_schema_aliasing parameter is " + "deprecated and will be removed in a future release.", + "1.4", + ) + self.legacy_schema_aliasing = legacy_schema_aliasing + + super().__init__(**opts) + + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer + + def do_savepoint(self, connection, name): + # give the DBAPI a push + connection.exec_driver_sql("IF @@TRANCOUNT = 0 BEGIN TRANSACTION") + super().do_savepoint(connection, name) + + def do_release_savepoint(self, connection, name): + # SQL Server does not support RELEASE SAVEPOINT + pass + + def do_rollback(self, dbapi_connection): + try: + super().do_rollback(dbapi_connection) + except self.dbapi.ProgrammingError as e: + if self.ignore_no_transaction_on_rollback and re.match( + r".*\b111214\b", str(e) + ): + util.warn( + "ProgrammingError 111214 " + "'No corresponding transaction found.' " + "has been suppressed via " + "ignore_no_transaction_on_rollback=True" + ) + else: + raise + + _isolation_lookup = { + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "SNAPSHOT", + } + + def get_isolation_level_values(self, dbapi_connection): + return list(self._isolation_lookup) + + def set_isolation_level(self, dbapi_connection, level): + cursor = dbapi_connection.cursor() + cursor.execute(f"SET TRANSACTION ISOLATION LEVEL {level}") + cursor.close() + if level == "SNAPSHOT": + dbapi_connection.commit() + + def get_isolation_level(self, dbapi_connection): + cursor = dbapi_connection.cursor() + view_name = "sys.system_views" + try: + cursor.execute( + ( + "SELECT name FROM {} WHERE name IN " + "('dm_exec_sessions', 'dm_pdw_nodes_exec_sessions')" + ).format(view_name) + ) + row = cursor.fetchone() + if not row: + raise NotImplementedError( + "Can't fetch isolation level on this particular " + "SQL Server version." + ) + + view_name = f"sys.{row[0]}" + + cursor.execute( + """ + SELECT CASE transaction_isolation_level + WHEN 0 THEN NULL + WHEN 1 THEN 'READ UNCOMMITTED' + WHEN 2 THEN 'READ COMMITTED' + WHEN 3 THEN 'REPEATABLE READ' + WHEN 4 THEN 'SERIALIZABLE' + WHEN 5 THEN 'SNAPSHOT' END + AS TRANSACTION_ISOLATION_LEVEL + FROM {} + where session_id = @@SPID + """.format( + view_name + ) + ) + except self.dbapi.Error as err: + raise NotImplementedError( + "Can't fetch isolation level; encountered error {} when " + 'attempting to query the "{}" view.'.format(err, view_name) + ) from err + else: + row = cursor.fetchone() + return row[0].upper() + finally: + cursor.close() + + def initialize(self, connection): + super().initialize(connection) + self._setup_version_attributes() + self._setup_supports_nvarchar_max(connection) + self._setup_supports_comments(connection) + + def _setup_version_attributes(self): + if self.server_version_info[0] not in list(range(8, 17)): + util.warn( + "Unrecognized server version info '%s'. Some SQL Server " + "features may not function properly." + % ".".join(str(x) for x in self.server_version_info) + ) + + if self.server_version_info >= MS_2008_VERSION: + self.supports_multivalues_insert = True + else: + self.supports_multivalues_insert = False + + if self.deprecate_large_types is None: + self.deprecate_large_types = ( + self.server_version_info >= MS_2012_VERSION + ) + + self._supports_offset_fetch = ( + self.server_version_info and self.server_version_info[0] >= 11 + ) + + def _setup_supports_nvarchar_max(self, connection): + try: + connection.scalar( + sql.text("SELECT CAST('test max support' AS NVARCHAR(max))") + ) + except exc.DBAPIError: + self._supports_nvarchar_max = False + else: + self._supports_nvarchar_max = True + + def _setup_supports_comments(self, connection): + if self._user_defined_supports_comments is not None: + return + + try: + connection.scalar( + sql.text( + "SELECT 1 FROM fn_listextendedproperty" + "(default, default, default, default, " + "default, default, default)" + ) + ) + except exc.DBAPIError: + self.supports_comments = False + else: + self.supports_comments = True + + def _get_default_schema_name(self, connection): + query = sql.text("SELECT schema_name()") + default_schema_name = connection.scalar(query) + if default_schema_name is not None: + # guard against the case where the default_schema_name is being + # fed back into a table reflection function. + return quoted_name(default_schema_name, quote=True) + else: + return self.schema_name + + @_db_plus_owner + def has_table(self, connection, tablename, dbname, owner, schema, **kw): + self._ensure_has_table_connection(connection) + + return self._internal_has_table(connection, tablename, owner, **kw) + + @reflection.cache + @_db_plus_owner + def has_sequence( + self, connection, sequencename, dbname, owner, schema, **kw + ): + sequences = ischema.sequences + + s = sql.select(sequences.c.sequence_name).where( + sequences.c.sequence_name == sequencename + ) + + if owner: + s = s.where(sequences.c.sequence_schema == owner) + + c = connection.execute(s) + + return c.first() is not None + + @reflection.cache + @_db_plus_owner_listing + def get_sequence_names(self, connection, dbname, owner, schema, **kw): + sequences = ischema.sequences + + s = sql.select(sequences.c.sequence_name) + if owner: + s = s.where(sequences.c.sequence_schema == owner) + + c = connection.execute(s) + + return [row[0] for row in c] + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = sql.select(ischema.schemata.c.schema_name).order_by( + ischema.schemata.c.schema_name + ) + schema_names = [r[0] for r in connection.execute(s)] + return schema_names + + @reflection.cache + @_db_plus_owner_listing + def get_table_names(self, connection, dbname, owner, schema, **kw): + tables = ischema.tables + s = ( + sql.select(tables.c.table_name) + .where( + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == "BASE TABLE", + ) + ) + .order_by(tables.c.table_name) + ) + table_names = [r[0] for r in connection.execute(s)] + return table_names + + @reflection.cache + @_db_plus_owner_listing + def get_view_names(self, connection, dbname, owner, schema, **kw): + tables = ischema.tables + s = ( + sql.select(tables.c.table_name) + .where( + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == "VIEW", + ) + ) + .order_by(tables.c.table_name) + ) + view_names = [r[0] for r in connection.execute(s)] + return view_names + + @reflection.cache + def _internal_has_table(self, connection, tablename, owner, **kw): + if tablename.startswith("#"): # temporary table + # mssql does not support temporary views + # SQL Error [4103] [S0001]: "#v": Temporary views are not allowed + return bool( + connection.scalar( + # U filters on user tables only. + text("SELECT object_id(:table_name, 'U')"), + {"table_name": f"tempdb.dbo.[{tablename}]"}, + ) + ) + else: + tables = ischema.tables + + s = sql.select(tables.c.table_name).where( + sql.and_( + sql.or_( + tables.c.table_type == "BASE TABLE", + tables.c.table_type == "VIEW", + ), + tables.c.table_name == tablename, + ) + ) + + if owner: + s = s.where(tables.c.table_schema == owner) + + c = connection.execute(s) + + return c.first() is not None + + def _default_or_error(self, connection, tablename, owner, method, **kw): + # TODO: try to avoid having to run a separate query here + if self._internal_has_table(connection, tablename, owner, **kw): + return method() + else: + raise exc.NoSuchTableError(f"{owner}.{tablename}") + + @reflection.cache + @_db_plus_owner + def get_indexes(self, connection, tablename, dbname, owner, schema, **kw): + filter_definition = ( + "ind.filter_definition" + if self.server_version_info >= MS_2008_VERSION + else "NULL as filter_definition" + ) + rp = connection.execution_options(future_result=True).execute( + sql.text( + f""" +select + ind.index_id, + ind.is_unique, + ind.name, + ind.type, + {filter_definition} +from + sys.indexes as ind +join sys.tables as tab on + ind.object_id = tab.object_id +join sys.schemas as sch on + sch.schema_id = tab.schema_id +where + tab.name = :tabname + and sch.name = :schname + and ind.is_primary_key = 0 + and ind.type != 0 +order by + ind.name + """ + ) + .bindparams( + sql.bindparam("tabname", tablename, ischema.CoerceUnicode()), + sql.bindparam("schname", owner, ischema.CoerceUnicode()), + ) + .columns(name=sqltypes.Unicode()) + ) + indexes = {} + for row in rp.mappings(): + indexes[row["index_id"]] = current = { + "name": row["name"], + "unique": row["is_unique"] == 1, + "column_names": [], + "include_columns": [], + "dialect_options": {}, + } + + do = current["dialect_options"] + index_type = row["type"] + if index_type in {1, 2}: + do["mssql_clustered"] = index_type == 1 + if index_type in {5, 6}: + do["mssql_clustered"] = index_type == 5 + do["mssql_columnstore"] = True + if row["filter_definition"] is not None: + do["mssql_where"] = row["filter_definition"] + + rp = connection.execution_options(future_result=True).execute( + sql.text( + """ +select + ind_col.index_id, + col.name, + ind_col.is_included_column +from + sys.columns as col +join sys.tables as tab on + tab.object_id = col.object_id +join sys.index_columns as ind_col on + ind_col.column_id = col.column_id + and ind_col.object_id = tab.object_id +join sys.schemas as sch on + sch.schema_id = tab.schema_id +where + tab.name = :tabname + and sch.name = :schname + """ + ) + .bindparams( + sql.bindparam("tabname", tablename, ischema.CoerceUnicode()), + sql.bindparam("schname", owner, ischema.CoerceUnicode()), + ) + .columns(name=sqltypes.Unicode()) + ) + for row in rp.mappings(): + if row["index_id"] not in indexes: + continue + index_def = indexes[row["index_id"]] + is_colstore = index_def["dialect_options"].get("mssql_columnstore") + is_clustered = index_def["dialect_options"].get("mssql_clustered") + if not (is_colstore and is_clustered): + # a clustered columnstore index includes all columns but does + # not want them in the index definition + if row["is_included_column"] and not is_colstore: + # a noncludsted columnstore index reports that includes + # columns but requires that are listed as normal columns + index_def["include_columns"].append(row["name"]) + else: + index_def["column_names"].append(row["name"]) + for index_info in indexes.values(): + # NOTE: "root level" include_columns is legacy, now part of + # dialect_options (issue #7382) + index_info["dialect_options"]["mssql_include"] = index_info[ + "include_columns" + ] + + if indexes: + return list(indexes.values()) + else: + return self._default_or_error( + connection, tablename, owner, ReflectionDefaults.indexes, **kw + ) + + @reflection.cache + @_db_plus_owner + def get_view_definition( + self, connection, viewname, dbname, owner, schema, **kw + ): + view_def = connection.execute( + sql.text( + "select mod.definition " + "from sys.sql_modules as mod " + "join sys.views as views on mod.object_id = views.object_id " + "join sys.schemas as sch on views.schema_id = sch.schema_id " + "where views.name=:viewname and sch.name=:schname" + ).bindparams( + sql.bindparam("viewname", viewname, ischema.CoerceUnicode()), + sql.bindparam("schname", owner, ischema.CoerceUnicode()), + ) + ).scalar() + if view_def: + return view_def + else: + raise exc.NoSuchTableError(f"{owner}.{viewname}") + + @reflection.cache + def get_table_comment(self, connection, table_name, schema=None, **kw): + if not self.supports_comments: + raise NotImplementedError( + "Can't get table comments on current SQL Server version in use" + ) + + schema_name = schema if schema else self.default_schema_name + COMMENT_SQL = """ + SELECT cast(com.value as nvarchar(max)) + FROM fn_listextendedproperty('MS_Description', + 'schema', :schema, 'table', :table, NULL, NULL + ) as com; + """ + + comment = connection.execute( + sql.text(COMMENT_SQL).bindparams( + sql.bindparam("schema", schema_name, ischema.CoerceUnicode()), + sql.bindparam("table", table_name, ischema.CoerceUnicode()), + ) + ).scalar() + if comment: + return {"text": comment} + else: + return self._default_or_error( + connection, + table_name, + None, + ReflectionDefaults.table_comment, + **kw, + ) + + def _temp_table_name_like_pattern(self, tablename): + # LIKE uses '%' to match zero or more characters and '_' to match any + # single character. We want to match literal underscores, so T-SQL + # requires that we enclose them in square brackets. + return tablename + ( + ("[_][_][_]%") if not tablename.startswith("##") else "" + ) + + def _get_internal_temp_table_name(self, connection, tablename): + # it's likely that schema is always "dbo", but since we can + # get it here, let's get it. + # see https://stackoverflow.com/questions/8311959/ + # specifying-schema-for-temporary-tables + + try: + return connection.execute( + sql.text( + "select table_schema, table_name " + "from tempdb.information_schema.tables " + "where table_name like :p1" + ), + {"p1": self._temp_table_name_like_pattern(tablename)}, + ).one() + except exc.MultipleResultsFound as me: + raise exc.UnreflectableTableError( + "Found more than one temporary table named '%s' in tempdb " + "at this time. Cannot reliably resolve that name to its " + "internal table name." % tablename + ) from me + except exc.NoResultFound as ne: + raise exc.NoSuchTableError( + "Unable to find a temporary table named '%s' in tempdb." + % tablename + ) from ne + + @reflection.cache + @_db_plus_owner + def get_columns(self, connection, tablename, dbname, owner, schema, **kw): + is_temp_table = tablename.startswith("#") + if is_temp_table: + owner, tablename = self._get_internal_temp_table_name( + connection, tablename + ) + + columns = ischema.mssql_temp_table_columns + else: + columns = ischema.columns + + computed_cols = ischema.computed_columns + identity_cols = ischema.identity_columns + if owner: + whereclause = sql.and_( + columns.c.table_name == tablename, + columns.c.table_schema == owner, + ) + full_name = columns.c.table_schema + "." + columns.c.table_name + else: + whereclause = columns.c.table_name == tablename + full_name = columns.c.table_name + + if self._supports_nvarchar_max: + computed_definition = computed_cols.c.definition + else: + # tds_version 4.2 does not support NVARCHAR(MAX) + computed_definition = sql.cast( + computed_cols.c.definition, NVARCHAR(4000) + ) + + object_id = func.object_id(full_name) + + s = ( + sql.select( + columns.c.column_name, + columns.c.data_type, + columns.c.is_nullable, + columns.c.character_maximum_length, + columns.c.numeric_precision, + columns.c.numeric_scale, + columns.c.column_default, + columns.c.collation_name, + computed_definition, + computed_cols.c.is_persisted, + identity_cols.c.is_identity, + identity_cols.c.seed_value, + identity_cols.c.increment_value, + ischema.extended_properties.c.value.label("comment"), + ) + .select_from(columns) + .outerjoin( + computed_cols, + onclause=sql.and_( + computed_cols.c.object_id == object_id, + computed_cols.c.name + == columns.c.column_name.collate("DATABASE_DEFAULT"), + ), + ) + .outerjoin( + identity_cols, + onclause=sql.and_( + identity_cols.c.object_id == object_id, + identity_cols.c.name + == columns.c.column_name.collate("DATABASE_DEFAULT"), + ), + ) + .outerjoin( + ischema.extended_properties, + onclause=sql.and_( + ischema.extended_properties.c["class"] == 1, + ischema.extended_properties.c.major_id == object_id, + ischema.extended_properties.c.minor_id + == columns.c.ordinal_position, + ischema.extended_properties.c.name == "MS_Description", + ), + ) + .where(whereclause) + .order_by(columns.c.ordinal_position) + ) + + c = connection.execution_options(future_result=True).execute(s) + + cols = [] + for row in c.mappings(): + name = row[columns.c.column_name] + type_ = row[columns.c.data_type] + nullable = row[columns.c.is_nullable] == "YES" + charlen = row[columns.c.character_maximum_length] + numericprec = row[columns.c.numeric_precision] + numericscale = row[columns.c.numeric_scale] + default = row[columns.c.column_default] + collation = row[columns.c.collation_name] + definition = row[computed_definition] + is_persisted = row[computed_cols.c.is_persisted] + is_identity = row[identity_cols.c.is_identity] + identity_start = row[identity_cols.c.seed_value] + identity_increment = row[identity_cols.c.increment_value] + comment = row[ischema.extended_properties.c.value] + + coltype = self.ischema_names.get(type_, None) + + kwargs = {} + if coltype in ( + MSString, + MSChar, + MSNVarchar, + MSNChar, + MSText, + MSNText, + MSBinary, + MSVarBinary, + sqltypes.LargeBinary, + ): + if charlen == -1: + charlen = None + kwargs["length"] = charlen + if collation: + kwargs["collation"] = collation + + if coltype is None: + util.warn( + "Did not recognize type '%s' of column '%s'" + % (type_, name) + ) + coltype = sqltypes.NULLTYPE + else: + if issubclass(coltype, sqltypes.Numeric): + kwargs["precision"] = numericprec + + if not issubclass(coltype, sqltypes.Float): + kwargs["scale"] = numericscale + + coltype = coltype(**kwargs) + cdict = { + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": is_identity is not None, + "comment": comment, + } + + if definition is not None and is_persisted is not None: + cdict["computed"] = { + "sqltext": definition, + "persisted": is_persisted, + } + + if is_identity is not None: + # identity_start and identity_increment are Decimal or None + if identity_start is None or identity_increment is None: + cdict["identity"] = {} + else: + if isinstance(coltype, sqltypes.BigInteger): + start = int(identity_start) + increment = int(identity_increment) + elif isinstance(coltype, sqltypes.Integer): + start = int(identity_start) + increment = int(identity_increment) + else: + start = identity_start + increment = identity_increment + + cdict["identity"] = { + "start": start, + "increment": increment, + } + + cols.append(cdict) + + if cols: + return cols + else: + return self._default_or_error( + connection, tablename, owner, ReflectionDefaults.columns, **kw + ) + + @reflection.cache + @_db_plus_owner + def get_pk_constraint( + self, connection, tablename, dbname, owner, schema, **kw + ): + pkeys = [] + TC = ischema.constraints + C = ischema.key_constraints.alias("C") + + # Primary key constraints + s = ( + sql.select( + C.c.column_name, + TC.c.constraint_type, + C.c.constraint_name, + func.objectproperty( + func.object_id( + C.c.table_schema + "." + C.c.constraint_name + ), + "CnstIsClustKey", + ).label("is_clustered"), + ) + .where( + sql.and_( + TC.c.constraint_name == C.c.constraint_name, + TC.c.table_schema == C.c.table_schema, + C.c.table_name == tablename, + C.c.table_schema == owner, + ), + ) + .order_by(TC.c.constraint_name, C.c.ordinal_position) + ) + c = connection.execution_options(future_result=True).execute(s) + constraint_name = None + is_clustered = None + for row in c.mappings(): + if "PRIMARY" in row[TC.c.constraint_type.name]: + pkeys.append(row["COLUMN_NAME"]) + if constraint_name is None: + constraint_name = row[C.c.constraint_name.name] + if is_clustered is None: + is_clustered = row["is_clustered"] + if pkeys: + return { + "constrained_columns": pkeys, + "name": constraint_name, + "dialect_options": {"mssql_clustered": is_clustered}, + } + else: + return self._default_or_error( + connection, + tablename, + owner, + ReflectionDefaults.pk_constraint, + **kw, + ) + + @reflection.cache + @_db_plus_owner + def get_foreign_keys( + self, connection, tablename, dbname, owner, schema, **kw + ): + # Foreign key constraints + s = ( + text( + """\ +WITH fk_info AS ( + SELECT + ischema_ref_con.constraint_schema, + ischema_ref_con.constraint_name, + ischema_key_col.ordinal_position, + ischema_key_col.table_schema, + ischema_key_col.table_name, + ischema_ref_con.unique_constraint_schema, + ischema_ref_con.unique_constraint_name, + ischema_ref_con.match_option, + ischema_ref_con.update_rule, + ischema_ref_con.delete_rule, + ischema_key_col.column_name AS constrained_column + FROM + INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS ischema_ref_con + INNER JOIN + INFORMATION_SCHEMA.KEY_COLUMN_USAGE ischema_key_col ON + ischema_key_col.table_schema = ischema_ref_con.constraint_schema + AND ischema_key_col.constraint_name = + ischema_ref_con.constraint_name + WHERE ischema_key_col.table_name = :tablename + AND ischema_key_col.table_schema = :owner +), +constraint_info AS ( + SELECT + ischema_key_col.constraint_schema, + ischema_key_col.constraint_name, + ischema_key_col.ordinal_position, + ischema_key_col.table_schema, + ischema_key_col.table_name, + ischema_key_col.column_name + FROM + INFORMATION_SCHEMA.KEY_COLUMN_USAGE ischema_key_col +), +index_info AS ( + SELECT + sys.schemas.name AS index_schema, + sys.indexes.name AS index_name, + sys.index_columns.key_ordinal AS ordinal_position, + sys.schemas.name AS table_schema, + sys.objects.name AS table_name, + sys.columns.name AS column_name + FROM + sys.indexes + INNER JOIN + sys.objects ON + sys.objects.object_id = sys.indexes.object_id + INNER JOIN + sys.schemas ON + sys.schemas.schema_id = sys.objects.schema_id + INNER JOIN + sys.index_columns ON + sys.index_columns.object_id = sys.objects.object_id + AND sys.index_columns.index_id = sys.indexes.index_id + INNER JOIN + sys.columns ON + sys.columns.object_id = sys.indexes.object_id + AND sys.columns.column_id = sys.index_columns.column_id +) + SELECT + fk_info.constraint_schema, + fk_info.constraint_name, + fk_info.ordinal_position, + fk_info.constrained_column, + constraint_info.table_schema AS referred_table_schema, + constraint_info.table_name AS referred_table_name, + constraint_info.column_name AS referred_column, + fk_info.match_option, + fk_info.update_rule, + fk_info.delete_rule + FROM + fk_info INNER JOIN constraint_info ON + constraint_info.constraint_schema = + fk_info.unique_constraint_schema + AND constraint_info.constraint_name = + fk_info.unique_constraint_name + AND constraint_info.ordinal_position = fk_info.ordinal_position + UNION + SELECT + fk_info.constraint_schema, + fk_info.constraint_name, + fk_info.ordinal_position, + fk_info.constrained_column, + index_info.table_schema AS referred_table_schema, + index_info.table_name AS referred_table_name, + index_info.column_name AS referred_column, + fk_info.match_option, + fk_info.update_rule, + fk_info.delete_rule + FROM + fk_info INNER JOIN index_info ON + index_info.index_schema = fk_info.unique_constraint_schema + AND index_info.index_name = fk_info.unique_constraint_name + AND index_info.ordinal_position = fk_info.ordinal_position + + ORDER BY fk_info.constraint_schema, fk_info.constraint_name, + fk_info.ordinal_position +""" + ) + .bindparams( + sql.bindparam("tablename", tablename, ischema.CoerceUnicode()), + sql.bindparam("owner", owner, ischema.CoerceUnicode()), + ) + .columns( + constraint_schema=sqltypes.Unicode(), + constraint_name=sqltypes.Unicode(), + table_schema=sqltypes.Unicode(), + table_name=sqltypes.Unicode(), + constrained_column=sqltypes.Unicode(), + referred_table_schema=sqltypes.Unicode(), + referred_table_name=sqltypes.Unicode(), + referred_column=sqltypes.Unicode(), + ) + ) + + # group rows by constraint ID, to handle multi-column FKs + fkeys = [] + + def fkey_rec(): + return { + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + "options": {}, + } + + fkeys = util.defaultdict(fkey_rec) + + for r in connection.execute(s).all(): + ( + _, # constraint schema + rfknm, + _, # ordinal position + scol, + rschema, + rtbl, + rcol, + # TODO: we support match= for foreign keys so + # we can support this also, PG has match=FULL for example + # but this seems to not be a valid value for SQL Server + _, # match rule + fkuprule, + fkdelrule, + ) = r + + rec = fkeys[rfknm] + rec["name"] = rfknm + + if fkuprule != "NO ACTION": + rec["options"]["onupdate"] = fkuprule + + if fkdelrule != "NO ACTION": + rec["options"]["ondelete"] = fkdelrule + + if not rec["referred_table"]: + rec["referred_table"] = rtbl + if schema is not None or owner != rschema: + if dbname: + rschema = dbname + "." + rschema + rec["referred_schema"] = rschema + + local_cols, remote_cols = ( + rec["constrained_columns"], + rec["referred_columns"], + ) + + local_cols.append(scol) + remote_cols.append(rcol) + + if fkeys: + return list(fkeys.values()) + else: + return self._default_or_error( + connection, + tablename, + owner, + ReflectionDefaults.foreign_keys, + **kw, + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/information_schema.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/information_schema.py new file mode 100644 index 0000000..0c5f237 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/information_schema.py @@ -0,0 +1,254 @@ +# dialects/mssql/information_schema.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from ... import cast +from ... import Column +from ... import MetaData +from ... import Table +from ...ext.compiler import compiles +from ...sql import expression +from ...types import Boolean +from ...types import Integer +from ...types import Numeric +from ...types import NVARCHAR +from ...types import String +from ...types import TypeDecorator +from ...types import Unicode + + +ischema = MetaData() + + +class CoerceUnicode(TypeDecorator): + impl = Unicode + cache_ok = True + + def bind_expression(self, bindvalue): + return _cast_on_2005(bindvalue) + + +class _cast_on_2005(expression.ColumnElement): + def __init__(self, bindvalue): + self.bindvalue = bindvalue + + +@compiles(_cast_on_2005) +def _compile(element, compiler, **kw): + from . import base + + if ( + compiler.dialect.server_version_info is None + or compiler.dialect.server_version_info < base.MS_2005_VERSION + ): + return compiler.process(element.bindvalue, **kw) + else: + return compiler.process(cast(element.bindvalue, Unicode), **kw) + + +schemata = Table( + "SCHEMATA", + ischema, + Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), + Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), + Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), + schema="INFORMATION_SCHEMA", +) + +tables = Table( + "TABLES", + ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("TABLE_TYPE", CoerceUnicode, key="table_type"), + schema="INFORMATION_SCHEMA", +) + +columns = Table( + "COLUMNS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column( + "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length" + ), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="INFORMATION_SCHEMA", +) + +mssql_temp_table_columns = Table( + "COLUMNS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column( + "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length" + ), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="tempdb.INFORMATION_SCHEMA", +) + +constraints = Table( + "TABLE_CONSTRAINTS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_TYPE", CoerceUnicode, key="constraint_type"), + schema="INFORMATION_SCHEMA", +) + +column_constraints = Table( + "CONSTRAINT_COLUMN_USAGE", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + schema="INFORMATION_SCHEMA", +) + +key_constraints = Table( + "KEY_COLUMN_USAGE", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + schema="INFORMATION_SCHEMA", +) + +ref_constraints = Table( + "REFERENTIAL_CONSTRAINTS", + ischema, + Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + # TODO: is CATLOG misspelled ? + Column( + "UNIQUE_CONSTRAINT_CATLOG", + CoerceUnicode, + key="unique_constraint_catalog", + ), + Column( + "UNIQUE_CONSTRAINT_SCHEMA", + CoerceUnicode, + key="unique_constraint_schema", + ), + Column( + "UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name" + ), + Column("MATCH_OPTION", String, key="match_option"), + Column("UPDATE_RULE", String, key="update_rule"), + Column("DELETE_RULE", String, key="delete_rule"), + schema="INFORMATION_SCHEMA", +) + +views = Table( + "VIEWS", + ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), + Column("CHECK_OPTION", String, key="check_option"), + Column("IS_UPDATABLE", String, key="is_updatable"), + schema="INFORMATION_SCHEMA", +) + +computed_columns = Table( + "computed_columns", + ischema, + Column("object_id", Integer), + Column("name", CoerceUnicode), + Column("is_computed", Boolean), + Column("is_persisted", Boolean), + Column("definition", CoerceUnicode), + schema="sys", +) + +sequences = Table( + "SEQUENCES", + ischema, + Column("SEQUENCE_CATALOG", CoerceUnicode, key="sequence_catalog"), + Column("SEQUENCE_SCHEMA", CoerceUnicode, key="sequence_schema"), + Column("SEQUENCE_NAME", CoerceUnicode, key="sequence_name"), + schema="INFORMATION_SCHEMA", +) + + +class NumericSqlVariant(TypeDecorator): + r"""This type casts sql_variant columns in the identity_columns view + to numeric. This is required because: + + * pyodbc does not support sql_variant + * pymssql under python 2 return the byte representation of the number, + int 1 is returned as "\x01\x00\x00\x00". On python 3 it returns the + correct value as string. + """ + + impl = Unicode + cache_ok = True + + def column_expression(self, colexpr): + return cast(colexpr, Numeric(38, 0)) + + +identity_columns = Table( + "identity_columns", + ischema, + Column("object_id", Integer), + Column("name", CoerceUnicode), + Column("is_identity", Boolean), + Column("seed_value", NumericSqlVariant), + Column("increment_value", NumericSqlVariant), + Column("last_value", NumericSqlVariant), + Column("is_not_for_replication", Boolean), + schema="sys", +) + + +class NVarcharSqlVariant(TypeDecorator): + """This type casts sql_variant columns in the extended_properties view + to nvarchar. This is required because pyodbc does not support sql_variant + """ + + impl = Unicode + cache_ok = True + + def column_expression(self, colexpr): + return cast(colexpr, NVARCHAR) + + +extended_properties = Table( + "extended_properties", + ischema, + Column("class", Integer), # TINYINT + Column("class_desc", CoerceUnicode), + Column("major_id", Integer), + Column("minor_id", Integer), + Column("name", CoerceUnicode), + Column("value", NVarcharSqlVariant), + schema="sys", +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/json.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/json.py new file mode 100644 index 0000000..18bea09 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/json.py @@ -0,0 +1,133 @@ +# dialects/mssql/json.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from ... import types as sqltypes + +# technically, all the dialect-specific datatypes that don't have any special +# behaviors would be private with names like _MSJson. However, we haven't been +# doing this for mysql.JSON or sqlite.JSON which both have JSON / JSONIndexType +# / JSONPathType in their json.py files, so keep consistent with that +# sub-convention for now. A future change can update them all to be +# package-private at once. + + +class JSON(sqltypes.JSON): + """MSSQL JSON type. + + MSSQL supports JSON-formatted data as of SQL Server 2016. + + The :class:`_mssql.JSON` datatype at the DDL level will represent the + datatype as ``NVARCHAR(max)``, but provides for JSON-level comparison + functions as well as Python coercion behavior. + + :class:`_mssql.JSON` is used automatically whenever the base + :class:`_types.JSON` datatype is used against a SQL Server backend. + + .. seealso:: + + :class:`_types.JSON` - main documentation for the generic + cross-platform JSON datatype. + + The :class:`_mssql.JSON` type supports persistence of JSON values + as well as the core index operations provided by :class:`_types.JSON` + datatype, by adapting the operations to render the ``JSON_VALUE`` + or ``JSON_QUERY`` functions at the database level. + + The SQL Server :class:`_mssql.JSON` type necessarily makes use of the + ``JSON_QUERY`` and ``JSON_VALUE`` functions when querying for elements + of a JSON object. These two functions have a major restriction in that + they are **mutually exclusive** based on the type of object to be returned. + The ``JSON_QUERY`` function **only** returns a JSON dictionary or list, + but not an individual string, numeric, or boolean element; the + ``JSON_VALUE`` function **only** returns an individual string, numeric, + or boolean element. **both functions either return NULL or raise + an error if they are not used against the correct expected value**. + + To handle this awkward requirement, indexed access rules are as follows: + + 1. When extracting a sub element from a JSON that is itself a JSON + dictionary or list, the :meth:`_types.JSON.Comparator.as_json` accessor + should be used:: + + stmt = select( + data_table.c.data["some key"].as_json() + ).where( + data_table.c.data["some key"].as_json() == {"sub": "structure"} + ) + + 2. When extracting a sub element from a JSON that is a plain boolean, + string, integer, or float, use the appropriate method among + :meth:`_types.JSON.Comparator.as_boolean`, + :meth:`_types.JSON.Comparator.as_string`, + :meth:`_types.JSON.Comparator.as_integer`, + :meth:`_types.JSON.Comparator.as_float`:: + + stmt = select( + data_table.c.data["some key"].as_string() + ).where( + data_table.c.data["some key"].as_string() == "some string" + ) + + .. versionadded:: 1.4 + + + """ + + # note there was a result processor here that was looking for "number", + # but none of the tests seem to exercise it. + + +# Note: these objects currently match exactly those of MySQL, however since +# these are not generalizable to all JSON implementations, remain separately +# implemented for each dialect. +class _FormatTypeMixin: + def _format_value(self, value): + raise NotImplementedError() + + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + def _format_value(self, value): + if isinstance(value, int): + value = "$[%s]" % value + else: + value = '$."%s"' % value + return value + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value): + return "$%s" % ( + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/provision.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/provision.py new file mode 100644 index 0000000..143d386 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/provision.py @@ -0,0 +1,155 @@ +# dialects/mssql/provision.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from sqlalchemy import inspect +from sqlalchemy import Integer +from ... import create_engine +from ... import exc +from ...schema import Column +from ...schema import DropConstraint +from ...schema import ForeignKeyConstraint +from ...schema import MetaData +from ...schema import Table +from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_pre_tables +from ...testing.provision import drop_db +from ...testing.provision import generate_driver_url +from ...testing.provision import get_temp_table_name +from ...testing.provision import log +from ...testing.provision import normalize_sequence +from ...testing.provision import run_reap_dbs +from ...testing.provision import temp_table_keyword_args + + +@generate_driver_url.for_db("mssql") +def generate_driver_url(url, driver, query_str): + backend = url.get_backend_name() + + new_url = url.set(drivername="%s+%s" % (backend, driver)) + + if driver not in ("pyodbc", "aioodbc"): + new_url = new_url.set(query="") + + if driver == "aioodbc": + new_url = new_url.update_query_dict({"MARS_Connection": "Yes"}) + + if query_str: + new_url = new_url.update_query_string(query_str) + + try: + new_url.get_dialect() + except exc.NoSuchModuleError: + return None + else: + return new_url + + +@create_db.for_db("mssql") +def _mssql_create_db(cfg, eng, ident): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + conn.exec_driver_sql("create database %s" % ident) + conn.exec_driver_sql( + "ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident + ) + conn.exec_driver_sql( + "ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident + ) + conn.exec_driver_sql("use %s" % ident) + conn.exec_driver_sql("create schema test_schema") + conn.exec_driver_sql("create schema test_schema_2") + + +@drop_db.for_db("mssql") +def _mssql_drop_db(cfg, eng, ident): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + _mssql_drop_ignore(conn, ident) + + +def _mssql_drop_ignore(conn, ident): + try: + # typically when this happens, we can't KILL the session anyway, + # so let the cleanup process drop the DBs + # for row in conn.exec_driver_sql( + # "select session_id from sys.dm_exec_sessions " + # "where database_id=db_id('%s')" % ident): + # log.info("killing SQL server session %s", row['session_id']) + # conn.exec_driver_sql("kill %s" % row['session_id']) + conn.exec_driver_sql("drop database %s" % ident) + log.info("Reaped db: %s", ident) + return True + except exc.DatabaseError as err: + log.warning("couldn't drop db: %s", err) + return False + + +@run_reap_dbs.for_db("mssql") +def _reap_mssql_dbs(url, idents): + log.info("db reaper connecting to %r", url) + eng = create_engine(url) + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + log.info("identifiers in file: %s", ", ".join(idents)) + + to_reap = conn.exec_driver_sql( + "select d.name from sys.databases as d where name " + "like 'TEST_%' and not exists (select session_id " + "from sys.dm_exec_sessions " + "where database_id=d.database_id)" + ) + all_names = {dbname.lower() for (dbname,) in to_reap} + to_drop = set() + for name in all_names: + if name in idents: + to_drop.add(name) + + dropped = total = 0 + for total, dbname in enumerate(to_drop, 1): + if _mssql_drop_ignore(conn, dbname): + dropped += 1 + log.info( + "Dropped %d out of %d stale databases detected", dropped, total + ) + + +@temp_table_keyword_args.for_db("mssql") +def _mssql_temp_table_keyword_args(cfg, eng): + return {} + + +@get_temp_table_name.for_db("mssql") +def _mssql_get_temp_table_name(cfg, eng, base_name): + return "##" + base_name + + +@drop_all_schema_objects_pre_tables.for_db("mssql") +def drop_all_schema_objects_pre_tables(cfg, eng): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + inspector = inspect(conn) + for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2): + for tname in inspector.get_table_names(schema=schema): + tb = Table( + tname, + MetaData(), + Column("x", Integer), + Column("y", Integer), + schema=schema, + ) + for fk in inspect(conn).get_foreign_keys(tname, schema=schema): + conn.execute( + DropConstraint( + ForeignKeyConstraint( + [tb.c.x], [tb.c.y], name=fk["name"] + ) + ) + ) + + +@normalize_sequence.for_db("mssql") +def normalize_sequence(cfg, sequence): + if sequence.start is None: + sequence.start = 1 + return sequence diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/pymssql.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/pymssql.py new file mode 100644 index 0000000..ea1f9bd --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/pymssql.py @@ -0,0 +1,125 @@ +# dialects/mssql/pymssql.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +""" +.. dialect:: mssql+pymssql + :name: pymssql + :dbapi: pymssql + :connectstring: mssql+pymssql://:@/?charset=utf8 + +pymssql is a Python module that provides a Python DBAPI interface around +`FreeTDS `_. + +.. versionchanged:: 2.0.5 + + pymssql was restored to SQLAlchemy's continuous integration testing + + +""" # noqa +import re + +from .base import MSDialect +from .base import MSIdentifierPreparer +from ... import types as sqltypes +from ... import util +from ...engine import processors + + +class _MSNumeric_pymssql(sqltypes.Numeric): + def result_processor(self, dialect, type_): + if not self.asdecimal: + return processors.to_float + else: + return sqltypes.Numeric.result_processor(self, dialect, type_) + + +class MSIdentifierPreparer_pymssql(MSIdentifierPreparer): + def __init__(self, dialect): + super().__init__(dialect) + # pymssql has the very unusual behavior that it uses pyformat + # yet does not require that percent signs be doubled + self._double_percents = False + + +class MSDialect_pymssql(MSDialect): + supports_statement_cache = True + supports_native_decimal = True + supports_native_uuid = True + driver = "pymssql" + + preparer = MSIdentifierPreparer_pymssql + + colspecs = util.update_copy( + MSDialect.colspecs, + {sqltypes.Numeric: _MSNumeric_pymssql, sqltypes.Float: sqltypes.Float}, + ) + + @classmethod + def import_dbapi(cls): + module = __import__("pymssql") + # pymmsql < 2.1.1 doesn't have a Binary method. we use string + client_ver = tuple(int(x) for x in module.__version__.split(".")) + if client_ver < (2, 1, 1): + # TODO: monkeypatching here is less than ideal + module.Binary = lambda x: x if hasattr(x, "decode") else str(x) + + if client_ver < (1,): + util.warn( + "The pymssql dialect expects at least " + "the 1.0 series of the pymssql DBAPI." + ) + return module + + def _get_server_version_info(self, connection): + vers = connection.exec_driver_sql("select @@version").scalar() + m = re.match(r"Microsoft .*? - (\d+)\.(\d+)\.(\d+)\.(\d+)", vers) + if m: + return tuple(int(x) for x in m.group(1, 2, 3, 4)) + else: + return None + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + opts.update(url.query) + port = opts.pop("port", None) + if port and "host" in opts: + opts["host"] = "%s:%s" % (opts["host"], port) + return ([], opts) + + def is_disconnect(self, e, connection, cursor): + for msg in ( + "Adaptive Server connection timed out", + "Net-Lib error during Connection reset by peer", + "message 20003", # connection timeout + "Error 10054", + "Not connected to any MS SQL server", + "Connection is closed", + "message 20006", # Write to the server failed + "message 20017", # Unexpected EOF from the server + "message 20047", # DBPROCESS is dead or not enabled + ): + if msg in str(e): + return True + else: + return False + + def get_isolation_level_values(self, dbapi_connection): + return super().get_isolation_level_values(dbapi_connection) + [ + "AUTOCOMMIT" + ] + + def set_isolation_level(self, dbapi_connection, level): + if level == "AUTOCOMMIT": + dbapi_connection.autocommit(True) + else: + dbapi_connection.autocommit(False) + super().set_isolation_level(dbapi_connection, level) + + +dialect = MSDialect_pymssql diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/pyodbc.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/pyodbc.py new file mode 100644 index 0000000..76ea046 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mssql/pyodbc.py @@ -0,0 +1,745 @@ +# dialects/mssql/pyodbc.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: mssql+pyodbc + :name: PyODBC + :dbapi: pyodbc + :connectstring: mssql+pyodbc://:@ + :url: https://pypi.org/project/pyodbc/ + +Connecting to PyODBC +-------------------- + +The URL here is to be translated to PyODBC connection strings, as +detailed in `ConnectionStrings `_. + +DSN Connections +^^^^^^^^^^^^^^^ + +A DSN connection in ODBC means that a pre-existing ODBC datasource is +configured on the client machine. The application then specifies the name +of this datasource, which encompasses details such as the specific ODBC driver +in use as well as the network address of the database. Assuming a datasource +is configured on the client, a basic DSN-based connection looks like:: + + engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn") + +Which above, will pass the following connection string to PyODBC:: + + DSN=some_dsn;UID=scott;PWD=tiger + +If the username and password are omitted, the DSN form will also add +the ``Trusted_Connection=yes`` directive to the ODBC string. + +Hostname Connections +^^^^^^^^^^^^^^^^^^^^ + +Hostname-based connections are also supported by pyodbc. These are often +easier to use than a DSN and have the additional advantage that the specific +database name to connect towards may be specified locally in the URL, rather +than it being fixed as part of a datasource configuration. + +When using a hostname connection, the driver name must also be specified in the +query parameters of the URL. As these names usually have spaces in them, the +name must be URL encoded which means using plus signs for spaces:: + + engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server") + +The ``driver`` keyword is significant to the pyodbc dialect and must be +specified in lowercase. + +Any other names passed in the query string are passed through in the pyodbc +connect string, such as ``authentication``, ``TrustServerCertificate``, etc. +Multiple keyword arguments must be separated by an ampersand (``&``); these +will be translated to semicolons when the pyodbc connect string is generated +internally:: + + e = create_engine( + "mssql+pyodbc://scott:tiger@mssql2017:1433/test?" + "driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes" + "&authentication=ActiveDirectoryIntegrated" + ) + +The equivalent URL can be constructed using :class:`_sa.engine.URL`:: + + from sqlalchemy.engine import URL + connection_url = URL.create( + "mssql+pyodbc", + username="scott", + password="tiger", + host="mssql2017", + port=1433, + database="test", + query={ + "driver": "ODBC Driver 18 for SQL Server", + "TrustServerCertificate": "yes", + "authentication": "ActiveDirectoryIntegrated", + }, + ) + + +Pass through exact Pyodbc string +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A PyODBC connection string can also be sent in pyodbc's format directly, as +specified in `the PyODBC documentation +`_, +using the parameter ``odbc_connect``. A :class:`_sa.engine.URL` object +can help make this easier:: + + from sqlalchemy.engine import URL + connection_string = "DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password" + connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": connection_string}) + + engine = create_engine(connection_url) + +.. _mssql_pyodbc_access_tokens: + +Connecting to databases with access tokens +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Some database servers are set up to only accept access tokens for login. For +example, SQL Server allows the use of Azure Active Directory tokens to connect +to databases. This requires creating a credential object using the +``azure-identity`` library. More information about the authentication step can be +found in `Microsoft's documentation +`_. + +After getting an engine, the credentials need to be sent to ``pyodbc.connect`` +each time a connection is requested. One way to do this is to set up an event +listener on the engine that adds the credential token to the dialect's connect +call. This is discussed more generally in :ref:`engines_dynamic_tokens`. For +SQL Server in particular, this is passed as an ODBC connection attribute with +a data structure `described by Microsoft +`_. + +The following code snippet will create an engine that connects to an Azure SQL +database using Azure credentials:: + + import struct + from sqlalchemy import create_engine, event + from sqlalchemy.engine.url import URL + from azure import identity + + SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h + TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database + + connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server" + + engine = create_engine(connection_string) + + azure_credentials = identity.DefaultAzureCredential() + + @event.listens_for(engine, "do_connect") + def provide_token(dialect, conn_rec, cargs, cparams): + # remove the "Trusted_Connection" parameter that SQLAlchemy adds + cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "") + + # create token credential + raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le") + token_struct = struct.pack(f"`_, + stating that a connection string when using an access token must not contain + ``UID``, ``PWD``, ``Authentication`` or ``Trusted_Connection`` parameters. + +.. _azure_synapse_ignore_no_transaction_on_rollback: + +Avoiding transaction-related exceptions on Azure Synapse Analytics +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Azure Synapse Analytics has a significant difference in its transaction +handling compared to plain SQL Server; in some cases an error within a Synapse +transaction can cause it to be arbitrarily terminated on the server side, which +then causes the DBAPI ``.rollback()`` method (as well as ``.commit()``) to +fail. The issue prevents the usual DBAPI contract of allowing ``.rollback()`` +to pass silently if no transaction is present as the driver does not expect +this condition. The symptom of this failure is an exception with a message +resembling 'No corresponding transaction found. (111214)' when attempting to +emit a ``.rollback()`` after an operation had a failure of some kind. + +This specific case can be handled by passing ``ignore_no_transaction_on_rollback=True`` to +the SQL Server dialect via the :func:`_sa.create_engine` function as follows:: + + engine = create_engine(connection_url, ignore_no_transaction_on_rollback=True) + +Using the above parameter, the dialect will catch ``ProgrammingError`` +exceptions raised during ``connection.rollback()`` and emit a warning +if the error message contains code ``111214``, however will not raise +an exception. + +.. versionadded:: 1.4.40 Added the + ``ignore_no_transaction_on_rollback=True`` parameter. + +Enable autocommit for Azure SQL Data Warehouse (DW) connections +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Azure SQL Data Warehouse does not support transactions, +and that can cause problems with SQLAlchemy's "autobegin" (and implicit +commit/rollback) behavior. We can avoid these problems by enabling autocommit +at both the pyodbc and engine levels:: + + connection_url = sa.engine.URL.create( + "mssql+pyodbc", + username="scott", + password="tiger", + host="dw.azure.example.com", + database="mydb", + query={ + "driver": "ODBC Driver 17 for SQL Server", + "autocommit": "True", + }, + ) + + engine = create_engine(connection_url).execution_options( + isolation_level="AUTOCOMMIT" + ) + +Avoiding sending large string parameters as TEXT/NTEXT +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +By default, for historical reasons, Microsoft's ODBC drivers for SQL Server +send long string parameters (greater than 4000 SBCS characters or 2000 Unicode +characters) as TEXT/NTEXT values. TEXT and NTEXT have been deprecated for many +years and are starting to cause compatibility issues with newer versions of +SQL_Server/Azure. For example, see `this +issue `_. + +Starting with ODBC Driver 18 for SQL Server we can override the legacy +behavior and pass long strings as varchar(max)/nvarchar(max) using the +``LongAsMax=Yes`` connection string parameter:: + + connection_url = sa.engine.URL.create( + "mssql+pyodbc", + username="scott", + password="tiger", + host="mssqlserver.example.com", + database="mydb", + query={ + "driver": "ODBC Driver 18 for SQL Server", + "LongAsMax": "Yes", + }, + ) + + +Pyodbc Pooling / connection close behavior +------------------------------------------ + +PyODBC uses internal `pooling +`_ by +default, which means connections will be longer lived than they are within +SQLAlchemy itself. As SQLAlchemy has its own pooling behavior, it is often +preferable to disable this behavior. This behavior can only be disabled +globally at the PyODBC module level, **before** any connections are made:: + + import pyodbc + + pyodbc.pooling = False + + # don't use the engine before pooling is set to False + engine = create_engine("mssql+pyodbc://user:pass@dsn") + +If this variable is left at its default value of ``True``, **the application +will continue to maintain active database connections**, even when the +SQLAlchemy engine itself fully discards a connection or if the engine is +disposed. + +.. seealso:: + + `pooling `_ - + in the PyODBC documentation. + +Driver / Unicode Support +------------------------- + +PyODBC works best with Microsoft ODBC drivers, particularly in the area +of Unicode support on both Python 2 and Python 3. + +Using the FreeTDS ODBC drivers on Linux or OSX with PyODBC is **not** +recommended; there have been historically many Unicode-related issues +in this area, including before Microsoft offered ODBC drivers for Linux +and OSX. Now that Microsoft offers drivers for all platforms, for +PyODBC support these are recommended. FreeTDS remains relevant for +non-ODBC drivers such as pymssql where it works very well. + + +Rowcount Support +---------------- + +Previous limitations with the SQLAlchemy ORM's "versioned rows" feature with +Pyodbc have been resolved as of SQLAlchemy 2.0.5. See the notes at +:ref:`mssql_rowcount_versioning`. + +.. _mssql_pyodbc_fastexecutemany: + +Fast Executemany Mode +--------------------- + +The PyODBC driver includes support for a "fast executemany" mode of execution +which greatly reduces round trips for a DBAPI ``executemany()`` call when using +Microsoft ODBC drivers, for **limited size batches that fit in memory**. The +feature is enabled by setting the attribute ``.fast_executemany`` on the DBAPI +cursor when an executemany call is to be used. The SQLAlchemy PyODBC SQL +Server dialect supports this parameter by passing the +``fast_executemany`` parameter to +:func:`_sa.create_engine` , when using the **Microsoft ODBC driver only**:: + + engine = create_engine( + "mssql+pyodbc://scott:tiger@mssql2017:1433/test?driver=ODBC+Driver+17+for+SQL+Server", + fast_executemany=True) + +.. versionchanged:: 2.0.9 - the ``fast_executemany`` parameter now has its + intended effect of this PyODBC feature taking effect for all INSERT + statements that are executed with multiple parameter sets, which don't + include RETURNING. Previously, SQLAlchemy 2.0's :term:`insertmanyvalues` + feature would cause ``fast_executemany`` to not be used in most cases + even if specified. + +.. versionadded:: 1.3 + +.. seealso:: + + `fast executemany `_ + - on github + +.. _mssql_pyodbc_setinputsizes: + +Setinputsizes Support +----------------------- + +As of version 2.0, the pyodbc ``cursor.setinputsizes()`` method is used for +all statement executions, except for ``cursor.executemany()`` calls when +fast_executemany=True where it is not supported (assuming +:ref:`insertmanyvalues ` is kept enabled, +"fastexecutemany" will not take place for INSERT statements in any case). + +The use of ``cursor.setinputsizes()`` can be disabled by passing +``use_setinputsizes=False`` to :func:`_sa.create_engine`. + +When ``use_setinputsizes`` is left at its default of ``True``, the +specific per-type symbols passed to ``cursor.setinputsizes()`` can be +programmatically customized using the :meth:`.DialectEvents.do_setinputsizes` +hook. See that method for usage examples. + +.. versionchanged:: 2.0 The mssql+pyodbc dialect now defaults to using + ``use_setinputsizes=True`` for all statement executions with the exception of + cursor.executemany() calls when fast_executemany=True. The behavior can + be turned off by passing ``use_setinputsizes=False`` to + :func:`_sa.create_engine`. + +""" # noqa + + +import datetime +import decimal +import re +import struct + +from .base import _MSDateTime +from .base import _MSUnicode +from .base import _MSUnicodeText +from .base import BINARY +from .base import DATETIMEOFFSET +from .base import MSDialect +from .base import MSExecutionContext +from .base import VARBINARY +from .json import JSON as _MSJson +from .json import JSONIndexType as _MSJsonIndexType +from .json import JSONPathType as _MSJsonPathType +from ... import exc +from ... import types as sqltypes +from ... import util +from ...connectors.pyodbc import PyODBCConnector +from ...engine import cursor as _cursor + + +class _ms_numeric_pyodbc: + """Turns Decimals with adjusted() < 0 or > 7 into strings. + + The routines here are needed for older pyodbc versions + as well as current mxODBC versions. + + """ + + def bind_processor(self, dialect): + super_process = super().bind_processor(dialect) + + if not dialect._need_decimal_fix: + return super_process + + def process(value): + if self.asdecimal and isinstance(value, decimal.Decimal): + adjusted = value.adjusted() + if adjusted < 0: + return self._small_dec_to_string(value) + elif adjusted > 7: + return self._large_dec_to_string(value) + + if super_process: + return super_process(value) + else: + return value + + return process + + # these routines needed for older versions of pyodbc. + # as of 2.1.8 this logic is integrated. + + def _small_dec_to_string(self, value): + return "%s0.%s%s" % ( + (value < 0 and "-" or ""), + "0" * (abs(value.adjusted()) - 1), + "".join([str(nint) for nint in value.as_tuple()[1]]), + ) + + def _large_dec_to_string(self, value): + _int = value.as_tuple()[1] + if "E" in str(value): + result = "%s%s%s" % ( + (value < 0 and "-" or ""), + "".join([str(s) for s in _int]), + "0" * (value.adjusted() - (len(_int) - 1)), + ) + else: + if (len(_int) - 1) > value.adjusted(): + result = "%s%s.%s" % ( + (value < 0 and "-" or ""), + "".join([str(s) for s in _int][0 : value.adjusted() + 1]), + "".join([str(s) for s in _int][value.adjusted() + 1 :]), + ) + else: + result = "%s%s" % ( + (value < 0 and "-" or ""), + "".join([str(s) for s in _int][0 : value.adjusted() + 1]), + ) + return result + + +class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric): + pass + + +class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float): + pass + + +class _ms_binary_pyodbc: + """Wraps binary values in dialect-specific Binary wrapper. + If the value is null, return a pyodbc-specific BinaryNull + object to prevent pyODBC [and FreeTDS] from defaulting binary + NULL types to SQLWCHAR and causing implicit conversion errors. + """ + + def bind_processor(self, dialect): + if dialect.dbapi is None: + return None + + DBAPIBinary = dialect.dbapi.Binary + + def process(value): + if value is not None: + return DBAPIBinary(value) + else: + # pyodbc-specific + return dialect.dbapi.BinaryNull + + return process + + +class _ODBCDateTimeBindProcessor: + """Add bind processors to handle datetimeoffset behaviors""" + + has_tz = False + + def bind_processor(self, dialect): + def process(value): + if value is None: + return None + elif isinstance(value, str): + # if a string was passed directly, allow it through + return value + elif not value.tzinfo or (not self.timezone and not self.has_tz): + # for DateTime(timezone=False) + return value + else: + # for DATETIMEOFFSET or DateTime(timezone=True) + # + # Convert to string format required by T-SQL + dto_string = value.strftime("%Y-%m-%d %H:%M:%S.%f %z") + # offset needs a colon, e.g., -0700 -> -07:00 + # "UTC offset in the form (+-)HHMM[SS[.ffffff]]" + # backend currently rejects seconds / fractional seconds + dto_string = re.sub( + r"([\+\-]\d{2})([\d\.]+)$", r"\1:\2", dto_string + ) + return dto_string + + return process + + +class _ODBCDateTime(_ODBCDateTimeBindProcessor, _MSDateTime): + pass + + +class _ODBCDATETIMEOFFSET(_ODBCDateTimeBindProcessor, DATETIMEOFFSET): + has_tz = True + + +class _VARBINARY_pyodbc(_ms_binary_pyodbc, VARBINARY): + pass + + +class _BINARY_pyodbc(_ms_binary_pyodbc, BINARY): + pass + + +class _String_pyodbc(sqltypes.String): + def get_dbapi_type(self, dbapi): + if self.length in (None, "max") or self.length >= 2000: + return (dbapi.SQL_VARCHAR, 0, 0) + else: + return dbapi.SQL_VARCHAR + + +class _Unicode_pyodbc(_MSUnicode): + def get_dbapi_type(self, dbapi): + if self.length in (None, "max") or self.length >= 2000: + return (dbapi.SQL_WVARCHAR, 0, 0) + else: + return dbapi.SQL_WVARCHAR + + +class _UnicodeText_pyodbc(_MSUnicodeText): + def get_dbapi_type(self, dbapi): + if self.length in (None, "max") or self.length >= 2000: + return (dbapi.SQL_WVARCHAR, 0, 0) + else: + return dbapi.SQL_WVARCHAR + + +class _JSON_pyodbc(_MSJson): + def get_dbapi_type(self, dbapi): + return (dbapi.SQL_WVARCHAR, 0, 0) + + +class _JSONIndexType_pyodbc(_MSJsonIndexType): + def get_dbapi_type(self, dbapi): + return dbapi.SQL_WVARCHAR + + +class _JSONPathType_pyodbc(_MSJsonPathType): + def get_dbapi_type(self, dbapi): + return dbapi.SQL_WVARCHAR + + +class MSExecutionContext_pyodbc(MSExecutionContext): + _embedded_scope_identity = False + + def pre_exec(self): + """where appropriate, issue "select scope_identity()" in the same + statement. + + Background on why "scope_identity()" is preferable to "@@identity": + https://msdn.microsoft.com/en-us/library/ms190315.aspx + + Background on why we attempt to embed "scope_identity()" into the same + statement as the INSERT: + https://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values? + + """ + + super().pre_exec() + + # don't embed the scope_identity select into an + # "INSERT .. DEFAULT VALUES" + if ( + self._select_lastrowid + and self.dialect.use_scope_identity + and len(self.parameters[0]) + ): + self._embedded_scope_identity = True + + self.statement += "; select scope_identity()" + + def post_exec(self): + if self._embedded_scope_identity: + # Fetch the last inserted id from the manipulated statement + # We may have to skip over a number of result sets with + # no data (due to triggers, etc.) + while True: + try: + # fetchall() ensures the cursor is consumed + # without closing it (FreeTDS particularly) + rows = self.cursor.fetchall() + except self.dialect.dbapi.Error: + # no way around this - nextset() consumes the previous set + # so we need to just keep flipping + self.cursor.nextset() + else: + if not rows: + # async adapter drivers just return None here + self.cursor.nextset() + continue + row = rows[0] + break + + self._lastrowid = int(row[0]) + + self.cursor_fetch_strategy = _cursor._NO_CURSOR_DML + else: + super().post_exec() + + +class MSDialect_pyodbc(PyODBCConnector, MSDialect): + supports_statement_cache = True + + # note this parameter is no longer used by the ORM or default dialect + # see #9414 + supports_sane_rowcount_returning = False + + execution_ctx_cls = MSExecutionContext_pyodbc + + colspecs = util.update_copy( + MSDialect.colspecs, + { + sqltypes.Numeric: _MSNumeric_pyodbc, + sqltypes.Float: _MSFloat_pyodbc, + BINARY: _BINARY_pyodbc, + # support DateTime(timezone=True) + sqltypes.DateTime: _ODBCDateTime, + DATETIMEOFFSET: _ODBCDATETIMEOFFSET, + # SQL Server dialect has a VARBINARY that is just to support + # "deprecate_large_types" w/ VARBINARY(max), but also we must + # handle the usual SQL standard VARBINARY + VARBINARY: _VARBINARY_pyodbc, + sqltypes.VARBINARY: _VARBINARY_pyodbc, + sqltypes.LargeBinary: _VARBINARY_pyodbc, + sqltypes.String: _String_pyodbc, + sqltypes.Unicode: _Unicode_pyodbc, + sqltypes.UnicodeText: _UnicodeText_pyodbc, + sqltypes.JSON: _JSON_pyodbc, + sqltypes.JSON.JSONIndexType: _JSONIndexType_pyodbc, + sqltypes.JSON.JSONPathType: _JSONPathType_pyodbc, + # this excludes Enum from the string/VARCHAR thing for now + # it looks like Enum's adaptation doesn't really support the + # String type itself having a dialect-level impl + sqltypes.Enum: sqltypes.Enum, + }, + ) + + def __init__( + self, + fast_executemany=False, + use_setinputsizes=True, + **params, + ): + super().__init__(use_setinputsizes=use_setinputsizes, **params) + self.use_scope_identity = ( + self.use_scope_identity + and self.dbapi + and hasattr(self.dbapi.Cursor, "nextset") + ) + self._need_decimal_fix = self.dbapi and self._dbapi_version() < ( + 2, + 1, + 8, + ) + self.fast_executemany = fast_executemany + if fast_executemany: + self.use_insertmanyvalues_wo_returning = False + + def _get_server_version_info(self, connection): + try: + # "Version of the instance of SQL Server, in the form + # of 'major.minor.build.revision'" + raw = connection.exec_driver_sql( + "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)" + ).scalar() + except exc.DBAPIError: + # SQL Server docs indicate this function isn't present prior to + # 2008. Before we had the VARCHAR cast above, pyodbc would also + # fail on this query. + return super()._get_server_version_info(connection) + else: + version = [] + r = re.compile(r"[.\-]") + for n in r.split(raw): + try: + version.append(int(n)) + except ValueError: + pass + return tuple(version) + + def on_connect(self): + super_ = super().on_connect() + + def on_connect(conn): + if super_ is not None: + super_(conn) + + self._setup_timestampoffset_type(conn) + + return on_connect + + def _setup_timestampoffset_type(self, connection): + # output converter function for datetimeoffset + def _handle_datetimeoffset(dto_value): + tup = struct.unpack("<6hI2h", dto_value) + return datetime.datetime( + tup[0], + tup[1], + tup[2], + tup[3], + tup[4], + tup[5], + tup[6] // 1000, + datetime.timezone( + datetime.timedelta(hours=tup[7], minutes=tup[8]) + ), + ) + + odbc_SQL_SS_TIMESTAMPOFFSET = -155 # as defined in SQLNCLI.h + connection.add_output_converter( + odbc_SQL_SS_TIMESTAMPOFFSET, _handle_datetimeoffset + ) + + def do_executemany(self, cursor, statement, parameters, context=None): + if self.fast_executemany: + cursor.fast_executemany = True + super().do_executemany(cursor, statement, parameters, context=context) + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error): + code = e.args[0] + if code in { + "08S01", + "01000", + "01002", + "08003", + "08007", + "08S02", + "08001", + "HYT00", + "HY010", + "10054", + }: + return True + return super().is_disconnect(e, connection, cursor) + + +dialect = MSDialect_pyodbc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__init__.py new file mode 100644 index 0000000..60bac87 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__init__.py @@ -0,0 +1,101 @@ +# dialects/mysql/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from . import aiomysql # noqa +from . import asyncmy # noqa +from . import base # noqa +from . import cymysql # noqa +from . import mariadbconnector # noqa +from . import mysqlconnector # noqa +from . import mysqldb # noqa +from . import pymysql # noqa +from . import pyodbc # noqa +from .base import BIGINT +from .base import BINARY +from .base import BIT +from .base import BLOB +from .base import BOOLEAN +from .base import CHAR +from .base import DATE +from .base import DATETIME +from .base import DECIMAL +from .base import DOUBLE +from .base import ENUM +from .base import FLOAT +from .base import INTEGER +from .base import JSON +from .base import LONGBLOB +from .base import LONGTEXT +from .base import MEDIUMBLOB +from .base import MEDIUMINT +from .base import MEDIUMTEXT +from .base import NCHAR +from .base import NUMERIC +from .base import NVARCHAR +from .base import REAL +from .base import SET +from .base import SMALLINT +from .base import TEXT +from .base import TIME +from .base import TIMESTAMP +from .base import TINYBLOB +from .base import TINYINT +from .base import TINYTEXT +from .base import VARBINARY +from .base import VARCHAR +from .base import YEAR +from .dml import Insert +from .dml import insert +from .expression import match +from ...util import compat + +# default dialect +base.dialect = dialect = mysqldb.dialect + +__all__ = ( + "BIGINT", + "BINARY", + "BIT", + "BLOB", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "DOUBLE", + "ENUM", + "FLOAT", + "INTEGER", + "INTEGER", + "JSON", + "LONGBLOB", + "LONGTEXT", + "MEDIUMBLOB", + "MEDIUMINT", + "MEDIUMTEXT", + "NCHAR", + "NVARCHAR", + "NUMERIC", + "SET", + "SMALLINT", + "REAL", + "TEXT", + "TIME", + "TIMESTAMP", + "TINYBLOB", + "TINYINT", + "TINYTEXT", + "VARBINARY", + "VARCHAR", + "YEAR", + "dialect", + "insert", + "Insert", + "match", +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..2a39bd6 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/aiomysql.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/aiomysql.cpython-311.pyc new file mode 100644 index 0000000..cd8e408 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/aiomysql.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/asyncmy.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/asyncmy.cpython-311.pyc new file mode 100644 index 0000000..1c1eb90 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/asyncmy.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..e26259b Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/base.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/cymysql.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/cymysql.cpython-311.pyc new file mode 100644 index 0000000..e08de25 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/cymysql.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/dml.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/dml.cpython-311.pyc new file mode 100644 index 0000000..87e30ca Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/dml.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/enumerated.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/enumerated.cpython-311.pyc new file mode 100644 index 0000000..2df5629 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/enumerated.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/expression.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/expression.cpython-311.pyc new file mode 100644 index 0000000..c3fd7e9 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/expression.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/json.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/json.cpython-311.pyc new file mode 100644 index 0000000..417677b Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/json.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mariadb.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mariadb.cpython-311.pyc new file mode 100644 index 0000000..79ef157 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mariadb.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mariadbconnector.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mariadbconnector.cpython-311.pyc new file mode 100644 index 0000000..a01ff5f Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mariadbconnector.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mysqlconnector.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mysqlconnector.cpython-311.pyc new file mode 100644 index 0000000..d7a6a76 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mysqlconnector.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mysqldb.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mysqldb.cpython-311.pyc new file mode 100644 index 0000000..d572a66 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/mysqldb.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/provision.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/provision.cpython-311.pyc new file mode 100644 index 0000000..3865a06 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/provision.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/pymysql.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/pymysql.cpython-311.pyc new file mode 100644 index 0000000..bc40e52 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/pymysql.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/pyodbc.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/pyodbc.cpython-311.pyc new file mode 100644 index 0000000..07a2fcc Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/pyodbc.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/reflection.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/reflection.cpython-311.pyc new file mode 100644 index 0000000..0b1bc4c Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/reflection.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/reserved_words.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/reserved_words.cpython-311.pyc new file mode 100644 index 0000000..1ef118c Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/reserved_words.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/types.cpython-311.pyc new file mode 100644 index 0000000..6d20ed0 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/__pycache__/types.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/aiomysql.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/aiomysql.py new file mode 100644 index 0000000..405fa82 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/aiomysql.py @@ -0,0 +1,332 @@ +# dialects/mysql/aiomysql.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: mysql+aiomysql + :name: aiomysql + :dbapi: aiomysql + :connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...] + :url: https://github.com/aio-libs/aiomysql + +The aiomysql dialect is SQLAlchemy's second Python asyncio dialect. + +Using a special asyncio mediation layer, the aiomysql dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio ` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4") + + +""" # noqa +from .pymysql import MySQLDialect_pymysql +from ... import pool +from ... import util +from ...engine import AdaptedConnection +from ...util.concurrency import asyncio +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncAdapt_aiomysql_cursor: + # TODO: base on connectors/asyncio.py + # see #10415 + server_side = False + __slots__ = ( + "_adapt_connection", + "_connection", + "await_", + "_cursor", + "_rows", + ) + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor(adapt_connection.dbapi.Cursor) + + # see https://github.com/aio-libs/aiomysql/issues/543 + self._cursor = self.await_(cursor.__aenter__()) + self._rows = [] + + @property + def description(self): + return self._cursor.description + + @property + def rowcount(self): + return self._cursor.rowcount + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + @property + def lastrowid(self): + return self._cursor.lastrowid + + def close(self): + # note we aren't actually closing the cursor here, + # we are just letting GC do it. to allow this to be async + # we would need the Result to change how it does "Safe close cursor". + # MySQL "cursors" don't actually have state to be "closed" besides + # exhausting rows, which we already have done for sync cursor. + # another option would be to emulate aiosqlite dialect and assign + # cursor only if we are doing server side cursor operation. + self._rows[:] = [] + + def execute(self, operation, parameters=None): + return self.await_(self._execute_async(operation, parameters)) + + def executemany(self, operation, seq_of_parameters): + return self.await_( + self._executemany_async(operation, seq_of_parameters) + ) + + async def _execute_async(self, operation, parameters): + async with self._adapt_connection._execute_mutex: + result = await self._cursor.execute(operation, parameters) + + if not self.server_side: + # aiomysql has a "fake" async result, so we have to pull it out + # of that here since our default result is not async. + # we could just as easily grab "_rows" here and be done with it + # but this is safer. + self._rows = list(await self._cursor.fetchall()) + return result + + async def _executemany_async(self, operation, seq_of_parameters): + async with self._adapt_connection._execute_mutex: + return await self._cursor.executemany(operation, seq_of_parameters) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor): + # TODO: base on connectors/asyncio.py + # see #10415 + __slots__ = () + server_side = True + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor) + + self._cursor = self.await_(cursor.__aenter__()) + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_aiomysql_connection(AdaptedConnection): + # TODO: base on connectors/asyncio.py + # see #10415 + await_ = staticmethod(await_only) + __slots__ = ("dbapi", "_execute_mutex") + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + self._execute_mutex = asyncio.Lock() + + def ping(self, reconnect): + return self.await_(self._connection.ping(reconnect)) + + def character_set_name(self): + return self._connection.character_set_name() + + def autocommit(self, value): + self.await_(self._connection.autocommit(value)) + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_aiomysql_ss_cursor(self) + else: + return AsyncAdapt_aiomysql_cursor(self) + + def rollback(self): + self.await_(self._connection.rollback()) + + def commit(self): + self.await_(self._connection.commit()) + + def terminate(self): + # it's not awaitable. + self._connection.close() + + def close(self) -> None: + self.await_(self._connection.ensure_closed()) + + +class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection): + # TODO: base on connectors/asyncio.py + # see #10415 + __slots__ = () + + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_aiomysql_dbapi: + def __init__(self, aiomysql, pymysql): + self.aiomysql = aiomysql + self.pymysql = pymysql + self.paramstyle = "format" + self._init_dbapi_attributes() + self.Cursor, self.SSCursor = self._init_cursors_subclasses() + + def _init_dbapi_attributes(self): + for name in ( + "Warning", + "Error", + "InterfaceError", + "DataError", + "DatabaseError", + "OperationalError", + "InterfaceError", + "IntegrityError", + "ProgrammingError", + "InternalError", + "NotSupportedError", + ): + setattr(self, name, getattr(self.aiomysql, name)) + + for name in ( + "NUMBER", + "STRING", + "DATETIME", + "BINARY", + "TIMESTAMP", + "Binary", + ): + setattr(self, name, getattr(self.pymysql, name)) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect) + + if util.asbool(async_fallback): + return AsyncAdaptFallback_aiomysql_connection( + self, + await_fallback(creator_fn(*arg, **kw)), + ) + else: + return AsyncAdapt_aiomysql_connection( + self, + await_only(creator_fn(*arg, **kw)), + ) + + def _init_cursors_subclasses(self): + # suppress unconditional warning emitted by aiomysql + class Cursor(self.aiomysql.Cursor): + async def _show_warnings(self, conn): + pass + + class SSCursor(self.aiomysql.SSCursor): + async def _show_warnings(self, conn): + pass + + return Cursor, SSCursor + + +class MySQLDialect_aiomysql(MySQLDialect_pymysql): + driver = "aiomysql" + supports_statement_cache = True + + supports_server_side_cursors = True + _sscursor = AsyncAdapt_aiomysql_ss_cursor + + is_async = True + has_terminate = True + + @classmethod + def import_dbapi(cls): + return AsyncAdapt_aiomysql_dbapi( + __import__("aiomysql"), __import__("pymysql") + ) + + @classmethod + def get_pool_class(cls, url): + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def do_terminate(self, dbapi_connection) -> None: + dbapi_connection.terminate() + + def create_connect_args(self, url): + return super().create_connect_args( + url, _translate_args=dict(username="user", database="db") + ) + + def is_disconnect(self, e, connection, cursor): + if super().is_disconnect(e, connection, cursor): + return True + else: + str_e = str(e).lower() + return "not connected" in str_e + + def _found_rows_client_flag(self): + from pymysql.constants import CLIENT + + return CLIENT.FOUND_ROWS + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = MySQLDialect_aiomysql diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/asyncmy.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/asyncmy.py new file mode 100644 index 0000000..7360044 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/asyncmy.py @@ -0,0 +1,337 @@ +# dialects/mysql/asyncmy.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: mysql+asyncmy + :name: asyncmy + :dbapi: asyncmy + :connectstring: mysql+asyncmy://user:password@host:port/dbname[?key=value&key=value...] + :url: https://github.com/long2ice/asyncmy + +Using a special asyncio mediation layer, the asyncmy dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio ` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4") + + +""" # noqa +from contextlib import asynccontextmanager + +from .pymysql import MySQLDialect_pymysql +from ... import pool +from ... import util +from ...engine import AdaptedConnection +from ...util.concurrency import asyncio +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncAdapt_asyncmy_cursor: + # TODO: base on connectors/asyncio.py + # see #10415 + server_side = False + __slots__ = ( + "_adapt_connection", + "_connection", + "await_", + "_cursor", + "_rows", + ) + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor() + + self._cursor = self.await_(cursor.__aenter__()) + self._rows = [] + + @property + def description(self): + return self._cursor.description + + @property + def rowcount(self): + return self._cursor.rowcount + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + @property + def lastrowid(self): + return self._cursor.lastrowid + + def close(self): + # note we aren't actually closing the cursor here, + # we are just letting GC do it. to allow this to be async + # we would need the Result to change how it does "Safe close cursor". + # MySQL "cursors" don't actually have state to be "closed" besides + # exhausting rows, which we already have done for sync cursor. + # another option would be to emulate aiosqlite dialect and assign + # cursor only if we are doing server side cursor operation. + self._rows[:] = [] + + def execute(self, operation, parameters=None): + return self.await_(self._execute_async(operation, parameters)) + + def executemany(self, operation, seq_of_parameters): + return self.await_( + self._executemany_async(operation, seq_of_parameters) + ) + + async def _execute_async(self, operation, parameters): + async with self._adapt_connection._mutex_and_adapt_errors(): + if parameters is None: + result = await self._cursor.execute(operation) + else: + result = await self._cursor.execute(operation, parameters) + + if not self.server_side: + # asyncmy has a "fake" async result, so we have to pull it out + # of that here since our default result is not async. + # we could just as easily grab "_rows" here and be done with it + # but this is safer. + self._rows = list(await self._cursor.fetchall()) + return result + + async def _executemany_async(self, operation, seq_of_parameters): + async with self._adapt_connection._mutex_and_adapt_errors(): + return await self._cursor.executemany(operation, seq_of_parameters) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor): + # TODO: base on connectors/asyncio.py + # see #10415 + __slots__ = () + server_side = True + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = self._connection.cursor( + adapt_connection.dbapi.asyncmy.cursors.SSCursor + ) + + self._cursor = self.await_(cursor.__aenter__()) + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_asyncmy_connection(AdaptedConnection): + # TODO: base on connectors/asyncio.py + # see #10415 + await_ = staticmethod(await_only) + __slots__ = ("dbapi", "_execute_mutex") + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + self._execute_mutex = asyncio.Lock() + + @asynccontextmanager + async def _mutex_and_adapt_errors(self): + async with self._execute_mutex: + try: + yield + except AttributeError: + raise self.dbapi.InternalError( + "network operation failed due to asyncmy attribute error" + ) + + def ping(self, reconnect): + assert not reconnect + return self.await_(self._do_ping()) + + async def _do_ping(self): + async with self._mutex_and_adapt_errors(): + return await self._connection.ping(False) + + def character_set_name(self): + return self._connection.character_set_name() + + def autocommit(self, value): + self.await_(self._connection.autocommit(value)) + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_asyncmy_ss_cursor(self) + else: + return AsyncAdapt_asyncmy_cursor(self) + + def rollback(self): + self.await_(self._connection.rollback()) + + def commit(self): + self.await_(self._connection.commit()) + + def terminate(self): + # it's not awaitable. + self._connection.close() + + def close(self) -> None: + self.await_(self._connection.ensure_closed()) + + +class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +def _Binary(x): + """Return x as a binary type.""" + return bytes(x) + + +class AsyncAdapt_asyncmy_dbapi: + def __init__(self, asyncmy): + self.asyncmy = asyncmy + self.paramstyle = "format" + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self): + for name in ( + "Warning", + "Error", + "InterfaceError", + "DataError", + "DatabaseError", + "OperationalError", + "InterfaceError", + "IntegrityError", + "ProgrammingError", + "InternalError", + "NotSupportedError", + ): + setattr(self, name, getattr(self.asyncmy.errors, name)) + + STRING = util.symbol("STRING") + NUMBER = util.symbol("NUMBER") + BINARY = util.symbol("BINARY") + DATETIME = util.symbol("DATETIME") + TIMESTAMP = util.symbol("TIMESTAMP") + Binary = staticmethod(_Binary) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect) + + if util.asbool(async_fallback): + return AsyncAdaptFallback_asyncmy_connection( + self, + await_fallback(creator_fn(*arg, **kw)), + ) + else: + return AsyncAdapt_asyncmy_connection( + self, + await_only(creator_fn(*arg, **kw)), + ) + + +class MySQLDialect_asyncmy(MySQLDialect_pymysql): + driver = "asyncmy" + supports_statement_cache = True + + supports_server_side_cursors = True + _sscursor = AsyncAdapt_asyncmy_ss_cursor + + is_async = True + has_terminate = True + + @classmethod + def import_dbapi(cls): + return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy")) + + @classmethod + def get_pool_class(cls, url): + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def do_terminate(self, dbapi_connection) -> None: + dbapi_connection.terminate() + + def create_connect_args(self, url): + return super().create_connect_args( + url, _translate_args=dict(username="user", database="db") + ) + + def is_disconnect(self, e, connection, cursor): + if super().is_disconnect(e, connection, cursor): + return True + else: + str_e = str(e).lower() + return ( + "not connected" in str_e or "network operation failed" in str_e + ) + + def _found_rows_client_flag(self): + from asyncmy.constants import CLIENT + + return CLIENT.FOUND_ROWS + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = MySQLDialect_asyncmy diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/base.py new file mode 100644 index 0000000..dacbb7a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/base.py @@ -0,0 +1,3447 @@ +# dialects/mysql/base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" + +.. dialect:: mysql + :name: MySQL / MariaDB + :full_support: 5.6, 5.7, 8.0 / 10.8, 10.9 + :normal_support: 5.6+ / 10+ + :best_effort: 5.0.2+ / 5.0.2+ + +Supported Versions and Features +------------------------------- + +SQLAlchemy supports MySQL starting with version 5.0.2 through modern releases, +as well as all modern versions of MariaDB. See the official MySQL +documentation for detailed information about features supported in any given +server release. + +.. versionchanged:: 1.4 minimum MySQL version supported is now 5.0.2. + +MariaDB Support +~~~~~~~~~~~~~~~ + +The MariaDB variant of MySQL retains fundamental compatibility with MySQL's +protocols however the development of these two products continues to diverge. +Within the realm of SQLAlchemy, the two databases have a small number of +syntactical and behavioral differences that SQLAlchemy accommodates automatically. +To connect to a MariaDB database, no changes to the database URL are required:: + + + engine = create_engine("mysql+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4") + +Upon first connect, the SQLAlchemy dialect employs a +server version detection scheme that determines if the +backing database reports as MariaDB. Based on this flag, the dialect +can make different choices in those of areas where its behavior +must be different. + +.. _mysql_mariadb_only_mode: + +MariaDB-Only Mode +~~~~~~~~~~~~~~~~~ + +The dialect also supports an **optional** "MariaDB-only" mode of connection, which may be +useful for the case where an application makes use of MariaDB-specific features +and is not compatible with a MySQL database. To use this mode of operation, +replace the "mysql" token in the above URL with "mariadb":: + + engine = create_engine("mariadb+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4") + +The above engine, upon first connect, will raise an error if the server version +detection detects that the backing database is not MariaDB. + +When using an engine with ``"mariadb"`` as the dialect name, **all mysql-specific options +that include the name "mysql" in them are now named with "mariadb"**. This means +options like ``mysql_engine`` should be named ``mariadb_engine``, etc. Both +"mysql" and "mariadb" options can be used simultaneously for applications that +use URLs with both "mysql" and "mariadb" dialects:: + + my_table = Table( + "mytable", + metadata, + Column("id", Integer, primary_key=True), + Column("textdata", String(50)), + mariadb_engine="InnoDB", + mysql_engine="InnoDB", + ) + + Index( + "textdata_ix", + my_table.c.textdata, + mysql_prefix="FULLTEXT", + mariadb_prefix="FULLTEXT", + ) + +Similar behavior will occur when the above structures are reflected, i.e. the +"mariadb" prefix will be present in the option names when the database URL +is based on the "mariadb" name. + +.. versionadded:: 1.4 Added "mariadb" dialect name supporting "MariaDB-only mode" + for the MySQL dialect. + +.. _mysql_connection_timeouts: + +Connection Timeouts and Disconnects +----------------------------------- + +MySQL / MariaDB feature an automatic connection close behavior, for connections that +have been idle for a fixed period of time, defaulting to eight hours. +To circumvent having this issue, use +the :paramref:`_sa.create_engine.pool_recycle` option which ensures that +a connection will be discarded and replaced with a new one if it has been +present in the pool for a fixed number of seconds:: + + engine = create_engine('mysql+mysqldb://...', pool_recycle=3600) + +For more comprehensive disconnect detection of pooled connections, including +accommodation of server restarts and network issues, a pre-ping approach may +be employed. See :ref:`pool_disconnects` for current approaches. + +.. seealso:: + + :ref:`pool_disconnects` - Background on several techniques for dealing + with timed out connections as well as database restarts. + +.. _mysql_storage_engines: + +CREATE TABLE arguments including Storage Engines +------------------------------------------------ + +Both MySQL's and MariaDB's CREATE TABLE syntax includes a wide array of special options, +including ``ENGINE``, ``CHARSET``, ``MAX_ROWS``, ``ROW_FORMAT``, +``INSERT_METHOD``, and many more. +To accommodate the rendering of these arguments, specify the form +``mysql_argument_name="value"``. For example, to specify a table with +``ENGINE`` of ``InnoDB``, ``CHARSET`` of ``utf8mb4``, and ``KEY_BLOCK_SIZE`` +of ``1024``:: + + Table('mytable', metadata, + Column('data', String(32)), + mysql_engine='InnoDB', + mysql_charset='utf8mb4', + mysql_key_block_size="1024" + ) + +When supporting :ref:`mysql_mariadb_only_mode` mode, similar keys against +the "mariadb" prefix must be included as well. The values can of course +vary independently so that different settings on MySQL vs. MariaDB may +be maintained:: + + # support both "mysql" and "mariadb-only" engine URLs + + Table('mytable', metadata, + Column('data', String(32)), + + mysql_engine='InnoDB', + mariadb_engine='InnoDB', + + mysql_charset='utf8mb4', + mariadb_charset='utf8', + + mysql_key_block_size="1024" + mariadb_key_block_size="1024" + + ) + +The MySQL / MariaDB dialects will normally transfer any keyword specified as +``mysql_keyword_name`` to be rendered as ``KEYWORD_NAME`` in the +``CREATE TABLE`` statement. A handful of these names will render with a space +instead of an underscore; to support this, the MySQL dialect has awareness of +these particular names, which include ``DATA DIRECTORY`` +(e.g. ``mysql_data_directory``), ``CHARACTER SET`` (e.g. +``mysql_character_set``) and ``INDEX DIRECTORY`` (e.g. +``mysql_index_directory``). + +The most common argument is ``mysql_engine``, which refers to the storage +engine for the table. Historically, MySQL server installations would default +to ``MyISAM`` for this value, although newer versions may be defaulting +to ``InnoDB``. The ``InnoDB`` engine is typically preferred for its support +of transactions and foreign keys. + +A :class:`_schema.Table` +that is created in a MySQL / MariaDB database with a storage engine +of ``MyISAM`` will be essentially non-transactional, meaning any +INSERT/UPDATE/DELETE statement referring to this table will be invoked as +autocommit. It also will have no support for foreign key constraints; while +the ``CREATE TABLE`` statement accepts foreign key options, when using the +``MyISAM`` storage engine these arguments are discarded. Reflecting such a +table will also produce no foreign key constraint information. + +For fully atomic transactions as well as support for foreign key +constraints, all participating ``CREATE TABLE`` statements must specify a +transactional engine, which in the vast majority of cases is ``InnoDB``. + + +Case Sensitivity and Table Reflection +------------------------------------- + +Both MySQL and MariaDB have inconsistent support for case-sensitive identifier +names, basing support on specific details of the underlying +operating system. However, it has been observed that no matter +what case sensitivity behavior is present, the names of tables in +foreign key declarations are *always* received from the database +as all-lower case, making it impossible to accurately reflect a +schema where inter-related tables use mixed-case identifier names. + +Therefore it is strongly advised that table names be declared as +all lower case both within SQLAlchemy as well as on the MySQL / MariaDB +database itself, especially if database reflection features are +to be used. + +.. _mysql_isolation_level: + +Transaction Isolation Level +--------------------------- + +All MySQL / MariaDB dialects support setting of transaction isolation level both via a +dialect-specific parameter :paramref:`_sa.create_engine.isolation_level` +accepted +by :func:`_sa.create_engine`, as well as the +:paramref:`.Connection.execution_options.isolation_level` argument as passed to +:meth:`_engine.Connection.execution_options`. +This feature works by issuing the +command ``SET SESSION TRANSACTION ISOLATION LEVEL `` for each new +connection. For the special AUTOCOMMIT isolation level, DBAPI-specific +techniques are used. + +To set isolation level using :func:`_sa.create_engine`:: + + engine = create_engine( + "mysql+mysqldb://scott:tiger@localhost/test", + isolation_level="READ UNCOMMITTED" + ) + +To set using per-connection execution options:: + + connection = engine.connect() + connection = connection.execution_options( + isolation_level="READ COMMITTED" + ) + +Valid values for ``isolation_level`` include: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +The special ``AUTOCOMMIT`` value makes use of the various "autocommit" +attributes provided by specific DBAPIs, and is currently supported by +MySQLdb, MySQL-Client, MySQL-Connector Python, and PyMySQL. Using it, +the database connection will return true for the value of +``SELECT @@autocommit;``. + +There are also more options for isolation level configurations, such as +"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply +different isolation level settings. See the discussion at +:ref:`dbapi_autocommit` for background. + +.. seealso:: + + :ref:`dbapi_autocommit` + +AUTO_INCREMENT Behavior +----------------------- + +When creating tables, SQLAlchemy will automatically set ``AUTO_INCREMENT`` on +the first :class:`.Integer` primary key column which is not marked as a +foreign key:: + + >>> t = Table('mytable', metadata, + ... Column('mytable_id', Integer, primary_key=True) + ... ) + >>> t.create() + CREATE TABLE mytable ( + id INTEGER NOT NULL AUTO_INCREMENT, + PRIMARY KEY (id) + ) + +You can disable this behavior by passing ``False`` to the +:paramref:`_schema.Column.autoincrement` argument of :class:`_schema.Column`. +This flag +can also be used to enable auto-increment on a secondary column in a +multi-column key for some storage engines:: + + Table('mytable', metadata, + Column('gid', Integer, primary_key=True, autoincrement=False), + Column('id', Integer, primary_key=True) + ) + +.. _mysql_ss_cursors: + +Server Side Cursors +------------------- + +Server-side cursor support is available for the mysqlclient, PyMySQL, +mariadbconnector dialects and may also be available in others. This makes use +of either the "buffered=True/False" flag if available or by using a class such +as ``MySQLdb.cursors.SSCursor`` or ``pymysql.cursors.SSCursor`` internally. + + +Server side cursors are enabled on a per-statement basis by using the +:paramref:`.Connection.execution_options.stream_results` connection execution +option:: + + with engine.connect() as conn: + result = conn.execution_options(stream_results=True).execute(text("select * from table")) + +Note that some kinds of SQL statements may not be supported with +server side cursors; generally, only SQL statements that return rows should be +used with this option. + +.. deprecated:: 1.4 The dialect-level server_side_cursors flag is deprecated + and will be removed in a future release. Please use the + :paramref:`_engine.Connection.stream_results` execution option for + unbuffered cursor support. + +.. seealso:: + + :ref:`engine_stream_results` + +.. _mysql_unicode: + +Unicode +------- + +Charset Selection +~~~~~~~~~~~~~~~~~ + +Most MySQL / MariaDB DBAPIs offer the option to set the client character set for +a connection. This is typically delivered using the ``charset`` parameter +in the URL, such as:: + + e = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4") + +This charset is the **client character set** for the connection. Some +MySQL DBAPIs will default this to a value such as ``latin1``, and some +will make use of the ``default-character-set`` setting in the ``my.cnf`` +file as well. Documentation for the DBAPI in use should be consulted +for specific behavior. + +The encoding used for Unicode has traditionally been ``'utf8'``. However, for +MySQL versions 5.5.3 and MariaDB 5.5 on forward, a new MySQL-specific encoding +``'utf8mb4'`` has been introduced, and as of MySQL 8.0 a warning is emitted by +the server if plain ``utf8`` is specified within any server-side directives, +replaced with ``utf8mb3``. The rationale for this new encoding is due to the +fact that MySQL's legacy utf-8 encoding only supports codepoints up to three +bytes instead of four. Therefore, when communicating with a MySQL or MariaDB +database that includes codepoints more than three bytes in size, this new +charset is preferred, if supported by both the database as well as the client +DBAPI, as in:: + + e = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4") + +All modern DBAPIs should support the ``utf8mb4`` charset. + +In order to use ``utf8mb4`` encoding for a schema that was created with legacy +``utf8``, changes to the MySQL/MariaDB schema and/or server configuration may be +required. + +.. seealso:: + + `The utf8mb4 Character Set \ + `_ - \ + in the MySQL documentation + +.. _mysql_binary_introducer: + +Dealing with Binary Data Warnings and Unicode +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MySQL versions 5.6, 5.7 and later (not MariaDB at the time of this writing) now +emit a warning when attempting to pass binary data to the database, while a +character set encoding is also in place, when the binary data itself is not +valid for that encoding:: + + default.py:509: Warning: (1300, "Invalid utf8mb4 character string: + 'F9876A'") + cursor.execute(statement, parameters) + +This warning is due to the fact that the MySQL client library is attempting to +interpret the binary string as a unicode object even if a datatype such +as :class:`.LargeBinary` is in use. To resolve this, the SQL statement requires +a binary "character set introducer" be present before any non-NULL value +that renders like this:: + + INSERT INTO table (data) VALUES (_binary %s) + +These character set introducers are provided by the DBAPI driver, assuming the +use of mysqlclient or PyMySQL (both of which are recommended). Add the query +string parameter ``binary_prefix=true`` to the URL to repair this warning:: + + # mysqlclient + engine = create_engine( + "mysql+mysqldb://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true") + + # PyMySQL + engine = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true") + + +The ``binary_prefix`` flag may or may not be supported by other MySQL drivers. + +SQLAlchemy itself cannot render this ``_binary`` prefix reliably, as it does +not work with the NULL value, which is valid to be sent as a bound parameter. +As the MySQL driver renders parameters directly into the SQL string, it's the +most efficient place for this additional keyword to be passed. + +.. seealso:: + + `Character set introducers `_ - on the MySQL website + + +ANSI Quoting Style +------------------ + +MySQL / MariaDB feature two varieties of identifier "quoting style", one using +backticks and the other using quotes, e.g. ```some_identifier``` vs. +``"some_identifier"``. All MySQL dialects detect which version +is in use by checking the value of :ref:`sql_mode` when a connection is first +established with a particular :class:`_engine.Engine`. +This quoting style comes +into play when rendering table and column names as well as when reflecting +existing database structures. The detection is entirely automatic and +no special configuration is needed to use either quoting style. + + +.. _mysql_sql_mode: + +Changing the sql_mode +--------------------- + +MySQL supports operating in multiple +`Server SQL Modes `_ for +both Servers and Clients. To change the ``sql_mode`` for a given application, a +developer can leverage SQLAlchemy's Events system. + +In the following example, the event system is used to set the ``sql_mode`` on +the ``first_connect`` and ``connect`` events:: + + from sqlalchemy import create_engine, event + + eng = create_engine("mysql+mysqldb://scott:tiger@localhost/test", echo='debug') + + # `insert=True` will ensure this is the very first listener to run + @event.listens_for(eng, "connect", insert=True) + def connect(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("SET sql_mode = 'STRICT_ALL_TABLES'") + + conn = eng.connect() + +In the example illustrated above, the "connect" event will invoke the "SET" +statement on the connection at the moment a particular DBAPI connection is +first created for a given Pool, before the connection is made available to the +connection pool. Additionally, because the function was registered with +``insert=True``, it will be prepended to the internal list of registered +functions. + + +MySQL / MariaDB SQL Extensions +------------------------------ + +Many of the MySQL / MariaDB SQL extensions are handled through SQLAlchemy's generic +function and operator support:: + + table.select(table.c.password==func.md5('plaintext')) + table.select(table.c.username.op('regexp')('^[a-d]')) + +And of course any valid SQL statement can be executed as a string as well. + +Some limited direct support for MySQL / MariaDB extensions to SQL is currently +available. + +* INSERT..ON DUPLICATE KEY UPDATE: See + :ref:`mysql_insert_on_duplicate_key_update` + +* SELECT pragma, use :meth:`_expression.Select.prefix_with` and + :meth:`_query.Query.prefix_with`:: + + select(...).prefix_with(['HIGH_PRIORITY', 'SQL_SMALL_RESULT']) + +* UPDATE with LIMIT:: + + update(..., mysql_limit=10, mariadb_limit=10) + +* optimizer hints, use :meth:`_expression.Select.prefix_with` and + :meth:`_query.Query.prefix_with`:: + + select(...).prefix_with("/*+ NO_RANGE_OPTIMIZATION(t4 PRIMARY) */") + +* index hints, use :meth:`_expression.Select.with_hint` and + :meth:`_query.Query.with_hint`:: + + select(...).with_hint(some_table, "USE INDEX xyz") + +* MATCH operator support:: + + from sqlalchemy.dialects.mysql import match + select(...).where(match(col1, col2, against="some expr").in_boolean_mode()) + + .. seealso:: + + :class:`_mysql.match` + +INSERT/DELETE...RETURNING +------------------------- + +The MariaDB dialect supports 10.5+'s ``INSERT..RETURNING`` and +``DELETE..RETURNING`` (10.0+) syntaxes. ``INSERT..RETURNING`` may be used +automatically in some cases in order to fetch newly generated identifiers in +place of the traditional approach of using ``cursor.lastrowid``, however +``cursor.lastrowid`` is currently still preferred for simple single-statement +cases for its better performance. + +To specify an explicit ``RETURNING`` clause, use the +:meth:`._UpdateBase.returning` method on a per-statement basis:: + + # INSERT..RETURNING + result = connection.execute( + table.insert(). + values(name='foo'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + + # DELETE..RETURNING + result = connection.execute( + table.delete(). + where(table.c.name=='foo'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + +.. versionadded:: 2.0 Added support for MariaDB RETURNING + +.. _mysql_insert_on_duplicate_key_update: + +INSERT...ON DUPLICATE KEY UPDATE (Upsert) +------------------------------------------ + +MySQL / MariaDB allow "upserts" (update or insert) +of rows into a table via the ``ON DUPLICATE KEY UPDATE`` clause of the +``INSERT`` statement. A candidate row will only be inserted if that row does +not match an existing primary or unique key in the table; otherwise, an UPDATE +will be performed. The statement allows for separate specification of the +values to INSERT versus the values for UPDATE. + +SQLAlchemy provides ``ON DUPLICATE KEY UPDATE`` support via the MySQL-specific +:func:`.mysql.insert()` function, which provides +the generative method :meth:`~.mysql.Insert.on_duplicate_key_update`: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.dialects.mysql import insert + + >>> insert_stmt = insert(my_table).values( + ... id='some_existing_id', + ... data='inserted value') + + >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( + ... data=insert_stmt.inserted.data, + ... status='U' + ... ) + >>> print(on_duplicate_key_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%s, %s) + ON DUPLICATE KEY UPDATE data = VALUES(data), status = %s + + +Unlike PostgreSQL's "ON CONFLICT" phrase, the "ON DUPLICATE KEY UPDATE" +phrase will always match on any primary key or unique key, and will always +perform an UPDATE if there's a match; there are no options for it to raise +an error or to skip performing an UPDATE. + +``ON DUPLICATE KEY UPDATE`` is used to perform an update of the already +existing row, using any combination of new values as well as values +from the proposed insertion. These values are normally specified using +keyword arguments passed to the +:meth:`_mysql.Insert.on_duplicate_key_update` +given column key values (usually the name of the column, unless it +specifies :paramref:`_schema.Column.key` +) as keys and literal or SQL expressions +as values: + +.. sourcecode:: pycon+sql + + >>> insert_stmt = insert(my_table).values( + ... id='some_existing_id', + ... data='inserted value') + + >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( + ... data="some data", + ... updated_at=func.current_timestamp(), + ... ) + + >>> print(on_duplicate_key_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%s, %s) + ON DUPLICATE KEY UPDATE data = %s, updated_at = CURRENT_TIMESTAMP + +In a manner similar to that of :meth:`.UpdateBase.values`, other parameter +forms are accepted, including a single dictionary: + +.. sourcecode:: pycon+sql + + >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( + ... {"data": "some data", "updated_at": func.current_timestamp()}, + ... ) + +as well as a list of 2-tuples, which will automatically provide +a parameter-ordered UPDATE statement in a manner similar to that described +at :ref:`tutorial_parameter_ordered_updates`. Unlike the :class:`_expression.Update` +object, +no special flag is needed to specify the intent since the argument form is +this context is unambiguous: + +.. sourcecode:: pycon+sql + + >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( + ... [ + ... ("data", "some data"), + ... ("updated_at", func.current_timestamp()), + ... ] + ... ) + + >>> print(on_duplicate_key_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%s, %s) + ON DUPLICATE KEY UPDATE data = %s, updated_at = CURRENT_TIMESTAMP + +.. versionchanged:: 1.3 support for parameter-ordered UPDATE clause within + MySQL ON DUPLICATE KEY UPDATE + +.. warning:: + + The :meth:`_mysql.Insert.on_duplicate_key_update` + method does **not** take into + account Python-side default UPDATE values or generation functions, e.g. + e.g. those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON DUPLICATE KEY style of UPDATE, + unless they are manually specified explicitly in the parameters. + + + +In order to refer to the proposed insertion row, the special alias +:attr:`_mysql.Insert.inserted` is available as an attribute on +the :class:`_mysql.Insert` object; this object is a +:class:`_expression.ColumnCollection` which contains all columns of the target +table: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh') + + >>> do_update_stmt = stmt.on_duplicate_key_update( + ... data="updated value", + ... author=stmt.inserted.author + ... ) + + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data, author) VALUES (%s, %s, %s) + ON DUPLICATE KEY UPDATE data = %s, author = VALUES(author) + +When rendered, the "inserted" namespace will produce the expression +``VALUES()``. + +.. versionadded:: 1.2 Added support for MySQL ON DUPLICATE KEY UPDATE clause + + + +rowcount Support +---------------- + +SQLAlchemy standardizes the DBAPI ``cursor.rowcount`` attribute to be the +usual definition of "number of rows matched by an UPDATE or DELETE" statement. +This is in contradiction to the default setting on most MySQL DBAPI drivers, +which is "number of rows actually modified/deleted". For this reason, the +SQLAlchemy MySQL dialects always add the ``constants.CLIENT.FOUND_ROWS`` +flag, or whatever is equivalent for the target dialect, upon connection. +This setting is currently hardcoded. + +.. seealso:: + + :attr:`_engine.CursorResult.rowcount` + + +.. _mysql_indexes: + +MySQL / MariaDB- Specific Index Options +----------------------------------------- + +MySQL and MariaDB-specific extensions to the :class:`.Index` construct are available. + +Index Length +~~~~~~~~~~~~~ + +MySQL and MariaDB both provide an option to create index entries with a certain length, where +"length" refers to the number of characters or bytes in each value which will +become part of the index. SQLAlchemy provides this feature via the +``mysql_length`` and/or ``mariadb_length`` parameters:: + + Index('my_index', my_table.c.data, mysql_length=10, mariadb_length=10) + + Index('a_b_idx', my_table.c.a, my_table.c.b, mysql_length={'a': 4, + 'b': 9}) + + Index('a_b_idx', my_table.c.a, my_table.c.b, mariadb_length={'a': 4, + 'b': 9}) + +Prefix lengths are given in characters for nonbinary string types and in bytes +for binary string types. The value passed to the keyword argument *must* be +either an integer (and, thus, specify the same prefix length value for all +columns of the index) or a dict in which keys are column names and values are +prefix length values for corresponding columns. MySQL and MariaDB only allow a +length for a column of an index if it is for a CHAR, VARCHAR, TEXT, BINARY, +VARBINARY and BLOB. + +Index Prefixes +~~~~~~~~~~~~~~ + +MySQL storage engines permit you to specify an index prefix when creating +an index. SQLAlchemy provides this feature via the +``mysql_prefix`` parameter on :class:`.Index`:: + + Index('my_index', my_table.c.data, mysql_prefix='FULLTEXT') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX, so it *must* be a valid index prefix for your MySQL +storage engine. + +.. seealso:: + + `CREATE INDEX `_ - MySQL documentation + +Index Types +~~~~~~~~~~~~~ + +Some MySQL storage engines permit you to specify an index type when creating +an index or primary key constraint. SQLAlchemy provides this feature via the +``mysql_using`` parameter on :class:`.Index`:: + + Index('my_index', my_table.c.data, mysql_using='hash', mariadb_using='hash') + +As well as the ``mysql_using`` parameter on :class:`.PrimaryKeyConstraint`:: + + PrimaryKeyConstraint("data", mysql_using='hash', mariadb_using='hash') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX or PRIMARY KEY clause, so it *must* be a valid index +type for your MySQL storage engine. + +More information can be found at: + +https://dev.mysql.com/doc/refman/5.0/en/create-index.html + +https://dev.mysql.com/doc/refman/5.0/en/create-table.html + +Index Parsers +~~~~~~~~~~~~~ + +CREATE FULLTEXT INDEX in MySQL also supports a "WITH PARSER" option. This +is available using the keyword argument ``mysql_with_parser``:: + + Index( + 'my_index', my_table.c.data, + mysql_prefix='FULLTEXT', mysql_with_parser="ngram", + mariadb_prefix='FULLTEXT', mariadb_with_parser="ngram", + ) + +.. versionadded:: 1.3 + + +.. _mysql_foreign_keys: + +MySQL / MariaDB Foreign Keys +----------------------------- + +MySQL and MariaDB's behavior regarding foreign keys has some important caveats. + +Foreign Key Arguments to Avoid +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Neither MySQL nor MariaDB support the foreign key arguments "DEFERRABLE", "INITIALLY", +or "MATCH". Using the ``deferrable`` or ``initially`` keyword argument with +:class:`_schema.ForeignKeyConstraint` or :class:`_schema.ForeignKey` +will have the effect of +these keywords being rendered in a DDL expression, which will then raise an +error on MySQL or MariaDB. In order to use these keywords on a foreign key while having +them ignored on a MySQL / MariaDB backend, use a custom compile rule:: + + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.schema import ForeignKeyConstraint + + @compiles(ForeignKeyConstraint, "mysql", "mariadb") + def process(element, compiler, **kw): + element.deferrable = element.initially = None + return compiler.visit_foreign_key_constraint(element, **kw) + +The "MATCH" keyword is in fact more insidious, and is explicitly disallowed +by SQLAlchemy in conjunction with the MySQL or MariaDB backends. This argument is +silently ignored by MySQL / MariaDB, but in addition has the effect of ON UPDATE and ON +DELETE options also being ignored by the backend. Therefore MATCH should +never be used with the MySQL / MariaDB backends; as is the case with DEFERRABLE and +INITIALLY, custom compilation rules can be used to correct a +ForeignKeyConstraint at DDL definition time. + +Reflection of Foreign Key Constraints +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Not all MySQL / MariaDB storage engines support foreign keys. When using the +very common ``MyISAM`` MySQL storage engine, the information loaded by table +reflection will not include foreign keys. For these tables, you may supply a +:class:`~sqlalchemy.ForeignKeyConstraint` at reflection time:: + + Table('mytable', metadata, + ForeignKeyConstraint(['other_id'], ['othertable.other_id']), + autoload_with=engine + ) + +.. seealso:: + + :ref:`mysql_storage_engines` + +.. _mysql_unique_constraints: + +MySQL / MariaDB Unique Constraints and Reflection +---------------------------------------------------- + +SQLAlchemy supports both the :class:`.Index` construct with the +flag ``unique=True``, indicating a UNIQUE index, as well as the +:class:`.UniqueConstraint` construct, representing a UNIQUE constraint. +Both objects/syntaxes are supported by MySQL / MariaDB when emitting DDL to create +these constraints. However, MySQL / MariaDB does not have a unique constraint +construct that is separate from a unique index; that is, the "UNIQUE" +constraint on MySQL / MariaDB is equivalent to creating a "UNIQUE INDEX". + +When reflecting these constructs, the +:meth:`_reflection.Inspector.get_indexes` +and the :meth:`_reflection.Inspector.get_unique_constraints` +methods will **both** +return an entry for a UNIQUE index in MySQL / MariaDB. However, when performing +full table reflection using ``Table(..., autoload_with=engine)``, +the :class:`.UniqueConstraint` construct is +**not** part of the fully reflected :class:`_schema.Table` construct under any +circumstances; this construct is always represented by a :class:`.Index` +with the ``unique=True`` setting present in the :attr:`_schema.Table.indexes` +collection. + + +TIMESTAMP / DATETIME issues +--------------------------- + +.. _mysql_timestamp_onupdate: + +Rendering ON UPDATE CURRENT TIMESTAMP for MySQL / MariaDB's explicit_defaults_for_timestamp +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MySQL / MariaDB have historically expanded the DDL for the :class:`_types.TIMESTAMP` +datatype into the phrase "TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE +CURRENT_TIMESTAMP", which includes non-standard SQL that automatically updates +the column with the current timestamp when an UPDATE occurs, eliminating the +usual need to use a trigger in such a case where server-side update changes are +desired. + +MySQL 5.6 introduced a new flag `explicit_defaults_for_timestamp +`_ which disables the above behavior, +and in MySQL 8 this flag defaults to true, meaning in order to get a MySQL +"on update timestamp" without changing this flag, the above DDL must be +rendered explicitly. Additionally, the same DDL is valid for use of the +``DATETIME`` datatype as well. + +SQLAlchemy's MySQL dialect does not yet have an option to generate +MySQL's "ON UPDATE CURRENT_TIMESTAMP" clause, noting that this is not a general +purpose "ON UPDATE" as there is no such syntax in standard SQL. SQLAlchemy's +:paramref:`_schema.Column.server_onupdate` parameter is currently not related +to this special MySQL behavior. + +To generate this DDL, make use of the :paramref:`_schema.Column.server_default` +parameter and pass a textual clause that also includes the ON UPDATE clause:: + + from sqlalchemy import Table, MetaData, Column, Integer, String, TIMESTAMP + from sqlalchemy import text + + metadata = MetaData() + + mytable = Table( + "mytable", + metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column( + 'last_updated', + TIMESTAMP, + server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ) + ) + +The same instructions apply to use of the :class:`_types.DateTime` and +:class:`_types.DATETIME` datatypes:: + + from sqlalchemy import DateTime + + mytable = Table( + "mytable", + metadata, + Column('id', Integer, primary_key=True), + Column('data', String(50)), + Column( + 'last_updated', + DateTime, + server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ) + ) + + +Even though the :paramref:`_schema.Column.server_onupdate` feature does not +generate this DDL, it still may be desirable to signal to the ORM that this +updated value should be fetched. This syntax looks like the following:: + + from sqlalchemy.schema import FetchedValue + + class MyClass(Base): + __tablename__ = 'mytable' + + id = Column(Integer, primary_key=True) + data = Column(String(50)) + last_updated = Column( + TIMESTAMP, + server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"), + server_onupdate=FetchedValue() + ) + + +.. _mysql_timestamp_null: + +TIMESTAMP Columns and NULL +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +MySQL historically enforces that a column which specifies the +TIMESTAMP datatype implicitly includes a default value of +CURRENT_TIMESTAMP, even though this is not stated, and additionally +sets the column as NOT NULL, the opposite behavior vs. that of all +other datatypes:: + + mysql> CREATE TABLE ts_test ( + -> a INTEGER, + -> b INTEGER NOT NULL, + -> c TIMESTAMP, + -> d TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + -> e TIMESTAMP NULL); + Query OK, 0 rows affected (0.03 sec) + + mysql> SHOW CREATE TABLE ts_test; + +---------+----------------------------------------------------- + | Table | Create Table + +---------+----------------------------------------------------- + | ts_test | CREATE TABLE `ts_test` ( + `a` int(11) DEFAULT NULL, + `b` int(11) NOT NULL, + `c` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + `d` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + `e` timestamp NULL DEFAULT NULL + ) ENGINE=MyISAM DEFAULT CHARSET=latin1 + +Above, we see that an INTEGER column defaults to NULL, unless it is specified +with NOT NULL. But when the column is of type TIMESTAMP, an implicit +default of CURRENT_TIMESTAMP is generated which also coerces the column +to be a NOT NULL, even though we did not specify it as such. + +This behavior of MySQL can be changed on the MySQL side using the +`explicit_defaults_for_timestamp +`_ configuration flag introduced in +MySQL 5.6. With this server setting enabled, TIMESTAMP columns behave like +any other datatype on the MySQL side with regards to defaults and nullability. + +However, to accommodate the vast majority of MySQL databases that do not +specify this new flag, SQLAlchemy emits the "NULL" specifier explicitly with +any TIMESTAMP column that does not specify ``nullable=False``. In order to +accommodate newer databases that specify ``explicit_defaults_for_timestamp``, +SQLAlchemy also emits NOT NULL for TIMESTAMP columns that do specify +``nullable=False``. The following example illustrates:: + + from sqlalchemy import MetaData, Integer, Table, Column, text + from sqlalchemy.dialects.mysql import TIMESTAMP + + m = MetaData() + t = Table('ts_test', m, + Column('a', Integer), + Column('b', Integer, nullable=False), + Column('c', TIMESTAMP), + Column('d', TIMESTAMP, nullable=False) + ) + + + from sqlalchemy import create_engine + e = create_engine("mysql+mysqldb://scott:tiger@localhost/test", echo=True) + m.create_all(e) + +output:: + + CREATE TABLE ts_test ( + a INTEGER, + b INTEGER NOT NULL, + c TIMESTAMP NULL, + d TIMESTAMP NOT NULL + ) + +""" # noqa +from __future__ import annotations + +from array import array as _array +from collections import defaultdict +from itertools import compress +import re +from typing import cast + +from . import reflection as _reflection +from .enumerated import ENUM +from .enumerated import SET +from .json import JSON +from .json import JSONIndexType +from .json import JSONPathType +from .reserved_words import RESERVED_WORDS_MARIADB +from .reserved_words import RESERVED_WORDS_MYSQL +from .types import _FloatType +from .types import _IntegerType +from .types import _MatchType +from .types import _NumericType +from .types import _StringType +from .types import BIGINT +from .types import BIT +from .types import CHAR +from .types import DATETIME +from .types import DECIMAL +from .types import DOUBLE +from .types import FLOAT +from .types import INTEGER +from .types import LONGBLOB +from .types import LONGTEXT +from .types import MEDIUMBLOB +from .types import MEDIUMINT +from .types import MEDIUMTEXT +from .types import NCHAR +from .types import NUMERIC +from .types import NVARCHAR +from .types import REAL +from .types import SMALLINT +from .types import TEXT +from .types import TIME +from .types import TIMESTAMP +from .types import TINYBLOB +from .types import TINYINT +from .types import TINYTEXT +from .types import VARCHAR +from .types import YEAR +from ... import exc +from ... import literal_column +from ... import log +from ... import schema as sa_schema +from ... import sql +from ... import util +from ...engine import cursor as _cursor +from ...engine import default +from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import coercions +from ...sql import compiler +from ...sql import elements +from ...sql import functions +from ...sql import operators +from ...sql import roles +from ...sql import sqltypes +from ...sql import util as sql_util +from ...sql import visitors +from ...sql.compiler import InsertmanyvaluesSentinelOpts +from ...sql.compiler import SQLCompiler +from ...sql.schema import SchemaConst +from ...types import BINARY +from ...types import BLOB +from ...types import BOOLEAN +from ...types import DATE +from ...types import UUID +from ...types import VARBINARY +from ...util import topological + + +SET_RE = re.compile( + r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE +) + +# old names +MSTime = TIME +MSSet = SET +MSEnum = ENUM +MSLongBlob = LONGBLOB +MSMediumBlob = MEDIUMBLOB +MSTinyBlob = TINYBLOB +MSBlob = BLOB +MSBinary = BINARY +MSVarBinary = VARBINARY +MSNChar = NCHAR +MSNVarChar = NVARCHAR +MSChar = CHAR +MSString = VARCHAR +MSLongText = LONGTEXT +MSMediumText = MEDIUMTEXT +MSTinyText = TINYTEXT +MSText = TEXT +MSYear = YEAR +MSTimeStamp = TIMESTAMP +MSBit = BIT +MSSmallInteger = SMALLINT +MSTinyInteger = TINYINT +MSMediumInteger = MEDIUMINT +MSBigInteger = BIGINT +MSNumeric = NUMERIC +MSDecimal = DECIMAL +MSDouble = DOUBLE +MSReal = REAL +MSFloat = FLOAT +MSInteger = INTEGER + +colspecs = { + _IntegerType: _IntegerType, + _NumericType: _NumericType, + _FloatType: _FloatType, + sqltypes.Numeric: NUMERIC, + sqltypes.Float: FLOAT, + sqltypes.Double: DOUBLE, + sqltypes.Time: TIME, + sqltypes.Enum: ENUM, + sqltypes.MatchType: _MatchType, + sqltypes.JSON: JSON, + sqltypes.JSON.JSONIndexType: JSONIndexType, + sqltypes.JSON.JSONPathType: JSONPathType, +} + +# Everything 3.23 through 5.1 excepting OpenGIS types. +ischema_names = { + "bigint": BIGINT, + "binary": BINARY, + "bit": BIT, + "blob": BLOB, + "boolean": BOOLEAN, + "char": CHAR, + "date": DATE, + "datetime": DATETIME, + "decimal": DECIMAL, + "double": DOUBLE, + "enum": ENUM, + "fixed": DECIMAL, + "float": FLOAT, + "int": INTEGER, + "integer": INTEGER, + "json": JSON, + "longblob": LONGBLOB, + "longtext": LONGTEXT, + "mediumblob": MEDIUMBLOB, + "mediumint": MEDIUMINT, + "mediumtext": MEDIUMTEXT, + "nchar": NCHAR, + "nvarchar": NVARCHAR, + "numeric": NUMERIC, + "set": SET, + "smallint": SMALLINT, + "text": TEXT, + "time": TIME, + "timestamp": TIMESTAMP, + "tinyblob": TINYBLOB, + "tinyint": TINYINT, + "tinytext": TINYTEXT, + "uuid": UUID, + "varbinary": VARBINARY, + "varchar": VARCHAR, + "year": YEAR, +} + + +class MySQLExecutionContext(default.DefaultExecutionContext): + def post_exec(self): + if ( + self.isdelete + and cast(SQLCompiler, self.compiled).effective_returning + and not self.cursor.description + ): + # All MySQL/mariadb drivers appear to not include + # cursor.description for DELETE..RETURNING with no rows if the + # WHERE criteria is a straight "false" condition such as our EMPTY + # IN condition. manufacture an empty result in this case (issue + # #10505) + # + # taken from cx_Oracle implementation + self.cursor_fetch_strategy = ( + _cursor.FullyBufferedCursorFetchStrategy( + self.cursor, + [ + (entry.keyname, None) + for entry in cast( + SQLCompiler, self.compiled + )._result_columns + ], + [], + ) + ) + + def create_server_side_cursor(self): + if self.dialect.supports_server_side_cursors: + return self._dbapi_connection.cursor(self.dialect._sscursor) + else: + raise NotImplementedError() + + def fire_sequence(self, seq, type_): + return self._execute_scalar( + ( + "select nextval(%s)" + % self.identifier_preparer.format_sequence(seq) + ), + type_, + ) + + +class MySQLCompiler(compiler.SQLCompiler): + render_table_with_column_in_update_from = True + """Overridden from base SQLCompiler value""" + + extract_map = compiler.SQLCompiler.extract_map.copy() + extract_map.update({"milliseconds": "millisecond"}) + + def default_from(self): + """Called when a ``SELECT`` statement has no froms, + and no ``FROM`` clause is to be appended. + + """ + if self.stack: + stmt = self.stack[-1]["selectable"] + if stmt._where_criteria: + return " FROM DUAL" + + return "" + + def visit_random_func(self, fn, **kw): + return "rand%s" % self.function_argspec(fn) + + def visit_rollup_func(self, fn, **kw): + clause = ", ".join( + elem._compiler_dispatch(self, **kw) for elem in fn.clauses + ) + return f"{clause} WITH ROLLUP" + + def visit_aggregate_strings_func(self, fn, **kw): + expr, delimeter = ( + elem._compiler_dispatch(self, **kw) for elem in fn.clauses + ) + return f"group_concat({expr} SEPARATOR {delimeter})" + + def visit_sequence(self, seq, **kw): + return "nextval(%s)" % self.preparer.format_sequence(seq) + + def visit_sysdate_func(self, fn, **kw): + return "SYSDATE()" + + def _render_json_extract_from_binary(self, binary, operator, **kw): + # note we are intentionally calling upon the process() calls in the + # order in which they appear in the SQL String as this is used + # by positional parameter rendering + + if binary.type._type_affinity is sqltypes.JSON: + return "JSON_EXTRACT(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + # for non-JSON, MySQL doesn't handle JSON null at all so it has to + # be explicit + case_expression = "CASE JSON_EXTRACT(%s, %s) WHEN 'null' THEN NULL" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + if binary.type._type_affinity is sqltypes.Integer: + type_expression = ( + "ELSE CAST(JSON_EXTRACT(%s, %s) AS SIGNED INTEGER)" + % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ) + elif binary.type._type_affinity is sqltypes.Numeric: + if ( + binary.type.scale is not None + and binary.type.precision is not None + ): + # using DECIMAL here because MySQL does not recognize NUMERIC + type_expression = ( + "ELSE CAST(JSON_EXTRACT(%s, %s) AS DECIMAL(%s, %s))" + % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + binary.type.precision, + binary.type.scale, + ) + ) + else: + # FLOAT / REAL not added in MySQL til 8.0.17 + type_expression = ( + "ELSE JSON_EXTRACT(%s, %s)+0.0000000000000000000000" + % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ) + elif binary.type._type_affinity is sqltypes.Boolean: + # the NULL handling is particularly weird with boolean, so + # explicitly return true/false constants + type_expression = "WHEN true THEN true ELSE false" + elif binary.type._type_affinity is sqltypes.String: + # (gord): this fails with a JSON value that's a four byte unicode + # string. SQLite has the same problem at the moment + # (zzzeek): I'm not really sure. let's take a look at a test case + # that hits each backend and maybe make a requires rule for it? + type_expression = "ELSE JSON_UNQUOTE(JSON_EXTRACT(%s, %s))" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + else: + # other affinity....this is not expected right now + type_expression = "ELSE JSON_EXTRACT(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + return case_expression + " " + type_expression + " END" + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + return self._render_json_extract_from_binary(binary, operator, **kw) + + def visit_on_duplicate_key_update(self, on_duplicate, **kw): + statement = self.current_executable + + if on_duplicate._parameter_ordering: + parameter_ordering = [ + coercions.expect(roles.DMLColumnRole, key) + for key in on_duplicate._parameter_ordering + ] + ordered_keys = set(parameter_ordering) + cols = [ + statement.table.c[key] + for key in parameter_ordering + if key in statement.table.c + ] + [c for c in statement.table.c if c.key not in ordered_keys] + else: + cols = statement.table.c + + clauses = [] + + requires_mysql8_alias = ( + self.dialect._requires_alias_for_on_duplicate_key + ) + + if requires_mysql8_alias: + if statement.table.name.lower() == "new": + _on_dup_alias_name = "new_1" + else: + _on_dup_alias_name = "new" + + # traverses through all table columns to preserve table column order + for column in (col for col in cols if col.key in on_duplicate.update): + val = on_duplicate.update[column.key] + + if coercions._is_literal(val): + val = elements.BindParameter(None, val, type_=column.type) + value_text = self.process(val.self_group(), use_schema=False) + else: + + def replace(obj): + if ( + isinstance(obj, elements.BindParameter) + and obj.type._isnull + ): + obj = obj._clone() + obj.type = column.type + return obj + elif ( + isinstance(obj, elements.ColumnClause) + and obj.table is on_duplicate.inserted_alias + ): + if requires_mysql8_alias: + column_literal_clause = ( + f"{_on_dup_alias_name}." + f"{self.preparer.quote(obj.name)}" + ) + else: + column_literal_clause = ( + f"VALUES({self.preparer.quote(obj.name)})" + ) + return literal_column(column_literal_clause) + else: + # element is not replaced + return None + + val = visitors.replacement_traverse(val, {}, replace) + value_text = self.process(val.self_group(), use_schema=False) + + name_text = self.preparer.quote(column.name) + clauses.append("%s = %s" % (name_text, value_text)) + + non_matching = set(on_duplicate.update) - {c.key for c in cols} + if non_matching: + util.warn( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( + self.statement.table.name, + (", ".join("'%s'" % c for c in non_matching)), + ) + ) + + if requires_mysql8_alias: + return ( + f"AS {_on_dup_alias_name} " + f"ON DUPLICATE KEY UPDATE {', '.join(clauses)}" + ) + else: + return f"ON DUPLICATE KEY UPDATE {', '.join(clauses)}" + + def visit_concat_op_expression_clauselist( + self, clauselist, operator, **kw + ): + return "concat(%s)" % ( + ", ".join(self.process(elem, **kw) for elem in clauselist.clauses) + ) + + def visit_concat_op_binary(self, binary, operator, **kw): + return "concat(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + _match_valid_flag_combinations = frozenset( + ( + # (boolean_mode, natural_language, query_expansion) + (False, False, False), + (True, False, False), + (False, True, False), + (False, False, True), + (False, True, True), + ) + ) + + _match_flag_expressions = ( + "IN BOOLEAN MODE", + "IN NATURAL LANGUAGE MODE", + "WITH QUERY EXPANSION", + ) + + def visit_mysql_match(self, element, **kw): + return self.visit_match_op_binary(element, element.operator, **kw) + + def visit_match_op_binary(self, binary, operator, **kw): + """ + Note that `mysql_boolean_mode` is enabled by default because of + backward compatibility + """ + + modifiers = binary.modifiers + + boolean_mode = modifiers.get("mysql_boolean_mode", True) + natural_language = modifiers.get("mysql_natural_language", False) + query_expansion = modifiers.get("mysql_query_expansion", False) + + flag_combination = (boolean_mode, natural_language, query_expansion) + + if flag_combination not in self._match_valid_flag_combinations: + flags = ( + "in_boolean_mode=%s" % boolean_mode, + "in_natural_language_mode=%s" % natural_language, + "with_query_expansion=%s" % query_expansion, + ) + + flags = ", ".join(flags) + + raise exc.CompileError("Invalid MySQL match flags: %s" % flags) + + match_clause = binary.left + match_clause = self.process(match_clause, **kw) + against_clause = self.process(binary.right, **kw) + + if any(flag_combination): + flag_expressions = compress( + self._match_flag_expressions, + flag_combination, + ) + + against_clause = [against_clause] + against_clause.extend(flag_expressions) + + against_clause = " ".join(against_clause) + + return "MATCH (%s) AGAINST (%s)" % (match_clause, against_clause) + + def get_from_hint_text(self, table, text): + return text + + def visit_typeclause(self, typeclause, type_=None, **kw): + if type_ is None: + type_ = typeclause.type.dialect_impl(self.dialect) + if isinstance(type_, sqltypes.TypeDecorator): + return self.visit_typeclause(typeclause, type_.impl, **kw) + elif isinstance(type_, sqltypes.Integer): + if getattr(type_, "unsigned", False): + return "UNSIGNED INTEGER" + else: + return "SIGNED INTEGER" + elif isinstance(type_, sqltypes.TIMESTAMP): + return "DATETIME" + elif isinstance( + type_, + ( + sqltypes.DECIMAL, + sqltypes.DateTime, + sqltypes.Date, + sqltypes.Time, + ), + ): + return self.dialect.type_compiler_instance.process(type_) + elif isinstance(type_, sqltypes.String) and not isinstance( + type_, (ENUM, SET) + ): + adapted = CHAR._adapt_string_for_cast(type_) + return self.dialect.type_compiler_instance.process(adapted) + elif isinstance(type_, sqltypes._Binary): + return "BINARY" + elif isinstance(type_, sqltypes.JSON): + return "JSON" + elif isinstance(type_, sqltypes.NUMERIC): + return self.dialect.type_compiler_instance.process(type_).replace( + "NUMERIC", "DECIMAL" + ) + elif ( + isinstance(type_, sqltypes.Float) + and self.dialect._support_float_cast + ): + return self.dialect.type_compiler_instance.process(type_) + else: + return None + + def visit_cast(self, cast, **kw): + type_ = self.process(cast.typeclause) + if type_ is None: + util.warn( + "Datatype %s does not support CAST on MySQL/MariaDb; " + "the CAST will be skipped." + % self.dialect.type_compiler_instance.process( + cast.typeclause.type + ) + ) + return self.process(cast.clause.self_group(), **kw) + + return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_) + + def render_literal_value(self, value, type_): + value = super().render_literal_value(value, type_) + if self.dialect._backslash_escapes: + value = value.replace("\\", "\\\\") + return value + + # override native_boolean=False behavior here, as + # MySQL still supports native boolean + def visit_true(self, element, **kw): + return "true" + + def visit_false(self, element, **kw): + return "false" + + def get_select_precolumns(self, select, **kw): + """Add special MySQL keywords in place of DISTINCT. + + .. deprecated:: 1.4 This usage is deprecated. + :meth:`_expression.Select.prefix_with` should be used for special + keywords at the start of a SELECT. + + """ + if isinstance(select._distinct, str): + util.warn_deprecated( + "Sending string values for 'distinct' is deprecated in the " + "MySQL dialect and will be removed in a future release. " + "Please use :meth:`.Select.prefix_with` for special keywords " + "at the start of a SELECT statement", + version="1.4", + ) + return select._distinct.upper() + " " + + return super().get_select_precolumns(select, **kw) + + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.add((join.left, join.right)) + + if join.full: + join_type = " FULL OUTER JOIN " + elif join.isouter: + join_type = " LEFT OUTER JOIN " + else: + join_type = " INNER JOIN " + + return "".join( + ( + self.process( + join.left, asfrom=True, from_linter=from_linter, **kwargs + ), + join_type, + self.process( + join.right, asfrom=True, from_linter=from_linter, **kwargs + ), + " ON ", + self.process(join.onclause, from_linter=from_linter, **kwargs), + ) + ) + + def for_update_clause(self, select, **kw): + if select._for_update_arg.read: + tmp = " LOCK IN SHARE MODE" + else: + tmp = " FOR UPDATE" + + if select._for_update_arg.of and self.dialect.supports_for_update_of: + tables = util.OrderedSet() + for c in select._for_update_arg.of: + tables.update(sql_util.surface_selectables_only(c)) + + tmp += " OF " + ", ".join( + self.process(table, ashint=True, use_schema=False, **kw) + for table in tables + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + + if select._for_update_arg.skip_locked: + tmp += " SKIP LOCKED" + + return tmp + + def limit_clause(self, select, **kw): + # MySQL supports: + # LIMIT + # LIMIT , + # and in server versions > 3.3: + # LIMIT OFFSET + # The latter is more readable for offsets but we're stuck with the + # former until we can refine dialects by server revision. + + limit_clause, offset_clause = ( + select._limit_clause, + select._offset_clause, + ) + + if limit_clause is None and offset_clause is None: + return "" + elif offset_clause is not None: + # As suggested by the MySQL docs, need to apply an + # artificial limit if one wasn't provided + # https://dev.mysql.com/doc/refman/5.0/en/select.html + if limit_clause is None: + # TODO: remove ?? + # hardwire the upper limit. Currently + # needed consistent with the usage of the upper + # bound as part of MySQL's "syntax" for OFFSET with + # no LIMIT. + return " \n LIMIT %s, %s" % ( + self.process(offset_clause, **kw), + "18446744073709551615", + ) + else: + return " \n LIMIT %s, %s" % ( + self.process(offset_clause, **kw), + self.process(limit_clause, **kw), + ) + else: + # No offset provided, so just use the limit + return " \n LIMIT %s" % (self.process(limit_clause, **kw),) + + def update_limit_clause(self, update_stmt): + limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None) + if limit: + return "LIMIT %s" % limit + else: + return None + + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + kw["asfrom"] = True + return ", ".join( + t._compiler_dispatch(self, **kw) + for t in [from_table] + list(extra_froms) + ) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + return None + + def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): + """If we have extra froms make sure we render any alias as hint.""" + ashint = False + if extra_froms: + ashint = True + return from_table._compiler_dispatch( + self, asfrom=True, iscrud=True, ashint=ashint, **kw + ) + + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the DELETE .. USING clause specific to MySQL.""" + kw["asfrom"] = True + return "USING " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) + + def visit_empty_set_expr(self, element_types, **kw): + return ( + "SELECT %(outer)s FROM (SELECT %(inner)s) " + "as _empty_set WHERE 1!=1" + % { + "inner": ", ".join( + "1 AS _in_%s" % idx + for idx, type_ in enumerate(element_types) + ), + "outer": ", ".join( + "_in_%s" % idx for idx, type_ in enumerate(element_types) + ), + } + ) + + def visit_is_distinct_from_binary(self, binary, operator, **kw): + return "NOT (%s <=> %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + return "%s <=> %s" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def _mariadb_regexp_flags(self, flags, pattern, **kw): + return "CONCAT('(?', %s, ')', %s)" % ( + self.render_literal_value(flags, sqltypes.STRINGTYPE), + self.process(pattern, **kw), + ) + + def _regexp_match(self, op_string, binary, operator, **kw): + flags = binary.modifiers["flags"] + if flags is None: + return self._generate_generic_binary(binary, op_string, **kw) + elif self.dialect.is_mariadb: + return "%s%s%s" % ( + self.process(binary.left, **kw), + op_string, + self._mariadb_regexp_flags(flags, binary.right), + ) + else: + text = "REGEXP_LIKE(%s, %s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + self.render_literal_value(flags, sqltypes.STRINGTYPE), + ) + if op_string == " NOT REGEXP ": + return "NOT %s" % text + else: + return text + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._regexp_match(" REGEXP ", binary, operator, **kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._regexp_match(" NOT REGEXP ", binary, operator, **kw) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + flags = binary.modifiers["flags"] + if flags is None: + return "REGEXP_REPLACE(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + elif self.dialect.is_mariadb: + return "REGEXP_REPLACE(%s, %s, %s)" % ( + self.process(binary.left, **kw), + self._mariadb_regexp_flags(flags, binary.right.clauses[0]), + self.process(binary.right.clauses[1], **kw), + ) + else: + return "REGEXP_REPLACE(%s, %s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + self.render_literal_value(flags, sqltypes.STRINGTYPE), + ) + + +class MySQLDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kw): + """Builds column DDL.""" + if ( + self.dialect.is_mariadb is True + and column.computed is not None + and column._user_defined_nullable is SchemaConst.NULL_UNSPECIFIED + ): + column.nullable = True + colspec = [ + self.preparer.format_column(column), + self.dialect.type_compiler_instance.process( + column.type, type_expression=column + ), + ] + + if column.computed is not None: + colspec.append(self.process(column.computed)) + + is_timestamp = isinstance( + column.type._unwrapped_dialect_impl(self.dialect), + sqltypes.TIMESTAMP, + ) + + if not column.nullable: + colspec.append("NOT NULL") + + # see: https://docs.sqlalchemy.org/en/latest/dialects/mysql.html#mysql_timestamp_null # noqa + elif column.nullable and is_timestamp: + colspec.append("NULL") + + comment = column.comment + if comment is not None: + literal = self.sql_compiler.render_literal_value( + comment, sqltypes.String() + ) + colspec.append("COMMENT " + literal) + + if ( + column.table is not None + and column is column.table._autoincrement_column + and ( + column.server_default is None + or isinstance(column.server_default, sa_schema.Identity) + ) + and not ( + self.dialect.supports_sequences + and isinstance(column.default, sa_schema.Sequence) + and not column.default.optional + ) + ): + colspec.append("AUTO_INCREMENT") + else: + default = self.get_column_default_string(column) + if default is not None: + colspec.append("DEFAULT " + default) + return " ".join(colspec) + + def post_create_table(self, table): + """Build table-level CREATE options like ENGINE and COLLATE.""" + + table_opts = [] + + opts = { + k[len(self.dialect.name) + 1 :].upper(): v + for k, v in table.kwargs.items() + if k.startswith("%s_" % self.dialect.name) + } + + if table.comment is not None: + opts["COMMENT"] = table.comment + + partition_options = [ + "PARTITION_BY", + "PARTITIONS", + "SUBPARTITIONS", + "SUBPARTITION_BY", + ] + + nonpart_options = set(opts).difference(partition_options) + part_options = set(opts).intersection(partition_options) + + for opt in topological.sort( + [ + ("DEFAULT_CHARSET", "COLLATE"), + ("DEFAULT_CHARACTER_SET", "COLLATE"), + ("CHARSET", "COLLATE"), + ("CHARACTER_SET", "COLLATE"), + ], + nonpart_options, + ): + arg = opts[opt] + if opt in _reflection._options_of_type_string: + arg = self.sql_compiler.render_literal_value( + arg, sqltypes.String() + ) + + if opt in ( + "DATA_DIRECTORY", + "INDEX_DIRECTORY", + "DEFAULT_CHARACTER_SET", + "CHARACTER_SET", + "DEFAULT_CHARSET", + "DEFAULT_COLLATE", + ): + opt = opt.replace("_", " ") + + joiner = "=" + if opt in ( + "TABLESPACE", + "DEFAULT CHARACTER SET", + "CHARACTER SET", + "COLLATE", + ): + joiner = " " + + table_opts.append(joiner.join((opt, arg))) + + for opt in topological.sort( + [ + ("PARTITION_BY", "PARTITIONS"), + ("PARTITION_BY", "SUBPARTITION_BY"), + ("PARTITION_BY", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITION_BY"), + ("SUBPARTITION_BY", "SUBPARTITIONS"), + ], + part_options, + ): + arg = opts[opt] + if opt in _reflection._options_of_type_string: + arg = self.sql_compiler.render_literal_value( + arg, sqltypes.String() + ) + + opt = opt.replace("_", " ") + joiner = " " + + table_opts.append(joiner.join((opt, arg))) + + return " ".join(table_opts) + + def visit_create_index(self, create, **kw): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + table = preparer.format_table(index.table) + + columns = [ + self.sql_compiler.process( + ( + elements.Grouping(expr) + if ( + isinstance(expr, elements.BinaryExpression) + or ( + isinstance(expr, elements.UnaryExpression) + and expr.modifier + not in (operators.desc_op, operators.asc_op) + ) + or isinstance(expr, functions.FunctionElement) + ) + else expr + ), + include_table=False, + literal_binds=True, + ) + for expr in index.expressions + ] + + name = self._prepared_index_name(index) + + text = "CREATE " + if index.unique: + text += "UNIQUE " + + index_prefix = index.kwargs.get("%s_prefix" % self.dialect.name, None) + if index_prefix: + text += index_prefix + " " + + text += "INDEX " + if create.if_not_exists: + text += "IF NOT EXISTS " + text += "%s ON %s " % (name, table) + + length = index.dialect_options[self.dialect.name]["length"] + if length is not None: + if isinstance(length, dict): + # length value can be a (column_name --> integer value) + # mapping specifying the prefix length for each column of the + # index + columns = ", ".join( + ( + "%s(%d)" % (expr, length[col.name]) + if col.name in length + else ( + "%s(%d)" % (expr, length[expr]) + if expr in length + else "%s" % expr + ) + ) + for col, expr in zip(index.expressions, columns) + ) + else: + # or can be an integer value specifying the same + # prefix length for all columns of the index + columns = ", ".join( + "%s(%d)" % (col, length) for col in columns + ) + else: + columns = ", ".join(columns) + text += "(%s)" % columns + + parser = index.dialect_options["mysql"]["with_parser"] + if parser is not None: + text += " WITH PARSER %s" % (parser,) + + using = index.dialect_options["mysql"]["using"] + if using is not None: + text += " USING %s" % (preparer.quote(using)) + + return text + + def visit_primary_key_constraint(self, constraint, **kw): + text = super().visit_primary_key_constraint(constraint) + using = constraint.dialect_options["mysql"]["using"] + if using: + text += " USING %s" % (self.preparer.quote(using)) + return text + + def visit_drop_index(self, drop, **kw): + index = drop.element + text = "\nDROP INDEX " + if drop.if_exists: + text += "IF EXISTS " + + return text + "%s ON %s" % ( + self._prepared_index_name(index, include_schema=False), + self.preparer.format_table(index.table), + ) + + def visit_drop_constraint(self, drop, **kw): + constraint = drop.element + if isinstance(constraint, sa_schema.ForeignKeyConstraint): + qual = "FOREIGN KEY " + const = self.preparer.format_constraint(constraint) + elif isinstance(constraint, sa_schema.PrimaryKeyConstraint): + qual = "PRIMARY KEY " + const = "" + elif isinstance(constraint, sa_schema.UniqueConstraint): + qual = "INDEX " + const = self.preparer.format_constraint(constraint) + elif isinstance(constraint, sa_schema.CheckConstraint): + if self.dialect.is_mariadb: + qual = "CONSTRAINT " + else: + qual = "CHECK " + const = self.preparer.format_constraint(constraint) + else: + qual = "" + const = self.preparer.format_constraint(constraint) + return "ALTER TABLE %s DROP %s%s" % ( + self.preparer.format_table(constraint.table), + qual, + const, + ) + + def define_constraint_match(self, constraint): + if constraint.match is not None: + raise exc.CompileError( + "MySQL ignores the 'MATCH' keyword while at the same time " + "causes ON UPDATE/ON DELETE clauses to be ignored." + ) + return "" + + def visit_set_table_comment(self, create, **kw): + return "ALTER TABLE %s COMMENT %s" % ( + self.preparer.format_table(create.element), + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.String() + ), + ) + + def visit_drop_table_comment(self, create, **kw): + return "ALTER TABLE %s COMMENT ''" % ( + self.preparer.format_table(create.element) + ) + + def visit_set_column_comment(self, create, **kw): + return "ALTER TABLE %s CHANGE %s %s" % ( + self.preparer.format_table(create.element.table), + self.preparer.format_column(create.element), + self.get_column_specification(create.element), + ) + + +class MySQLTypeCompiler(compiler.GenericTypeCompiler): + def _extend_numeric(self, type_, spec): + "Extend a numeric-type declaration with MySQL specific extensions." + + if not self._mysql_type(type_): + return spec + + if type_.unsigned: + spec += " UNSIGNED" + if type_.zerofill: + spec += " ZEROFILL" + return spec + + def _extend_string(self, type_, defaults, spec): + """Extend a string-type declaration with standard SQL CHARACTER SET / + COLLATE annotations and MySQL specific extensions. + + """ + + def attr(name): + return getattr(type_, name, defaults.get(name)) + + if attr("charset"): + charset = "CHARACTER SET %s" % attr("charset") + elif attr("ascii"): + charset = "ASCII" + elif attr("unicode"): + charset = "UNICODE" + else: + charset = None + + if attr("collation"): + collation = "COLLATE %s" % type_.collation + elif attr("binary"): + collation = "BINARY" + else: + collation = None + + if attr("national"): + # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. + return " ".join( + [c for c in ("NATIONAL", spec, collation) if c is not None] + ) + return " ".join( + [c for c in (spec, charset, collation) if c is not None] + ) + + def _mysql_type(self, type_): + return isinstance(type_, (_StringType, _NumericType)) + + def visit_NUMERIC(self, type_, **kw): + if type_.precision is None: + return self._extend_numeric(type_, "NUMERIC") + elif type_.scale is None: + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s)" % {"precision": type_.precision}, + ) + else: + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) + + def visit_DECIMAL(self, type_, **kw): + if type_.precision is None: + return self._extend_numeric(type_, "DECIMAL") + elif type_.scale is None: + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s)" % {"precision": type_.precision}, + ) + else: + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) + + def visit_DOUBLE(self, type_, **kw): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric( + type_, + "DOUBLE(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) + else: + return self._extend_numeric(type_, "DOUBLE") + + def visit_REAL(self, type_, **kw): + if type_.precision is not None and type_.scale is not None: + return self._extend_numeric( + type_, + "REAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) + else: + return self._extend_numeric(type_, "REAL") + + def visit_FLOAT(self, type_, **kw): + if ( + self._mysql_type(type_) + and type_.scale is not None + and type_.precision is not None + ): + return self._extend_numeric( + type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale) + ) + elif type_.precision is not None: + return self._extend_numeric( + type_, "FLOAT(%s)" % (type_.precision,) + ) + else: + return self._extend_numeric(type_, "FLOAT") + + def visit_INTEGER(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, + "INTEGER(%(display_width)s)" + % {"display_width": type_.display_width}, + ) + else: + return self._extend_numeric(type_, "INTEGER") + + def visit_BIGINT(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, + "BIGINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) + else: + return self._extend_numeric(type_, "BIGINT") + + def visit_MEDIUMINT(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, + "MEDIUMINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) + else: + return self._extend_numeric(type_, "MEDIUMINT") + + def visit_TINYINT(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, "TINYINT(%s)" % type_.display_width + ) + else: + return self._extend_numeric(type_, "TINYINT") + + def visit_SMALLINT(self, type_, **kw): + if self._mysql_type(type_) and type_.display_width is not None: + return self._extend_numeric( + type_, + "SMALLINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) + else: + return self._extend_numeric(type_, "SMALLINT") + + def visit_BIT(self, type_, **kw): + if type_.length is not None: + return "BIT(%s)" % type_.length + else: + return "BIT" + + def visit_DATETIME(self, type_, **kw): + if getattr(type_, "fsp", None): + return "DATETIME(%d)" % type_.fsp + else: + return "DATETIME" + + def visit_DATE(self, type_, **kw): + return "DATE" + + def visit_TIME(self, type_, **kw): + if getattr(type_, "fsp", None): + return "TIME(%d)" % type_.fsp + else: + return "TIME" + + def visit_TIMESTAMP(self, type_, **kw): + if getattr(type_, "fsp", None): + return "TIMESTAMP(%d)" % type_.fsp + else: + return "TIMESTAMP" + + def visit_YEAR(self, type_, **kw): + if type_.display_width is None: + return "YEAR" + else: + return "YEAR(%s)" % type_.display_width + + def visit_TEXT(self, type_, **kw): + if type_.length is not None: + return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) + else: + return self._extend_string(type_, {}, "TEXT") + + def visit_TINYTEXT(self, type_, **kw): + return self._extend_string(type_, {}, "TINYTEXT") + + def visit_MEDIUMTEXT(self, type_, **kw): + return self._extend_string(type_, {}, "MEDIUMTEXT") + + def visit_LONGTEXT(self, type_, **kw): + return self._extend_string(type_, {}, "LONGTEXT") + + def visit_VARCHAR(self, type_, **kw): + if type_.length is not None: + return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) + else: + raise exc.CompileError( + "VARCHAR requires a length on dialect %s" % self.dialect.name + ) + + def visit_CHAR(self, type_, **kw): + if type_.length is not None: + return self._extend_string( + type_, {}, "CHAR(%(length)s)" % {"length": type_.length} + ) + else: + return self._extend_string(type_, {}, "CHAR") + + def visit_NVARCHAR(self, type_, **kw): + # We'll actually generate the equiv. "NATIONAL VARCHAR" instead + # of "NVARCHAR". + if type_.length is not None: + return self._extend_string( + type_, + {"national": True}, + "VARCHAR(%(length)s)" % {"length": type_.length}, + ) + else: + raise exc.CompileError( + "NVARCHAR requires a length on dialect %s" % self.dialect.name + ) + + def visit_NCHAR(self, type_, **kw): + # We'll actually generate the equiv. + # "NATIONAL CHAR" instead of "NCHAR". + if type_.length is not None: + return self._extend_string( + type_, + {"national": True}, + "CHAR(%(length)s)" % {"length": type_.length}, + ) + else: + return self._extend_string(type_, {"national": True}, "CHAR") + + def visit_UUID(self, type_, **kw): + return "UUID" + + def visit_VARBINARY(self, type_, **kw): + return "VARBINARY(%d)" % type_.length + + def visit_JSON(self, type_, **kw): + return "JSON" + + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_) + + def visit_enum(self, type_, **kw): + if not type_.native_enum: + return super().visit_enum(type_) + else: + return self._visit_enumerated_values("ENUM", type_, type_.enums) + + def visit_BLOB(self, type_, **kw): + if type_.length is not None: + return "BLOB(%d)" % type_.length + else: + return "BLOB" + + def visit_TINYBLOB(self, type_, **kw): + return "TINYBLOB" + + def visit_MEDIUMBLOB(self, type_, **kw): + return "MEDIUMBLOB" + + def visit_LONGBLOB(self, type_, **kw): + return "LONGBLOB" + + def _visit_enumerated_values(self, name, type_, enumerated_values): + quoted_enums = [] + for e in enumerated_values: + quoted_enums.append("'%s'" % e.replace("'", "''")) + return self._extend_string( + type_, {}, "%s(%s)" % (name, ",".join(quoted_enums)) + ) + + def visit_ENUM(self, type_, **kw): + return self._visit_enumerated_values("ENUM", type_, type_.enums) + + def visit_SET(self, type_, **kw): + return self._visit_enumerated_values("SET", type_, type_.values) + + def visit_BOOLEAN(self, type_, **kw): + return "BOOL" + + +class MySQLIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS_MYSQL + + def __init__(self, dialect, server_ansiquotes=False, **kw): + if not server_ansiquotes: + quote = "`" + else: + quote = '"' + + super().__init__(dialect, initial_quote=quote, escape_quote=quote) + + def _quote_free_identifiers(self, *ids): + """Unilaterally identifier-quote any number of strings.""" + + return tuple([self.quote_identifier(i) for i in ids if i is not None]) + + +class MariaDBIdentifierPreparer(MySQLIdentifierPreparer): + reserved_words = RESERVED_WORDS_MARIADB + + +@log.class_logger +class MySQLDialect(default.DefaultDialect): + """Details of the MySQL dialect. + Not used directly in application code. + """ + + name = "mysql" + supports_statement_cache = True + + supports_alter = True + + # MySQL has no true "boolean" type; we + # allow for the "true" and "false" keywords, however + supports_native_boolean = False + + # identifiers are 64, however aliases can be 255... + max_identifier_length = 255 + max_index_name_length = 64 + max_constraint_name_length = 64 + + div_is_floordiv = False + + supports_native_enum = True + + returns_native_bytes = True + + supports_sequences = False # default for MySQL ... + # ... may be updated to True for MariaDB 10.3+ in initialize() + + sequences_optional = False + + supports_for_update_of = False # default for MySQL ... + # ... may be updated to True for MySQL 8+ in initialize() + + _requires_alias_for_on_duplicate_key = False # Only available ... + # ... in MySQL 8+ + + # MySQL doesn't support "DEFAULT VALUES" but *does* support + # "VALUES (DEFAULT)" + supports_default_values = False + supports_default_metavalue = True + + use_insertmanyvalues: bool = True + insertmanyvalues_implicit_sentinel = ( + InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT + ) + + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + supports_multivalues_insert = True + insert_null_pk_still_autoincrements = True + + supports_comments = True + inline_comments = True + default_paramstyle = "format" + colspecs = colspecs + + cte_follows_insert = True + + statement_compiler = MySQLCompiler + ddl_compiler = MySQLDDLCompiler + type_compiler_cls = MySQLTypeCompiler + ischema_names = ischema_names + preparer = MySQLIdentifierPreparer + + is_mariadb = False + _mariadb_normalized_version_info = None + + # default SQL compilation settings - + # these are modified upon initialize(), + # i.e. first connect + _backslash_escapes = True + _server_ansiquotes = False + + construct_arguments = [ + (sa_schema.Table, {"*": None}), + (sql.Update, {"limit": None}), + (sa_schema.PrimaryKeyConstraint, {"using": None}), + ( + sa_schema.Index, + { + "using": None, + "length": None, + "prefix": None, + "with_parser": None, + }, + ), + ] + + def __init__( + self, + json_serializer=None, + json_deserializer=None, + is_mariadb=None, + **kwargs, + ): + kwargs.pop("use_ansiquotes", None) # legacy + default.DefaultDialect.__init__(self, **kwargs) + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer + self._set_mariadb(is_mariadb, None) + + def get_isolation_level_values(self, dbapi_conn): + return ( + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + ) + + def set_isolation_level(self, dbapi_connection, level): + cursor = dbapi_connection.cursor() + cursor.execute(f"SET SESSION TRANSACTION ISOLATION LEVEL {level}") + cursor.execute("COMMIT") + cursor.close() + + def get_isolation_level(self, dbapi_connection): + cursor = dbapi_connection.cursor() + if self._is_mysql and self.server_version_info >= (5, 7, 20): + cursor.execute("SELECT @@transaction_isolation") + else: + cursor.execute("SELECT @@tx_isolation") + row = cursor.fetchone() + if row is None: + util.warn( + "Could not retrieve transaction isolation level for MySQL " + "connection." + ) + raise NotImplementedError() + val = row[0] + cursor.close() + if isinstance(val, bytes): + val = val.decode() + return val.upper().replace("-", " ") + + @classmethod + def _is_mariadb_from_url(cls, url): + dbapi = cls.import_dbapi() + dialect = cls(dbapi=dbapi) + + cargs, cparams = dialect.create_connect_args(url) + conn = dialect.connect(*cargs, **cparams) + try: + cursor = conn.cursor() + cursor.execute("SELECT VERSION() LIKE '%MariaDB%'") + val = cursor.fetchone()[0] + except: + raise + else: + return bool(val) + finally: + conn.close() + + def _get_server_version_info(self, connection): + # get database server version info explicitly over the wire + # to avoid proxy servers like MaxScale getting in the + # way with their own values, see #4205 + dbapi_con = connection.connection + cursor = dbapi_con.cursor() + cursor.execute("SELECT VERSION()") + val = cursor.fetchone()[0] + cursor.close() + if isinstance(val, bytes): + val = val.decode() + + return self._parse_server_version(val) + + def _parse_server_version(self, val): + version = [] + is_mariadb = False + + r = re.compile(r"[.\-+]") + tokens = r.split(val) + for token in tokens: + parsed_token = re.match( + r"^(?:(\d+)(?:a|b|c)?|(MariaDB\w*))$", token + ) + if not parsed_token: + continue + elif parsed_token.group(2): + self._mariadb_normalized_version_info = tuple(version[-3:]) + is_mariadb = True + else: + digit = int(parsed_token.group(1)) + version.append(digit) + + server_version_info = tuple(version) + + self._set_mariadb( + server_version_info and is_mariadb, server_version_info + ) + + if not is_mariadb: + self._mariadb_normalized_version_info = server_version_info + + if server_version_info < (5, 0, 2): + raise NotImplementedError( + "the MySQL/MariaDB dialect supports server " + "version info 5.0.2 and above." + ) + + # setting it here to help w the test suite + self.server_version_info = server_version_info + return server_version_info + + def _set_mariadb(self, is_mariadb, server_version_info): + if is_mariadb is None: + return + + if not is_mariadb and self.is_mariadb: + raise exc.InvalidRequestError( + "MySQL version %s is not a MariaDB variant." + % (".".join(map(str, server_version_info)),) + ) + if is_mariadb: + self.preparer = MariaDBIdentifierPreparer + # this would have been set by the default dialect already, + # so set it again + self.identifier_preparer = self.preparer(self) + + # this will be updated on first connect in initialize() + # if using older mariadb version + self.delete_returning = True + self.insert_returning = True + + self.is_mariadb = is_mariadb + + def do_begin_twophase(self, connection, xid): + connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid)) + + def do_prepare_twophase(self, connection, xid): + connection.execute(sql.text("XA END :xid"), dict(xid=xid)) + connection.execute(sql.text("XA PREPARE :xid"), dict(xid=xid)) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + connection.execute(sql.text("XA END :xid"), dict(xid=xid)) + connection.execute(sql.text("XA ROLLBACK :xid"), dict(xid=xid)) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute(sql.text("XA COMMIT :xid"), dict(xid=xid)) + + def do_recover_twophase(self, connection): + resultset = connection.exec_driver_sql("XA RECOVER") + return [row["data"][0 : row["gtrid_length"]] for row in resultset] + + def is_disconnect(self, e, connection, cursor): + if isinstance( + e, + ( + self.dbapi.OperationalError, + self.dbapi.ProgrammingError, + self.dbapi.InterfaceError, + ), + ) and self._extract_error_code(e) in ( + 1927, + 2006, + 2013, + 2014, + 2045, + 2055, + 4031, + ): + return True + elif isinstance( + e, (self.dbapi.InterfaceError, self.dbapi.InternalError) + ): + # if underlying connection is closed, + # this is the error you get + return "(0, '')" in str(e) + else: + return False + + def _compat_fetchall(self, rp, charset=None): + """Proxy result rows to smooth over MySQL-Python driver + inconsistencies.""" + + return [_DecodingRow(row, charset) for row in rp.fetchall()] + + def _compat_fetchone(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver + inconsistencies.""" + + row = rp.fetchone() + if row: + return _DecodingRow(row, charset) + else: + return None + + def _compat_first(self, rp, charset=None): + """Proxy a result row to smooth over MySQL-Python driver + inconsistencies.""" + + row = rp.first() + if row: + return _DecodingRow(row, charset) + else: + return None + + def _extract_error_code(self, exception): + raise NotImplementedError() + + def _get_default_schema_name(self, connection): + return connection.exec_driver_sql("SELECT DATABASE()").scalar() + + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): + self._ensure_has_table_connection(connection) + + if schema is None: + schema = self.default_schema_name + + assert schema is not None + + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) + + # DESCRIBE *must* be used because there is no information schema + # table that returns information on temp tables that is consistently + # available on MariaDB / MySQL / engine-agnostic etc. + # therefore we have no choice but to use DESCRIBE and an error catch + # to detect "False". See issue #9058 + + try: + with connection.exec_driver_sql( + f"DESCRIBE {full_name}", + execution_options={"skip_user_error_events": True}, + ) as rs: + return rs.fetchone() is not None + except exc.DBAPIError as e: + # https://dev.mysql.com/doc/mysql-errors/8.0/en/server-error-reference.html # noqa: E501 + # there are a lot of codes that *may* pop up here at some point + # but we continue to be fairly conservative. We include: + # 1146: Table '%s.%s' doesn't exist - what every MySQL has emitted + # for decades + # + # mysql 8 suddenly started emitting: + # 1049: Unknown database '%s' - for nonexistent schema + # + # also added: + # 1051: Unknown table '%s' - not known to emit + # + # there's more "doesn't exist" kinds of messages but they are + # less clear if mysql 8 would suddenly start using one of those + if self._extract_error_code(e.orig) in (1146, 1049, 1051): + return False + raise + + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): + if not self.supports_sequences: + self._sequences_not_supported() + if not schema: + schema = self.default_schema_name + # MariaDB implements sequences as a special type of table + # + cursor = connection.execute( + sql.text( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + "WHERE TABLE_TYPE='SEQUENCE' and TABLE_NAME=:name AND " + "TABLE_SCHEMA=:schema_name" + ), + dict( + name=str(sequence_name), + schema_name=str(schema), + ), + ) + return cursor.first() is not None + + def _sequences_not_supported(self): + raise NotImplementedError( + "Sequences are supported only by the " + "MariaDB series 10.3 or greater" + ) + + @reflection.cache + def get_sequence_names(self, connection, schema=None, **kw): + if not self.supports_sequences: + self._sequences_not_supported() + if not schema: + schema = self.default_schema_name + # MariaDB implements sequences as a special type of table + cursor = connection.execute( + sql.text( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + "WHERE TABLE_TYPE='SEQUENCE' and TABLE_SCHEMA=:schema_name" + ), + dict(schema_name=schema), + ) + return [ + row[0] + for row in self._compat_fetchall( + cursor, charset=self._connection_charset + ) + ] + + def initialize(self, connection): + # this is driver-based, does not need server version info + # and is fairly critical for even basic SQL operations + self._connection_charset = self._detect_charset(connection) + + # call super().initialize() because we need to have + # server_version_info set up. in 1.4 under python 2 only this does the + # "check unicode returns" thing, which is the one area that some + # SQL gets compiled within initialize() currently + default.DefaultDialect.initialize(self, connection) + + self._detect_sql_mode(connection) + self._detect_ansiquotes(connection) # depends on sql mode + self._detect_casing(connection) + if self._server_ansiquotes: + # if ansiquotes == True, build a new IdentifierPreparer + # with the new setting + self.identifier_preparer = self.preparer( + self, server_ansiquotes=self._server_ansiquotes + ) + + self.supports_sequences = ( + self.is_mariadb and self.server_version_info >= (10, 3) + ) + + self.supports_for_update_of = ( + self._is_mysql and self.server_version_info >= (8,) + ) + + self._needs_correct_for_88718_96365 = ( + not self.is_mariadb and self.server_version_info >= (8,) + ) + + self.delete_returning = ( + self.is_mariadb and self.server_version_info >= (10, 0, 5) + ) + + self.insert_returning = ( + self.is_mariadb and self.server_version_info >= (10, 5) + ) + + self._requires_alias_for_on_duplicate_key = ( + self._is_mysql and self.server_version_info >= (8, 0, 20) + ) + + self._warn_for_known_db_issues() + + def _warn_for_known_db_issues(self): + if self.is_mariadb: + mdb_version = self._mariadb_normalized_version_info + if mdb_version > (10, 2) and mdb_version < (10, 2, 9): + util.warn( + "MariaDB %r before 10.2.9 has known issues regarding " + "CHECK constraints, which impact handling of NULL values " + "with SQLAlchemy's boolean datatype (MDEV-13596). An " + "additional issue prevents proper migrations of columns " + "with CHECK constraints (MDEV-11114). Please upgrade to " + "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 " + "series, to avoid these issues." % (mdb_version,) + ) + + @property + def _support_float_cast(self): + if not self.server_version_info: + return False + elif self.is_mariadb: + # ref https://mariadb.com/kb/en/mariadb-1045-release-notes/ + return self.server_version_info >= (10, 4, 5) + else: + # ref https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-17.html#mysqld-8-0-17-feature # noqa + return self.server_version_info >= (8, 0, 17) + + @property + def _is_mariadb(self): + return self.is_mariadb + + @property + def _is_mysql(self): + return not self.is_mariadb + + @property + def _is_mariadb_102(self): + return self.is_mariadb and self._mariadb_normalized_version_info > ( + 10, + 2, + ) + + @reflection.cache + def get_schema_names(self, connection, **kw): + rp = connection.exec_driver_sql("SHOW schemas") + return [r[0] for r in rp] + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + """Return a Unicode SHOW TABLES from a given schema.""" + if schema is not None: + current_schema = schema + else: + current_schema = self.default_schema_name + + charset = self._connection_charset + + rp = connection.exec_driver_sql( + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(current_schema) + ) + + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] == "BASE TABLE" + ] + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + if schema is None: + schema = self.default_schema_name + charset = self._connection_charset + rp = connection.exec_driver_sql( + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(schema) + ) + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] in ("VIEW", "SYSTEM VIEW") + ] + + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + if parsed_state.table_options: + return parsed_state.table_options + else: + return ReflectionDefaults.table_options() + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + if parsed_state.columns: + return parsed_state.columns + else: + return ReflectionDefaults.columns() + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + for key in parsed_state.keys: + if key["type"] == "PRIMARY": + # There can be only one. + cols = [s[0] for s in key["columns"]] + return {"constrained_columns": cols, "name": None} + return ReflectionDefaults.pk_constraint() + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + default_schema = None + + fkeys = [] + + for spec in parsed_state.fk_constraints: + ref_name = spec["table"][-1] + ref_schema = len(spec["table"]) > 1 and spec["table"][-2] or schema + + if not ref_schema: + if default_schema is None: + default_schema = connection.dialect.default_schema_name + if schema == default_schema: + ref_schema = schema + + loc_names = spec["local"] + ref_names = spec["foreign"] + + con_kw = {} + for opt in ("onupdate", "ondelete"): + if spec.get(opt, False) not in ("NO ACTION", None): + con_kw[opt] = spec[opt] + + fkey_d = { + "name": spec["name"], + "constrained_columns": loc_names, + "referred_schema": ref_schema, + "referred_table": ref_name, + "referred_columns": ref_names, + "options": con_kw, + } + fkeys.append(fkey_d) + + if self._needs_correct_for_88718_96365: + self._correct_for_mysql_bugs_88718_96365(fkeys, connection) + + return fkeys if fkeys else ReflectionDefaults.foreign_keys() + + def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection): + # Foreign key is always in lower case (MySQL 8.0) + # https://bugs.mysql.com/bug.php?id=88718 + # issue #4344 for SQLAlchemy + + # table name also for MySQL 8.0 + # https://bugs.mysql.com/bug.php?id=96365 + # issue #4751 for SQLAlchemy + + # for lower_case_table_names=2, information_schema.columns + # preserves the original table/schema casing, but SHOW CREATE + # TABLE does not. this problem is not in lower_case_table_names=1, + # but use case-insensitive matching for these two modes in any case. + + if self._casing in (1, 2): + + def lower(s): + return s.lower() + + else: + # if on case sensitive, there can be two tables referenced + # with the same name different casing, so we need to use + # case-sensitive matching. + def lower(s): + return s + + default_schema_name = connection.dialect.default_schema_name + col_tuples = [ + ( + lower(rec["referred_schema"] or default_schema_name), + lower(rec["referred_table"]), + col_name, + ) + for rec in fkeys + for col_name in rec["referred_columns"] + ] + + if col_tuples: + correct_for_wrong_fk_case = connection.execute( + sql.text( + """ + select table_schema, table_name, column_name + from information_schema.columns + where (table_schema, table_name, lower(column_name)) in + :table_data; + """ + ).bindparams(sql.bindparam("table_data", expanding=True)), + dict(table_data=col_tuples), + ) + + # in casing=0, table name and schema name come back in their + # exact case. + # in casing=1, table name and schema name come back in lower + # case. + # in casing=2, table name and schema name come back from the + # information_schema.columns view in the case + # that was used in CREATE DATABASE and CREATE TABLE, but + # SHOW CREATE TABLE converts them to *lower case*, therefore + # not matching. So for this case, case-insensitive lookup + # is necessary + d = defaultdict(dict) + for schema, tname, cname in correct_for_wrong_fk_case: + d[(lower(schema), lower(tname))]["SCHEMANAME"] = schema + d[(lower(schema), lower(tname))]["TABLENAME"] = tname + d[(lower(schema), lower(tname))][cname.lower()] = cname + + for fkey in fkeys: + rec = d[ + ( + lower(fkey["referred_schema"] or default_schema_name), + lower(fkey["referred_table"]), + ) + ] + + fkey["referred_table"] = rec["TABLENAME"] + if fkey["referred_schema"] is not None: + fkey["referred_schema"] = rec["SCHEMANAME"] + + fkey["referred_columns"] = [ + rec[col.lower()] for col in fkey["referred_columns"] + ] + + @reflection.cache + def get_check_constraints(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + + cks = [ + {"name": spec["name"], "sqltext": spec["sqltext"]} + for spec in parsed_state.ck_constraints + ] + cks.sort(key=lambda d: d["name"] or "~") # sort None as last + return cks if cks else ReflectionDefaults.check_constraints() + + @reflection.cache + def get_table_comment(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + comment = parsed_state.table_options.get(f"{self.name}_comment", None) + if comment is not None: + return {"text": comment} + else: + return ReflectionDefaults.table_comment() + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + + indexes = [] + + for spec in parsed_state.keys: + dialect_options = {} + unique = False + flavor = spec["type"] + if flavor == "PRIMARY": + continue + if flavor == "UNIQUE": + unique = True + elif flavor in ("FULLTEXT", "SPATIAL"): + dialect_options["%s_prefix" % self.name] = flavor + elif flavor is None: + pass + else: + self.logger.info( + "Converting unknown KEY type %s to a plain KEY", flavor + ) + pass + + if spec["parser"]: + dialect_options["%s_with_parser" % (self.name)] = spec[ + "parser" + ] + + index_d = {} + + index_d["name"] = spec["name"] + index_d["column_names"] = [s[0] for s in spec["columns"]] + mysql_length = { + s[0]: s[1] for s in spec["columns"] if s[1] is not None + } + if mysql_length: + dialect_options["%s_length" % self.name] = mysql_length + + index_d["unique"] = unique + if flavor: + index_d["type"] = flavor + + if dialect_options: + index_d["dialect_options"] = dialect_options + + indexes.append(index_d) + indexes.sort(key=lambda d: d["name"] or "~") # sort None as last + return indexes if indexes else ReflectionDefaults.indexes() + + @reflection.cache + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + parsed_state = self._parsed_state_or_create( + connection, table_name, schema, **kw + ) + + ucs = [ + { + "name": key["name"], + "column_names": [col[0] for col in key["columns"]], + "duplicates_index": key["name"], + } + for key in parsed_state.keys + if key["type"] == "UNIQUE" + ] + ucs.sort(key=lambda d: d["name"] or "~") # sort None as last + if ucs: + return ucs + else: + return ReflectionDefaults.unique_constraints() + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + charset = self._connection_charset + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers(schema, view_name) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) + if sql.upper().startswith("CREATE TABLE"): + # it's a table, not a view + raise exc.NoSuchTableError(full_name) + return sql + + def _parsed_state_or_create( + self, connection, table_name, schema=None, **kw + ): + return self._setup_parser( + connection, + table_name, + schema, + info_cache=kw.get("info_cache", None), + ) + + @util.memoized_property + def _tabledef_parser(self): + """return the MySQLTableDefinitionParser, generate if needed. + + The deferred creation ensures that the dialect has + retrieved server version information first. + + """ + preparer = self.identifier_preparer + return _reflection.MySQLTableDefinitionParser(self, preparer) + + @reflection.cache + def _setup_parser(self, connection, table_name, schema=None, **kw): + charset = self._connection_charset + parser = self._tabledef_parser + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) + if parser._check_view(sql): + # Adapt views to something table-like. + columns = self._describe_table( + connection, None, charset, full_name=full_name + ) + sql = parser._describe_to_create(table_name, columns) + return parser.parse(sql, charset) + + def _fetch_setting(self, connection, setting_name): + charset = self._connection_charset + + if self.server_version_info and self.server_version_info < (5, 6): + sql = "SHOW VARIABLES LIKE '%s'" % setting_name + fetch_col = 1 + else: + sql = "SELECT @@%s" % setting_name + fetch_col = 0 + + show_var = connection.exec_driver_sql(sql) + row = self._compat_first(show_var, charset=charset) + if not row: + return None + else: + return row[fetch_col] + + def _detect_charset(self, connection): + raise NotImplementedError() + + def _detect_casing(self, connection): + """Sniff out identifier case sensitivity. + + Cached per-connection. This value can not change without a server + restart. + + """ + # https://dev.mysql.com/doc/refman/en/identifier-case-sensitivity.html + + setting = self._fetch_setting(connection, "lower_case_table_names") + if setting is None: + cs = 0 + else: + # 4.0.15 returns OFF or ON according to [ticket:489] + # 3.23 doesn't, 4.0.27 doesn't.. + if setting == "OFF": + cs = 0 + elif setting == "ON": + cs = 1 + else: + cs = int(setting) + self._casing = cs + return cs + + def _detect_collations(self, connection): + """Pull the active COLLATIONS list from the server. + + Cached per-connection. + """ + + collations = {} + charset = self._connection_charset + rs = connection.exec_driver_sql("SHOW COLLATION") + for row in self._compat_fetchall(rs, charset): + collations[row[0]] = row[1] + return collations + + def _detect_sql_mode(self, connection): + setting = self._fetch_setting(connection, "sql_mode") + + if setting is None: + util.warn( + "Could not retrieve SQL_MODE; please ensure the " + "MySQL user has permissions to SHOW VARIABLES" + ) + self._sql_mode = "" + else: + self._sql_mode = setting or "" + + def _detect_ansiquotes(self, connection): + """Detect and adjust for the ANSI_QUOTES sql mode.""" + + mode = self._sql_mode + if not mode: + mode = "" + elif mode.isdigit(): + mode_no = int(mode) + mode = (mode_no | 4 == mode_no) and "ANSI_QUOTES" or "" + + self._server_ansiquotes = "ANSI_QUOTES" in mode + + # as of MySQL 5.0.1 + self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode + + def _show_create_table( + self, connection, table, charset=None, full_name=None + ): + """Run SHOW CREATE TABLE for a ``Table``.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "SHOW CREATE TABLE %s" % full_name + + rp = None + try: + rp = connection.execution_options( + skip_user_error_events=True + ).exec_driver_sql(st) + except exc.DBAPIError as e: + if self._extract_error_code(e.orig) == 1146: + raise exc.NoSuchTableError(full_name) from e + else: + raise + row = self._compat_first(rp, charset=charset) + if not row: + raise exc.NoSuchTableError(full_name) + return row[1].strip() + + def _describe_table(self, connection, table, charset=None, full_name=None): + """Run DESCRIBE for a ``Table`` and return processed rows.""" + + if full_name is None: + full_name = self.identifier_preparer.format_table(table) + st = "DESCRIBE %s" % full_name + + rp, rows = None, None + try: + try: + rp = connection.execution_options( + skip_user_error_events=True + ).exec_driver_sql(st) + except exc.DBAPIError as e: + code = self._extract_error_code(e.orig) + if code == 1146: + raise exc.NoSuchTableError(full_name) from e + + elif code == 1356: + raise exc.UnreflectableTableError( + "Table or view named %s could not be " + "reflected: %s" % (full_name, e) + ) from e + + else: + raise + rows = self._compat_fetchall(rp, charset=charset) + finally: + if rp: + rp.close() + return rows + + +class _DecodingRow: + """Return unicode-decoded values based on type inspection. + + Smooth over data type issues (esp. with alpha driver versions) and + normalize strings as Unicode regardless of user-configured driver + encoding settings. + + """ + + # Some MySQL-python versions can return some columns as + # sets.Set(['value']) (seriously) but thankfully that doesn't + # seem to come up in DDL queries. + + _encoding_compat = { + "koi8r": "koi8_r", + "koi8u": "koi8_u", + "utf16": "utf-16-be", # MySQL's uft16 is always bigendian + "utf8mb4": "utf8", # real utf8 + "utf8mb3": "utf8", # real utf8; saw this happen on CI but I cannot + # reproduce, possibly mariadb10.6 related + "eucjpms": "ujis", + } + + def __init__(self, rowproxy, charset): + self.rowproxy = rowproxy + self.charset = self._encoding_compat.get(charset, charset) + + def __getitem__(self, index): + item = self.rowproxy[index] + if isinstance(item, _array): + item = item.tostring() + + if self.charset and isinstance(item, bytes): + return item.decode(self.charset) + else: + return item + + def __getattr__(self, attr): + item = getattr(self.rowproxy, attr) + if isinstance(item, _array): + item = item.tostring() + if self.charset and isinstance(item, bytes): + return item.decode(self.charset) + else: + return item diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/cymysql.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/cymysql.py new file mode 100644 index 0000000..f199aa4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/cymysql.py @@ -0,0 +1,84 @@ +# dialects/mysql/cymysql.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" + +.. dialect:: mysql+cymysql + :name: CyMySQL + :dbapi: cymysql + :connectstring: mysql+cymysql://:@/[?] + :url: https://github.com/nakagami/CyMySQL + +.. note:: + + The CyMySQL dialect is **not tested as part of SQLAlchemy's continuous + integration** and may have unresolved issues. The recommended MySQL + dialects are mysqlclient and PyMySQL. + +""" # noqa + +from .base import BIT +from .base import MySQLDialect +from .mysqldb import MySQLDialect_mysqldb +from ... import util + + +class _cymysqlBIT(BIT): + def result_processor(self, dialect, coltype): + """Convert MySQL's 64 bit, variable length binary string to a long.""" + + def process(value): + if value is not None: + v = 0 + for i in iter(value): + v = v << 8 | i + return v + return value + + return process + + +class MySQLDialect_cymysql(MySQLDialect_mysqldb): + driver = "cymysql" + supports_statement_cache = True + + description_encoding = None + supports_sane_rowcount = True + supports_sane_multi_rowcount = False + supports_unicode_statements = True + + colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT}) + + @classmethod + def import_dbapi(cls): + return __import__("cymysql") + + def _detect_charset(self, connection): + return connection.connection.charset + + def _extract_error_code(self, exception): + return exception.errno + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.OperationalError): + return self._extract_error_code(e) in ( + 2006, + 2013, + 2014, + 2045, + 2055, + ) + elif isinstance(e, self.dbapi.InterfaceError): + # if underlying connection is closed, + # this is the error you get + return True + else: + return False + + +dialect = MySQLDialect_cymysql diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/dml.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/dml.py new file mode 100644 index 0000000..e4005c2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/dml.py @@ -0,0 +1,219 @@ +# dialects/mysql/dml.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +from typing import Any +from typing import List +from typing import Mapping +from typing import Optional +from typing import Tuple +from typing import Union + +from ... import exc +from ... import util +from ...sql._typing import _DMLTableArgument +from ...sql.base import _exclusive_against +from ...sql.base import _generative +from ...sql.base import ColumnCollection +from ...sql.base import ReadOnlyColumnCollection +from ...sql.dml import Insert as StandardInsert +from ...sql.elements import ClauseElement +from ...sql.elements import KeyedColumnElement +from ...sql.expression import alias +from ...sql.selectable import NamedFromClause +from ...util.typing import Self + + +__all__ = ("Insert", "insert") + + +def insert(table: _DMLTableArgument) -> Insert: + """Construct a MySQL/MariaDB-specific variant :class:`_mysql.Insert` + construct. + + .. container:: inherited_member + + The :func:`sqlalchemy.dialects.mysql.insert` function creates + a :class:`sqlalchemy.dialects.mysql.Insert`. This class is based + on the dialect-agnostic :class:`_sql.Insert` construct which may + be constructed using the :func:`_sql.insert` function in + SQLAlchemy Core. + + The :class:`_mysql.Insert` construct includes additional methods + :meth:`_mysql.Insert.on_duplicate_key_update`. + + """ + return Insert(table) + + +class Insert(StandardInsert): + """MySQL-specific implementation of INSERT. + + Adds methods for MySQL-specific syntaxes such as ON DUPLICATE KEY UPDATE. + + The :class:`~.mysql.Insert` object is created using the + :func:`sqlalchemy.dialects.mysql.insert` function. + + .. versionadded:: 1.2 + + """ + + stringify_dialect = "mysql" + inherit_cache = False + + @property + def inserted( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: + """Provide the "inserted" namespace for an ON DUPLICATE KEY UPDATE + statement + + MySQL's ON DUPLICATE KEY UPDATE clause allows reference to the row + that would be inserted, via a special function called ``VALUES()``. + This attribute provides all columns in this row to be referenceable + such that they will render within a ``VALUES()`` function inside the + ON DUPLICATE KEY UPDATE clause. The attribute is named ``.inserted`` + so as not to conflict with the existing + :meth:`_expression.Insert.values` method. + + .. tip:: The :attr:`_mysql.Insert.inserted` attribute is an instance + of :class:`_expression.ColumnCollection`, which provides an + interface the same as that of the :attr:`_schema.Table.c` + collection described at :ref:`metadata_tables_and_columns`. + With this collection, ordinary names are accessible like attributes + (e.g. ``stmt.inserted.some_column``), but special names and + dictionary method names should be accessed using indexed access, + such as ``stmt.inserted["column name"]`` or + ``stmt.inserted["values"]``. See the docstring for + :class:`_expression.ColumnCollection` for further examples. + + .. seealso:: + + :ref:`mysql_insert_on_duplicate_key_update` - example of how + to use :attr:`_expression.Insert.inserted` + + """ + return self.inserted_alias.columns + + @util.memoized_property + def inserted_alias(self) -> NamedFromClause: + return alias(self.table, name="inserted") + + @_generative + @_exclusive_against( + "_post_values_clause", + msgs={ + "_post_values_clause": "This Insert construct already " + "has an ON DUPLICATE KEY clause present" + }, + ) + def on_duplicate_key_update(self, *args: _UpdateArg, **kw: Any) -> Self: + r""" + Specifies the ON DUPLICATE KEY UPDATE clause. + + :param \**kw: Column keys linked to UPDATE values. The + values may be any SQL expression or supported literal Python + values. + + .. warning:: This dictionary does **not** take into account + Python-specified default UPDATE values or generation functions, + e.g. those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON DUPLICATE KEY UPDATE + style of UPDATE, unless values are manually specified here. + + :param \*args: As an alternative to passing key/value parameters, + a dictionary or list of 2-tuples can be passed as a single positional + argument. + + Passing a single dictionary is equivalent to the keyword argument + form:: + + insert().on_duplicate_key_update({"name": "some name"}) + + Passing a list of 2-tuples indicates that the parameter assignments + in the UPDATE clause should be ordered as sent, in a manner similar + to that described for the :class:`_expression.Update` + construct overall + in :ref:`tutorial_parameter_ordered_updates`:: + + insert().on_duplicate_key_update( + [("name", "some name"), ("value", "some value")]) + + .. versionchanged:: 1.3 parameters can be specified as a dictionary + or list of 2-tuples; the latter form provides for parameter + ordering. + + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`mysql_insert_on_duplicate_key_update` + + """ + if args and kw: + raise exc.ArgumentError( + "Can't pass kwargs and positional arguments simultaneously" + ) + + if args: + if len(args) > 1: + raise exc.ArgumentError( + "Only a single dictionary or list of tuples " + "is accepted positionally." + ) + values = args[0] + else: + values = kw + + self._post_values_clause = OnDuplicateClause( + self.inserted_alias, values + ) + return self + + +class OnDuplicateClause(ClauseElement): + __visit_name__ = "on_duplicate_key_update" + + _parameter_ordering: Optional[List[str]] = None + + stringify_dialect = "mysql" + + def __init__( + self, inserted_alias: NamedFromClause, update: _UpdateArg + ) -> None: + self.inserted_alias = inserted_alias + + # auto-detect that parameters should be ordered. This is copied from + # Update._proces_colparams(), however we don't look for a special flag + # in this case since we are not disambiguating from other use cases as + # we are in Update.values(). + if isinstance(update, list) and ( + update and isinstance(update[0], tuple) + ): + self._parameter_ordering = [key for key, value in update] + update = dict(update) + + if isinstance(update, dict): + if not update: + raise ValueError( + "update parameter dictionary must not be empty" + ) + elif isinstance(update, ColumnCollection): + update = dict(update) + else: + raise ValueError( + "update parameter must be a non-empty dictionary " + "or a ColumnCollection such as the `.c.` collection " + "of a Table object" + ) + self.update = update + + +_UpdateArg = Union[ + Mapping[Any, Any], List[Tuple[str, Any]], ColumnCollection[Any, Any] +] diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/enumerated.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/enumerated.py new file mode 100644 index 0000000..96499d7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/enumerated.py @@ -0,0 +1,244 @@ +# dialects/mysql/enumerated.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +import re + +from .types import _StringType +from ... import exc +from ... import sql +from ... import util +from ...sql import sqltypes + + +class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType): + """MySQL ENUM type.""" + + __visit_name__ = "ENUM" + + native_enum = True + + def __init__(self, *enums, **kw): + """Construct an ENUM. + + E.g.:: + + Column('myenum', ENUM("foo", "bar", "baz")) + + :param enums: The range of valid values for this ENUM. Values in + enums are not quoted, they will be escaped and surrounded by single + quotes when generating the schema. This object may also be a + PEP-435-compliant enumerated type. + + .. versionadded: 1.1 added support for PEP-435-compliant enumerated + types. + + :param strict: This flag has no effect. + + .. versionchanged:: The MySQL ENUM type as well as the base Enum + type now validates all Python data values. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + kw.pop("strict", None) + self._enum_init(enums, kw) + _StringType.__init__(self, length=self.length, **kw) + + @classmethod + def adapt_emulated_to_native(cls, impl, **kw): + """Produce a MySQL native :class:`.mysql.ENUM` from plain + :class:`.Enum`. + + """ + kw.setdefault("validate_strings", impl.validate_strings) + kw.setdefault("values_callable", impl.values_callable) + kw.setdefault("omit_aliases", impl._omit_aliases) + return cls(**kw) + + def _object_value_for_elem(self, elem): + # mysql sends back a blank string for any value that + # was persisted that was not in the enums; that is, it does no + # validation on the incoming data, it "truncates" it to be + # the blank string. Return it straight. + if elem == "": + return elem + else: + return super()._object_value_for_elem(elem) + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[ENUM, _StringType, sqltypes.Enum] + ) + + +class SET(_StringType): + """MySQL SET type.""" + + __visit_name__ = "SET" + + def __init__(self, *values, **kw): + """Construct a SET. + + E.g.:: + + Column('myset', SET("foo", "bar", "baz")) + + + The list of potential values is required in the case that this + set will be used to generate DDL for a table, or if the + :paramref:`.SET.retrieve_as_bitwise` flag is set to True. + + :param values: The range of valid values for this SET. The values + are not quoted, they will be escaped and surrounded by single + quotes when generating the schema. + + :param convert_unicode: Same flag as that of + :paramref:`.String.convert_unicode`. + + :param collation: same as that of :paramref:`.String.collation` + + :param charset: same as that of :paramref:`.VARCHAR.charset`. + + :param ascii: same as that of :paramref:`.VARCHAR.ascii`. + + :param unicode: same as that of :paramref:`.VARCHAR.unicode`. + + :param binary: same as that of :paramref:`.VARCHAR.binary`. + + :param retrieve_as_bitwise: if True, the data for the set type will be + persisted and selected using an integer value, where a set is coerced + into a bitwise mask for persistence. MySQL allows this mode which + has the advantage of being able to store values unambiguously, + such as the blank string ``''``. The datatype will appear + as the expression ``col + 0`` in a SELECT statement, so that the + value is coerced into an integer value in result sets. + This flag is required if one wishes + to persist a set that can store the blank string ``''`` as a value. + + .. warning:: + + When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is + essential that the list of set values is expressed in the + **exact same order** as exists on the MySQL database. + + """ + self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False) + self.values = tuple(values) + if not self.retrieve_as_bitwise and "" in values: + raise exc.ArgumentError( + "Can't use the blank value '' in a SET without " + "setting retrieve_as_bitwise=True" + ) + if self.retrieve_as_bitwise: + self._bitmap = { + value: 2**idx for idx, value in enumerate(self.values) + } + self._bitmap.update( + (2**idx, value) for idx, value in enumerate(self.values) + ) + length = max([len(v) for v in values] + [0]) + kw.setdefault("length", length) + super().__init__(**kw) + + def column_expression(self, colexpr): + if self.retrieve_as_bitwise: + return sql.type_coerce( + sql.type_coerce(colexpr, sqltypes.Integer) + 0, self + ) + else: + return colexpr + + def result_processor(self, dialect, coltype): + if self.retrieve_as_bitwise: + + def process(value): + if value is not None: + value = int(value) + + return set(util.map_bits(self._bitmap.__getitem__, value)) + else: + return None + + else: + super_convert = super().result_processor(dialect, coltype) + + def process(value): + if isinstance(value, str): + # MySQLdb returns a string, let's parse + if super_convert: + value = super_convert(value) + return set(re.findall(r"[^,]+", value)) + else: + # mysql-connector-python does a naive + # split(",") which throws in an empty string + if value is not None: + value.discard("") + return value + + return process + + def bind_processor(self, dialect): + super_convert = super().bind_processor(dialect) + if self.retrieve_as_bitwise: + + def process(value): + if value is None: + return None + elif isinstance(value, (int, str)): + if super_convert: + return super_convert(value) + else: + return value + else: + int_value = 0 + for v in value: + int_value |= self._bitmap[v] + return int_value + + else: + + def process(value): + # accept strings and int (actually bitflag) values directly + if value is not None and not isinstance(value, (int, str)): + value = ",".join(value) + + if super_convert: + return super_convert(value) + else: + return value + + return process + + def adapt(self, impltype, **kw): + kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise + return util.constructor_copy(self, impltype, *self.values, **kw) + + def __repr__(self): + return util.generic_repr( + self, + to_inspect=[SET, _StringType], + additional_kw=[ + ("retrieve_as_bitwise", False), + ], + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/expression.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/expression.py new file mode 100644 index 0000000..b81b58a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/expression.py @@ -0,0 +1,141 @@ +# dialects/mysql/expression.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from ... import exc +from ... import util +from ...sql import coercions +from ...sql import elements +from ...sql import operators +from ...sql import roles +from ...sql.base import _generative +from ...sql.base import Generative +from ...util.typing import Self + + +class match(Generative, elements.BinaryExpression): + """Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause. + + E.g.:: + + from sqlalchemy import desc + from sqlalchemy.dialects.mysql import match + + match_expr = match( + users_table.c.firstname, + users_table.c.lastname, + against="Firstname Lastname", + ) + + stmt = ( + select(users_table) + .where(match_expr.in_boolean_mode()) + .order_by(desc(match_expr)) + ) + + Would produce SQL resembling:: + + SELECT id, firstname, lastname + FROM user + WHERE MATCH(firstname, lastname) AGAINST (:param_1 IN BOOLEAN MODE) + ORDER BY MATCH(firstname, lastname) AGAINST (:param_2) DESC + + The :func:`_mysql.match` function is a standalone version of the + :meth:`_sql.ColumnElement.match` method available on all + SQL expressions, as when :meth:`_expression.ColumnElement.match` is + used, but allows to pass multiple columns + + :param cols: column expressions to match against + + :param against: expression to be compared towards + + :param in_boolean_mode: boolean, set "boolean mode" to true + + :param in_natural_language_mode: boolean , set "natural language" to true + + :param with_query_expansion: boolean, set "query expansion" to true + + .. versionadded:: 1.4.19 + + .. seealso:: + + :meth:`_expression.ColumnElement.match` + + """ + + __visit_name__ = "mysql_match" + + inherit_cache = True + + def __init__(self, *cols, **kw): + if not cols: + raise exc.ArgumentError("columns are required") + + against = kw.pop("against", None) + + if against is None: + raise exc.ArgumentError("against is required") + against = coercions.expect( + roles.ExpressionElementRole, + against, + ) + + left = elements.BooleanClauseList._construct_raw( + operators.comma_op, + clauses=cols, + ) + left.group = False + + flags = util.immutabledict( + { + "mysql_boolean_mode": kw.pop("in_boolean_mode", False), + "mysql_natural_language": kw.pop( + "in_natural_language_mode", False + ), + "mysql_query_expansion": kw.pop("with_query_expansion", False), + } + ) + + if kw: + raise exc.ArgumentError("unknown arguments: %s" % (", ".join(kw))) + + super().__init__(left, against, operators.match_op, modifiers=flags) + + @_generative + def in_boolean_mode(self) -> Self: + """Apply the "IN BOOLEAN MODE" modifier to the MATCH expression. + + :return: a new :class:`_mysql.match` instance with modifications + applied. + """ + + self.modifiers = self.modifiers.union({"mysql_boolean_mode": True}) + return self + + @_generative + def in_natural_language_mode(self) -> Self: + """Apply the "IN NATURAL LANGUAGE MODE" modifier to the MATCH + expression. + + :return: a new :class:`_mysql.match` instance with modifications + applied. + """ + + self.modifiers = self.modifiers.union({"mysql_natural_language": True}) + return self + + @_generative + def with_query_expansion(self) -> Self: + """Apply the "WITH QUERY EXPANSION" modifier to the MATCH expression. + + :return: a new :class:`_mysql.match` instance with modifications + applied. + """ + + self.modifiers = self.modifiers.union({"mysql_query_expansion": True}) + return self diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/json.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/json.py new file mode 100644 index 0000000..ebe4a34 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/json.py @@ -0,0 +1,81 @@ +# dialects/mysql/json.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from ... import types as sqltypes + + +class JSON(sqltypes.JSON): + """MySQL JSON type. + + MySQL supports JSON as of version 5.7. + MariaDB supports JSON (as an alias for LONGTEXT) as of version 10.2. + + :class:`_mysql.JSON` is used automatically whenever the base + :class:`_types.JSON` datatype is used against a MySQL or MariaDB backend. + + .. seealso:: + + :class:`_types.JSON` - main documentation for the generic + cross-platform JSON datatype. + + The :class:`.mysql.JSON` type supports persistence of JSON values + as well as the core index operations provided by :class:`_types.JSON` + datatype, by adapting the operations to render the ``JSON_EXTRACT`` + function at the database level. + + """ + + pass + + +class _FormatTypeMixin: + def _format_value(self, value): + raise NotImplementedError() + + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + def _format_value(self, value): + if isinstance(value, int): + value = "$[%s]" % value + else: + value = '$."%s"' % value + return value + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value): + return "$%s" % ( + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mariadb.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mariadb.py new file mode 100644 index 0000000..10a05f9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mariadb.py @@ -0,0 +1,32 @@ +# dialects/mysql/mariadb.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from .base import MariaDBIdentifierPreparer +from .base import MySQLDialect + + +class MariaDBDialect(MySQLDialect): + is_mariadb = True + supports_statement_cache = True + name = "mariadb" + preparer = MariaDBIdentifierPreparer + + +def loader(driver): + driver_mod = __import__( + "sqlalchemy.dialects.mysql.%s" % driver + ).dialects.mysql + driver_cls = getattr(driver_mod, driver).dialect + + return type( + "MariaDBDialect_%s" % driver, + ( + MariaDBDialect, + driver_cls, + ), + {"supports_statement_cache": True}, + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mariadbconnector.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mariadbconnector.py new file mode 100644 index 0000000..9bb3fa4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -0,0 +1,275 @@ +# dialects/mysql/mariadbconnector.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +""" + +.. dialect:: mysql+mariadbconnector + :name: MariaDB Connector/Python + :dbapi: mariadb + :connectstring: mariadb+mariadbconnector://:@[:]/ + :url: https://pypi.org/project/mariadb/ + +Driver Status +------------- + +MariaDB Connector/Python enables Python programs to access MariaDB and MySQL +databases using an API which is compliant with the Python DB API 2.0 (PEP-249). +It is written in C and uses MariaDB Connector/C client library for client server +communication. + +Note that the default driver for a ``mariadb://`` connection URI continues to +be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver. + +.. mariadb: https://github.com/mariadb-corporation/mariadb-connector-python + +""" # noqa +import re +from uuid import UUID as _python_UUID + +from .base import MySQLCompiler +from .base import MySQLDialect +from .base import MySQLExecutionContext +from ... import sql +from ... import util +from ...sql import sqltypes + + +mariadb_cpy_minimum_version = (1, 0, 1) + + +class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]): + # work around JIRA issue + # https://jira.mariadb.org/browse/CONPY-270. When that issue is fixed, + # this type can be removed. + def result_processor(self, dialect, coltype): + if self.as_uuid: + + def process(value): + if value is not None: + if hasattr(value, "decode"): + value = value.decode("ascii") + value = _python_UUID(value) + return value + + return process + else: + + def process(value): + if value is not None: + if hasattr(value, "decode"): + value = value.decode("ascii") + value = str(_python_UUID(value)) + return value + + return process + + +class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext): + _lastrowid = None + + def create_server_side_cursor(self): + return self._dbapi_connection.cursor(buffered=False) + + def create_default_cursor(self): + return self._dbapi_connection.cursor(buffered=True) + + def post_exec(self): + super().post_exec() + + self._rowcount = self.cursor.rowcount + + if self.isinsert and self.compiled.postfetch_lastrowid: + self._lastrowid = self.cursor.lastrowid + + def get_lastrowid(self): + return self._lastrowid + + +class MySQLCompiler_mariadbconnector(MySQLCompiler): + pass + + +class MySQLDialect_mariadbconnector(MySQLDialect): + driver = "mariadbconnector" + supports_statement_cache = True + + # set this to True at the module level to prevent the driver from running + # against a backend that server detects as MySQL. currently this appears to + # be unnecessary as MariaDB client libraries have always worked against + # MySQL databases. However, if this changes at some point, this can be + # adjusted, but PLEASE ADD A TEST in test/dialect/mysql/test_dialect.py if + # this change is made at some point to ensure the correct exception + # is raised at the correct point when running the driver against + # a MySQL backend. + # is_mariadb = True + + supports_unicode_statements = True + encoding = "utf8mb4" + convert_unicode = True + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + supports_native_decimal = True + default_paramstyle = "qmark" + execution_ctx_cls = MySQLExecutionContext_mariadbconnector + statement_compiler = MySQLCompiler_mariadbconnector + + supports_server_side_cursors = True + + colspecs = util.update_copy( + MySQLDialect.colspecs, {sqltypes.Uuid: _MariaDBUUID} + ) + + @util.memoized_property + def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + return tuple( + [ + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) + else: + return (99, 99, 99) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.paramstyle = "qmark" + if self.dbapi is not None: + if self._dbapi_version < mariadb_cpy_minimum_version: + raise NotImplementedError( + "The minimum required version for MariaDB " + "Connector/Python is %s" + % ".".join(str(x) for x in mariadb_cpy_minimum_version) + ) + + @classmethod + def import_dbapi(cls): + return __import__("mariadb") + + def is_disconnect(self, e, connection, cursor): + if super().is_disconnect(e, connection, cursor): + return True + elif isinstance(e, self.dbapi.Error): + str_e = str(e).lower() + return "not connected" in str_e or "isn't valid" in str_e + else: + return False + + def create_connect_args(self, url): + opts = url.translate_connect_args() + + int_params = [ + "connect_timeout", + "read_timeout", + "write_timeout", + "client_flag", + "port", + "pool_size", + ] + bool_params = [ + "local_infile", + "ssl_verify_cert", + "ssl", + "pool_reset_connection", + ] + + for key in int_params: + util.coerce_kw_type(opts, key, int) + for key in bool_params: + util.coerce_kw_type(opts, key, bool) + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + client_flag = opts.get("client_flag", 0) + if self.dbapi is not None: + try: + CLIENT_FLAGS = __import__( + self.dbapi.__name__ + ".constants.CLIENT" + ).constants.CLIENT + client_flag |= CLIENT_FLAGS.FOUND_ROWS + except (AttributeError, ImportError): + self.supports_sane_rowcount = False + opts["client_flag"] = client_flag + return [[], opts] + + def _extract_error_code(self, exception): + try: + rc = exception.errno + except: + rc = -1 + return rc + + def _detect_charset(self, connection): + return "utf8mb4" + + def get_isolation_level_values(self, dbapi_connection): + return ( + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ) + + def set_isolation_level(self, connection, level): + if level == "AUTOCOMMIT": + connection.autocommit = True + else: + connection.autocommit = False + super().set_isolation_level(connection, level) + + def do_begin_twophase(self, connection, xid): + connection.execute( + sql.text("XA BEGIN :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + + def do_prepare_twophase(self, connection, xid): + connection.execute( + sql.text("XA END :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + connection.execute( + sql.text("XA PREPARE :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + connection.execute( + sql.text("XA END :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + connection.execute( + sql.text("XA ROLLBACK :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + self.do_prepare_twophase(connection, xid) + connection.execute( + sql.text("XA COMMIT :xid").bindparams( + sql.bindparam("xid", xid, literal_execute=True) + ) + ) + + +dialect = MySQLDialect_mariadbconnector diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mysqlconnector.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mysqlconnector.py new file mode 100644 index 0000000..b152339 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -0,0 +1,179 @@ +# dialects/mysql/mysqlconnector.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" +.. dialect:: mysql+mysqlconnector + :name: MySQL Connector/Python + :dbapi: myconnpy + :connectstring: mysql+mysqlconnector://:@[:]/ + :url: https://pypi.org/project/mysql-connector-python/ + +.. note:: + + The MySQL Connector/Python DBAPI has had many issues since its release, + some of which may remain unresolved, and the mysqlconnector dialect is + **not tested as part of SQLAlchemy's continuous integration**. + The recommended MySQL dialects are mysqlclient and PyMySQL. + +""" # noqa + +import re + +from .base import BIT +from .base import MySQLCompiler +from .base import MySQLDialect +from .base import MySQLIdentifierPreparer +from ... import util + + +class MySQLCompiler_mysqlconnector(MySQLCompiler): + def visit_mod_binary(self, binary, operator, **kw): + return ( + self.process(binary.left, **kw) + + " % " + + self.process(binary.right, **kw) + ) + + +class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer): + @property + def _double_percents(self): + return False + + @_double_percents.setter + def _double_percents(self, value): + pass + + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) + return value + + +class _myconnpyBIT(BIT): + def result_processor(self, dialect, coltype): + """MySQL-connector already converts mysql bits, so.""" + + return None + + +class MySQLDialect_mysqlconnector(MySQLDialect): + driver = "mysqlconnector" + supports_statement_cache = True + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + supports_native_decimal = True + + default_paramstyle = "format" + statement_compiler = MySQLCompiler_mysqlconnector + + preparer = MySQLIdentifierPreparer_mysqlconnector + + colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT}) + + @classmethod + def import_dbapi(cls): + from mysql import connector + + return connector + + def do_ping(self, dbapi_connection): + dbapi_connection.ping(False) + return True + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + + opts.update(url.query) + + util.coerce_kw_type(opts, "allow_local_infile", bool) + util.coerce_kw_type(opts, "autocommit", bool) + util.coerce_kw_type(opts, "buffered", bool) + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "connection_timeout", int) + util.coerce_kw_type(opts, "connect_timeout", int) + util.coerce_kw_type(opts, "consume_results", bool) + util.coerce_kw_type(opts, "force_ipv6", bool) + util.coerce_kw_type(opts, "get_warnings", bool) + util.coerce_kw_type(opts, "pool_reset_session", bool) + util.coerce_kw_type(opts, "pool_size", int) + util.coerce_kw_type(opts, "raise_on_warnings", bool) + util.coerce_kw_type(opts, "raw", bool) + util.coerce_kw_type(opts, "ssl_verify_cert", bool) + util.coerce_kw_type(opts, "use_pure", bool) + util.coerce_kw_type(opts, "use_unicode", bool) + + # unfortunately, MySQL/connector python refuses to release a + # cursor without reading fully, so non-buffered isn't an option + opts.setdefault("buffered", True) + + # FOUND_ROWS must be set in ClientFlag to enable + # supports_sane_rowcount. + if self.dbapi is not None: + try: + from mysql.connector.constants import ClientFlag + + client_flags = opts.get( + "client_flags", ClientFlag.get_default() + ) + client_flags |= ClientFlag.FOUND_ROWS + opts["client_flags"] = client_flags + except Exception: + pass + return [[], opts] + + @util.memoized_property + def _mysqlconnector_version_info(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) + if m: + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) + + def _detect_charset(self, connection): + return connection.connection.charset + + def _extract_error_code(self, exception): + return exception.errno + + def is_disconnect(self, e, connection, cursor): + errnos = (2006, 2013, 2014, 2045, 2055, 2048) + exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError) + if isinstance(e, exceptions): + return ( + e.errno in errnos + or "MySQL Connection not available." in str(e) + or "Connection to MySQL is not available" in str(e) + ) + else: + return False + + def _compat_fetchall(self, rp, charset=None): + return rp.fetchall() + + def _compat_fetchone(self, rp, charset=None): + return rp.fetchone() + + _isolation_lookup = { + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + } + + def _set_isolation_level(self, connection, level): + if level == "AUTOCOMMIT": + connection.autocommit = True + else: + connection.autocommit = False + super()._set_isolation_level(connection, level) + + +dialect = MySQLDialect_mysqlconnector diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mysqldb.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mysqldb.py new file mode 100644 index 0000000..0c632b6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/mysqldb.py @@ -0,0 +1,303 @@ +# dialects/mysql/mysqldb.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +""" + +.. dialect:: mysql+mysqldb + :name: mysqlclient (maintained fork of MySQL-Python) + :dbapi: mysqldb + :connectstring: mysql+mysqldb://:@[:]/ + :url: https://pypi.org/project/mysqlclient/ + +Driver Status +------------- + +The mysqlclient DBAPI is a maintained fork of the +`MySQL-Python `_ DBAPI +that is no longer maintained. `mysqlclient`_ supports Python 2 and Python 3 +and is very stable. + +.. _mysqlclient: https://github.com/PyMySQL/mysqlclient-python + +.. _mysqldb_unicode: + +Unicode +------- + +Please see :ref:`mysql_unicode` for current recommendations on unicode +handling. + +.. _mysqldb_ssl: + +SSL Connections +---------------- + +The mysqlclient and PyMySQL DBAPIs accept an additional dictionary under the +key "ssl", which may be specified using the +:paramref:`_sa.create_engine.connect_args` dictionary:: + + engine = create_engine( + "mysql+mysqldb://scott:tiger@192.168.0.134/test", + connect_args={ + "ssl": { + "ca": "/home/gord/client-ssl/ca.pem", + "cert": "/home/gord/client-ssl/client-cert.pem", + "key": "/home/gord/client-ssl/client-key.pem" + } + } + ) + +For convenience, the following keys may also be specified inline within the URL +where they will be interpreted into the "ssl" dictionary automatically: +"ssl_ca", "ssl_cert", "ssl_key", "ssl_capath", "ssl_cipher", +"ssl_check_hostname". An example is as follows:: + + connection_uri = ( + "mysql+mysqldb://scott:tiger@192.168.0.134/test" + "?ssl_ca=/home/gord/client-ssl/ca.pem" + "&ssl_cert=/home/gord/client-ssl/client-cert.pem" + "&ssl_key=/home/gord/client-ssl/client-key.pem" + ) + +.. seealso:: + + :ref:`pymysql_ssl` in the PyMySQL dialect + + +Using MySQLdb with Google Cloud SQL +----------------------------------- + +Google Cloud SQL now recommends use of the MySQLdb dialect. Connect +using a URL like the following:: + + mysql+mysqldb://root@/?unix_socket=/cloudsql/: + +Server Side Cursors +------------------- + +The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`. + +""" + +import re + +from .base import MySQLCompiler +from .base import MySQLDialect +from .base import MySQLExecutionContext +from .base import MySQLIdentifierPreparer +from .base import TEXT +from ... import sql +from ... import util + + +class MySQLExecutionContext_mysqldb(MySQLExecutionContext): + pass + + +class MySQLCompiler_mysqldb(MySQLCompiler): + pass + + +class MySQLDialect_mysqldb(MySQLDialect): + driver = "mysqldb" + supports_statement_cache = True + supports_unicode_statements = True + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + supports_native_decimal = True + + default_paramstyle = "format" + execution_ctx_cls = MySQLExecutionContext_mysqldb + statement_compiler = MySQLCompiler_mysqldb + preparer = MySQLIdentifierPreparer + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._mysql_dbapi_version = ( + self._parse_dbapi_version(self.dbapi.__version__) + if self.dbapi is not None and hasattr(self.dbapi, "__version__") + else (0, 0, 0) + ) + + def _parse_dbapi_version(self, version): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version) + if m: + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) + else: + return (0, 0, 0) + + @util.langhelpers.memoized_property + def supports_server_side_cursors(self): + try: + cursors = __import__("MySQLdb.cursors").cursors + self._sscursor = cursors.SSCursor + return True + except (ImportError, AttributeError): + return False + + @classmethod + def import_dbapi(cls): + return __import__("MySQLdb") + + def on_connect(self): + super_ = super().on_connect() + + def on_connect(conn): + if super_ is not None: + super_(conn) + + charset_name = conn.character_set_name() + + if charset_name is not None: + cursor = conn.cursor() + cursor.execute("SET NAMES %s" % charset_name) + cursor.close() + + return on_connect + + def do_ping(self, dbapi_connection): + dbapi_connection.ping() + return True + + def do_executemany(self, cursor, statement, parameters, context=None): + rowcount = cursor.executemany(statement, parameters) + if context is not None: + context._rowcount = rowcount + + def _check_unicode_returns(self, connection): + # work around issue fixed in + # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8 + # specific issue w/ the utf8mb4_bin collation and unicode returns + + collation = connection.exec_driver_sql( + "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'" + % ( + self.identifier_preparer.quote("Charset"), + self.identifier_preparer.quote("Collation"), + ) + ).scalar() + has_utf8mb4_bin = self.server_version_info > (5,) and collation + if has_utf8mb4_bin: + additional_tests = [ + sql.collate( + sql.cast( + sql.literal_column("'test collated returns'"), + TEXT(charset="utf8mb4"), + ), + "utf8mb4_bin", + ) + ] + else: + additional_tests = [] + return super()._check_unicode_returns(connection, additional_tests) + + def create_connect_args(self, url, _translate_args=None): + if _translate_args is None: + _translate_args = dict( + database="db", username="user", password="passwd" + ) + + opts = url.translate_connect_args(**_translate_args) + opts.update(url.query) + + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "connect_timeout", int) + util.coerce_kw_type(opts, "read_timeout", int) + util.coerce_kw_type(opts, "write_timeout", int) + util.coerce_kw_type(opts, "client_flag", int) + util.coerce_kw_type(opts, "local_infile", int) + # Note: using either of the below will cause all strings to be + # returned as Unicode, both in raw SQL operations and with column + # types like String and MSString. + util.coerce_kw_type(opts, "use_unicode", bool) + util.coerce_kw_type(opts, "charset", str) + + # Rich values 'cursorclass' and 'conv' are not supported via + # query string. + + ssl = {} + keys = [ + ("ssl_ca", str), + ("ssl_key", str), + ("ssl_cert", str), + ("ssl_capath", str), + ("ssl_cipher", str), + ("ssl_check_hostname", bool), + ] + for key, kw_type in keys: + if key in opts: + ssl[key[4:]] = opts[key] + util.coerce_kw_type(ssl, key[4:], kw_type) + del opts[key] + if ssl: + opts["ssl"] = ssl + + # FOUND_ROWS must be set in CLIENT_FLAGS to enable + # supports_sane_rowcount. + client_flag = opts.get("client_flag", 0) + + client_flag_found_rows = self._found_rows_client_flag() + if client_flag_found_rows is not None: + client_flag |= client_flag_found_rows + opts["client_flag"] = client_flag + return [[], opts] + + def _found_rows_client_flag(self): + if self.dbapi is not None: + try: + CLIENT_FLAGS = __import__( + self.dbapi.__name__ + ".constants.CLIENT" + ).constants.CLIENT + except (AttributeError, ImportError): + return None + else: + return CLIENT_FLAGS.FOUND_ROWS + else: + return None + + def _extract_error_code(self, exception): + return exception.args[0] + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + try: + # note: the SQL here would be + # "SHOW VARIABLES LIKE 'character_set%%'" + cset_name = connection.connection.character_set_name + except AttributeError: + util.warn( + "No 'character_set_name' can be detected with " + "this MySQL-Python version; " + "please upgrade to a recent version of MySQL-Python. " + "Assuming latin1." + ) + return "latin1" + else: + return cset_name() + + def get_isolation_level_values(self, dbapi_connection): + return ( + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ) + + def set_isolation_level(self, dbapi_connection, level): + if level == "AUTOCOMMIT": + dbapi_connection.autocommit(True) + else: + dbapi_connection.autocommit(False) + super().set_isolation_level(dbapi_connection, level) + + +dialect = MySQLDialect_mysqldb diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/provision.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/provision.py new file mode 100644 index 0000000..3f05bce --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/provision.py @@ -0,0 +1,107 @@ +# dialects/mysql/provision.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from ... import exc +from ...testing.provision import configure_follower +from ...testing.provision import create_db +from ...testing.provision import drop_db +from ...testing.provision import generate_driver_url +from ...testing.provision import temp_table_keyword_args +from ...testing.provision import upsert + + +@generate_driver_url.for_db("mysql", "mariadb") +def generate_driver_url(url, driver, query_str): + backend = url.get_backend_name() + + # NOTE: at the moment, tests are running mariadbconnector + # against both mariadb and mysql backends. if we want this to be + # limited, do the decision making here to reject a "mysql+mariadbconnector" + # URL. Optionally also re-enable the module level + # MySQLDialect_mariadbconnector.is_mysql flag as well, which must include + # a unit and/or functional test. + + # all the Jenkins tests have been running mysqlclient Python library + # built against mariadb client drivers for years against all MySQL / + # MariaDB versions going back to MySQL 5.6, currently they can talk + # to MySQL databases without problems. + + if backend == "mysql": + dialect_cls = url.get_dialect() + if dialect_cls._is_mariadb_from_url(url): + backend = "mariadb" + + new_url = url.set( + drivername="%s+%s" % (backend, driver) + ).update_query_string(query_str) + + try: + new_url.get_dialect() + except exc.NoSuchModuleError: + return None + else: + return new_url + + +@create_db.for_db("mysql", "mariadb") +def _mysql_create_db(cfg, eng, ident): + with eng.begin() as conn: + try: + _mysql_drop_db(cfg, conn, ident) + except Exception: + pass + + with eng.begin() as conn: + conn.exec_driver_sql( + "CREATE DATABASE %s CHARACTER SET utf8mb4" % ident + ) + conn.exec_driver_sql( + "CREATE DATABASE %s_test_schema CHARACTER SET utf8mb4" % ident + ) + conn.exec_driver_sql( + "CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb4" % ident + ) + + +@configure_follower.for_db("mysql", "mariadb") +def _mysql_configure_follower(config, ident): + config.test_schema = "%s_test_schema" % ident + config.test_schema_2 = "%s_test_schema_2" % ident + + +@drop_db.for_db("mysql", "mariadb") +def _mysql_drop_db(cfg, eng, ident): + with eng.begin() as conn: + conn.exec_driver_sql("DROP DATABASE %s_test_schema" % ident) + conn.exec_driver_sql("DROP DATABASE %s_test_schema_2" % ident) + conn.exec_driver_sql("DROP DATABASE %s" % ident) + + +@temp_table_keyword_args.for_db("mysql", "mariadb") +def _mysql_temp_table_keyword_args(cfg, eng): + return {"prefixes": ["TEMPORARY"]} + + +@upsert.for_db("mariadb") +def _upsert( + cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False +): + from sqlalchemy.dialects.mysql import insert + + stmt = insert(table) + + if set_lambda: + stmt = stmt.on_duplicate_key_update(**set_lambda(stmt.inserted)) + else: + pk1 = table.primary_key.c[0] + stmt = stmt.on_duplicate_key_update({pk1.key: pk1}) + + stmt = stmt.returning( + *returning, sort_by_parameter_order=sort_by_parameter_order + ) + return stmt diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/pymysql.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/pymysql.py new file mode 100644 index 0000000..830e441 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/pymysql.py @@ -0,0 +1,137 @@ +# dialects/mysql/pymysql.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" + +.. dialect:: mysql+pymysql + :name: PyMySQL + :dbapi: pymysql + :connectstring: mysql+pymysql://:@/[?] + :url: https://pymysql.readthedocs.io/ + +Unicode +------- + +Please see :ref:`mysql_unicode` for current recommendations on unicode +handling. + +.. _pymysql_ssl: + +SSL Connections +------------------ + +The PyMySQL DBAPI accepts the same SSL arguments as that of MySQLdb, +described at :ref:`mysqldb_ssl`. See that section for additional examples. + +If the server uses an automatically-generated certificate that is self-signed +or does not match the host name (as seen from the client), it may also be +necessary to indicate ``ssl_check_hostname=false`` in PyMySQL:: + + connection_uri = ( + "mysql+pymysql://scott:tiger@192.168.0.134/test" + "?ssl_ca=/home/gord/client-ssl/ca.pem" + "&ssl_cert=/home/gord/client-ssl/client-cert.pem" + "&ssl_key=/home/gord/client-ssl/client-key.pem" + "&ssl_check_hostname=false" + ) + + +MySQL-Python Compatibility +-------------------------- + +The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver, +and targets 100% compatibility. Most behavioral notes for MySQL-python apply +to the pymysql driver as well. + +""" # noqa + +from .mysqldb import MySQLDialect_mysqldb +from ...util import langhelpers + + +class MySQLDialect_pymysql(MySQLDialect_mysqldb): + driver = "pymysql" + supports_statement_cache = True + + description_encoding = None + + @langhelpers.memoized_property + def supports_server_side_cursors(self): + try: + cursors = __import__("pymysql.cursors").cursors + self._sscursor = cursors.SSCursor + return True + except (ImportError, AttributeError): + return False + + @classmethod + def import_dbapi(cls): + return __import__("pymysql") + + @langhelpers.memoized_property + def _send_false_to_ping(self): + """determine if pymysql has deprecated, changed the default of, + or removed the 'reconnect' argument of connection.ping(). + + See #10492 and + https://github.com/PyMySQL/mysqlclient/discussions/651#discussioncomment-7308971 + for background. + + """ # noqa: E501 + + try: + Connection = __import__( + "pymysql.connections" + ).connections.Connection + except (ImportError, AttributeError): + return True + else: + insp = langhelpers.get_callable_argspec(Connection.ping) + try: + reconnect_arg = insp.args[1] + except IndexError: + return False + else: + return reconnect_arg == "reconnect" and ( + not insp.defaults or insp.defaults[0] is not False + ) + + def do_ping(self, dbapi_connection): + if self._send_false_to_ping: + dbapi_connection.ping(False) + else: + dbapi_connection.ping() + + return True + + def create_connect_args(self, url, _translate_args=None): + if _translate_args is None: + _translate_args = dict(username="user") + return super().create_connect_args( + url, _translate_args=_translate_args + ) + + def is_disconnect(self, e, connection, cursor): + if super().is_disconnect(e, connection, cursor): + return True + elif isinstance(e, self.dbapi.Error): + str_e = str(e).lower() + return ( + "already closed" in str_e or "connection was killed" in str_e + ) + else: + return False + + def _extract_error_code(self, exception): + if isinstance(exception.args[0], Exception): + exception = exception.args[0] + return exception.args[0] + + +dialect = MySQLDialect_pymysql diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/pyodbc.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/pyodbc.py new file mode 100644 index 0000000..428c8df --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/pyodbc.py @@ -0,0 +1,138 @@ +# dialects/mysql/pyodbc.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" + + +.. dialect:: mysql+pyodbc + :name: PyODBC + :dbapi: pyodbc + :connectstring: mysql+pyodbc://:@ + :url: https://pypi.org/project/pyodbc/ + +.. note:: + + The PyODBC for MySQL dialect is **not tested as part of + SQLAlchemy's continuous integration**. + The recommended MySQL dialects are mysqlclient and PyMySQL. + However, if you want to use the mysql+pyodbc dialect and require + full support for ``utf8mb4`` characters (including supplementary + characters like emoji) be sure to use a current release of + MySQL Connector/ODBC and specify the "ANSI" (**not** "Unicode") + version of the driver in your DSN or connection string. + +Pass through exact pyodbc connection string:: + + import urllib + connection_string = ( + 'DRIVER=MySQL ODBC 8.0 ANSI Driver;' + 'SERVER=localhost;' + 'PORT=3307;' + 'DATABASE=mydb;' + 'UID=root;' + 'PWD=(whatever);' + 'charset=utf8mb4;' + ) + params = urllib.parse.quote_plus(connection_string) + connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params + +""" # noqa + +import re + +from .base import MySQLDialect +from .base import MySQLExecutionContext +from .types import TIME +from ... import exc +from ... import util +from ...connectors.pyodbc import PyODBCConnector +from ...sql.sqltypes import Time + + +class _pyodbcTIME(TIME): + def result_processor(self, dialect, coltype): + def process(value): + # pyodbc returns a datetime.time object; no need to convert + return value + + return process + + +class MySQLExecutionContext_pyodbc(MySQLExecutionContext): + def get_lastrowid(self): + cursor = self.create_cursor() + cursor.execute("SELECT LAST_INSERT_ID()") + lastrowid = cursor.fetchone()[0] + cursor.close() + return lastrowid + + +class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): + supports_statement_cache = True + colspecs = util.update_copy(MySQLDialect.colspecs, {Time: _pyodbcTIME}) + supports_unicode_statements = True + execution_ctx_cls = MySQLExecutionContext_pyodbc + + pyodbc_driver_name = "MySQL" + + def _detect_charset(self, connection): + """Sniff out the character set in use for connection results.""" + + # Prefer 'character_set_results' for the current connection over the + # value in the driver. SET NAMES or individual variable SETs will + # change the charset without updating the driver's view of the world. + # + # If it's decided that issuing that sort of SQL leaves you SOL, then + # this can prefer the driver value. + + # set this to None as _fetch_setting attempts to use it (None is OK) + self._connection_charset = None + try: + value = self._fetch_setting(connection, "character_set_client") + if value: + return value + except exc.DBAPIError: + pass + + util.warn( + "Could not detect the connection character set. " + "Assuming latin1." + ) + return "latin1" + + def _get_server_version_info(self, connection): + return MySQLDialect._get_server_version_info(self, connection) + + def _extract_error_code(self, exception): + m = re.compile(r"\((\d+)\)").search(str(exception.args)) + c = m.group(1) + if c: + return int(c) + else: + return None + + def on_connect(self): + super_ = super().on_connect() + + def on_connect(conn): + if super_ is not None: + super_(conn) + + # declare Unicode encoding for pyodbc as per + # https://github.com/mkleehammer/pyodbc/wiki/Unicode + pyodbc_SQL_CHAR = 1 # pyodbc.SQL_CHAR + pyodbc_SQL_WCHAR = -8 # pyodbc.SQL_WCHAR + conn.setdecoding(pyodbc_SQL_CHAR, encoding="utf-8") + conn.setdecoding(pyodbc_SQL_WCHAR, encoding="utf-8") + conn.setencoding(encoding="utf-8") + + return on_connect + + +dialect = MySQLDialect_pyodbc diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/reflection.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/reflection.py new file mode 100644 index 0000000..c764e8c --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/reflection.py @@ -0,0 +1,677 @@ +# dialects/mysql/reflection.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +import re + +from .enumerated import ENUM +from .enumerated import SET +from .types import DATETIME +from .types import TIME +from .types import TIMESTAMP +from ... import log +from ... import types as sqltypes +from ... import util + + +class ReflectedState: + """Stores raw information about a SHOW CREATE TABLE statement.""" + + def __init__(self): + self.columns = [] + self.table_options = {} + self.table_name = None + self.keys = [] + self.fk_constraints = [] + self.ck_constraints = [] + + +@log.class_logger +class MySQLTableDefinitionParser: + """Parses the results of a SHOW CREATE TABLE statement.""" + + def __init__(self, dialect, preparer): + self.dialect = dialect + self.preparer = preparer + self._prep_regexes() + + def parse(self, show_create, charset): + state = ReflectedState() + state.charset = charset + for line in re.split(r"\r?\n", show_create): + if line.startswith(" " + self.preparer.initial_quote): + self._parse_column(line, state) + # a regular table options line + elif line.startswith(") "): + self._parse_table_options(line, state) + # an ANSI-mode table options line + elif line == ")": + pass + elif line.startswith("CREATE "): + self._parse_table_name(line, state) + elif "PARTITION" in line: + self._parse_partition_options(line, state) + # Not present in real reflection, but may be if + # loading from a file. + elif not line: + pass + else: + type_, spec = self._parse_constraints(line) + if type_ is None: + util.warn("Unknown schema content: %r" % line) + elif type_ == "key": + state.keys.append(spec) + elif type_ == "fk_constraint": + state.fk_constraints.append(spec) + elif type_ == "ck_constraint": + state.ck_constraints.append(spec) + else: + pass + return state + + def _check_view(self, sql: str) -> bool: + return bool(self._re_is_view.match(sql)) + + def _parse_constraints(self, line): + """Parse a KEY or CONSTRAINT line. + + :param line: A line of SHOW CREATE TABLE output + """ + + # KEY + m = self._re_key.match(line) + if m: + spec = m.groupdict() + # convert columns into name, length pairs + # NOTE: we may want to consider SHOW INDEX as the + # format of indexes in MySQL becomes more complex + spec["columns"] = self._parse_keyexprs(spec["columns"]) + if spec["version_sql"]: + m2 = self._re_key_version_sql.match(spec["version_sql"]) + if m2 and m2.groupdict()["parser"]: + spec["parser"] = m2.groupdict()["parser"] + if spec["parser"]: + spec["parser"] = self.preparer.unformat_identifiers( + spec["parser"] + )[0] + return "key", spec + + # FOREIGN KEY CONSTRAINT + m = self._re_fk_constraint.match(line) + if m: + spec = m.groupdict() + spec["table"] = self.preparer.unformat_identifiers(spec["table"]) + spec["local"] = [c[0] for c in self._parse_keyexprs(spec["local"])] + spec["foreign"] = [ + c[0] for c in self._parse_keyexprs(spec["foreign"]) + ] + return "fk_constraint", spec + + # CHECK constraint + m = self._re_ck_constraint.match(line) + if m: + spec = m.groupdict() + return "ck_constraint", spec + + # PARTITION and SUBPARTITION + m = self._re_partition.match(line) + if m: + # Punt! + return "partition", line + + # No match. + return (None, line) + + def _parse_table_name(self, line, state): + """Extract the table name. + + :param line: The first line of SHOW CREATE TABLE + """ + + regex, cleanup = self._pr_name + m = regex.match(line) + if m: + state.table_name = cleanup(m.group("name")) + + def _parse_table_options(self, line, state): + """Build a dictionary of all reflected table-level options. + + :param line: The final line of SHOW CREATE TABLE output. + """ + + options = {} + + if line and line != ")": + rest_of_line = line + for regex, cleanup in self._pr_options: + m = regex.search(rest_of_line) + if not m: + continue + directive, value = m.group("directive"), m.group("val") + if cleanup: + value = cleanup(value) + options[directive.lower()] = value + rest_of_line = regex.sub("", rest_of_line) + + for nope in ("auto_increment", "data directory", "index directory"): + options.pop(nope, None) + + for opt, val in options.items(): + state.table_options["%s_%s" % (self.dialect.name, opt)] = val + + def _parse_partition_options(self, line, state): + options = {} + new_line = line[:] + + while new_line.startswith("(") or new_line.startswith(" "): + new_line = new_line[1:] + + for regex, cleanup in self._pr_options: + m = regex.search(new_line) + if not m or "PARTITION" not in regex.pattern: + continue + + directive = m.group("directive") + directive = directive.lower() + is_subpartition = directive == "subpartition" + + if directive == "partition" or is_subpartition: + new_line = new_line.replace(") */", "") + new_line = new_line.replace(",", "") + if is_subpartition and new_line.endswith(")"): + new_line = new_line[:-1] + if self.dialect.name == "mariadb" and new_line.endswith(")"): + if ( + "MAXVALUE" in new_line + or "MINVALUE" in new_line + or "ENGINE" in new_line + ): + # final line of MariaDB partition endswith ")" + new_line = new_line[:-1] + + defs = "%s_%s_definitions" % (self.dialect.name, directive) + options[defs] = new_line + + else: + directive = directive.replace(" ", "_") + value = m.group("val") + if cleanup: + value = cleanup(value) + options[directive] = value + break + + for opt, val in options.items(): + part_def = "%s_partition_definitions" % (self.dialect.name) + subpart_def = "%s_subpartition_definitions" % (self.dialect.name) + if opt == part_def or opt == subpart_def: + # builds a string of definitions + if opt not in state.table_options: + state.table_options[opt] = val + else: + state.table_options[opt] = "%s, %s" % ( + state.table_options[opt], + val, + ) + else: + state.table_options["%s_%s" % (self.dialect.name, opt)] = val + + def _parse_column(self, line, state): + """Extract column details. + + Falls back to a 'minimal support' variant if full parse fails. + + :param line: Any column-bearing line from SHOW CREATE TABLE + """ + + spec = None + m = self._re_column.match(line) + if m: + spec = m.groupdict() + spec["full"] = True + else: + m = self._re_column_loose.match(line) + if m: + spec = m.groupdict() + spec["full"] = False + if not spec: + util.warn("Unknown column definition %r" % line) + return + if not spec["full"]: + util.warn("Incomplete reflection of column definition %r" % line) + + name, type_, args = spec["name"], spec["coltype"], spec["arg"] + + try: + col_type = self.dialect.ischema_names[type_] + except KeyError: + util.warn( + "Did not recognize type '%s' of column '%s'" % (type_, name) + ) + col_type = sqltypes.NullType + + # Column type positional arguments eg. varchar(32) + if args is None or args == "": + type_args = [] + elif args[0] == "'" and args[-1] == "'": + type_args = self._re_csv_str.findall(args) + else: + type_args = [int(v) for v in self._re_csv_int.findall(args)] + + # Column type keyword options + type_kw = {} + + if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)): + if type_args: + type_kw["fsp"] = type_args.pop(0) + + for kw in ("unsigned", "zerofill"): + if spec.get(kw, False): + type_kw[kw] = True + for kw in ("charset", "collate"): + if spec.get(kw, False): + type_kw[kw] = spec[kw] + if issubclass(col_type, (ENUM, SET)): + type_args = _strip_values(type_args) + + if issubclass(col_type, SET) and "" in type_args: + type_kw["retrieve_as_bitwise"] = True + + type_instance = col_type(*type_args, **type_kw) + + col_kw = {} + + # NOT NULL + col_kw["nullable"] = True + # this can be "NULL" in the case of TIMESTAMP + if spec.get("notnull", False) == "NOT NULL": + col_kw["nullable"] = False + # For generated columns, the nullability is marked in a different place + if spec.get("notnull_generated", False) == "NOT NULL": + col_kw["nullable"] = False + + # AUTO_INCREMENT + if spec.get("autoincr", False): + col_kw["autoincrement"] = True + elif issubclass(col_type, sqltypes.Integer): + col_kw["autoincrement"] = False + + # DEFAULT + default = spec.get("default", None) + + if default == "NULL": + # eliminates the need to deal with this later. + default = None + + comment = spec.get("comment", None) + + if comment is not None: + comment = cleanup_text(comment) + + sqltext = spec.get("generated") + if sqltext is not None: + computed = dict(sqltext=sqltext) + persisted = spec.get("persistence") + if persisted is not None: + computed["persisted"] = persisted == "STORED" + col_kw["computed"] = computed + + col_d = dict( + name=name, type=type_instance, default=default, comment=comment + ) + col_d.update(col_kw) + state.columns.append(col_d) + + def _describe_to_create(self, table_name, columns): + """Re-format DESCRIBE output as a SHOW CREATE TABLE string. + + DESCRIBE is a much simpler reflection and is sufficient for + reflecting views for runtime use. This method formats DDL + for columns only- keys are omitted. + + :param columns: A sequence of DESCRIBE or SHOW COLUMNS 6-tuples. + SHOW FULL COLUMNS FROM rows must be rearranged for use with + this function. + """ + + buffer = [] + for row in columns: + (name, col_type, nullable, default, extra) = ( + row[i] for i in (0, 1, 2, 4, 5) + ) + + line = [" "] + line.append(self.preparer.quote_identifier(name)) + line.append(col_type) + if not nullable: + line.append("NOT NULL") + if default: + if "auto_increment" in default: + pass + elif col_type.startswith("timestamp") and default.startswith( + "C" + ): + line.append("DEFAULT") + line.append(default) + elif default == "NULL": + line.append("DEFAULT") + line.append(default) + else: + line.append("DEFAULT") + line.append("'%s'" % default.replace("'", "''")) + if extra: + line.append(extra) + + buffer.append(" ".join(line)) + + return "".join( + [ + ( + "CREATE TABLE %s (\n" + % self.preparer.quote_identifier(table_name) + ), + ",\n".join(buffer), + "\n) ", + ] + ) + + def _parse_keyexprs(self, identifiers): + """Unpack '"col"(2),"col" ASC'-ish strings into components.""" + + return [ + (colname, int(length) if length else None, modifiers) + for colname, length, modifiers in self._re_keyexprs.findall( + identifiers + ) + ] + + def _prep_regexes(self): + """Pre-compile regular expressions.""" + + self._re_columns = [] + self._pr_options = [] + + _final = self.preparer.final_quote + + quotes = dict( + zip( + ("iq", "fq", "esc_fq"), + [ + re.escape(s) + for s in ( + self.preparer.initial_quote, + _final, + self.preparer._escape_identifier(_final), + ) + ], + ) + ) + + self._pr_name = _pr_compile( + r"^CREATE (?:\w+ +)?TABLE +" + r"%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($" % quotes, + self.preparer._unescape_identifier, + ) + + self._re_is_view = _re_compile(r"^CREATE(?! TABLE)(\s.*)?\sVIEW") + + # `col`,`col2`(32),`col3`(15) DESC + # + self._re_keyexprs = _re_compile( + r"(?:" + r"(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)" + r"(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+" % quotes + ) + + # 'foo' or 'foo','bar' or 'fo,o','ba''a''r' + self._re_csv_str = _re_compile(r"\x27(?:\x27\x27|[^\x27])*\x27") + + # 123 or 123,456 + self._re_csv_int = _re_compile(r"\d+") + + # `colname` [type opts] + # (NOT NULL | NULL) + # DEFAULT ('value' | CURRENT_TIMESTAMP...) + # COMMENT 'comment' + # COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT) + # STORAGE (DISK|MEMORY) + self._re_column = _re_compile( + r" " + r"%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"(?P\w+)" + r"(?:\((?P(?:\d+|\d+,\d+|" + r"(?:'(?:''|[^'])*',?)+))\))?" + r"(?: +(?PUNSIGNED))?" + r"(?: +(?PZEROFILL))?" + r"(?: +CHARACTER SET +(?P[\w_]+))?" + r"(?: +COLLATE +(?P[\w_]+))?" + r"(?: +(?P(?:NOT )?NULL))?" + r"(?: +DEFAULT +(?P" + r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+" + r"(?: +ON UPDATE [\-\w\.\(\)]+)?)" + r"))?" + r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P\(" + r".*\))? ?(?PVIRTUAL|STORED)?" + r"(?: +(?P(?:NOT )?NULL))?" + r")?" + r"(?: +(?PAUTO_INCREMENT))?" + r"(?: +COMMENT +'(?P(?:''|[^'])*)')?" + r"(?: +COLUMN_FORMAT +(?P\w+))?" + r"(?: +STORAGE +(?P\w+))?" + r"(?: +(?P.*))?" + r",?$" % quotes + ) + + # Fallback, try to parse as little as possible + self._re_column_loose = _re_compile( + r" " + r"%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"(?P\w+)" + r"(?:\((?P(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?" + r".*?(?P(?:NOT )NULL)?" % quotes + ) + + # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))? + # (`col` (ASC|DESC)?, `col` (ASC|DESC)?) + # KEY_BLOCK_SIZE size | WITH PARSER name /*!50100 WITH PARSER name */ + self._re_key = _re_compile( + r" " + r"(?:(?P\S+) )?KEY" + r"(?: +%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?" + r"(?: +USING +(?P\S+))?" + r" +\((?P.+?)\)" + r"(?: +USING +(?P\S+))?" + r"(?: +KEY_BLOCK_SIZE *[ =]? *(?P\S+))?" + r"(?: +WITH PARSER +(?P\S+))?" + r"(?: +COMMENT +(?P(\x27\x27|\x27([^\x27])*?\x27)+))?" + r"(?: +/\*(?P.+)\*/ *)?" + r",?$" % quotes + ) + + # https://forums.mysql.com/read.php?20,567102,567111#msg-567111 + # It means if the MySQL version >= \d+, execute what's in the comment + self._re_key_version_sql = _re_compile( + r"\!\d+ " r"(?: *WITH PARSER +(?P\S+) *)?" + ) + + # CONSTRAINT `name` FOREIGN KEY (`local_col`) + # REFERENCES `remote` (`remote_col`) + # MATCH FULL | MATCH PARTIAL | MATCH SIMPLE + # ON DELETE CASCADE ON UPDATE RESTRICT + # + # unique constraints come back as KEYs + kw = quotes.copy() + kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION" + self._re_fk_constraint = _re_compile( + r" " + r"CONSTRAINT +" + r"%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"FOREIGN KEY +" + r"\((?P[^\)]+?)\) REFERENCES +" + r"(?P%(iq)s[^%(fq)s]+%(fq)s" + r"(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +" + r"\((?P(?:%(iq)s[^%(fq)s]+%(fq)s(?: *, *)?)+)\)" + r"(?: +(?PMATCH \w+))?" + r"(?: +ON DELETE (?P%(on)s))?" + r"(?: +ON UPDATE (?P%(on)s))?" % kw + ) + + # CONSTRAINT `CONSTRAINT_1` CHECK (`x` > 5)' + # testing on MariaDB 10.2 shows that the CHECK constraint + # is returned on a line by itself, so to match without worrying + # about parenthesis in the expression we go to the end of the line + self._re_ck_constraint = _re_compile( + r" " + r"CONSTRAINT +" + r"%(iq)s(?P(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"CHECK +" + r"\((?P.+)\),?" % kw + ) + + # PARTITION + # + # punt! + self._re_partition = _re_compile(r"(?:.*)(?:SUB)?PARTITION(?:.*)") + + # Table-level options (COLLATE, ENGINE, etc.) + # Do the string options first, since they have quoted + # strings we need to get rid of. + for option in _options_of_type_string: + self._add_option_string(option) + + for option in ( + "ENGINE", + "TYPE", + "AUTO_INCREMENT", + "AVG_ROW_LENGTH", + "CHARACTER SET", + "DEFAULT CHARSET", + "CHECKSUM", + "COLLATE", + "DELAY_KEY_WRITE", + "INSERT_METHOD", + "MAX_ROWS", + "MIN_ROWS", + "PACK_KEYS", + "ROW_FORMAT", + "KEY_BLOCK_SIZE", + "STATS_SAMPLE_PAGES", + ): + self._add_option_word(option) + + for option in ( + "PARTITION BY", + "SUBPARTITION BY", + "PARTITIONS", + "SUBPARTITIONS", + "PARTITION", + "SUBPARTITION", + ): + self._add_partition_option_word(option) + + self._add_option_regex("UNION", r"\([^\)]+\)") + self._add_option_regex("TABLESPACE", r".*? STORAGE DISK") + self._add_option_regex( + "RAID_TYPE", + r"\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+", + ) + + _optional_equals = r"(?:\s*(?:=\s*)|\s+)" + + def _add_option_string(self, directive): + regex = r"(?P%s)%s" r"'(?P(?:[^']|'')*?)'(?!')" % ( + re.escape(directive), + self._optional_equals, + ) + self._pr_options.append(_pr_compile(regex, cleanup_text)) + + def _add_option_word(self, directive): + regex = r"(?P%s)%s" r"(?P\w+)" % ( + re.escape(directive), + self._optional_equals, + ) + self._pr_options.append(_pr_compile(regex)) + + def _add_partition_option_word(self, directive): + if directive == "PARTITION BY" or directive == "SUBPARTITION BY": + regex = r"(?%s)%s" r"(?P\w+.*)" % ( + re.escape(directive), + self._optional_equals, + ) + elif directive == "SUBPARTITIONS" or directive == "PARTITIONS": + regex = r"(?%s)%s" r"(?P\d+)" % ( + re.escape(directive), + self._optional_equals, + ) + else: + regex = r"(?%s)(?!\S)" % (re.escape(directive),) + self._pr_options.append(_pr_compile(regex)) + + def _add_option_regex(self, directive, regex): + regex = r"(?P%s)%s" r"(?P%s)" % ( + re.escape(directive), + self._optional_equals, + regex, + ) + self._pr_options.append(_pr_compile(regex)) + + +_options_of_type_string = ( + "COMMENT", + "DATA DIRECTORY", + "INDEX DIRECTORY", + "PASSWORD", + "CONNECTION", +) + + +def _pr_compile(regex, cleanup=None): + """Prepare a 2-tuple of compiled regex and callable.""" + + return (_re_compile(regex), cleanup) + + +def _re_compile(regex): + """Compile a string to regex, I and UNICODE.""" + + return re.compile(regex, re.I | re.UNICODE) + + +def _strip_values(values): + "Strip reflected values quotes" + strip_values = [] + for a in values: + if a[0:1] == '"' or a[0:1] == "'": + # strip enclosing quotes and unquote interior + a = a[1:-1].replace(a[0] * 2, a[0]) + strip_values.append(a) + return strip_values + + +def cleanup_text(raw_text: str) -> str: + if "\\" in raw_text: + raw_text = re.sub( + _control_char_regexp, lambda s: _control_char_map[s[0]], raw_text + ) + return raw_text.replace("''", "'") + + +_control_char_map = { + "\\\\": "\\", + "\\0": "\0", + "\\a": "\a", + "\\b": "\b", + "\\t": "\t", + "\\n": "\n", + "\\v": "\v", + "\\f": "\f", + "\\r": "\r", + # '\\e':'\e', +} +_control_char_regexp = re.compile( + "|".join(re.escape(k) for k in _control_char_map) +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/reserved_words.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/reserved_words.py new file mode 100644 index 0000000..04764c1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/reserved_words.py @@ -0,0 +1,571 @@ +# dialects/mysql/reserved_words.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +# generated using: +# https://gist.github.com/kkirsche/4f31f2153ed7a3248be1ec44ca6ddbc9 +# +# https://mariadb.com/kb/en/reserved-words/ +# includes: Reserved Words, Oracle Mode (separate set unioned) +# excludes: Exceptions, Function Names +# mypy: ignore-errors + +RESERVED_WORDS_MARIADB = { + "accessible", + "add", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "asensitive", + "before", + "between", + "bigint", + "binary", + "blob", + "both", + "by", + "call", + "cascade", + "case", + "change", + "char", + "character", + "check", + "collate", + "column", + "condition", + "constraint", + "continue", + "convert", + "create", + "cross", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "databases", + "day_hour", + "day_microsecond", + "day_minute", + "day_second", + "dec", + "decimal", + "declare", + "default", + "delayed", + "delete", + "desc", + "describe", + "deterministic", + "distinct", + "distinctrow", + "div", + "do_domain_ids", + "double", + "drop", + "dual", + "each", + "else", + "elseif", + "enclosed", + "escaped", + "except", + "exists", + "exit", + "explain", + "false", + "fetch", + "float", + "float4", + "float8", + "for", + "force", + "foreign", + "from", + "fulltext", + "general", + "grant", + "group", + "having", + "high_priority", + "hour_microsecond", + "hour_minute", + "hour_second", + "if", + "ignore", + "ignore_domain_ids", + "ignore_server_ids", + "in", + "index", + "infile", + "inner", + "inout", + "insensitive", + "insert", + "int", + "int1", + "int2", + "int3", + "int4", + "int8", + "integer", + "intersect", + "interval", + "into", + "is", + "iterate", + "join", + "key", + "keys", + "kill", + "leading", + "leave", + "left", + "like", + "limit", + "linear", + "lines", + "load", + "localtime", + "localtimestamp", + "lock", + "long", + "longblob", + "longtext", + "loop", + "low_priority", + "master_heartbeat_period", + "master_ssl_verify_server_cert", + "match", + "maxvalue", + "mediumblob", + "mediumint", + "mediumtext", + "middleint", + "minute_microsecond", + "minute_second", + "mod", + "modifies", + "natural", + "no_write_to_binlog", + "not", + "null", + "numeric", + "offset", + "on", + "optimize", + "option", + "optionally", + "or", + "order", + "out", + "outer", + "outfile", + "over", + "page_checksum", + "parse_vcol_expr", + "partition", + "position", + "precision", + "primary", + "procedure", + "purge", + "range", + "read", + "read_write", + "reads", + "real", + "recursive", + "ref_system_id", + "references", + "regexp", + "release", + "rename", + "repeat", + "replace", + "require", + "resignal", + "restrict", + "return", + "returning", + "revoke", + "right", + "rlike", + "rows", + "row_number", + "schema", + "schemas", + "second_microsecond", + "select", + "sensitive", + "separator", + "set", + "show", + "signal", + "slow", + "smallint", + "spatial", + "specific", + "sql", + "sql_big_result", + "sql_calc_found_rows", + "sql_small_result", + "sqlexception", + "sqlstate", + "sqlwarning", + "ssl", + "starting", + "stats_auto_recalc", + "stats_persistent", + "stats_sample_pages", + "straight_join", + "table", + "terminated", + "then", + "tinyblob", + "tinyint", + "tinytext", + "to", + "trailing", + "trigger", + "true", + "undo", + "union", + "unique", + "unlock", + "unsigned", + "update", + "usage", + "use", + "using", + "utc_date", + "utc_time", + "utc_timestamp", + "values", + "varbinary", + "varchar", + "varcharacter", + "varying", + "when", + "where", + "while", + "window", + "with", + "write", + "xor", + "year_month", + "zerofill", +}.union( + { + "body", + "elsif", + "goto", + "history", + "others", + "package", + "period", + "raise", + "rowtype", + "system", + "system_time", + "versioning", + "without", + } +) + +# https://dev.mysql.com/doc/refman/8.3/en/keywords.html +# https://dev.mysql.com/doc/refman/8.0/en/keywords.html +# https://dev.mysql.com/doc/refman/5.7/en/keywords.html +# https://dev.mysql.com/doc/refman/5.6/en/keywords.html +# includes: MySQL x.0 Keywords and Reserved Words +# excludes: MySQL x.0 New Keywords and Reserved Words, +# MySQL x.0 Removed Keywords and Reserved Words +RESERVED_WORDS_MYSQL = { + "accessible", + "add", + "admin", + "all", + "alter", + "analyze", + "and", + "array", + "as", + "asc", + "asensitive", + "before", + "between", + "bigint", + "binary", + "blob", + "both", + "by", + "call", + "cascade", + "case", + "change", + "char", + "character", + "check", + "collate", + "column", + "condition", + "constraint", + "continue", + "convert", + "create", + "cross", + "cube", + "cume_dist", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "databases", + "day_hour", + "day_microsecond", + "day_minute", + "day_second", + "dec", + "decimal", + "declare", + "default", + "delayed", + "delete", + "dense_rank", + "desc", + "describe", + "deterministic", + "distinct", + "distinctrow", + "div", + "double", + "drop", + "dual", + "each", + "else", + "elseif", + "empty", + "enclosed", + "escaped", + "except", + "exists", + "exit", + "explain", + "false", + "fetch", + "first_value", + "float", + "float4", + "float8", + "for", + "force", + "foreign", + "from", + "fulltext", + "function", + "general", + "generated", + "get", + "get_master_public_key", + "grant", + "group", + "grouping", + "groups", + "having", + "high_priority", + "hour_microsecond", + "hour_minute", + "hour_second", + "if", + "ignore", + "ignore_server_ids", + "in", + "index", + "infile", + "inner", + "inout", + "insensitive", + "insert", + "int", + "int1", + "int2", + "int3", + "int4", + "int8", + "integer", + "intersect", + "interval", + "into", + "io_after_gtids", + "io_before_gtids", + "is", + "iterate", + "join", + "json_table", + "key", + "keys", + "kill", + "lag", + "last_value", + "lateral", + "lead", + "leading", + "leave", + "left", + "like", + "limit", + "linear", + "lines", + "load", + "localtime", + "localtimestamp", + "lock", + "long", + "longblob", + "longtext", + "loop", + "low_priority", + "master_bind", + "master_heartbeat_period", + "master_ssl_verify_server_cert", + "match", + "maxvalue", + "mediumblob", + "mediumint", + "mediumtext", + "member", + "middleint", + "minute_microsecond", + "minute_second", + "mod", + "modifies", + "natural", + "no_write_to_binlog", + "not", + "nth_value", + "ntile", + "null", + "numeric", + "of", + "on", + "optimize", + "optimizer_costs", + "option", + "optionally", + "or", + "order", + "out", + "outer", + "outfile", + "over", + "parse_gcol_expr", + "parallel", + "partition", + "percent_rank", + "persist", + "persist_only", + "precision", + "primary", + "procedure", + "purge", + "qualify", + "range", + "rank", + "read", + "read_write", + "reads", + "real", + "recursive", + "references", + "regexp", + "release", + "rename", + "repeat", + "replace", + "require", + "resignal", + "restrict", + "return", + "revoke", + "right", + "rlike", + "role", + "row", + "row_number", + "rows", + "schema", + "schemas", + "second_microsecond", + "select", + "sensitive", + "separator", + "set", + "show", + "signal", + "slow", + "smallint", + "spatial", + "specific", + "sql", + "sql_after_gtids", + "sql_before_gtids", + "sql_big_result", + "sql_calc_found_rows", + "sql_small_result", + "sqlexception", + "sqlstate", + "sqlwarning", + "ssl", + "starting", + "stored", + "straight_join", + "system", + "table", + "terminated", + "then", + "tinyblob", + "tinyint", + "tinytext", + "to", + "trailing", + "trigger", + "true", + "undo", + "union", + "unique", + "unlock", + "unsigned", + "update", + "usage", + "use", + "using", + "utc_date", + "utc_time", + "utc_timestamp", + "values", + "varbinary", + "varchar", + "varcharacter", + "varying", + "virtual", + "when", + "where", + "while", + "window", + "with", + "write", + "xor", + "year_month", + "zerofill", +} diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/types.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/types.py new file mode 100644 index 0000000..734f6ae --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/mysql/types.py @@ -0,0 +1,774 @@ +# dialects/mysql/types.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +import datetime + +from ... import exc +from ... import util +from ...sql import sqltypes + + +class _NumericType: + """Base for MySQL numeric types. + + This is the base both for NUMERIC as well as INTEGER, hence + it's a mixin. + + """ + + def __init__(self, unsigned=False, zerofill=False, **kw): + self.unsigned = unsigned + self.zerofill = zerofill + super().__init__(**kw) + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[_NumericType, sqltypes.Numeric] + ) + + +class _FloatType(_NumericType, sqltypes.Float): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + if isinstance(self, (REAL, DOUBLE)) and ( + (precision is None and scale is not None) + or (precision is not None and scale is None) + ): + raise exc.ArgumentError( + "You must specify both precision and scale or omit " + "both altogether." + ) + super().__init__(precision=precision, asdecimal=asdecimal, **kw) + self.scale = scale + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[_FloatType, _NumericType, sqltypes.Float] + ) + + +class _IntegerType(_NumericType, sqltypes.Integer): + def __init__(self, display_width=None, **kw): + self.display_width = display_width + super().__init__(**kw) + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer] + ) + + +class _StringType(sqltypes.String): + """Base for MySQL string types.""" + + def __init__( + self, + charset=None, + collation=None, + ascii=False, # noqa + binary=False, + unicode=False, + national=False, + **kw, + ): + self.charset = charset + + # allow collate= or collation= + kw.setdefault("collation", kw.pop("collate", collation)) + + self.ascii = ascii + self.unicode = unicode + self.binary = binary + self.national = national + super().__init__(**kw) + + def __repr__(self): + return util.generic_repr( + self, to_inspect=[_StringType, sqltypes.String] + ) + + +class _MatchType(sqltypes.Float, sqltypes.MatchType): + def __init__(self, **kw): + # TODO: float arguments? + sqltypes.Float.__init__(self) + sqltypes.MatchType.__init__(self) + + +class NUMERIC(_NumericType, sqltypes.NUMERIC): + """MySQL NUMERIC type.""" + + __visit_name__ = "NUMERIC" + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a NUMERIC. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + +class DECIMAL(_NumericType, sqltypes.DECIMAL): + """MySQL DECIMAL type.""" + + __visit_name__ = "DECIMAL" + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DECIMAL. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + +class DOUBLE(_FloatType, sqltypes.DOUBLE): + """MySQL DOUBLE type.""" + + __visit_name__ = "DOUBLE" + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a DOUBLE. + + .. note:: + + The :class:`.DOUBLE` type by default converts from float + to Decimal, using a truncation that defaults to 10 digits. + Specify either ``scale=n`` or ``decimal_return_scale=n`` in order + to change this scale, or ``asdecimal=False`` to return values + directly as Python floating points. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + +class REAL(_FloatType, sqltypes.REAL): + """MySQL REAL type.""" + + __visit_name__ = "REAL" + + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): + """Construct a REAL. + + .. note:: + + The :class:`.REAL` type by default converts from float + to Decimal, using a truncation that defaults to 10 digits. + Specify either ``scale=n`` or ``decimal_return_scale=n`` in order + to change this scale, or ``asdecimal=False`` to return values + directly as Python floating points. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + +class FLOAT(_FloatType, sqltypes.FLOAT): + """MySQL FLOAT type.""" + + __visit_name__ = "FLOAT" + + def __init__(self, precision=None, scale=None, asdecimal=False, **kw): + """Construct a FLOAT. + + :param precision: Total digits in this number. If scale and precision + are both None, values are stored to limits allowed by the server. + + :param scale: The number of digits after the decimal point. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) + + def bind_processor(self, dialect): + return None + + +class INTEGER(_IntegerType, sqltypes.INTEGER): + """MySQL INTEGER type.""" + + __visit_name__ = "INTEGER" + + def __init__(self, display_width=None, **kw): + """Construct an INTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__(display_width=display_width, **kw) + + +class BIGINT(_IntegerType, sqltypes.BIGINT): + """MySQL BIGINTEGER type.""" + + __visit_name__ = "BIGINT" + + def __init__(self, display_width=None, **kw): + """Construct a BIGINTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__(display_width=display_width, **kw) + + +class MEDIUMINT(_IntegerType): + """MySQL MEDIUMINTEGER type.""" + + __visit_name__ = "MEDIUMINT" + + def __init__(self, display_width=None, **kw): + """Construct a MEDIUMINTEGER + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__(display_width=display_width, **kw) + + +class TINYINT(_IntegerType): + """MySQL TINYINT type.""" + + __visit_name__ = "TINYINT" + + def __init__(self, display_width=None, **kw): + """Construct a TINYINT. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__(display_width=display_width, **kw) + + +class SMALLINT(_IntegerType, sqltypes.SMALLINT): + """MySQL SMALLINTEGER type.""" + + __visit_name__ = "SMALLINT" + + def __init__(self, display_width=None, **kw): + """Construct a SMALLINTEGER. + + :param display_width: Optional, maximum display width for this number. + + :param unsigned: a boolean, optional. + + :param zerofill: Optional. If true, values will be stored as strings + left-padded with zeros. Note that this does not effect the values + returned by the underlying database API, which continue to be + numeric. + + """ + super().__init__(display_width=display_width, **kw) + + +class BIT(sqltypes.TypeEngine): + """MySQL BIT type. + + This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater + for MyISAM, MEMORY, InnoDB and BDB. For older versions, use a + MSTinyInteger() type. + + """ + + __visit_name__ = "BIT" + + def __init__(self, length=None): + """Construct a BIT. + + :param length: Optional, number of bits. + + """ + self.length = length + + def result_processor(self, dialect, coltype): + """Convert a MySQL's 64 bit, variable length binary string to a long. + + TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector + already do this, so this logic should be moved to those dialects. + + """ + + def process(value): + if value is not None: + v = 0 + for i in value: + if not isinstance(i, int): + i = ord(i) # convert byte to int on Python 2 + v = v << 8 | i + return v + return value + + return process + + +class TIME(sqltypes.TIME): + """MySQL TIME type.""" + + __visit_name__ = "TIME" + + def __init__(self, timezone=False, fsp=None): + """Construct a MySQL TIME type. + + :param timezone: not used by the MySQL dialect. + :param fsp: fractional seconds precision value. + MySQL 5.6 supports storage of fractional seconds; + this parameter will be used when emitting DDL + for the TIME type. + + .. note:: + + DBAPI driver support for fractional seconds may + be limited; current support includes + MySQL Connector/Python. + + """ + super().__init__(timezone=timezone) + self.fsp = fsp + + def result_processor(self, dialect, coltype): + time = datetime.time + + def process(value): + # convert from a timedelta value + if value is not None: + microseconds = value.microseconds + seconds = value.seconds + minutes = seconds // 60 + return time( + minutes // 60, + minutes % 60, + seconds - minutes * 60, + microsecond=microseconds, + ) + else: + return None + + return process + + +class TIMESTAMP(sqltypes.TIMESTAMP): + """MySQL TIMESTAMP type.""" + + __visit_name__ = "TIMESTAMP" + + def __init__(self, timezone=False, fsp=None): + """Construct a MySQL TIMESTAMP type. + + :param timezone: not used by the MySQL dialect. + :param fsp: fractional seconds precision value. + MySQL 5.6.4 supports storage of fractional seconds; + this parameter will be used when emitting DDL + for the TIMESTAMP type. + + .. note:: + + DBAPI driver support for fractional seconds may + be limited; current support includes + MySQL Connector/Python. + + """ + super().__init__(timezone=timezone) + self.fsp = fsp + + +class DATETIME(sqltypes.DATETIME): + """MySQL DATETIME type.""" + + __visit_name__ = "DATETIME" + + def __init__(self, timezone=False, fsp=None): + """Construct a MySQL DATETIME type. + + :param timezone: not used by the MySQL dialect. + :param fsp: fractional seconds precision value. + MySQL 5.6.4 supports storage of fractional seconds; + this parameter will be used when emitting DDL + for the DATETIME type. + + .. note:: + + DBAPI driver support for fractional seconds may + be limited; current support includes + MySQL Connector/Python. + + """ + super().__init__(timezone=timezone) + self.fsp = fsp + + +class YEAR(sqltypes.TypeEngine): + """MySQL YEAR type, for single byte storage of years 1901-2155.""" + + __visit_name__ = "YEAR" + + def __init__(self, display_width=None): + self.display_width = display_width + + +class TEXT(_StringType, sqltypes.TEXT): + """MySQL TEXT type, for character storage encoded up to 2^16 bytes.""" + + __visit_name__ = "TEXT" + + def __init__(self, length=None, **kw): + """Construct a TEXT. + + :param length: Optional, if provided the server may optimize storage + by substituting the smallest TEXT type sufficient to store + ``length`` bytes of characters. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super().__init__(length=length, **kw) + + +class TINYTEXT(_StringType): + """MySQL TINYTEXT type, for character storage encoded up to 2^8 bytes.""" + + __visit_name__ = "TINYTEXT" + + def __init__(self, **kwargs): + """Construct a TINYTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super().__init__(**kwargs) + + +class MEDIUMTEXT(_StringType): + """MySQL MEDIUMTEXT type, for character storage encoded up + to 2^24 bytes.""" + + __visit_name__ = "MEDIUMTEXT" + + def __init__(self, **kwargs): + """Construct a MEDIUMTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super().__init__(**kwargs) + + +class LONGTEXT(_StringType): + """MySQL LONGTEXT type, for character storage encoded up to 2^32 bytes.""" + + __visit_name__ = "LONGTEXT" + + def __init__(self, **kwargs): + """Construct a LONGTEXT. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super().__init__(**kwargs) + + +class VARCHAR(_StringType, sqltypes.VARCHAR): + """MySQL VARCHAR type, for variable-length character data.""" + + __visit_name__ = "VARCHAR" + + def __init__(self, length=None, **kwargs): + """Construct a VARCHAR. + + :param charset: Optional, a column-level character set for this string + value. Takes precedence to 'ascii' or 'unicode' short-hand. + + :param collation: Optional, a column-level collation for this string + value. Takes precedence to 'binary' short-hand. + + :param ascii: Defaults to False: short-hand for the ``latin1`` + character set, generates ASCII in schema. + + :param unicode: Defaults to False: short-hand for the ``ucs2`` + character set, generates UNICODE in schema. + + :param national: Optional. If true, use the server's configured + national character set. + + :param binary: Defaults to False: short-hand, pick the binary + collation type that matches the column's character set. Generates + BINARY in schema. This does not affect the type of data stored, + only the collation of character data. + + """ + super().__init__(length=length, **kwargs) + + +class CHAR(_StringType, sqltypes.CHAR): + """MySQL CHAR type, for fixed-length character data.""" + + __visit_name__ = "CHAR" + + def __init__(self, length=None, **kwargs): + """Construct a CHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + super().__init__(length=length, **kwargs) + + @classmethod + def _adapt_string_for_cast(cls, type_): + # copy the given string type into a CHAR + # for the purposes of rendering a CAST expression + type_ = sqltypes.to_instance(type_) + if isinstance(type_, sqltypes.CHAR): + return type_ + elif isinstance(type_, _StringType): + return CHAR( + length=type_.length, + charset=type_.charset, + collation=type_.collation, + ascii=type_.ascii, + binary=type_.binary, + unicode=type_.unicode, + national=False, # not supported in CAST + ) + else: + return CHAR(length=type_.length) + + +class NVARCHAR(_StringType, sqltypes.NVARCHAR): + """MySQL NVARCHAR type. + + For variable-length character data in the server's configured national + character set. + """ + + __visit_name__ = "NVARCHAR" + + def __init__(self, length=None, **kwargs): + """Construct an NVARCHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + kwargs["national"] = True + super().__init__(length=length, **kwargs) + + +class NCHAR(_StringType, sqltypes.NCHAR): + """MySQL NCHAR type. + + For fixed-length character data in the server's configured national + character set. + """ + + __visit_name__ = "NCHAR" + + def __init__(self, length=None, **kwargs): + """Construct an NCHAR. + + :param length: Maximum data length, in characters. + + :param binary: Optional, use the default binary collation for the + national character set. This does not affect the type of data + stored, use a BINARY type for binary data. + + :param collation: Optional, request a particular collation. Must be + compatible with the national character set. + + """ + kwargs["national"] = True + super().__init__(length=length, **kwargs) + + +class TINYBLOB(sqltypes._Binary): + """MySQL TINYBLOB type, for binary data up to 2^8 bytes.""" + + __visit_name__ = "TINYBLOB" + + +class MEDIUMBLOB(sqltypes._Binary): + """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes.""" + + __visit_name__ = "MEDIUMBLOB" + + +class LONGBLOB(sqltypes._Binary): + """MySQL LONGBLOB type, for binary data up to 2^32 bytes.""" + + __visit_name__ = "LONGBLOB" diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__init__.py new file mode 100644 index 0000000..d855122 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__init__.py @@ -0,0 +1,67 @@ +# dialects/oracle/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from types import ModuleType + +from . import base # noqa +from . import cx_oracle # noqa +from . import oracledb # noqa +from .base import BFILE +from .base import BINARY_DOUBLE +from .base import BINARY_FLOAT +from .base import BLOB +from .base import CHAR +from .base import CLOB +from .base import DATE +from .base import DOUBLE_PRECISION +from .base import FLOAT +from .base import INTERVAL +from .base import LONG +from .base import NCHAR +from .base import NCLOB +from .base import NUMBER +from .base import NVARCHAR +from .base import NVARCHAR2 +from .base import RAW +from .base import REAL +from .base import ROWID +from .base import TIMESTAMP +from .base import VARCHAR +from .base import VARCHAR2 + +# Alias oracledb also as oracledb_async +oracledb_async = type( + "oracledb_async", (ModuleType,), {"dialect": oracledb.dialect_async} +) + +base.dialect = dialect = cx_oracle.dialect + +__all__ = ( + "VARCHAR", + "NVARCHAR", + "CHAR", + "NCHAR", + "DATE", + "NUMBER", + "BLOB", + "BFILE", + "CLOB", + "NCLOB", + "TIMESTAMP", + "RAW", + "FLOAT", + "DOUBLE_PRECISION", + "BINARY_DOUBLE", + "BINARY_FLOAT", + "LONG", + "dialect", + "INTERVAL", + "VARCHAR2", + "NVARCHAR2", + "ROWID", + "REAL", +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..cc02ead Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..bf594ef Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/base.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/cx_oracle.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/cx_oracle.cpython-311.pyc new file mode 100644 index 0000000..9e8e947 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/cx_oracle.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/dictionary.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/dictionary.cpython-311.pyc new file mode 100644 index 0000000..89ce69c Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/dictionary.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/oracledb.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/oracledb.cpython-311.pyc new file mode 100644 index 0000000..9325524 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/oracledb.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/provision.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/provision.cpython-311.pyc new file mode 100644 index 0000000..6d3c52d Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/provision.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/types.cpython-311.pyc new file mode 100644 index 0000000..24bfa8d Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/__pycache__/types.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/base.py new file mode 100644 index 0000000..a548b34 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/base.py @@ -0,0 +1,3240 @@ +# dialects/oracle/base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" +.. dialect:: oracle + :name: Oracle + :full_support: 18c + :normal_support: 11+ + :best_effort: 9+ + + +Auto Increment Behavior +----------------------- + +SQLAlchemy Table objects which include integer primary keys are usually +assumed to have "autoincrementing" behavior, meaning they can generate their +own primary key values upon INSERT. For use within Oracle, two options are +available, which are the use of IDENTITY columns (Oracle 12 and above only) +or the association of a SEQUENCE with the column. + +Specifying GENERATED AS IDENTITY (Oracle 12 and above) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Starting from version 12 Oracle can make use of identity columns using +the :class:`_sql.Identity` to specify the autoincrementing behavior:: + + t = Table('mytable', metadata, + Column('id', Integer, Identity(start=3), primary_key=True), + Column(...), ... + ) + +The CREATE TABLE for the above :class:`_schema.Table` object would be: + +.. sourcecode:: sql + + CREATE TABLE mytable ( + id INTEGER GENERATED BY DEFAULT AS IDENTITY (START WITH 3), + ..., + PRIMARY KEY (id) + ) + +The :class:`_schema.Identity` object support many options to control the +"autoincrementing" behavior of the column, like the starting value, the +incrementing value, etc. +In addition to the standard options, Oracle supports setting +:paramref:`_schema.Identity.always` to ``None`` to use the default +generated mode, rendering GENERATED AS IDENTITY in the DDL. It also supports +setting :paramref:`_schema.Identity.on_null` to ``True`` to specify ON NULL +in conjunction with a 'BY DEFAULT' identity column. + +Using a SEQUENCE (all Oracle versions) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Older version of Oracle had no "autoincrement" +feature, SQLAlchemy relies upon sequences to produce these values. With the +older Oracle versions, *a sequence must always be explicitly specified to +enable autoincrement*. This is divergent with the majority of documentation +examples which assume the usage of an autoincrement-capable database. To +specify sequences, use the sqlalchemy.schema.Sequence object which is passed +to a Column construct:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq', start=1), primary_key=True), + Column(...), ... + ) + +This step is also required when using table reflection, i.e. autoload_with=engine:: + + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq', start=1), primary_key=True), + autoload_with=engine + ) + +.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct + in a :class:`_schema.Column` to specify the option of an autoincrementing + column. + +.. _oracle_isolation_level: + +Transaction Isolation Level / Autocommit +---------------------------------------- + +The Oracle database supports "READ COMMITTED" and "SERIALIZABLE" modes of +isolation. The AUTOCOMMIT isolation level is also supported by the cx_Oracle +dialect. + +To set using per-connection execution options:: + + connection = engine.connect() + connection = connection.execution_options( + isolation_level="AUTOCOMMIT" + ) + +For ``READ COMMITTED`` and ``SERIALIZABLE``, the Oracle dialect sets the +level at the session level using ``ALTER SESSION``, which is reverted back +to its default setting when the connection is returned to the connection +pool. + +Valid values for ``isolation_level`` include: + +* ``READ COMMITTED`` +* ``AUTOCOMMIT`` +* ``SERIALIZABLE`` + +.. note:: The implementation for the + :meth:`_engine.Connection.get_isolation_level` method as implemented by the + Oracle dialect necessarily forces the start of a transaction using the + Oracle LOCAL_TRANSACTION_ID function; otherwise no level is normally + readable. + + Additionally, the :meth:`_engine.Connection.get_isolation_level` method will + raise an exception if the ``v$transaction`` view is not available due to + permissions or other reasons, which is a common occurrence in Oracle + installations. + + The cx_Oracle dialect attempts to call the + :meth:`_engine.Connection.get_isolation_level` method when the dialect makes + its first connection to the database in order to acquire the + "default"isolation level. This default level is necessary so that the level + can be reset on a connection after it has been temporarily modified using + :meth:`_engine.Connection.execution_options` method. In the common event + that the :meth:`_engine.Connection.get_isolation_level` method raises an + exception due to ``v$transaction`` not being readable as well as any other + database-related failure, the level is assumed to be "READ COMMITTED". No + warning is emitted for this initial first-connect condition as it is + expected to be a common restriction on Oracle databases. + +.. versionadded:: 1.3.16 added support for AUTOCOMMIT to the cx_oracle dialect + as well as the notion of a default isolation level + +.. versionadded:: 1.3.21 Added support for SERIALIZABLE as well as live + reading of the isolation level. + +.. versionchanged:: 1.3.22 In the event that the default isolation + level cannot be read due to permissions on the v$transaction view as + is common in Oracle installations, the default isolation level is hardcoded + to "READ COMMITTED" which was the behavior prior to 1.3.21. + +.. seealso:: + + :ref:`dbapi_autocommit` + +Identifier Casing +----------------- + +In Oracle, the data dictionary represents all case insensitive identifier +names using UPPERCASE text. SQLAlchemy on the other hand considers an +all-lower case identifier name to be case insensitive. The Oracle dialect +converts all case insensitive identifiers to and from those two formats during +schema level communication, such as reflection of tables and indexes. Using +an UPPERCASE name on the SQLAlchemy side indicates a case sensitive +identifier, and SQLAlchemy will quote the name - this will cause mismatches +against data dictionary data received from Oracle, so unless identifier names +have been truly created as case sensitive (i.e. using quoted names), all +lowercase names should be used on the SQLAlchemy side. + +.. _oracle_max_identifier_lengths: + +Max Identifier Lengths +---------------------- + +Oracle has changed the default max identifier length as of Oracle Server +version 12.2. Prior to this version, the length was 30, and for 12.2 and +greater it is now 128. This change impacts SQLAlchemy in the area of +generated SQL label names as well as the generation of constraint names, +particularly in the case where the constraint naming convention feature +described at :ref:`constraint_naming_conventions` is being used. + +To assist with this change and others, Oracle includes the concept of a +"compatibility" version, which is a version number that is independent of the +actual server version in order to assist with migration of Oracle databases, +and may be configured within the Oracle server itself. This compatibility +version is retrieved using the query ``SELECT value FROM v$parameter WHERE +name = 'compatible';``. The SQLAlchemy Oracle dialect, when tasked with +determining the default max identifier length, will attempt to use this query +upon first connect in order to determine the effective compatibility version of +the server, which determines what the maximum allowed identifier length is for +the server. If the table is not available, the server version information is +used instead. + +As of SQLAlchemy 1.4, the default max identifier length for the Oracle dialect +is 128 characters. Upon first connect, the compatibility version is detected +and if it is less than Oracle version 12.2, the max identifier length is +changed to be 30 characters. In all cases, setting the +:paramref:`_sa.create_engine.max_identifier_length` parameter will bypass this +change and the value given will be used as is:: + + engine = create_engine( + "oracle+cx_oracle://scott:tiger@oracle122", + max_identifier_length=30) + +The maximum identifier length comes into play both when generating anonymized +SQL labels in SELECT statements, but more crucially when generating constraint +names from a naming convention. It is this area that has created the need for +SQLAlchemy to change this default conservatively. For example, the following +naming convention produces two very different constraint names based on the +identifier length:: + + from sqlalchemy import Column + from sqlalchemy import Index + from sqlalchemy import Integer + from sqlalchemy import MetaData + from sqlalchemy import Table + from sqlalchemy.dialects import oracle + from sqlalchemy.schema import CreateIndex + + m = MetaData(naming_convention={"ix": "ix_%(column_0N_name)s"}) + + t = Table( + "t", + m, + Column("some_column_name_1", Integer), + Column("some_column_name_2", Integer), + Column("some_column_name_3", Integer), + ) + + ix = Index( + None, + t.c.some_column_name_1, + t.c.some_column_name_2, + t.c.some_column_name_3, + ) + + oracle_dialect = oracle.dialect(max_identifier_length=30) + print(CreateIndex(ix).compile(dialect=oracle_dialect)) + +With an identifier length of 30, the above CREATE INDEX looks like:: + + CREATE INDEX ix_some_column_name_1s_70cd ON t + (some_column_name_1, some_column_name_2, some_column_name_3) + +However with length=128, it becomes:: + + CREATE INDEX ix_some_column_name_1some_column_name_2some_column_name_3 ON t + (some_column_name_1, some_column_name_2, some_column_name_3) + +Applications which have run versions of SQLAlchemy prior to 1.4 on an Oracle +server version 12.2 or greater are therefore subject to the scenario of a +database migration that wishes to "DROP CONSTRAINT" on a name that was +previously generated with the shorter length. This migration will fail when +the identifier length is changed without the name of the index or constraint +first being adjusted. Such applications are strongly advised to make use of +:paramref:`_sa.create_engine.max_identifier_length` +in order to maintain control +of the generation of truncated names, and to fully review and test all database +migrations in a staging environment when changing this value to ensure that the +impact of this change has been mitigated. + +.. versionchanged:: 1.4 the default max_identifier_length for Oracle is 128 + characters, which is adjusted down to 30 upon first connect if an older + version of Oracle server (compatibility version < 12.2) is detected. + + +LIMIT/OFFSET/FETCH Support +-------------------------- + +Methods like :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` make +use of ``FETCH FIRST N ROW / OFFSET N ROWS`` syntax assuming +Oracle 12c or above, and assuming the SELECT statement is not embedded within +a compound statement like UNION. This syntax is also available directly by using +the :meth:`_sql.Select.fetch` method. + +.. versionchanged:: 2.0 the Oracle dialect now uses + ``FETCH FIRST N ROW / OFFSET N ROWS`` for all + :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` usage including + within the ORM and legacy :class:`_orm.Query`. To force the legacy + behavior using window functions, specify the ``enable_offset_fetch=False`` + dialect parameter to :func:`_sa.create_engine`. + +The use of ``FETCH FIRST / OFFSET`` may be disabled on any Oracle version +by passing ``enable_offset_fetch=False`` to :func:`_sa.create_engine`, which +will force the use of "legacy" mode that makes use of window functions. +This mode is also selected automatically when using a version of Oracle +prior to 12c. + +When using legacy mode, or when a :class:`.Select` statement +with limit/offset is embedded in a compound statement, an emulated approach for +LIMIT / OFFSET based on window functions is used, which involves creation of a +subquery using ``ROW_NUMBER`` that is prone to performance issues as well as +SQL construction issues for complex statements. However, this approach is +supported by all Oracle versions. See notes below. + +Notes on LIMIT / OFFSET emulation (when fetch() method cannot be used) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If using :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset`, or with the +ORM the :meth:`_orm.Query.limit` and :meth:`_orm.Query.offset` methods on an +Oracle version prior to 12c, the following notes apply: + +* SQLAlchemy currently makes use of ROWNUM to achieve + LIMIT/OFFSET; the exact methodology is taken from + https://blogs.oracle.com/oraclemagazine/on-rownum-and-limiting-results . + +* the "FIRST_ROWS()" optimization keyword is not used by default. To enable + the usage of this optimization directive, specify ``optimize_limits=True`` + to :func:`_sa.create_engine`. + + .. versionchanged:: 1.4 + The Oracle dialect renders limit/offset integer values using a "post + compile" scheme which renders the integer directly before passing the + statement to the cursor for execution. The ``use_binds_for_limits`` flag + no longer has an effect. + + .. seealso:: + + :ref:`change_4808`. + +.. _oracle_returning: + +RETURNING Support +----------------- + +The Oracle database supports RETURNING fully for INSERT, UPDATE and DELETE +statements that are invoked with a single collection of bound parameters +(that is, a ``cursor.execute()`` style statement; SQLAlchemy does not generally +support RETURNING with :term:`executemany` statements). Multiple rows may be +returned as well. + +.. versionchanged:: 2.0 the Oracle backend has full support for RETURNING + on parity with other backends. + + + +ON UPDATE CASCADE +----------------- + +Oracle doesn't have native ON UPDATE CASCADE functionality. A trigger based +solution is available at +https://asktom.oracle.com/tkyte/update_cascade/index.html . + +When using the SQLAlchemy ORM, the ORM has limited ability to manually issue +cascading updates - specify ForeignKey objects using the +"deferrable=True, initially='deferred'" keyword arguments, +and specify "passive_updates=False" on each relationship(). + +Oracle 8 Compatibility +---------------------- + +.. warning:: The status of Oracle 8 compatibility is not known for SQLAlchemy + 2.0. + +When Oracle 8 is detected, the dialect internally configures itself to the +following behaviors: + +* the use_ansi flag is set to False. This has the effect of converting all + JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN + makes use of Oracle's (+) operator. + +* the NVARCHAR2 and NCLOB datatypes are no longer generated as DDL when + the :class:`~sqlalchemy.types.Unicode` is used - VARCHAR2 and CLOB are issued + instead. This because these types don't seem to work correctly on Oracle 8 + even though they are available. The :class:`~sqlalchemy.types.NVARCHAR` and + :class:`~sqlalchemy.dialects.oracle.NCLOB` types will always generate + NVARCHAR2 and NCLOB. + + +Synonym/DBLINK Reflection +------------------------- + +When using reflection with Table objects, the dialect can optionally search +for tables indicated by synonyms, either in local or remote schemas or +accessed over DBLINK, by passing the flag ``oracle_resolve_synonyms=True`` as +a keyword argument to the :class:`_schema.Table` construct:: + + some_table = Table('some_table', autoload_with=some_engine, + oracle_resolve_synonyms=True) + +When this flag is set, the given name (such as ``some_table`` above) will +be searched not just in the ``ALL_TABLES`` view, but also within the +``ALL_SYNONYMS`` view to see if this name is actually a synonym to another +name. If the synonym is located and refers to a DBLINK, the oracle dialect +knows how to locate the table's information using DBLINK syntax(e.g. +``@dblink``). + +``oracle_resolve_synonyms`` is accepted wherever reflection arguments are +accepted, including methods such as :meth:`_schema.MetaData.reflect` and +:meth:`_reflection.Inspector.get_columns`. + +If synonyms are not in use, this flag should be left disabled. + +.. _oracle_constraint_reflection: + +Constraint Reflection +--------------------- + +The Oracle dialect can return information about foreign key, unique, and +CHECK constraints, as well as indexes on tables. + +Raw information regarding these constraints can be acquired using +:meth:`_reflection.Inspector.get_foreign_keys`, +:meth:`_reflection.Inspector.get_unique_constraints`, +:meth:`_reflection.Inspector.get_check_constraints`, and +:meth:`_reflection.Inspector.get_indexes`. + +.. versionchanged:: 1.2 The Oracle dialect can now reflect UNIQUE and + CHECK constraints. + +When using reflection at the :class:`_schema.Table` level, the +:class:`_schema.Table` +will also include these constraints. + +Note the following caveats: + +* When using the :meth:`_reflection.Inspector.get_check_constraints` method, + Oracle + builds a special "IS NOT NULL" constraint for columns that specify + "NOT NULL". This constraint is **not** returned by default; to include + the "IS NOT NULL" constraints, pass the flag ``include_all=True``:: + + from sqlalchemy import create_engine, inspect + + engine = create_engine("oracle+cx_oracle://s:t@dsn") + inspector = inspect(engine) + all_check_constraints = inspector.get_check_constraints( + "some_table", include_all=True) + +* in most cases, when reflecting a :class:`_schema.Table`, + a UNIQUE constraint will + **not** be available as a :class:`.UniqueConstraint` object, as Oracle + mirrors unique constraints with a UNIQUE index in most cases (the exception + seems to be when two or more unique constraints represent the same columns); + the :class:`_schema.Table` will instead represent these using + :class:`.Index` + with the ``unique=True`` flag set. + +* Oracle creates an implicit index for the primary key of a table; this index + is **excluded** from all index results. + +* the list of columns reflected for an index will not include column names + that start with SYS_NC. + +Table names with SYSTEM/SYSAUX tablespaces +------------------------------------------- + +The :meth:`_reflection.Inspector.get_table_names` and +:meth:`_reflection.Inspector.get_temp_table_names` +methods each return a list of table names for the current engine. These methods +are also part of the reflection which occurs within an operation such as +:meth:`_schema.MetaData.reflect`. By default, +these operations exclude the ``SYSTEM`` +and ``SYSAUX`` tablespaces from the operation. In order to change this, the +default list of tablespaces excluded can be changed at the engine level using +the ``exclude_tablespaces`` parameter:: + + # exclude SYSAUX and SOME_TABLESPACE, but not SYSTEM + e = create_engine( + "oracle+cx_oracle://scott:tiger@xe", + exclude_tablespaces=["SYSAUX", "SOME_TABLESPACE"]) + +DateTime Compatibility +---------------------- + +Oracle has no datatype known as ``DATETIME``, it instead has only ``DATE``, +which can actually store a date and time value. For this reason, the Oracle +dialect provides a type :class:`_oracle.DATE` which is a subclass of +:class:`.DateTime`. This type has no special behavior, and is only +present as a "marker" for this type; additionally, when a database column +is reflected and the type is reported as ``DATE``, the time-supporting +:class:`_oracle.DATE` type is used. + +.. _oracle_table_options: + +Oracle Table Options +------------------------- + +The CREATE TABLE phrase supports the following options with Oracle +in conjunction with the :class:`_schema.Table` construct: + + +* ``ON COMMIT``:: + + Table( + "some_table", metadata, ..., + prefixes=['GLOBAL TEMPORARY'], oracle_on_commit='PRESERVE ROWS') + +* ``COMPRESS``:: + + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=True) + + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=6) + + The ``oracle_compress`` parameter accepts either an integer compression + level, or ``True`` to use the default compression level. + +.. _oracle_index_options: + +Oracle Specific Index Options +----------------------------- + +Bitmap Indexes +~~~~~~~~~~~~~~ + +You can specify the ``oracle_bitmap`` parameter to create a bitmap index +instead of a B-tree index:: + + Index('my_index', my_table.c.data, oracle_bitmap=True) + +Bitmap indexes cannot be unique and cannot be compressed. SQLAlchemy will not +check for such limitations, only the database will. + +Index compression +~~~~~~~~~~~~~~~~~ + +Oracle has a more efficient storage mode for indexes containing lots of +repeated values. Use the ``oracle_compress`` parameter to turn on key +compression:: + + Index('my_index', my_table.c.data, oracle_compress=True) + + Index('my_index', my_table.c.data1, my_table.c.data2, unique=True, + oracle_compress=1) + +The ``oracle_compress`` parameter accepts either an integer specifying the +number of prefix columns to compress, or ``True`` to use the default (all +columns for non-unique indexes, all but the last column for unique indexes). + +""" # noqa + +from __future__ import annotations + +from collections import defaultdict +from functools import lru_cache +from functools import wraps +import re + +from . import dictionary +from .types import _OracleBoolean +from .types import _OracleDate +from .types import BFILE +from .types import BINARY_DOUBLE +from .types import BINARY_FLOAT +from .types import DATE +from .types import FLOAT +from .types import INTERVAL +from .types import LONG +from .types import NCLOB +from .types import NUMBER +from .types import NVARCHAR2 # noqa +from .types import OracleRaw # noqa +from .types import RAW +from .types import ROWID # noqa +from .types import TIMESTAMP +from .types import VARCHAR2 # noqa +from ... import Computed +from ... import exc +from ... import schema as sa_schema +from ... import sql +from ... import util +from ...engine import default +from ...engine import ObjectKind +from ...engine import ObjectScope +from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import and_ +from ...sql import bindparam +from ...sql import compiler +from ...sql import expression +from ...sql import func +from ...sql import null +from ...sql import or_ +from ...sql import select +from ...sql import sqltypes +from ...sql import util as sql_util +from ...sql import visitors +from ...sql.visitors import InternalTraversal +from ...types import BLOB +from ...types import CHAR +from ...types import CLOB +from ...types import DOUBLE_PRECISION +from ...types import INTEGER +from ...types import NCHAR +from ...types import NVARCHAR +from ...types import REAL +from ...types import VARCHAR + +RESERVED_WORDS = set( + "SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN " + "DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED " + "ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE " + "ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE " + "BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES " + "AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS " + "NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER " + "CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR " + "DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL".split() +) + +NO_ARG_FNS = set( + "UID CURRENT_DATE SYSDATE USER CURRENT_TIME CURRENT_TIMESTAMP".split() +) + + +colspecs = { + sqltypes.Boolean: _OracleBoolean, + sqltypes.Interval: INTERVAL, + sqltypes.DateTime: DATE, + sqltypes.Date: _OracleDate, +} + +ischema_names = { + "VARCHAR2": VARCHAR, + "NVARCHAR2": NVARCHAR, + "CHAR": CHAR, + "NCHAR": NCHAR, + "DATE": DATE, + "NUMBER": NUMBER, + "BLOB": BLOB, + "BFILE": BFILE, + "CLOB": CLOB, + "NCLOB": NCLOB, + "TIMESTAMP": TIMESTAMP, + "TIMESTAMP WITH TIME ZONE": TIMESTAMP, + "TIMESTAMP WITH LOCAL TIME ZONE": TIMESTAMP, + "INTERVAL DAY TO SECOND": INTERVAL, + "RAW": RAW, + "FLOAT": FLOAT, + "DOUBLE PRECISION": DOUBLE_PRECISION, + "REAL": REAL, + "LONG": LONG, + "BINARY_DOUBLE": BINARY_DOUBLE, + "BINARY_FLOAT": BINARY_FLOAT, + "ROWID": ROWID, +} + + +class OracleTypeCompiler(compiler.GenericTypeCompiler): + # Note: + # Oracle DATE == DATETIME + # Oracle does not allow milliseconds in DATE + # Oracle does not support TIME columns + + def visit_datetime(self, type_, **kw): + return self.visit_DATE(type_, **kw) + + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) + + def visit_double(self, type_, **kw): + return self.visit_DOUBLE_PRECISION(type_, **kw) + + def visit_unicode(self, type_, **kw): + if self.dialect._use_nchar_for_unicode: + return self.visit_NVARCHAR2(type_, **kw) + else: + return self.visit_VARCHAR2(type_, **kw) + + def visit_INTERVAL(self, type_, **kw): + return "INTERVAL DAY%s TO SECOND%s" % ( + type_.day_precision is not None + and "(%d)" % type_.day_precision + or "", + type_.second_precision is not None + and "(%d)" % type_.second_precision + or "", + ) + + def visit_LONG(self, type_, **kw): + return "LONG" + + def visit_TIMESTAMP(self, type_, **kw): + if getattr(type_, "local_timezone", False): + return "TIMESTAMP WITH LOCAL TIME ZONE" + elif type_.timezone: + return "TIMESTAMP WITH TIME ZONE" + else: + return "TIMESTAMP" + + def visit_DOUBLE_PRECISION(self, type_, **kw): + return self._generate_numeric(type_, "DOUBLE PRECISION", **kw) + + def visit_BINARY_DOUBLE(self, type_, **kw): + return self._generate_numeric(type_, "BINARY_DOUBLE", **kw) + + def visit_BINARY_FLOAT(self, type_, **kw): + return self._generate_numeric(type_, "BINARY_FLOAT", **kw) + + def visit_FLOAT(self, type_, **kw): + kw["_requires_binary_precision"] = True + return self._generate_numeric(type_, "FLOAT", **kw) + + def visit_NUMBER(self, type_, **kw): + return self._generate_numeric(type_, "NUMBER", **kw) + + def _generate_numeric( + self, + type_, + name, + precision=None, + scale=None, + _requires_binary_precision=False, + **kw, + ): + if precision is None: + precision = getattr(type_, "precision", None) + + if _requires_binary_precision: + binary_precision = getattr(type_, "binary_precision", None) + + if precision and binary_precision is None: + # https://www.oracletutorial.com/oracle-basics/oracle-float/ + estimated_binary_precision = int(precision / 0.30103) + raise exc.ArgumentError( + "Oracle FLOAT types use 'binary precision', which does " + "not convert cleanly from decimal 'precision'. Please " + "specify " + f"this type with a separate Oracle variant, such as " + f"{type_.__class__.__name__}(precision={precision})." + f"with_variant(oracle.FLOAT" + f"(binary_precision=" + f"{estimated_binary_precision}), 'oracle'), so that the " + "Oracle specific 'binary_precision' may be specified " + "accurately." + ) + else: + precision = binary_precision + + if scale is None: + scale = getattr(type_, "scale", None) + + if precision is None: + return name + elif scale is None: + n = "%(name)s(%(precision)s)" + return n % {"name": name, "precision": precision} + else: + n = "%(name)s(%(precision)s, %(scale)s)" + return n % {"name": name, "precision": precision, "scale": scale} + + def visit_string(self, type_, **kw): + return self.visit_VARCHAR2(type_, **kw) + + def visit_VARCHAR2(self, type_, **kw): + return self._visit_varchar(type_, "", "2") + + def visit_NVARCHAR2(self, type_, **kw): + return self._visit_varchar(type_, "N", "2") + + visit_NVARCHAR = visit_NVARCHAR2 + + def visit_VARCHAR(self, type_, **kw): + return self._visit_varchar(type_, "", "") + + def _visit_varchar(self, type_, n, num): + if not type_.length: + return "%(n)sVARCHAR%(two)s" % {"two": num, "n": n} + elif not n and self.dialect._supports_char_length: + varchar = "VARCHAR%(two)s(%(length)s CHAR)" + return varchar % {"length": type_.length, "two": num} + else: + varchar = "%(n)sVARCHAR%(two)s(%(length)s)" + return varchar % {"length": type_.length, "two": num, "n": n} + + def visit_text(self, type_, **kw): + return self.visit_CLOB(type_, **kw) + + def visit_unicode_text(self, type_, **kw): + if self.dialect._use_nchar_for_unicode: + return self.visit_NCLOB(type_, **kw) + else: + return self.visit_CLOB(type_, **kw) + + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) + + def visit_big_integer(self, type_, **kw): + return self.visit_NUMBER(type_, precision=19, **kw) + + def visit_boolean(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) + + def visit_RAW(self, type_, **kw): + if type_.length: + return "RAW(%(length)s)" % {"length": type_.length} + else: + return "RAW" + + def visit_ROWID(self, type_, **kw): + return "ROWID" + + +class OracleCompiler(compiler.SQLCompiler): + """Oracle compiler modifies the lexical structure of Select + statements to work under non-ANSI configured Oracle databases, if + the use_ansi flag is False. + """ + + compound_keywords = util.update_copy( + compiler.SQLCompiler.compound_keywords, + {expression.CompoundSelect.EXCEPT: "MINUS"}, + ) + + def __init__(self, *args, **kwargs): + self.__wheres = {} + super().__init__(*args, **kwargs) + + def visit_mod_binary(self, binary, operator, **kw): + return "mod(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_char_length_func(self, fn, **kw): + return "LENGTH" + self.function_argspec(fn, **kw) + + def visit_match_op_binary(self, binary, operator, **kw): + return "CONTAINS (%s, %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_true(self, expr, **kw): + return "1" + + def visit_false(self, expr, **kw): + return "0" + + def get_cte_preamble(self, recursive): + return "WITH" + + def get_select_hint_text(self, byfroms): + return " ".join("/*+ %s */" % text for table, text in byfroms.items()) + + def function_argspec(self, fn, **kw): + if len(fn.clauses) > 0 or fn.name.upper() not in NO_ARG_FNS: + return compiler.SQLCompiler.function_argspec(self, fn, **kw) + else: + return "" + + def visit_function(self, func, **kw): + text = super().visit_function(func, **kw) + if kw.get("asfrom", False): + text = "TABLE (%s)" % text + return text + + def visit_table_valued_column(self, element, **kw): + text = super().visit_table_valued_column(element, **kw) + text = text + ".COLUMN_VALUE" + return text + + def default_from(self): + """Called when a ``SELECT`` statement has no froms, + and no ``FROM`` clause is to be appended. + + The Oracle compiler tacks a "FROM DUAL" to the statement. + """ + + return " FROM DUAL" + + def visit_join(self, join, from_linter=None, **kwargs): + if self.dialect.use_ansi: + return compiler.SQLCompiler.visit_join( + self, join, from_linter=from_linter, **kwargs + ) + else: + if from_linter: + from_linter.edges.add((join.left, join.right)) + + kwargs["asfrom"] = True + if isinstance(join.right, expression.FromGrouping): + right = join.right.element + else: + right = join.right + return ( + self.process(join.left, from_linter=from_linter, **kwargs) + + ", " + + self.process(right, from_linter=from_linter, **kwargs) + ) + + def _get_nonansi_join_whereclause(self, froms): + clauses = [] + + def visit_join(join): + if join.isouter: + # https://docs.oracle.com/database/121/SQLRF/queries006.htm#SQLRF52354 + # "apply the outer join operator (+) to all columns of B in + # the join condition in the WHERE clause" - that is, + # unconditionally regardless of operator or the other side + def visit_binary(binary): + if isinstance( + binary.left, expression.ColumnClause + ) and join.right.is_derived_from(binary.left.table): + binary.left = _OuterJoinColumn(binary.left) + elif isinstance( + binary.right, expression.ColumnClause + ) and join.right.is_derived_from(binary.right.table): + binary.right = _OuterJoinColumn(binary.right) + + clauses.append( + visitors.cloned_traverse( + join.onclause, {}, {"binary": visit_binary} + ) + ) + else: + clauses.append(join.onclause) + + for j in join.left, join.right: + if isinstance(j, expression.Join): + visit_join(j) + elif isinstance(j, expression.FromGrouping): + visit_join(j.element) + + for f in froms: + if isinstance(f, expression.Join): + visit_join(f) + + if not clauses: + return None + else: + return sql.and_(*clauses) + + def visit_outer_join_column(self, vc, **kw): + return self.process(vc.column, **kw) + "(+)" + + def visit_sequence(self, seq, **kw): + return self.preparer.format_sequence(seq) + ".nextval" + + def get_render_as_alias_suffix(self, alias_name_text): + """Oracle doesn't like ``FROM table AS alias``""" + + return " " + alias_name_text + + def returning_clause( + self, stmt, returning_cols, *, populate_result_map, **kw + ): + columns = [] + binds = [] + + for i, column in enumerate( + expression._select_iterables(returning_cols) + ): + if ( + self.isupdate + and isinstance(column, sa_schema.Column) + and isinstance(column.server_default, Computed) + and not self.dialect._supports_update_returning_computed_cols + ): + util.warn( + "Computed columns don't work with Oracle UPDATE " + "statements that use RETURNING; the value of the column " + "*before* the UPDATE takes place is returned. It is " + "advised to not use RETURNING with an Oracle computed " + "column. Consider setting implicit_returning to False on " + "the Table object in order to avoid implicit RETURNING " + "clauses from being generated for this Table." + ) + if column.type._has_column_expression: + col_expr = column.type.column_expression(column) + else: + col_expr = column + + outparam = sql.outparam("ret_%d" % i, type_=column.type) + self.binds[outparam.key] = outparam + binds.append( + self.bindparam_string(self._truncate_bindparam(outparam)) + ) + + # has_out_parameters would in a normal case be set to True + # as a result of the compiler visiting an outparam() object. + # in this case, the above outparam() objects are not being + # visited. Ensure the statement itself didn't have other + # outparam() objects independently. + # technically, this could be supported, but as it would be + # a very strange use case without a clear rationale, disallow it + if self.has_out_parameters: + raise exc.InvalidRequestError( + "Using explicit outparam() objects with " + "UpdateBase.returning() in the same Core DML statement " + "is not supported in the Oracle dialect." + ) + + self._oracle_returning = True + + columns.append(self.process(col_expr, within_columns_clause=False)) + if populate_result_map: + self._add_to_result_map( + getattr(col_expr, "name", col_expr._anon_name_label), + getattr(col_expr, "name", col_expr._anon_name_label), + ( + column, + getattr(column, "name", None), + getattr(column, "key", None), + ), + column.type, + ) + + return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) + + def _row_limit_clause(self, select, **kw): + """ORacle 12c supports OFFSET/FETCH operators + Use it instead subquery with row_number + + """ + + if ( + select._fetch_clause is not None + or not self.dialect._supports_offset_fetch + ): + return super()._row_limit_clause( + select, use_literal_execute_for_simple_int=True, **kw + ) + else: + return self.fetch_clause( + select, + fetch_clause=self._get_limit_or_fetch(select), + use_literal_execute_for_simple_int=True, + **kw, + ) + + def _get_limit_or_fetch(self, select): + if select._fetch_clause is None: + return select._limit_clause + else: + return select._fetch_clause + + def translate_select_structure(self, select_stmt, **kwargs): + select = select_stmt + + if not getattr(select, "_oracle_visit", None): + if not self.dialect.use_ansi: + froms = self._display_froms_for_select( + select, kwargs.get("asfrom", False) + ) + whereclause = self._get_nonansi_join_whereclause(froms) + if whereclause is not None: + select = select.where(whereclause) + select._oracle_visit = True + + # if fetch is used this is not needed + if ( + select._has_row_limiting_clause + and not self.dialect._supports_offset_fetch + and select._fetch_clause is None + ): + limit_clause = select._limit_clause + offset_clause = select._offset_clause + + if select._simple_int_clause(limit_clause): + limit_clause = limit_clause.render_literal_execute() + + if select._simple_int_clause(offset_clause): + offset_clause = offset_clause.render_literal_execute() + + # currently using form at: + # https://blogs.oracle.com/oraclemagazine/\ + # on-rownum-and-limiting-results + + orig_select = select + select = select._generate() + select._oracle_visit = True + + # add expressions to accommodate FOR UPDATE OF + for_update = select._for_update_arg + if for_update is not None and for_update.of: + for_update = for_update._clone() + for_update._copy_internals() + + for elem in for_update.of: + if not select.selected_columns.contains_column(elem): + select = select.add_columns(elem) + + # Wrap the middle select and add the hint + inner_subquery = select.alias() + limitselect = sql.select( + *[ + c + for c in inner_subquery.c + if orig_select.selected_columns.corresponding_column(c) + is not None + ] + ) + + if ( + limit_clause is not None + and self.dialect.optimize_limits + and select._simple_int_clause(limit_clause) + ): + limitselect = limitselect.prefix_with( + expression.text( + "/*+ FIRST_ROWS(%s) */" + % self.process(limit_clause, **kwargs) + ) + ) + + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + # add expressions to accommodate FOR UPDATE OF + if for_update is not None and for_update.of: + adapter = sql_util.ClauseAdapter(inner_subquery) + for_update.of = [ + adapter.traverse(elem) for elem in for_update.of + ] + + # If needed, add the limiting clause + if limit_clause is not None: + if select._simple_int_clause(limit_clause) and ( + offset_clause is None + or select._simple_int_clause(offset_clause) + ): + max_row = limit_clause + + if offset_clause is not None: + max_row = max_row + offset_clause + + else: + max_row = limit_clause + + if offset_clause is not None: + max_row = max_row + offset_clause + limitselect = limitselect.where( + sql.literal_column("ROWNUM") <= max_row + ) + + # If needed, add the ora_rn, and wrap again with offset. + if offset_clause is None: + limitselect._for_update_arg = for_update + select = limitselect + else: + limitselect = limitselect.add_columns( + sql.literal_column("ROWNUM").label("ora_rn") + ) + limitselect._oracle_visit = True + limitselect._is_wrapper = True + + if for_update is not None and for_update.of: + limitselect_cols = limitselect.selected_columns + for elem in for_update.of: + if ( + limitselect_cols.corresponding_column(elem) + is None + ): + limitselect = limitselect.add_columns(elem) + + limit_subquery = limitselect.alias() + origselect_cols = orig_select.selected_columns + offsetselect = sql.select( + *[ + c + for c in limit_subquery.c + if origselect_cols.corresponding_column(c) + is not None + ] + ) + + offsetselect._oracle_visit = True + offsetselect._is_wrapper = True + + if for_update is not None and for_update.of: + adapter = sql_util.ClauseAdapter(limit_subquery) + for_update.of = [ + adapter.traverse(elem) for elem in for_update.of + ] + + offsetselect = offsetselect.where( + sql.literal_column("ora_rn") > offset_clause + ) + + offsetselect._for_update_arg = for_update + select = offsetselect + + return select + + def limit_clause(self, select, **kw): + return "" + + def visit_empty_set_expr(self, type_, **kw): + return "SELECT 1 FROM DUAL WHERE 1!=1" + + def for_update_clause(self, select, **kw): + if self.is_subquery(): + return "" + + tmp = " FOR UPDATE" + + if select._for_update_arg.of: + tmp += " OF " + ", ".join( + self.process(elem, **kw) for elem in select._for_update_arg.of + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + if select._for_update_arg.skip_locked: + tmp += " SKIP LOCKED" + + return tmp + + def visit_is_distinct_from_binary(self, binary, operator, **kw): + return "DECODE(%s, %s, 0, 1) = 1" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + return "DECODE(%s, %s, 0, 1) = 0" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + string = self.process(binary.left, **kw) + pattern = self.process(binary.right, **kw) + flags = binary.modifiers["flags"] + if flags is None: + return "REGEXP_LIKE(%s, %s)" % (string, pattern) + else: + return "REGEXP_LIKE(%s, %s, %s)" % ( + string, + pattern, + self.render_literal_value(flags, sqltypes.STRINGTYPE), + ) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return "NOT %s" % self.visit_regexp_match_op_binary( + binary, operator, **kw + ) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + string = self.process(binary.left, **kw) + pattern_replace = self.process(binary.right, **kw) + flags = binary.modifiers["flags"] + if flags is None: + return "REGEXP_REPLACE(%s, %s)" % ( + string, + pattern_replace, + ) + else: + return "REGEXP_REPLACE(%s, %s, %s)" % ( + string, + pattern_replace, + self.render_literal_value(flags, sqltypes.STRINGTYPE), + ) + + def visit_aggregate_strings_func(self, fn, **kw): + return "LISTAGG%s" % self.function_argspec(fn, **kw) + + +class OracleDDLCompiler(compiler.DDLCompiler): + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % constraint.ondelete + + # oracle has no ON UPDATE CASCADE - + # its only available via triggers + # https://asktom.oracle.com/tkyte/update_cascade/index.html + if constraint.onupdate is not None: + util.warn( + "Oracle does not contain native UPDATE CASCADE " + "functionality - onupdates will not be rendered for foreign " + "keys. Consider using deferrable=True, initially='deferred' " + "or triggers." + ) + + return text + + def visit_drop_table_comment(self, drop, **kw): + return "COMMENT ON TABLE %s IS ''" % self.preparer.format_table( + drop.element + ) + + def visit_create_index(self, create, **kw): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + if index.dialect_options["oracle"]["bitmap"]: + text += "BITMAP " + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=True), + preparer.format_table(index.table, use_schema=True), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) + if index.dialect_options["oracle"]["compress"] is not False: + if index.dialect_options["oracle"]["compress"] is True: + text += " COMPRESS" + else: + text += " COMPRESS %d" % ( + index.dialect_options["oracle"]["compress"] + ) + return text + + def post_create_table(self, table): + table_opts = [] + opts = table.dialect_options["oracle"] + + if opts["on_commit"]: + on_commit_options = opts["on_commit"].replace("_", " ").upper() + table_opts.append("\n ON COMMIT %s" % on_commit_options) + + if opts["compress"]: + if opts["compress"] is True: + table_opts.append("\n COMPRESS") + else: + table_opts.append("\n COMPRESS FOR %s" % (opts["compress"])) + + return "".join(table_opts) + + def get_identity_options(self, identity_options): + text = super().get_identity_options(identity_options) + text = text.replace("NO MINVALUE", "NOMINVALUE") + text = text.replace("NO MAXVALUE", "NOMAXVALUE") + text = text.replace("NO CYCLE", "NOCYCLE") + if identity_options.order is not None: + text += " ORDER" if identity_options.order else " NOORDER" + return text.strip() + + def visit_computed_column(self, generated, **kw): + text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + if generated.persisted is True: + raise exc.CompileError( + "Oracle computed columns do not support 'stored' persistence; " + "set the 'persisted' flag to None or False for Oracle support." + ) + elif generated.persisted is False: + text += " VIRTUAL" + return text + + def visit_identity_column(self, identity, **kw): + if identity.always is None: + kind = "" + else: + kind = "ALWAYS" if identity.always else "BY DEFAULT" + text = "GENERATED %s" % kind + if identity.on_null: + text += " ON NULL" + text += " AS IDENTITY" + options = self.get_identity_options(identity) + if options: + text += " (%s)" % options + return text + + +class OracleIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = {x.lower() for x in RESERVED_WORDS} + illegal_initial_characters = {str(dig) for dig in range(0, 10)}.union( + ["_", "$"] + ) + + def _bindparam_requires_quotes(self, value): + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(str(value)) + ) + + def format_savepoint(self, savepoint): + name = savepoint.ident.lstrip("_") + return super().format_savepoint(savepoint, name) + + +class OracleExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq, type_): + return self._execute_scalar( + "SELECT " + + self.identifier_preparer.format_sequence(seq) + + ".nextval FROM DUAL", + type_, + ) + + def pre_exec(self): + if self.statement and "_oracle_dblink" in self.execution_options: + self.statement = self.statement.replace( + dictionary.DB_LINK_PLACEHOLDER, + self.execution_options["_oracle_dblink"], + ) + + +class OracleDialect(default.DefaultDialect): + name = "oracle" + supports_statement_cache = True + supports_alter = True + max_identifier_length = 128 + + _supports_offset_fetch = True + + insert_returning = True + update_returning = True + delete_returning = True + + div_is_floordiv = False + + supports_simple_order_by_label = False + cte_follows_insert = True + returns_native_bytes = True + + supports_sequences = True + sequences_optional = False + postfetch_lastrowid = False + + default_paramstyle = "named" + colspecs = colspecs + ischema_names = ischema_names + requires_name_normalize = True + + supports_comments = True + + supports_default_values = False + supports_default_metavalue = True + supports_empty_insert = False + supports_identity_columns = True + + statement_compiler = OracleCompiler + ddl_compiler = OracleDDLCompiler + type_compiler_cls = OracleTypeCompiler + preparer = OracleIdentifierPreparer + execution_ctx_cls = OracleExecutionContext + + reflection_options = ("oracle_resolve_synonyms",) + + _use_nchar_for_unicode = False + + construct_arguments = [ + ( + sa_schema.Table, + {"resolve_synonyms": False, "on_commit": None, "compress": False}, + ), + (sa_schema.Index, {"bitmap": False, "compress": False}), + ] + + @util.deprecated_params( + use_binds_for_limits=( + "1.4", + "The ``use_binds_for_limits`` Oracle dialect parameter is " + "deprecated. The dialect now renders LIMIT /OFFSET integers " + "inline in all cases using a post-compilation hook, so that the " + "value is still represented by a 'bound parameter' on the Core " + "Expression side.", + ) + ) + def __init__( + self, + use_ansi=True, + optimize_limits=False, + use_binds_for_limits=None, + use_nchar_for_unicode=False, + exclude_tablespaces=("SYSTEM", "SYSAUX"), + enable_offset_fetch=True, + **kwargs, + ): + default.DefaultDialect.__init__(self, **kwargs) + self._use_nchar_for_unicode = use_nchar_for_unicode + self.use_ansi = use_ansi + self.optimize_limits = optimize_limits + self.exclude_tablespaces = exclude_tablespaces + self.enable_offset_fetch = self._supports_offset_fetch = ( + enable_offset_fetch + ) + + def initialize(self, connection): + super().initialize(connection) + + # Oracle 8i has RETURNING: + # https://docs.oracle.com/cd/A87860_01/doc/index.htm + + # so does Oracle8: + # https://docs.oracle.com/cd/A64702_01/doc/index.htm + + if self._is_oracle_8: + self.colspecs = self.colspecs.copy() + self.colspecs.pop(sqltypes.Interval) + self.use_ansi = False + + self.supports_identity_columns = self.server_version_info >= (12,) + self._supports_offset_fetch = ( + self.enable_offset_fetch and self.server_version_info >= (12,) + ) + + def _get_effective_compat_server_version_info(self, connection): + # dialect does not need compat levels below 12.2, so don't query + # in those cases + + if self.server_version_info < (12, 2): + return self.server_version_info + try: + compat = connection.exec_driver_sql( + "SELECT value FROM v$parameter WHERE name = 'compatible'" + ).scalar() + except exc.DBAPIError: + compat = None + + if compat: + try: + return tuple(int(x) for x in compat.split(".")) + except: + return self.server_version_info + else: + return self.server_version_info + + @property + def _is_oracle_8(self): + return self.server_version_info and self.server_version_info < (9,) + + @property + def _supports_table_compression(self): + return self.server_version_info and self.server_version_info >= (10, 1) + + @property + def _supports_table_compress_for(self): + return self.server_version_info and self.server_version_info >= (11,) + + @property + def _supports_char_length(self): + return not self._is_oracle_8 + + @property + def _supports_update_returning_computed_cols(self): + # on version 18 this error is no longet present while it happens on 11 + # it may work also on versions before the 18 + return self.server_version_info and self.server_version_info >= (18,) + + @property + def _supports_except_all(self): + return self.server_version_info and self.server_version_info >= (21,) + + def do_release_savepoint(self, connection, name): + # Oracle does not support RELEASE SAVEPOINT + pass + + def _check_max_identifier_length(self, connection): + if self._get_effective_compat_server_version_info(connection) < ( + 12, + 2, + ): + return 30 + else: + # use the default + return None + + def get_isolation_level_values(self, dbapi_connection): + return ["READ COMMITTED", "SERIALIZABLE"] + + def get_default_isolation_level(self, dbapi_conn): + try: + return self.get_isolation_level(dbapi_conn) + except NotImplementedError: + raise + except: + return "READ COMMITTED" + + def _execute_reflection( + self, connection, query, dblink, returns_long, params=None + ): + if dblink and not dblink.startswith("@"): + dblink = f"@{dblink}" + execution_options = { + # handle db links + "_oracle_dblink": dblink or "", + # override any schema translate map + "schema_translate_map": None, + } + + if dblink and returns_long: + # Oracle seems to error with + # "ORA-00997: illegal use of LONG datatype" when returning + # LONG columns via a dblink in a query with bind params + # This type seems to be very hard to cast into something else + # so it seems easier to just use bind param in this case + def visit_bindparam(bindparam): + bindparam.literal_execute = True + + query = visitors.cloned_traverse( + query, {}, {"bindparam": visit_bindparam} + ) + return connection.execute( + query, params, execution_options=execution_options + ) + + @util.memoized_property + def _has_table_query(self): + # materialized views are returned by all_tables + tables = ( + select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.owner, + ) + .union_all( + select( + dictionary.all_views.c.view_name.label("table_name"), + dictionary.all_views.c.owner, + ) + ) + .subquery("tables_and_views") + ) + + query = select(tables.c.table_name).where( + tables.c.table_name == bindparam("table_name"), + tables.c.owner == bindparam("owner"), + ) + return query + + @reflection.cache + def has_table( + self, connection, table_name, schema=None, dblink=None, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + self._ensure_has_table_connection(connection) + + if not schema: + schema = self.default_schema_name + + params = { + "table_name": self.denormalize_name(table_name), + "owner": self.denormalize_schema_name(schema), + } + cursor = self._execute_reflection( + connection, + self._has_table_query, + dblink, + returns_long=False, + params=params, + ) + return bool(cursor.scalar()) + + @reflection.cache + def has_sequence( + self, connection, sequence_name, schema=None, dblink=None, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + + query = select(dictionary.all_sequences.c.sequence_name).where( + dictionary.all_sequences.c.sequence_name + == self.denormalize_schema_name(sequence_name), + dictionary.all_sequences.c.sequence_owner + == self.denormalize_schema_name(schema), + ) + + cursor = self._execute_reflection( + connection, query, dblink, returns_long=False + ) + return bool(cursor.scalar()) + + def _get_default_schema_name(self, connection): + return self.normalize_name( + connection.exec_driver_sql( + "select sys_context( 'userenv', 'current_schema' ) from dual" + ).scalar() + ) + + def denormalize_schema_name(self, name): + # look for quoted_name + force = getattr(name, "quote", None) + if force is None and name == "public": + # look for case insensitive, no quoting specified, "public" + return "PUBLIC" + return super().denormalize_name(name) + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("filter_names", InternalTraversal.dp_string_list), + ("dblink", InternalTraversal.dp_string), + ) + def _get_synonyms(self, connection, schema, filter_names, dblink, **kw): + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + + has_filter_names, params = self._prepare_filter_names(filter_names) + query = select( + dictionary.all_synonyms.c.synonym_name, + dictionary.all_synonyms.c.table_name, + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.db_link, + ).where(dictionary.all_synonyms.c.owner == owner) + if has_filter_names: + query = query.where( + dictionary.all_synonyms.c.synonym_name.in_( + params["filter_names"] + ) + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).mappings() + return result.all() + + @lru_cache() + def _all_objects_query( + self, owner, scope, kind, has_filter_names, has_mat_views + ): + query = ( + select(dictionary.all_objects.c.object_name) + .select_from(dictionary.all_objects) + .where(dictionary.all_objects.c.owner == owner) + ) + + # NOTE: materialized views are listed in all_objects twice; + # once as MATERIALIZE VIEW and once as TABLE + if kind is ObjectKind.ANY: + # materilaized view are listed also as tables so there is no + # need to add them to the in_. + query = query.where( + dictionary.all_objects.c.object_type.in_(("TABLE", "VIEW")) + ) + else: + object_type = [] + if ObjectKind.VIEW in kind: + object_type.append("VIEW") + if ( + ObjectKind.MATERIALIZED_VIEW in kind + and ObjectKind.TABLE not in kind + ): + # materilaized view are listed also as tables so there is no + # need to add them to the in_ if also selecting tables. + object_type.append("MATERIALIZED VIEW") + if ObjectKind.TABLE in kind: + object_type.append("TABLE") + if has_mat_views and ObjectKind.MATERIALIZED_VIEW not in kind: + # materialized view are listed also as tables, + # so they need to be filtered out + # EXCEPT ALL / MINUS profiles as faster than using + # NOT EXISTS or NOT IN with a subquery, but it's in + # general faster to get the mat view names and exclude + # them only when needed + query = query.where( + dictionary.all_objects.c.object_name.not_in( + bindparam("mat_views") + ) + ) + query = query.where( + dictionary.all_objects.c.object_type.in_(object_type) + ) + + # handles scope + if scope is ObjectScope.DEFAULT: + query = query.where(dictionary.all_objects.c.temporary == "N") + elif scope is ObjectScope.TEMPORARY: + query = query.where(dictionary.all_objects.c.temporary == "Y") + + if has_filter_names: + query = query.where( + dictionary.all_objects.c.object_name.in_( + bindparam("filter_names") + ) + ) + return query + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("scope", InternalTraversal.dp_plain_obj), + ("kind", InternalTraversal.dp_plain_obj), + ("filter_names", InternalTraversal.dp_string_list), + ("dblink", InternalTraversal.dp_string), + ) + def _get_all_objects( + self, connection, schema, scope, kind, filter_names, dblink, **kw + ): + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + + has_filter_names, params = self._prepare_filter_names(filter_names) + has_mat_views = False + if ( + ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # see note in _all_objects_query + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw + ) + if mat_views: + params["mat_views"] = mat_views + has_mat_views = True + + query = self._all_objects_query( + owner, scope, kind, has_filter_names, has_mat_views + ) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ).scalars() + + return result.all() + + def _handle_synonyms_decorator(fn): + @wraps(fn) + def wrapper(self, *args, **kwargs): + return self._handle_synonyms(fn, *args, **kwargs) + + return wrapper + + def _handle_synonyms(self, fn, connection, *args, **kwargs): + if not kwargs.get("oracle_resolve_synonyms", False): + return fn(self, connection, *args, **kwargs) + + original_kw = kwargs.copy() + schema = kwargs.pop("schema", None) + result = self._get_synonyms( + connection, + schema=schema, + filter_names=kwargs.pop("filter_names", None), + dblink=kwargs.pop("dblink", None), + info_cache=kwargs.get("info_cache", None), + ) + + dblinks_owners = defaultdict(dict) + for row in result: + key = row["db_link"], row["table_owner"] + tn = self.normalize_name(row["table_name"]) + dblinks_owners[key][tn] = row["synonym_name"] + + if not dblinks_owners: + # No synonym, do the plain thing + return fn(self, connection, *args, **original_kw) + + data = {} + for (dblink, table_owner), mapping in dblinks_owners.items(): + call_kw = { + **original_kw, + "schema": table_owner, + "dblink": self.normalize_name(dblink), + "filter_names": mapping.keys(), + } + call_result = fn(self, connection, *args, **call_kw) + for (_, tn), value in call_result: + synonym_name = self.normalize_name(mapping[tn]) + data[(schema, synonym_name)] = value + return data.items() + + @reflection.cache + def get_schema_names(self, connection, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + query = select(dictionary.all_users.c.username).order_by( + dictionary.all_users.c.username + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + + @reflection.cache + def get_table_names(self, connection, schema=None, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + # note that table_names() isn't loading DBLINKed or synonym'ed tables + if schema is None: + schema = self.default_schema_name + + den_schema = self.denormalize_schema_name(schema) + if kw.get("oracle_resolve_synonyms", False): + tables = ( + select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.owner, + dictionary.all_tables.c.iot_name, + dictionary.all_tables.c.duration, + dictionary.all_tables.c.tablespace_name, + ) + .union_all( + select( + dictionary.all_synonyms.c.synonym_name.label( + "table_name" + ), + dictionary.all_synonyms.c.owner, + dictionary.all_tables.c.iot_name, + dictionary.all_tables.c.duration, + dictionary.all_tables.c.tablespace_name, + ) + .select_from(dictionary.all_tables) + .join( + dictionary.all_synonyms, + and_( + dictionary.all_tables.c.table_name + == dictionary.all_synonyms.c.table_name, + dictionary.all_tables.c.owner + == func.coalesce( + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.owner, + ), + ), + ) + ) + .subquery("available_tables") + ) + else: + tables = dictionary.all_tables + + query = select(tables.c.table_name) + if self.exclude_tablespaces: + query = query.where( + func.coalesce( + tables.c.tablespace_name, "no tablespace" + ).not_in(self.exclude_tablespaces) + ) + query = query.where( + tables.c.owner == den_schema, + tables.c.iot_name.is_(null()), + tables.c.duration.is_(null()), + ) + + # remove materialized views + mat_query = select( + dictionary.all_mviews.c.mview_name.label("table_name") + ).where(dictionary.all_mviews.c.owner == den_schema) + + query = ( + query.except_all(mat_query) + if self._supports_except_all + else query.except_(mat_query) + ) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + + @reflection.cache + def get_temp_table_names(self, connection, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + schema = self.denormalize_schema_name(self.default_schema_name) + + query = select(dictionary.all_tables.c.table_name) + if self.exclude_tablespaces: + query = query.where( + func.coalesce( + dictionary.all_tables.c.tablespace_name, "no tablespace" + ).not_in(self.exclude_tablespaces) + ) + query = query.where( + dictionary.all_tables.c.owner == schema, + dictionary.all_tables.c.iot_name.is_(null()), + dictionary.all_tables.c.duration.is_not(null()), + ) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + + @reflection.cache + def get_materialized_view_names( + self, connection, schema=None, dblink=None, _normalize=True, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + + query = select(dictionary.all_mviews.c.mview_name).where( + dictionary.all_mviews.c.owner + == self.denormalize_schema_name(schema) + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + if _normalize: + return [self.normalize_name(row) for row in result] + else: + return result.all() + + @reflection.cache + def get_view_names(self, connection, schema=None, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + + query = select(dictionary.all_views.c.view_name).where( + dictionary.all_views.c.owner + == self.denormalize_schema_name(schema) + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + + @reflection.cache + def get_sequence_names(self, connection, schema=None, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + query = select(dictionary.all_sequences.c.sequence_name).where( + dictionary.all_sequences.c.sequence_owner + == self.denormalize_schema_name(schema) + ) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + + def _value_or_raise(self, data, table, schema): + table = self.normalize_name(str(table)) + try: + return dict(data)[(schema, table)] + except KeyError: + raise exc.NoSuchTableError( + f"{schema}.{table}" if schema else table + ) from None + + def _prepare_filter_names(self, filter_names): + if filter_names: + fn = [self.denormalize_name(name) for name in filter_names] + return True, {"filter_names": fn} + else: + return False, {} + + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_table_options( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _table_options_query( + self, owner, scope, kind, has_filter_names, has_mat_views + ): + query = select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.compression, + dictionary.all_tables.c.compress_for, + ).where(dictionary.all_tables.c.owner == owner) + if has_filter_names: + query = query.where( + dictionary.all_tables.c.table_name.in_( + bindparam("filter_names") + ) + ) + if scope is ObjectScope.DEFAULT: + query = query.where(dictionary.all_tables.c.duration.is_(null())) + elif scope is ObjectScope.TEMPORARY: + query = query.where( + dictionary.all_tables.c.duration.is_not(null()) + ) + + if ( + has_mat_views + and ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # cant use EXCEPT ALL / MINUS here because we don't have an + # excludable row vs. the query above + # outerjoin + where null works better on oracle 21 but 11 does + # not like it at all. this is the next best thing + + query = query.where( + dictionary.all_tables.c.table_name.not_in( + bindparam("mat_views") + ) + ) + elif ( + ObjectKind.TABLE not in kind + and ObjectKind.MATERIALIZED_VIEW in kind + ): + query = query.where( + dictionary.all_tables.c.table_name.in_(bindparam("mat_views")) + ) + return query + + @_handle_synonyms_decorator + def get_multi_table_options( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + + has_filter_names, params = self._prepare_filter_names(filter_names) + has_mat_views = False + + if ( + ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # see note in _table_options_query + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw + ) + if mat_views: + params["mat_views"] = mat_views + has_mat_views = True + elif ( + ObjectKind.TABLE not in kind + and ObjectKind.MATERIALIZED_VIEW in kind + ): + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw + ) + params["mat_views"] = mat_views + + options = {} + default = ReflectionDefaults.table_options + + if ObjectKind.TABLE in kind or ObjectKind.MATERIALIZED_VIEW in kind: + query = self._table_options_query( + owner, scope, kind, has_filter_names, has_mat_views + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ) + + for table, compression, compress_for in result: + if compression == "ENABLED": + data = {"oracle_compress": compress_for} + else: + data = default() + options[(schema, self.normalize_name(table))] = data + if ObjectKind.VIEW in kind and ObjectScope.DEFAULT in scope: + # add the views (no temporary views) + for view in self.get_view_names(connection, schema, dblink, **kw): + if not filter_names or view in filter_names: + options[(schema, view)] = default() + + return options.items() + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + + data = self.get_multi_columns( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def _run_batches( + self, connection, query, dblink, returns_long, mappings, all_objects + ): + each_batch = 500 + batches = list(all_objects) + while batches: + batch = batches[0:each_batch] + batches[0:each_batch] = [] + + result = self._execute_reflection( + connection, + query, + dblink, + returns_long=returns_long, + params={"all_objects": batch}, + ) + if mappings: + yield from result.mappings() + else: + yield from result + + @lru_cache() + def _column_query(self, owner): + all_cols = dictionary.all_tab_cols + all_comments = dictionary.all_col_comments + all_ids = dictionary.all_tab_identity_cols + + if self.server_version_info >= (12,): + add_cols = ( + all_cols.c.default_on_null, + sql.case( + (all_ids.c.table_name.is_(None), sql.null()), + else_=all_ids.c.generation_type + + "," + + all_ids.c.identity_options, + ).label("identity_options"), + ) + join_identity_cols = True + else: + add_cols = ( + sql.null().label("default_on_null"), + sql.null().label("identity_options"), + ) + join_identity_cols = False + + # NOTE: on oracle cannot create tables/views without columns and + # a table cannot have all column hidden: + # ORA-54039: table must have at least one column that is not invisible + # all_tab_cols returns data for tables/views/mat-views. + # all_tab_cols does not return recycled tables + + query = ( + select( + all_cols.c.table_name, + all_cols.c.column_name, + all_cols.c.data_type, + all_cols.c.char_length, + all_cols.c.data_precision, + all_cols.c.data_scale, + all_cols.c.nullable, + all_cols.c.data_default, + all_comments.c.comments, + all_cols.c.virtual_column, + *add_cols, + ).select_from(all_cols) + # NOTE: all_col_comments has a row for each column even if no + # comment is present, so a join could be performed, but there + # seems to be no difference compared to an outer join + .outerjoin( + all_comments, + and_( + all_cols.c.table_name == all_comments.c.table_name, + all_cols.c.column_name == all_comments.c.column_name, + all_cols.c.owner == all_comments.c.owner, + ), + ) + ) + if join_identity_cols: + query = query.outerjoin( + all_ids, + and_( + all_cols.c.table_name == all_ids.c.table_name, + all_cols.c.column_name == all_ids.c.column_name, + all_cols.c.owner == all_ids.c.owner, + ), + ) + + query = query.where( + all_cols.c.table_name.in_(bindparam("all_objects")), + all_cols.c.hidden_column == "NO", + all_cols.c.owner == owner, + ).order_by(all_cols.c.table_name, all_cols.c.column_id) + return query + + @_handle_synonyms_decorator + def get_multi_columns( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + query = self._column_query(owner) + + if ( + filter_names + and kind is ObjectKind.ANY + and scope is ObjectScope.ANY + ): + all_objects = [self.denormalize_name(n) for n in filter_names] + else: + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) + + columns = defaultdict(list) + + # all_tab_cols.data_default is LONG + result = self._run_batches( + connection, + query, + dblink, + returns_long=True, + mappings=True, + all_objects=all_objects, + ) + + def maybe_int(value): + if isinstance(value, float) and value.is_integer(): + return int(value) + else: + return value + + remove_size = re.compile(r"\(\d+\)") + + for row_dict in result: + table_name = self.normalize_name(row_dict["table_name"]) + orig_colname = row_dict["column_name"] + colname = self.normalize_name(orig_colname) + coltype = row_dict["data_type"] + precision = maybe_int(row_dict["data_precision"]) + + if coltype == "NUMBER": + scale = maybe_int(row_dict["data_scale"]) + if precision is None and scale == 0: + coltype = INTEGER() + else: + coltype = NUMBER(precision, scale) + elif coltype == "FLOAT": + # https://docs.oracle.com/cd/B14117_01/server.101/b10758/sqlqr06.htm + if precision == 126: + # The DOUBLE PRECISION datatype is a floating-point + # number with binary precision 126. + coltype = DOUBLE_PRECISION() + elif precision == 63: + # The REAL datatype is a floating-point number with a + # binary precision of 63, or 18 decimal. + coltype = REAL() + else: + # non standard precision + coltype = FLOAT(binary_precision=precision) + + elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"): + char_length = maybe_int(row_dict["char_length"]) + coltype = self.ischema_names.get(coltype)(char_length) + elif "WITH TIME ZONE" in coltype: + coltype = TIMESTAMP(timezone=True) + elif "WITH LOCAL TIME ZONE" in coltype: + coltype = TIMESTAMP(local_timezone=True) + else: + coltype = re.sub(remove_size, "", coltype) + try: + coltype = self.ischema_names[coltype] + except KeyError: + util.warn( + "Did not recognize type '%s' of column '%s'" + % (coltype, colname) + ) + coltype = sqltypes.NULLTYPE + + default = row_dict["data_default"] + if row_dict["virtual_column"] == "YES": + computed = dict(sqltext=default) + default = None + else: + computed = None + + identity_options = row_dict["identity_options"] + if identity_options is not None: + identity = self._parse_identity_options( + identity_options, row_dict["default_on_null"] + ) + default = None + else: + identity = None + + cdict = { + "name": colname, + "type": coltype, + "nullable": row_dict["nullable"] == "Y", + "default": default, + "comment": row_dict["comments"], + } + if orig_colname.lower() == orig_colname: + cdict["quote"] = True + if computed is not None: + cdict["computed"] = computed + if identity is not None: + cdict["identity"] = identity + + columns[(schema, table_name)].append(cdict) + + # NOTE: default not needed since all tables have columns + # default = ReflectionDefaults.columns + # return ( + # (key, value if value else default()) + # for key, value in columns.items() + # ) + return columns.items() + + def _parse_identity_options(self, identity_options, default_on_null): + # identity_options is a string that starts with 'ALWAYS,' or + # 'BY DEFAULT,' and continues with + # START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1, + # CYCLE_FLAG: N, CACHE_SIZE: 1, ORDER_FLAG: N, SCALE_FLAG: N, + # EXTEND_FLAG: N, SESSION_FLAG: N, KEEP_VALUE: N + parts = [p.strip() for p in identity_options.split(",")] + identity = { + "always": parts[0] == "ALWAYS", + "on_null": default_on_null == "YES", + } + + for part in parts[1:]: + option, value = part.split(":") + value = value.strip() + + if "START WITH" in option: + identity["start"] = int(value) + elif "INCREMENT BY" in option: + identity["increment"] = int(value) + elif "MAX_VALUE" in option: + identity["maxvalue"] = int(value) + elif "MIN_VALUE" in option: + identity["minvalue"] = int(value) + elif "CYCLE_FLAG" in option: + identity["cycle"] = value == "Y" + elif "CACHE_SIZE" in option: + identity["cache"] = int(value) + elif "ORDER_FLAG" in option: + identity["order"] = value == "Y" + return identity + + @reflection.cache + def get_table_comment(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_table_comment( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _comment_query(self, owner, scope, kind, has_filter_names): + # NOTE: all_tab_comments / all_mview_comments have a row for all + # object even if they don't have comments + queries = [] + if ObjectKind.TABLE in kind or ObjectKind.VIEW in kind: + # all_tab_comments returns also plain views + tbl_view = select( + dictionary.all_tab_comments.c.table_name, + dictionary.all_tab_comments.c.comments, + ).where( + dictionary.all_tab_comments.c.owner == owner, + dictionary.all_tab_comments.c.table_name.not_like("BIN$%"), + ) + if ObjectKind.VIEW not in kind: + tbl_view = tbl_view.where( + dictionary.all_tab_comments.c.table_type == "TABLE" + ) + elif ObjectKind.TABLE not in kind: + tbl_view = tbl_view.where( + dictionary.all_tab_comments.c.table_type == "VIEW" + ) + queries.append(tbl_view) + if ObjectKind.MATERIALIZED_VIEW in kind: + mat_view = select( + dictionary.all_mview_comments.c.mview_name.label("table_name"), + dictionary.all_mview_comments.c.comments, + ).where( + dictionary.all_mview_comments.c.owner == owner, + dictionary.all_mview_comments.c.mview_name.not_like("BIN$%"), + ) + queries.append(mat_view) + if len(queries) == 1: + query = queries[0] + else: + union = sql.union_all(*queries).subquery("tables_and_views") + query = select(union.c.table_name, union.c.comments) + + name_col = query.selected_columns.table_name + + if scope in (ObjectScope.DEFAULT, ObjectScope.TEMPORARY): + temp = "Y" if scope is ObjectScope.TEMPORARY else "N" + # need distinct since materialized view are listed also + # as tables in all_objects + query = query.distinct().join( + dictionary.all_objects, + and_( + dictionary.all_objects.c.owner == owner, + dictionary.all_objects.c.object_name == name_col, + dictionary.all_objects.c.temporary == temp, + ), + ) + if has_filter_names: + query = query.where(name_col.in_(bindparam("filter_names"))) + return query + + @_handle_synonyms_decorator + def get_multi_table_comment( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._comment_query(owner, scope, kind, has_filter_names) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ) + default = ReflectionDefaults.table_comment + # materialized views by default seem to have a comment like + # "snapshot table for snapshot owner.mat_view_name" + ignore_mat_view = "snapshot table for snapshot " + return ( + ( + (schema, self.normalize_name(table)), + ( + {"text": comment} + if comment is not None + and not comment.startswith(ignore_mat_view) + else default() + ), + ) + for table, comment in result + ) + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_indexes( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _index_query(self, owner): + return ( + select( + dictionary.all_ind_columns.c.table_name, + dictionary.all_ind_columns.c.index_name, + dictionary.all_ind_columns.c.column_name, + dictionary.all_indexes.c.index_type, + dictionary.all_indexes.c.uniqueness, + dictionary.all_indexes.c.compression, + dictionary.all_indexes.c.prefix_length, + dictionary.all_ind_columns.c.descend, + dictionary.all_ind_expressions.c.column_expression, + ) + .select_from(dictionary.all_ind_columns) + .join( + dictionary.all_indexes, + sql.and_( + dictionary.all_ind_columns.c.index_name + == dictionary.all_indexes.c.index_name, + dictionary.all_ind_columns.c.index_owner + == dictionary.all_indexes.c.owner, + ), + ) + .outerjoin( + # NOTE: this adds about 20% to the query time. Using a + # case expression with a scalar subquery only when needed + # with the assumption that most indexes are not expression + # would be faster but oracle does not like that with + # LONG datatype. It errors with: + # ORA-00997: illegal use of LONG datatype + dictionary.all_ind_expressions, + sql.and_( + dictionary.all_ind_expressions.c.index_name + == dictionary.all_ind_columns.c.index_name, + dictionary.all_ind_expressions.c.index_owner + == dictionary.all_ind_columns.c.index_owner, + dictionary.all_ind_expressions.c.column_position + == dictionary.all_ind_columns.c.column_position, + ), + ) + .where( + dictionary.all_indexes.c.table_owner == owner, + dictionary.all_indexes.c.table_name.in_( + bindparam("all_objects") + ), + ) + .order_by( + dictionary.all_ind_columns.c.index_name, + dictionary.all_ind_columns.c.column_position, + ) + ) + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("dblink", InternalTraversal.dp_string), + ("all_objects", InternalTraversal.dp_string_list), + ) + def _get_indexes_rows(self, connection, schema, dblink, all_objects, **kw): + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + + query = self._index_query(owner) + + pks = { + row_dict["constraint_name"] + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ) + if row_dict["constraint_type"] == "P" + } + + # all_ind_expressions.column_expression is LONG + result = self._run_batches( + connection, + query, + dblink, + returns_long=True, + mappings=True, + all_objects=all_objects, + ) + + return [ + row_dict + for row_dict in result + if row_dict["index_name"] not in pks + ] + + @_handle_synonyms_decorator + def get_multi_indexes( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) + + uniqueness = {"NONUNIQUE": False, "UNIQUE": True} + enabled = {"DISABLED": False, "ENABLED": True} + is_bitmap = {"BITMAP", "FUNCTION-BASED BITMAP"} + + indexes = defaultdict(dict) + + for row_dict in self._get_indexes_rows( + connection, schema, dblink, all_objects, **kw + ): + index_name = self.normalize_name(row_dict["index_name"]) + table_name = self.normalize_name(row_dict["table_name"]) + table_indexes = indexes[(schema, table_name)] + + if index_name not in table_indexes: + table_indexes[index_name] = index_dict = { + "name": index_name, + "column_names": [], + "dialect_options": {}, + "unique": uniqueness.get(row_dict["uniqueness"], False), + } + do = index_dict["dialect_options"] + if row_dict["index_type"] in is_bitmap: + do["oracle_bitmap"] = True + if enabled.get(row_dict["compression"], False): + do["oracle_compress"] = row_dict["prefix_length"] + + else: + index_dict = table_indexes[index_name] + + expr = row_dict["column_expression"] + if expr is not None: + index_dict["column_names"].append(None) + if "expressions" in index_dict: + index_dict["expressions"].append(expr) + else: + index_dict["expressions"] = index_dict["column_names"][:-1] + index_dict["expressions"].append(expr) + + if row_dict["descend"].lower() != "asc": + assert row_dict["descend"].lower() == "desc" + cs = index_dict.setdefault("column_sorting", {}) + cs[expr] = ("desc",) + else: + assert row_dict["descend"].lower() == "asc" + cn = self.normalize_name(row_dict["column_name"]) + index_dict["column_names"].append(cn) + if "expressions" in index_dict: + index_dict["expressions"].append(cn) + + default = ReflectionDefaults.indexes + + return ( + (key, list(indexes[key].values()) if key in indexes else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_pk_constraint( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _constraint_query(self, owner): + local = dictionary.all_cons_columns.alias("local") + remote = dictionary.all_cons_columns.alias("remote") + return ( + select( + dictionary.all_constraints.c.table_name, + dictionary.all_constraints.c.constraint_type, + dictionary.all_constraints.c.constraint_name, + local.c.column_name.label("local_column"), + remote.c.table_name.label("remote_table"), + remote.c.column_name.label("remote_column"), + remote.c.owner.label("remote_owner"), + dictionary.all_constraints.c.search_condition, + dictionary.all_constraints.c.delete_rule, + ) + .select_from(dictionary.all_constraints) + .join( + local, + and_( + local.c.owner == dictionary.all_constraints.c.owner, + dictionary.all_constraints.c.constraint_name + == local.c.constraint_name, + ), + ) + .outerjoin( + remote, + and_( + dictionary.all_constraints.c.r_owner == remote.c.owner, + dictionary.all_constraints.c.r_constraint_name + == remote.c.constraint_name, + or_( + remote.c.position.is_(sql.null()), + local.c.position == remote.c.position, + ), + ), + ) + .where( + dictionary.all_constraints.c.owner == owner, + dictionary.all_constraints.c.table_name.in_( + bindparam("all_objects") + ), + dictionary.all_constraints.c.constraint_type.in_( + ("R", "P", "U", "C") + ), + ) + .order_by( + dictionary.all_constraints.c.constraint_name, local.c.position + ) + ) + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("dblink", InternalTraversal.dp_string), + ("all_objects", InternalTraversal.dp_string_list), + ) + def _get_all_constraint_rows( + self, connection, schema, dblink, all_objects, **kw + ): + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + query = self._constraint_query(owner) + + # since the result is cached a list must be created + values = list( + self._run_batches( + connection, + query, + dblink, + returns_long=False, + mappings=True, + all_objects=all_objects, + ) + ) + return values + + @_handle_synonyms_decorator + def get_multi_pk_constraint( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) + + primary_keys = defaultdict(dict) + default = ReflectionDefaults.pk_constraint + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "P": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + column_name = self.normalize_name(row_dict["local_column"]) + + table_pk = primary_keys[(schema, table_name)] + if not table_pk: + table_pk["name"] = constraint_name + table_pk["constrained_columns"] = [column_name] + else: + table_pk["constrained_columns"].append(column_name) + + return ( + (key, primary_keys[key] if key in primary_keys else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + @reflection.cache + def get_foreign_keys( + self, + connection, + table_name, + schema=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_foreign_keys( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @_handle_synonyms_decorator + def get_multi_foreign_keys( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) + + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + + all_remote_owners = set() + fkeys = defaultdict(dict) + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "R": + continue + + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + table_fkey = fkeys[(schema, table_name)] + + assert constraint_name is not None + + local_column = self.normalize_name(row_dict["local_column"]) + remote_table = self.normalize_name(row_dict["remote_table"]) + remote_column = self.normalize_name(row_dict["remote_column"]) + remote_owner_orig = row_dict["remote_owner"] + remote_owner = self.normalize_name(remote_owner_orig) + if remote_owner_orig is not None: + all_remote_owners.add(remote_owner_orig) + + if remote_table is None: + # ticket 363 + if dblink and not dblink.startswith("@"): + dblink = f"@{dblink}" + util.warn( + "Got 'None' querying 'table_name' from " + f"all_cons_columns{dblink or ''} - does the user have " + "proper rights to the table?" + ) + continue + + if constraint_name not in table_fkey: + table_fkey[constraint_name] = fkey = { + "name": constraint_name, + "constrained_columns": [], + "referred_schema": None, + "referred_table": remote_table, + "referred_columns": [], + "options": {}, + } + + if resolve_synonyms: + # will be removed below + fkey["_ref_schema"] = remote_owner + + if schema is not None or remote_owner_orig != owner: + fkey["referred_schema"] = remote_owner + + delete_rule = row_dict["delete_rule"] + if delete_rule != "NO ACTION": + fkey["options"]["ondelete"] = delete_rule + + else: + fkey = table_fkey[constraint_name] + + fkey["constrained_columns"].append(local_column) + fkey["referred_columns"].append(remote_column) + + if resolve_synonyms and all_remote_owners: + query = select( + dictionary.all_synonyms.c.owner, + dictionary.all_synonyms.c.table_name, + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.synonym_name, + ).where(dictionary.all_synonyms.c.owner.in_(all_remote_owners)) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).mappings() + + remote_owners_lut = {} + for row in result: + synonym_owner = self.normalize_name(row["owner"]) + table_name = self.normalize_name(row["table_name"]) + + remote_owners_lut[(synonym_owner, table_name)] = ( + row["table_owner"], + row["synonym_name"], + ) + + empty = (None, None) + for table_fkeys in fkeys.values(): + for table_fkey in table_fkeys.values(): + key = ( + table_fkey.pop("_ref_schema"), + table_fkey["referred_table"], + ) + remote_owner, syn_name = remote_owners_lut.get(key, empty) + if syn_name: + sn = self.normalize_name(syn_name) + table_fkey["referred_table"] = sn + if schema is not None or remote_owner != owner: + ro = self.normalize_name(remote_owner) + table_fkey["referred_schema"] = ro + else: + table_fkey["referred_schema"] = None + default = ReflectionDefaults.foreign_keys + + return ( + (key, list(fkeys[key].values()) if key in fkeys else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + @reflection.cache + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_unique_constraints( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @_handle_synonyms_decorator + def get_multi_unique_constraints( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) + + unique_cons = defaultdict(dict) + + index_names = { + row_dict["index_name"] + for row_dict in self._get_indexes_rows( + connection, schema, dblink, all_objects, **kw + ) + } + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "U": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name_orig = row_dict["constraint_name"] + constraint_name = self.normalize_name(constraint_name_orig) + column_name = self.normalize_name(row_dict["local_column"]) + table_uc = unique_cons[(schema, table_name)] + + assert constraint_name is not None + + if constraint_name not in table_uc: + table_uc[constraint_name] = uc = { + "name": constraint_name, + "column_names": [], + "duplicates_index": ( + constraint_name + if constraint_name_orig in index_names + else None + ), + } + else: + uc = table_uc[constraint_name] + + uc["column_names"].append(column_name) + + default = ReflectionDefaults.unique_constraints + + return ( + ( + key, + ( + list(unique_cons[key].values()) + if key in unique_cons + else default() + ), + ) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + @reflection.cache + def get_view_definition( + self, + connection, + view_name, + schema=None, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + if kw.get("oracle_resolve_synonyms", False): + synonyms = self._get_synonyms( + connection, schema, filter_names=[view_name], dblink=dblink + ) + if synonyms: + assert len(synonyms) == 1 + row_dict = synonyms[0] + dblink = self.normalize_name(row_dict["db_link"]) + schema = row_dict["table_owner"] + view_name = row_dict["table_name"] + + name = self.denormalize_name(view_name) + owner = self.denormalize_schema_name( + schema or self.default_schema_name + ) + query = ( + select(dictionary.all_views.c.text) + .where( + dictionary.all_views.c.view_name == name, + dictionary.all_views.c.owner == owner, + ) + .union_all( + select(dictionary.all_mviews.c.query).where( + dictionary.all_mviews.c.mview_name == name, + dictionary.all_mviews.c.owner == owner, + ) + ) + ) + + rp = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalar() + if rp is None: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) + else: + return rp + + @reflection.cache + def get_check_constraints( + self, connection, table_name, schema=None, include_all=False, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_check_constraints( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + include_all=include_all, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @_handle_synonyms_decorator + def get_multi_check_constraints( + self, + connection, + *, + schema, + filter_names, + dblink=None, + scope, + kind, + include_all=False, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) + + not_null = re.compile(r"..+?. IS NOT NULL$") + + check_constraints = defaultdict(list) + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "C": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + search_condition = row_dict["search_condition"] + + table_checks = check_constraints[(schema, table_name)] + if constraint_name is not None and ( + include_all or not not_null.match(search_condition) + ): + table_checks.append( + {"name": constraint_name, "sqltext": search_condition} + ) + + default = ReflectionDefaults.check_constraints + + return ( + ( + key, + ( + check_constraints[key] + if key in check_constraints + else default() + ), + ) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + def _list_dblinks(self, connection, dblink=None): + query = select(dictionary.all_db_links.c.db_link) + links = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(link) for link in links] + + +class _OuterJoinColumn(sql.ClauseElement): + __visit_name__ = "outer_join_column" + + def __init__(self, column): + self.column = column diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/cx_oracle.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/cx_oracle.py new file mode 100644 index 0000000..9346224 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/cx_oracle.py @@ -0,0 +1,1492 @@ +# dialects/oracle/cx_oracle.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" +.. dialect:: oracle+cx_oracle + :name: cx-Oracle + :dbapi: cx_oracle + :connectstring: oracle+cx_oracle://user:pass@hostname:port[/dbname][?service_name=[&key=value&key=value...]] + :url: https://oracle.github.io/python-cx_Oracle/ + +DSN vs. Hostname connections +----------------------------- + +cx_Oracle provides several methods of indicating the target database. The +dialect translates from a series of different URL forms. + +Hostname Connections with Easy Connect Syntax +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Given a hostname, port and service name of the target Oracle Database, for +example from Oracle's `Easy Connect syntax +`_, +then connect in SQLAlchemy using the ``service_name`` query string parameter:: + + engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:port/?service_name=myservice&encoding=UTF-8&nencoding=UTF-8") + +The `full Easy Connect syntax +`_ +is not supported. Instead, use a ``tnsnames.ora`` file and connect using a +DSN. + +Connections with tnsnames.ora or Oracle Cloud +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Alternatively, if no port, database name, or ``service_name`` is provided, the +dialect will use an Oracle DSN "connection string". This takes the "hostname" +portion of the URL as the data source name. For example, if the +``tnsnames.ora`` file contains a `Net Service Name +`_ +of ``myalias`` as below:: + + myalias = + (DESCRIPTION = + (ADDRESS = (PROTOCOL = TCP)(HOST = mymachine.example.com)(PORT = 1521)) + (CONNECT_DATA = + (SERVER = DEDICATED) + (SERVICE_NAME = orclpdb1) + ) + ) + +The cx_Oracle dialect connects to this database service when ``myalias`` is the +hostname portion of the URL, without specifying a port, database name or +``service_name``:: + + engine = create_engine("oracle+cx_oracle://scott:tiger@myalias/?encoding=UTF-8&nencoding=UTF-8") + +Users of Oracle Cloud should use this syntax and also configure the cloud +wallet as shown in cx_Oracle documentation `Connecting to Autononmous Databases +`_. + +SID Connections +^^^^^^^^^^^^^^^ + +To use Oracle's obsolete SID connection syntax, the SID can be passed in a +"database name" portion of the URL as below:: + + engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:1521/dbname?encoding=UTF-8&nencoding=UTF-8") + +Above, the DSN passed to cx_Oracle is created by ``cx_Oracle.makedsn()`` as +follows:: + + >>> import cx_Oracle + >>> cx_Oracle.makedsn("hostname", 1521, sid="dbname") + '(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=hostname)(PORT=1521))(CONNECT_DATA=(SID=dbname)))' + +Passing cx_Oracle connect arguments +----------------------------------- + +Additional connection arguments can usually be passed via the URL +query string; particular symbols like ``cx_Oracle.SYSDBA`` are intercepted +and converted to the correct symbol:: + + e = create_engine( + "oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true") + +.. versionchanged:: 1.3 the cx_oracle dialect now accepts all argument names + within the URL string itself, to be passed to the cx_Oracle DBAPI. As + was the case earlier but not correctly documented, the + :paramref:`_sa.create_engine.connect_args` parameter also accepts all + cx_Oracle DBAPI connect arguments. + +To pass arguments directly to ``.connect()`` without using the query +string, use the :paramref:`_sa.create_engine.connect_args` dictionary. +Any cx_Oracle parameter value and/or constant may be passed, such as:: + + import cx_Oracle + e = create_engine( + "oracle+cx_oracle://user:pass@dsn", + connect_args={ + "encoding": "UTF-8", + "nencoding": "UTF-8", + "mode": cx_Oracle.SYSDBA, + "events": True + } + ) + +Note that the default value for ``encoding`` and ``nencoding`` was changed to +"UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when using that +version, or later. + +Options consumed by the SQLAlchemy cx_Oracle dialect outside of the driver +-------------------------------------------------------------------------- + +There are also options that are consumed by the SQLAlchemy cx_oracle dialect +itself. These options are always passed directly to :func:`_sa.create_engine` +, such as:: + + e = create_engine( + "oracle+cx_oracle://user:pass@dsn", coerce_to_decimal=False) + +The parameters accepted by the cx_oracle dialect are as follows: + +* ``arraysize`` - set the cx_oracle.arraysize value on cursors; defaults + to ``None``, indicating that the driver default should be used (typically + the value is 100). This setting controls how many rows are buffered when + fetching rows, and can have a significant effect on performance when + modified. The setting is used for both ``cx_Oracle`` as well as + ``oracledb``. + + .. versionchanged:: 2.0.26 - changed the default value from 50 to None, + to use the default value of the driver itself. + +* ``auto_convert_lobs`` - defaults to True; See :ref:`cx_oracle_lob`. + +* ``coerce_to_decimal`` - see :ref:`cx_oracle_numeric` for detail. + +* ``encoding_errors`` - see :ref:`cx_oracle_unicode_encoding_errors` for detail. + +.. _cx_oracle_sessionpool: + +Using cx_Oracle SessionPool +--------------------------- + +The cx_Oracle library provides its own connection pool implementation that may +be used in place of SQLAlchemy's pooling functionality. This can be achieved +by using the :paramref:`_sa.create_engine.creator` parameter to provide a +function that returns a new connection, along with setting +:paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable +SQLAlchemy's pooling:: + + import cx_Oracle + from sqlalchemy import create_engine + from sqlalchemy.pool import NullPool + + pool = cx_Oracle.SessionPool( + user="scott", password="tiger", dsn="orclpdb", + min=2, max=5, increment=1, threaded=True, + encoding="UTF-8", nencoding="UTF-8" + ) + + engine = create_engine("oracle+cx_oracle://", creator=pool.acquire, poolclass=NullPool) + +The above engine may then be used normally where cx_Oracle's pool handles +connection pooling:: + + with engine.connect() as conn: + print(conn.scalar("select 1 FROM dual")) + + +As well as providing a scalable solution for multi-user applications, the +cx_Oracle session pool supports some Oracle features such as DRCP and +`Application Continuity +`_. + +Using Oracle Database Resident Connection Pooling (DRCP) +-------------------------------------------------------- + +When using Oracle's `DRCP +`_, +the best practice is to pass a connection class and "purity" when acquiring a +connection from the SessionPool. Refer to the `cx_Oracle DRCP documentation +`_. + +This can be achieved by wrapping ``pool.acquire()``:: + + import cx_Oracle + from sqlalchemy import create_engine + from sqlalchemy.pool import NullPool + + pool = cx_Oracle.SessionPool( + user="scott", password="tiger", dsn="orclpdb", + min=2, max=5, increment=1, threaded=True, + encoding="UTF-8", nencoding="UTF-8" + ) + + def creator(): + return pool.acquire(cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF) + + engine = create_engine("oracle+cx_oracle://", creator=creator, poolclass=NullPool) + +The above engine may then be used normally where cx_Oracle handles session +pooling and Oracle Database additionally uses DRCP:: + + with engine.connect() as conn: + print(conn.scalar("select 1 FROM dual")) + +.. _cx_oracle_unicode: + +Unicode +------- + +As is the case for all DBAPIs under Python 3, all strings are inherently +Unicode strings. In all cases however, the driver requires an explicit +encoding configuration. + +Ensuring the Correct Client Encoding +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The long accepted standard for establishing client encoding for nearly all +Oracle related software is via the `NLS_LANG `_ +environment variable. cx_Oracle like most other Oracle drivers will use +this environment variable as the source of its encoding configuration. The +format of this variable is idiosyncratic; a typical value would be +``AMERICAN_AMERICA.AL32UTF8``. + +The cx_Oracle driver also supports a programmatic alternative which is to +pass the ``encoding`` and ``nencoding`` parameters directly to its +``.connect()`` function. These can be present in the URL as follows:: + + engine = create_engine("oracle+cx_oracle://scott:tiger@orclpdb/?encoding=UTF-8&nencoding=UTF-8") + +For the meaning of the ``encoding`` and ``nencoding`` parameters, please +consult +`Characters Sets and National Language Support (NLS) `_. + +.. seealso:: + + `Characters Sets and National Language Support (NLS) `_ + - in the cx_Oracle documentation. + + +Unicode-specific Column datatypes +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The Core expression language handles unicode data by use of the :class:`.Unicode` +and :class:`.UnicodeText` +datatypes. These types correspond to the VARCHAR2 and CLOB Oracle datatypes by +default. When using these datatypes with Unicode data, it is expected that +the Oracle database is configured with a Unicode-aware character set, as well +as that the ``NLS_LANG`` environment variable is set appropriately, so that +the VARCHAR2 and CLOB datatypes can accommodate the data. + +In the case that the Oracle database is not configured with a Unicode character +set, the two options are to use the :class:`_types.NCHAR` and +:class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag +``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, +which will cause the +SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` / +:class:`.UnicodeText` datatypes instead of VARCHAR/CLOB. + +.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText` + datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle datatypes + unless the ``use_nchar_for_unicode=True`` is passed to the dialect + when :func:`_sa.create_engine` is called. + + +.. _cx_oracle_unicode_encoding_errors: + +Encoding Errors +^^^^^^^^^^^^^^^ + +For the unusual case that data in the Oracle database is present with a broken +encoding, the dialect accepts a parameter ``encoding_errors`` which will be +passed to Unicode decoding functions in order to affect how decoding errors are +handled. The value is ultimately consumed by the Python `decode +`_ function, and +is passed both via cx_Oracle's ``encodingErrors`` parameter consumed by +``Cursor.var()``, as well as SQLAlchemy's own decoding function, as the +cx_Oracle dialect makes use of both under different circumstances. + +.. versionadded:: 1.3.11 + + +.. _cx_oracle_setinputsizes: + +Fine grained control over cx_Oracle data binding performance with setinputsizes +------------------------------------------------------------------------------- + +The cx_Oracle DBAPI has a deep and fundamental reliance upon the usage of the +DBAPI ``setinputsizes()`` call. The purpose of this call is to establish the +datatypes that are bound to a SQL statement for Python values being passed as +parameters. While virtually no other DBAPI assigns any use to the +``setinputsizes()`` call, the cx_Oracle DBAPI relies upon it heavily in its +interactions with the Oracle client interface, and in some scenarios it is not +possible for SQLAlchemy to know exactly how data should be bound, as some +settings can cause profoundly different performance characteristics, while +altering the type coercion behavior at the same time. + +Users of the cx_Oracle dialect are **strongly encouraged** to read through +cx_Oracle's list of built-in datatype symbols at +https://cx-oracle.readthedocs.io/en/latest/api_manual/module.html#database-types. +Note that in some cases, significant performance degradation can occur when +using these types vs. not, in particular when specifying ``cx_Oracle.CLOB``. + +On the SQLAlchemy side, the :meth:`.DialectEvents.do_setinputsizes` event can +be used both for runtime visibility (e.g. logging) of the setinputsizes step as +well as to fully control how ``setinputsizes()`` is used on a per-statement +basis. + +.. versionadded:: 1.2.9 Added :meth:`.DialectEvents.setinputsizes` + + +Example 1 - logging all setinputsizes calls +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The following example illustrates how to log the intermediary values from a +SQLAlchemy perspective before they are converted to the raw ``setinputsizes()`` +parameter dictionary. The keys of the dictionary are :class:`.BindParameter` +objects which have a ``.key`` and a ``.type`` attribute:: + + from sqlalchemy import create_engine, event + + engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe") + + @event.listens_for(engine, "do_setinputsizes") + def _log_setinputsizes(inputsizes, cursor, statement, parameters, context): + for bindparam, dbapitype in inputsizes.items(): + log.info( + "Bound parameter name: %s SQLAlchemy type: %r " + "DBAPI object: %s", + bindparam.key, bindparam.type, dbapitype) + +Example 2 - remove all bindings to CLOB +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The ``CLOB`` datatype in cx_Oracle incurs a significant performance overhead, +however is set by default for the ``Text`` type within the SQLAlchemy 1.2 +series. This setting can be modified as follows:: + + from sqlalchemy import create_engine, event + from cx_Oracle import CLOB + + engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe") + + @event.listens_for(engine, "do_setinputsizes") + def _remove_clob(inputsizes, cursor, statement, parameters, context): + for bindparam, dbapitype in list(inputsizes.items()): + if dbapitype is CLOB: + del inputsizes[bindparam] + +.. _cx_oracle_returning: + +RETURNING Support +----------------- + +The cx_Oracle dialect implements RETURNING using OUT parameters. +The dialect supports RETURNING fully. + +.. _cx_oracle_lob: + +LOB Datatypes +-------------- + +LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and +BLOB. Modern versions of cx_Oracle and oracledb are optimized for these +datatypes to be delivered as a single buffer. As such, SQLAlchemy makes use of +these newer type handlers by default. + +To disable the use of newer type handlers and deliver LOB objects as classic +buffered objects with a ``read()`` method, the parameter +``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`, +which takes place only engine-wide. + +Two Phase Transactions Not Supported +------------------------------------- + +Two phase transactions are **not supported** under cx_Oracle due to poor +driver support. As of cx_Oracle 6.0b1, the interface for +two phase transactions has been changed to be more of a direct pass-through +to the underlying OCI layer with less automation. The additional logic +to support this system is not implemented in SQLAlchemy. + +.. _cx_oracle_numeric: + +Precision Numerics +------------------ + +SQLAlchemy's numeric types can handle receiving and returning values as Python +``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a +subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in +use, the :paramref:`.Numeric.asdecimal` flag determines if values should be +coerced to ``Decimal`` upon return, or returned as float objects. To make +matters more complicated under Oracle, Oracle's ``NUMBER`` type can also +represent integer values if the "scale" is zero, so the Oracle-specific +:class:`_oracle.NUMBER` type takes this into account as well. + +The cx_Oracle dialect makes extensive use of connection- and cursor-level +"outputtypehandler" callables in order to coerce numeric values as requested. +These callables are specific to the specific flavor of :class:`.Numeric` in +use, as well as if no SQLAlchemy typing objects are present. There are +observed scenarios where Oracle may sends incomplete or ambiguous information +about the numeric types being returned, such as a query where the numeric types +are buried under multiple levels of subquery. The type handlers do their best +to make the right decision in all cases, deferring to the underlying cx_Oracle +DBAPI for all those cases where the driver can make the best decision. + +When no typing objects are present, as when executing plain SQL strings, a +default "outputtypehandler" is present which will generally return numeric +values which specify precision and scale as Python ``Decimal`` objects. To +disable this coercion to decimal for performance reasons, pass the flag +``coerce_to_decimal=False`` to :func:`_sa.create_engine`:: + + engine = create_engine("oracle+cx_oracle://dsn", coerce_to_decimal=False) + +The ``coerce_to_decimal`` flag only impacts the results of plain string +SQL statements that are not otherwise associated with a :class:`.Numeric` +SQLAlchemy type (or a subclass of such). + +.. versionchanged:: 1.2 The numeric handling system for cx_Oracle has been + reworked to take advantage of newer cx_Oracle features as well + as better integration of outputtypehandlers. + +""" # noqa +from __future__ import annotations + +import decimal +import random +import re + +from . import base as oracle +from .base import OracleCompiler +from .base import OracleDialect +from .base import OracleExecutionContext +from .types import _OracleDateLiteralRender +from ... import exc +from ... import util +from ...engine import cursor as _cursor +from ...engine import interfaces +from ...engine import processors +from ...sql import sqltypes +from ...sql._typing import is_sql_compiler + +# source: +# https://github.com/oracle/python-cx_Oracle/issues/596#issuecomment-999243649 +_CX_ORACLE_MAGIC_LOB_SIZE = 131072 + + +class _OracleInteger(sqltypes.Integer): + def get_dbapi_type(self, dbapi): + # see https://github.com/oracle/python-cx_Oracle/issues/ + # 208#issuecomment-409715955 + return int + + def _cx_oracle_var(self, dialect, cursor, arraysize=None): + cx_Oracle = dialect.dbapi + return cursor.var( + cx_Oracle.STRING, + 255, + arraysize=arraysize if arraysize is not None else cursor.arraysize, + outconverter=int, + ) + + def _cx_oracle_outputtypehandler(self, dialect): + def handler(cursor, name, default_type, size, precision, scale): + return self._cx_oracle_var(dialect, cursor) + + return handler + + +class _OracleNumeric(sqltypes.Numeric): + is_number = False + + def bind_processor(self, dialect): + if self.scale == 0: + return None + elif self.asdecimal: + processor = processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + + def process(value): + if isinstance(value, (int, float)): + return processor(value) + elif value is not None and value.is_infinite(): + return float(value) + else: + return value + + return process + else: + return processors.to_float + + def result_processor(self, dialect, coltype): + return None + + def _cx_oracle_outputtypehandler(self, dialect): + cx_Oracle = dialect.dbapi + + def handler(cursor, name, default_type, size, precision, scale): + outconverter = None + + if precision: + if self.asdecimal: + if default_type == cx_Oracle.NATIVE_FLOAT: + # receiving float and doing Decimal after the fact + # allows for float("inf") to be handled + type_ = default_type + outconverter = decimal.Decimal + else: + type_ = decimal.Decimal + else: + if self.is_number and scale == 0: + # integer. cx_Oracle is observed to handle the widest + # variety of ints when no directives are passed, + # from 5.2 to 7.0. See [ticket:4457] + return None + else: + type_ = cx_Oracle.NATIVE_FLOAT + + else: + if self.asdecimal: + if default_type == cx_Oracle.NATIVE_FLOAT: + type_ = default_type + outconverter = decimal.Decimal + else: + type_ = decimal.Decimal + else: + if self.is_number and scale == 0: + # integer. cx_Oracle is observed to handle the widest + # variety of ints when no directives are passed, + # from 5.2 to 7.0. See [ticket:4457] + return None + else: + type_ = cx_Oracle.NATIVE_FLOAT + + return cursor.var( + type_, + 255, + arraysize=cursor.arraysize, + outconverter=outconverter, + ) + + return handler + + +class _OracleUUID(sqltypes.Uuid): + def get_dbapi_type(self, dbapi): + return dbapi.STRING + + +class _OracleBinaryFloat(_OracleNumeric): + def get_dbapi_type(self, dbapi): + return dbapi.NATIVE_FLOAT + + +class _OracleBINARY_FLOAT(_OracleBinaryFloat, oracle.BINARY_FLOAT): + pass + + +class _OracleBINARY_DOUBLE(_OracleBinaryFloat, oracle.BINARY_DOUBLE): + pass + + +class _OracleNUMBER(_OracleNumeric): + is_number = True + + +class _CXOracleDate(oracle._OracleDate): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + def process(value): + if value is not None: + return value.date() + else: + return value + + return process + + +class _CXOracleTIMESTAMP(_OracleDateLiteralRender, sqltypes.TIMESTAMP): + def literal_processor(self, dialect): + return self._literal_processor_datetime(dialect) + + +class _LOBDataType: + pass + + +# TODO: the names used across CHAR / VARCHAR / NCHAR / NVARCHAR +# here are inconsistent and not very good +class _OracleChar(sqltypes.CHAR): + def get_dbapi_type(self, dbapi): + return dbapi.FIXED_CHAR + + +class _OracleNChar(sqltypes.NCHAR): + def get_dbapi_type(self, dbapi): + return dbapi.FIXED_NCHAR + + +class _OracleUnicodeStringNCHAR(oracle.NVARCHAR2): + def get_dbapi_type(self, dbapi): + return dbapi.NCHAR + + +class _OracleUnicodeStringCHAR(sqltypes.Unicode): + def get_dbapi_type(self, dbapi): + return dbapi.LONG_STRING + + +class _OracleUnicodeTextNCLOB(_LOBDataType, oracle.NCLOB): + def get_dbapi_type(self, dbapi): + # previously, this was dbapi.NCLOB. + # DB_TYPE_NVARCHAR will instead be passed to setinputsizes() + # when this datatype is used. + return dbapi.DB_TYPE_NVARCHAR + + +class _OracleUnicodeTextCLOB(_LOBDataType, sqltypes.UnicodeText): + def get_dbapi_type(self, dbapi): + # previously, this was dbapi.CLOB. + # DB_TYPE_NVARCHAR will instead be passed to setinputsizes() + # when this datatype is used. + return dbapi.DB_TYPE_NVARCHAR + + +class _OracleText(_LOBDataType, sqltypes.Text): + def get_dbapi_type(self, dbapi): + # previously, this was dbapi.CLOB. + # DB_TYPE_NVARCHAR will instead be passed to setinputsizes() + # when this datatype is used. + return dbapi.DB_TYPE_NVARCHAR + + +class _OracleLong(_LOBDataType, oracle.LONG): + def get_dbapi_type(self, dbapi): + return dbapi.LONG_STRING + + +class _OracleString(sqltypes.String): + pass + + +class _OracleEnum(sqltypes.Enum): + def bind_processor(self, dialect): + enum_proc = sqltypes.Enum.bind_processor(self, dialect) + + def process(value): + raw_str = enum_proc(value) + return raw_str + + return process + + +class _OracleBinary(_LOBDataType, sqltypes.LargeBinary): + def get_dbapi_type(self, dbapi): + # previously, this was dbapi.BLOB. + # DB_TYPE_RAW will instead be passed to setinputsizes() + # when this datatype is used. + return dbapi.DB_TYPE_RAW + + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if not dialect.auto_convert_lobs: + return None + else: + return super().result_processor(dialect, coltype) + + +class _OracleInterval(oracle.INTERVAL): + def get_dbapi_type(self, dbapi): + return dbapi.INTERVAL + + +class _OracleRaw(oracle.RAW): + pass + + +class _OracleRowid(oracle.ROWID): + def get_dbapi_type(self, dbapi): + return dbapi.ROWID + + +class OracleCompiler_cx_oracle(OracleCompiler): + _oracle_cx_sql_compiler = True + + _oracle_returning = False + + # Oracle bind names can't start with digits or underscores. + # currently we rely upon Oracle-specific quoting of bind names in most + # cases. however for expanding params, the escape chars are used. + # see #8708 + bindname_escape_characters = util.immutabledict( + { + "%": "P", + "(": "A", + ")": "Z", + ":": "C", + ".": "C", + "[": "C", + "]": "C", + " ": "C", + "\\": "C", + "/": "C", + "?": "C", + } + ) + + def bindparam_string(self, name, **kw): + quote = getattr(name, "quote", None) + if ( + quote is True + or quote is not False + and self.preparer._bindparam_requires_quotes(name) + # bind param quoting for Oracle doesn't work with post_compile + # params. For those, the default bindparam_string will escape + # special chars, and the appending of a number "_1" etc. will + # take care of reserved words + and not kw.get("post_compile", False) + ): + # interesting to note about expanding parameters - since the + # new parameters take the form _, at least if + # they are originally formed from reserved words, they no longer + # need quoting :). names that include illegal characters + # won't work however. + quoted_name = '"%s"' % name + kw["escaped_from"] = name + name = quoted_name + return OracleCompiler.bindparam_string(self, name, **kw) + + # TODO: we could likely do away with quoting altogether for + # Oracle parameters and use the custom escaping here + escaped_from = kw.get("escaped_from", None) + if not escaped_from: + if self._bind_translate_re.search(name): + # not quite the translate use case as we want to + # also get a quick boolean if we even found + # unusual characters in the name + new_name = self._bind_translate_re.sub( + lambda m: self._bind_translate_chars[m.group(0)], + name, + ) + if new_name[0].isdigit() or new_name[0] == "_": + new_name = "D" + new_name + kw["escaped_from"] = name + name = new_name + elif name[0].isdigit() or name[0] == "_": + new_name = "D" + name + kw["escaped_from"] = name + name = new_name + + return OracleCompiler.bindparam_string(self, name, **kw) + + +class OracleExecutionContext_cx_oracle(OracleExecutionContext): + out_parameters = None + + def _generate_out_parameter_vars(self): + # check for has_out_parameters or RETURNING, create cx_Oracle.var + # objects if so + if self.compiled.has_out_parameters or self.compiled._oracle_returning: + out_parameters = self.out_parameters + assert out_parameters is not None + + len_params = len(self.parameters) + + quoted_bind_names = self.compiled.escaped_bind_names + for bindparam in self.compiled.binds.values(): + if bindparam.isoutparam: + name = self.compiled.bind_names[bindparam] + type_impl = bindparam.type.dialect_impl(self.dialect) + + if hasattr(type_impl, "_cx_oracle_var"): + out_parameters[name] = type_impl._cx_oracle_var( + self.dialect, self.cursor, arraysize=len_params + ) + else: + dbtype = type_impl.get_dbapi_type(self.dialect.dbapi) + + cx_Oracle = self.dialect.dbapi + + assert cx_Oracle is not None + + if dbtype is None: + raise exc.InvalidRequestError( + "Cannot create out parameter for " + "parameter " + "%r - its type %r is not supported by" + " cx_oracle" % (bindparam.key, bindparam.type) + ) + + # note this is an OUT parameter. Using + # non-LOB datavalues with large unicode-holding + # values causes the failure (both cx_Oracle and + # oracledb): + # ORA-22835: Buffer too small for CLOB to CHAR or + # BLOB to RAW conversion (actual: 16507, + # maximum: 4000) + # [SQL: INSERT INTO long_text (x, y, z) VALUES + # (:x, :y, :z) RETURNING long_text.x, long_text.y, + # long_text.z INTO :ret_0, :ret_1, :ret_2] + # so even for DB_TYPE_NVARCHAR we convert to a LOB + + if isinstance(type_impl, _LOBDataType): + if dbtype == cx_Oracle.DB_TYPE_NVARCHAR: + dbtype = cx_Oracle.NCLOB + elif dbtype == cx_Oracle.DB_TYPE_RAW: + dbtype = cx_Oracle.BLOB + # other LOB types go in directly + + out_parameters[name] = self.cursor.var( + dbtype, + # this is fine also in oracledb_async since + # the driver will await the read coroutine + outconverter=lambda value: value.read(), + arraysize=len_params, + ) + elif ( + isinstance(type_impl, _OracleNumeric) + and type_impl.asdecimal + ): + out_parameters[name] = self.cursor.var( + decimal.Decimal, + arraysize=len_params, + ) + + else: + out_parameters[name] = self.cursor.var( + dbtype, arraysize=len_params + ) + + for param in self.parameters: + param[quoted_bind_names.get(name, name)] = ( + out_parameters[name] + ) + + def _generate_cursor_outputtype_handler(self): + output_handlers = {} + + for keyname, name, objects, type_ in self.compiled._result_columns: + handler = type_._cached_custom_processor( + self.dialect, + "cx_oracle_outputtypehandler", + self._get_cx_oracle_type_handler, + ) + + if handler: + denormalized_name = self.dialect.denormalize_name(keyname) + output_handlers[denormalized_name] = handler + + if output_handlers: + default_handler = self._dbapi_connection.outputtypehandler + + def output_type_handler( + cursor, name, default_type, size, precision, scale + ): + if name in output_handlers: + return output_handlers[name]( + cursor, name, default_type, size, precision, scale + ) + else: + return default_handler( + cursor, name, default_type, size, precision, scale + ) + + self.cursor.outputtypehandler = output_type_handler + + def _get_cx_oracle_type_handler(self, impl): + if hasattr(impl, "_cx_oracle_outputtypehandler"): + return impl._cx_oracle_outputtypehandler(self.dialect) + else: + return None + + def pre_exec(self): + super().pre_exec() + if not getattr(self.compiled, "_oracle_cx_sql_compiler", False): + return + + self.out_parameters = {} + + self._generate_out_parameter_vars() + + self._generate_cursor_outputtype_handler() + + def post_exec(self): + if ( + self.compiled + and is_sql_compiler(self.compiled) + and self.compiled._oracle_returning + ): + initial_buffer = self.fetchall_for_returning( + self.cursor, _internal=True + ) + + fetch_strategy = _cursor.FullyBufferedCursorFetchStrategy( + self.cursor, + [ + (entry.keyname, None) + for entry in self.compiled._result_columns + ], + initial_buffer=initial_buffer, + ) + + self.cursor_fetch_strategy = fetch_strategy + + def create_cursor(self): + c = self._dbapi_connection.cursor() + if self.dialect.arraysize: + c.arraysize = self.dialect.arraysize + + return c + + def fetchall_for_returning(self, cursor, *, _internal=False): + compiled = self.compiled + if ( + not _internal + and compiled is None + or not is_sql_compiler(compiled) + or not compiled._oracle_returning + ): + raise NotImplementedError( + "execution context was not prepared for Oracle RETURNING" + ) + + # create a fake cursor result from the out parameters. unlike + # get_out_parameter_values(), the result-row handlers here will be + # applied at the Result level + + numcols = len(self.out_parameters) + + # [stmt_result for stmt_result in outparam.values] == each + # statement in executemany + # [val for val in stmt_result] == each row for a particular + # statement + return list( + zip( + *[ + [ + val + for stmt_result in self.out_parameters[ + f"ret_{j}" + ].values + for val in (stmt_result or ()) + ] + for j in range(numcols) + ] + ) + ) + + def get_out_parameter_values(self, out_param_names): + # this method should not be called when the compiler has + # RETURNING as we've turned the has_out_parameters flag set to + # False. + assert not self.compiled.returning + + return [ + self.dialect._paramval(self.out_parameters[name]) + for name in out_param_names + ] + + +class OracleDialect_cx_oracle(OracleDialect): + supports_statement_cache = True + execution_ctx_cls = OracleExecutionContext_cx_oracle + statement_compiler = OracleCompiler_cx_oracle + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + + insert_executemany_returning = True + insert_executemany_returning_sort_by_parameter_order = True + update_executemany_returning = True + delete_executemany_returning = True + + bind_typing = interfaces.BindTyping.SETINPUTSIZES + + driver = "cx_oracle" + + colspecs = util.update_copy( + OracleDialect.colspecs, + { + sqltypes.TIMESTAMP: _CXOracleTIMESTAMP, + sqltypes.Numeric: _OracleNumeric, + sqltypes.Float: _OracleNumeric, + oracle.BINARY_FLOAT: _OracleBINARY_FLOAT, + oracle.BINARY_DOUBLE: _OracleBINARY_DOUBLE, + sqltypes.Integer: _OracleInteger, + oracle.NUMBER: _OracleNUMBER, + sqltypes.Date: _CXOracleDate, + sqltypes.LargeBinary: _OracleBinary, + sqltypes.Boolean: oracle._OracleBoolean, + sqltypes.Interval: _OracleInterval, + oracle.INTERVAL: _OracleInterval, + sqltypes.Text: _OracleText, + sqltypes.String: _OracleString, + sqltypes.UnicodeText: _OracleUnicodeTextCLOB, + sqltypes.CHAR: _OracleChar, + sqltypes.NCHAR: _OracleNChar, + sqltypes.Enum: _OracleEnum, + oracle.LONG: _OracleLong, + oracle.RAW: _OracleRaw, + sqltypes.Unicode: _OracleUnicodeStringCHAR, + sqltypes.NVARCHAR: _OracleUnicodeStringNCHAR, + sqltypes.Uuid: _OracleUUID, + oracle.NCLOB: _OracleUnicodeTextNCLOB, + oracle.ROWID: _OracleRowid, + }, + ) + + execute_sequence_format = list + + _cx_oracle_threaded = None + + _cursor_var_unicode_kwargs = util.immutabledict() + + @util.deprecated_params( + threaded=( + "1.3", + "The 'threaded' parameter to the cx_oracle/oracledb dialect " + "is deprecated as a dialect-level argument, and will be removed " + "in a future release. As of version 1.3, it defaults to False " + "rather than True. The 'threaded' option can be passed to " + "cx_Oracle directly in the URL query string passed to " + ":func:`_sa.create_engine`.", + ) + ) + def __init__( + self, + auto_convert_lobs=True, + coerce_to_decimal=True, + arraysize=None, + encoding_errors=None, + threaded=None, + **kwargs, + ): + OracleDialect.__init__(self, **kwargs) + self.arraysize = arraysize + self.encoding_errors = encoding_errors + if encoding_errors: + self._cursor_var_unicode_kwargs = { + "encodingErrors": encoding_errors + } + if threaded is not None: + self._cx_oracle_threaded = threaded + self.auto_convert_lobs = auto_convert_lobs + self.coerce_to_decimal = coerce_to_decimal + if self._use_nchar_for_unicode: + self.colspecs = self.colspecs.copy() + self.colspecs[sqltypes.Unicode] = _OracleUnicodeStringNCHAR + self.colspecs[sqltypes.UnicodeText] = _OracleUnicodeTextNCLOB + + dbapi_module = self.dbapi + self._load_version(dbapi_module) + + if dbapi_module is not None: + # these constants will first be seen in SQLAlchemy datatypes + # coming from the get_dbapi_type() method. We then + # will place the following types into setinputsizes() calls + # on each statement. Oracle constants that are not in this + # list will not be put into setinputsizes(). + self.include_set_input_sizes = { + dbapi_module.DATETIME, + dbapi_module.DB_TYPE_NVARCHAR, # used for CLOB, NCLOB + dbapi_module.DB_TYPE_RAW, # used for BLOB + dbapi_module.NCLOB, # not currently used except for OUT param + dbapi_module.CLOB, # not currently used except for OUT param + dbapi_module.LOB, # not currently used + dbapi_module.BLOB, # not currently used except for OUT param + dbapi_module.NCHAR, + dbapi_module.FIXED_NCHAR, + dbapi_module.FIXED_CHAR, + dbapi_module.TIMESTAMP, + int, # _OracleInteger, + # _OracleBINARY_FLOAT, _OracleBINARY_DOUBLE, + dbapi_module.NATIVE_FLOAT, + } + + self._paramval = lambda value: value.getvalue() + + def _load_version(self, dbapi_module): + version = (0, 0, 0) + if dbapi_module is not None: + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", dbapi_module.version) + if m: + version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + self.cx_oracle_ver = version + if self.cx_oracle_ver < (8,) and self.cx_oracle_ver > (0, 0, 0): + raise exc.InvalidRequestError( + "cx_Oracle version 8 and above are supported" + ) + + @classmethod + def import_dbapi(cls): + import cx_Oracle + + return cx_Oracle + + def initialize(self, connection): + super().initialize(connection) + self._detect_decimal_char(connection) + + def get_isolation_level(self, dbapi_connection): + # sources: + + # general idea of transaction id, have to start one, etc. + # https://stackoverflow.com/questions/10711204/how-to-check-isoloation-level + + # how to decode xid cols from v$transaction to match + # https://asktom.oracle.com/pls/apex/f?p=100:11:0::::P11_QUESTION_ID:9532779900346079444 + + # Oracle tuple comparison without using IN: + # https://www.sql-workbench.eu/comparison/tuple_comparison.html + + with dbapi_connection.cursor() as cursor: + # this is the only way to ensure a transaction is started without + # actually running DML. There's no way to see the configured + # isolation level without getting it from v$transaction which + # means transaction has to be started. + outval = cursor.var(str) + cursor.execute( + """ + begin + :trans_id := dbms_transaction.local_transaction_id( TRUE ); + end; + """, + {"trans_id": outval}, + ) + trans_id = outval.getvalue() + xidusn, xidslot, xidsqn = trans_id.split(".", 2) + + cursor.execute( + "SELECT CASE BITAND(t.flag, POWER(2, 28)) " + "WHEN 0 THEN 'READ COMMITTED' " + "ELSE 'SERIALIZABLE' END AS isolation_level " + "FROM v$transaction t WHERE " + "(t.xidusn, t.xidslot, t.xidsqn) = " + "((:xidusn, :xidslot, :xidsqn))", + {"xidusn": xidusn, "xidslot": xidslot, "xidsqn": xidsqn}, + ) + row = cursor.fetchone() + if row is None: + raise exc.InvalidRequestError( + "could not retrieve isolation level" + ) + result = row[0] + + return result + + def get_isolation_level_values(self, dbapi_connection): + return super().get_isolation_level_values(dbapi_connection) + [ + "AUTOCOMMIT" + ] + + def set_isolation_level(self, dbapi_connection, level): + if level == "AUTOCOMMIT": + dbapi_connection.autocommit = True + else: + dbapi_connection.autocommit = False + dbapi_connection.rollback() + with dbapi_connection.cursor() as cursor: + cursor.execute(f"ALTER SESSION SET ISOLATION_LEVEL={level}") + + def _detect_decimal_char(self, connection): + # we have the option to change this setting upon connect, + # or just look at what it is upon connect and convert. + # to minimize the chance of interference with changes to + # NLS_TERRITORY or formatting behavior of the DB, we opt + # to just look at it + + dbapi_connection = connection.connection + + with dbapi_connection.cursor() as cursor: + # issue #8744 + # nls_session_parameters is not available in some Oracle + # modes like "mount mode". But then, v$nls_parameters is not + # available if the connection doesn't have SYSDBA priv. + # + # simplify the whole thing and just use the method that we were + # doing in the test suite already, selecting a number + + def output_type_handler( + cursor, name, defaultType, size, precision, scale + ): + return cursor.var( + self.dbapi.STRING, 255, arraysize=cursor.arraysize + ) + + cursor.outputtypehandler = output_type_handler + cursor.execute("SELECT 1.1 FROM DUAL") + value = cursor.fetchone()[0] + + decimal_char = value.lstrip("0")[1] + assert not decimal_char[0].isdigit() + + self._decimal_char = decimal_char + + if self._decimal_char != ".": + _detect_decimal = self._detect_decimal + _to_decimal = self._to_decimal + + self._detect_decimal = lambda value: _detect_decimal( + value.replace(self._decimal_char, ".") + ) + self._to_decimal = lambda value: _to_decimal( + value.replace(self._decimal_char, ".") + ) + + def _detect_decimal(self, value): + if "." in value: + return self._to_decimal(value) + else: + return int(value) + + _to_decimal = decimal.Decimal + + def _generate_connection_outputtype_handler(self): + """establish the default outputtypehandler established at the + connection level. + + """ + + dialect = self + cx_Oracle = dialect.dbapi + + number_handler = _OracleNUMBER( + asdecimal=True + )._cx_oracle_outputtypehandler(dialect) + float_handler = _OracleNUMBER( + asdecimal=False + )._cx_oracle_outputtypehandler(dialect) + + def output_type_handler( + cursor, name, default_type, size, precision, scale + ): + if ( + default_type == cx_Oracle.NUMBER + and default_type is not cx_Oracle.NATIVE_FLOAT + ): + if not dialect.coerce_to_decimal: + return None + elif precision == 0 and scale in (0, -127): + # ambiguous type, this occurs when selecting + # numbers from deep subqueries + return cursor.var( + cx_Oracle.STRING, + 255, + outconverter=dialect._detect_decimal, + arraysize=cursor.arraysize, + ) + elif precision and scale > 0: + return number_handler( + cursor, name, default_type, size, precision, scale + ) + else: + return float_handler( + cursor, name, default_type, size, precision, scale + ) + + # if unicode options were specified, add a decoder, otherwise + # cx_Oracle should return Unicode + elif ( + dialect._cursor_var_unicode_kwargs + and default_type + in ( + cx_Oracle.STRING, + cx_Oracle.FIXED_CHAR, + ) + and default_type is not cx_Oracle.CLOB + and default_type is not cx_Oracle.NCLOB + ): + return cursor.var( + str, + size, + cursor.arraysize, + **dialect._cursor_var_unicode_kwargs, + ) + + elif dialect.auto_convert_lobs and default_type in ( + cx_Oracle.CLOB, + cx_Oracle.NCLOB, + ): + return cursor.var( + cx_Oracle.DB_TYPE_NVARCHAR, + _CX_ORACLE_MAGIC_LOB_SIZE, + cursor.arraysize, + **dialect._cursor_var_unicode_kwargs, + ) + + elif dialect.auto_convert_lobs and default_type in ( + cx_Oracle.BLOB, + ): + return cursor.var( + cx_Oracle.DB_TYPE_RAW, + _CX_ORACLE_MAGIC_LOB_SIZE, + cursor.arraysize, + ) + + return output_type_handler + + def on_connect(self): + output_type_handler = self._generate_connection_outputtype_handler() + + def on_connect(conn): + conn.outputtypehandler = output_type_handler + + return on_connect + + def create_connect_args(self, url): + opts = dict(url.query) + + for opt in ("use_ansi", "auto_convert_lobs"): + if opt in opts: + util.warn_deprecated( + f"{self.driver} dialect option {opt!r} should only be " + "passed to create_engine directly, not within the URL " + "string", + version="1.3", + ) + util.coerce_kw_type(opts, opt, bool) + setattr(self, opt, opts.pop(opt)) + + database = url.database + service_name = opts.pop("service_name", None) + if database or service_name: + # if we have a database, then we have a remote host + port = url.port + if port: + port = int(port) + else: + port = 1521 + + if database and service_name: + raise exc.InvalidRequestError( + '"service_name" option shouldn\'t ' + 'be used with a "database" part of the url' + ) + if database: + makedsn_kwargs = {"sid": database} + if service_name: + makedsn_kwargs = {"service_name": service_name} + + dsn = self.dbapi.makedsn(url.host, port, **makedsn_kwargs) + else: + # we have a local tnsname + dsn = url.host + + if dsn is not None: + opts["dsn"] = dsn + if url.password is not None: + opts["password"] = url.password + if url.username is not None: + opts["user"] = url.username + + if self._cx_oracle_threaded is not None: + opts.setdefault("threaded", self._cx_oracle_threaded) + + def convert_cx_oracle_constant(value): + if isinstance(value, str): + try: + int_val = int(value) + except ValueError: + value = value.upper() + return getattr(self.dbapi, value) + else: + return int_val + else: + return value + + util.coerce_kw_type(opts, "mode", convert_cx_oracle_constant) + util.coerce_kw_type(opts, "threaded", bool) + util.coerce_kw_type(opts, "events", bool) + util.coerce_kw_type(opts, "purity", convert_cx_oracle_constant) + return ([], opts) + + def _get_server_version_info(self, connection): + return tuple(int(x) for x in connection.connection.version.split(".")) + + def is_disconnect(self, e, connection, cursor): + (error,) = e.args + if isinstance( + e, (self.dbapi.InterfaceError, self.dbapi.DatabaseError) + ) and "not connected" in str(e): + return True + + if hasattr(error, "code") and error.code in { + 28, + 3114, + 3113, + 3135, + 1033, + 2396, + }: + # ORA-00028: your session has been killed + # ORA-03114: not connected to ORACLE + # ORA-03113: end-of-file on communication channel + # ORA-03135: connection lost contact + # ORA-01033: ORACLE initialization or shutdown in progress + # ORA-02396: exceeded maximum idle time, please connect again + # TODO: Others ? + return True + + if re.match(r"^(?:DPI-1010|DPI-1080|DPY-1001|DPY-4011)", str(e)): + # DPI-1010: not connected + # DPI-1080: connection was closed by ORA-3113 + # python-oracledb's DPY-1001: not connected to database + # python-oracledb's DPY-4011: the database or network closed the + # connection + # TODO: others? + return True + + return False + + def create_xid(self): + """create a two-phase transaction ID. + + this id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). its format is unspecified. + + """ + + id_ = random.randint(0, 2**128) + return (0x1234, "%032x" % id_, "%032x" % 9) + + def do_executemany(self, cursor, statement, parameters, context=None): + if isinstance(parameters, tuple): + parameters = list(parameters) + cursor.executemany(statement, parameters) + + def do_begin_twophase(self, connection, xid): + connection.connection.begin(*xid) + connection.connection.info["cx_oracle_xid"] = xid + + def do_prepare_twophase(self, connection, xid): + result = connection.connection.prepare() + connection.info["cx_oracle_prepared"] = result + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + self.do_rollback(connection.connection) + # TODO: need to end XA state here + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if not is_prepared: + self.do_commit(connection.connection) + else: + if recover: + raise NotImplementedError( + "2pc recovery not implemented for cx_Oracle" + ) + oci_prepared = connection.info["cx_oracle_prepared"] + if oci_prepared: + self.do_commit(connection.connection) + # TODO: need to end XA state here + + def do_set_input_sizes(self, cursor, list_of_tuples, context): + if self.positional: + # not usually used, here to support if someone is modifying + # the dialect to use positional style + cursor.setinputsizes( + *[dbtype for key, dbtype, sqltype in list_of_tuples] + ) + else: + collection = ( + (key, dbtype) + for key, dbtype, sqltype in list_of_tuples + if dbtype + ) + + cursor.setinputsizes(**{key: dbtype for key, dbtype in collection}) + + def do_recover_twophase(self, connection): + raise NotImplementedError( + "recover two phase query for cx_Oracle not implemented" + ) + + +dialect = OracleDialect_cx_oracle diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/dictionary.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/dictionary.py new file mode 100644 index 0000000..63479b9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/dictionary.py @@ -0,0 +1,507 @@ +# dialects/oracle/dictionary.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .types import DATE +from .types import LONG +from .types import NUMBER +from .types import RAW +from .types import VARCHAR2 +from ... import Column +from ... import MetaData +from ... import Table +from ... import table +from ...sql.sqltypes import CHAR + +# constants +DB_LINK_PLACEHOLDER = "__$sa_dblink$__" +# tables +dual = table("dual") +dictionary_meta = MetaData() + +# NOTE: all the dictionary_meta are aliases because oracle does not like +# using the full table@dblink for every column in query, and complains with +# ORA-00960: ambiguous column naming in select list +all_tables = Table( + "all_tables" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("tablespace_name", VARCHAR2(30)), + Column("cluster_name", VARCHAR2(128)), + Column("iot_name", VARCHAR2(128)), + Column("status", VARCHAR2(8)), + Column("pct_free", NUMBER), + Column("pct_used", NUMBER), + Column("ini_trans", NUMBER), + Column("max_trans", NUMBER), + Column("initial_extent", NUMBER), + Column("next_extent", NUMBER), + Column("min_extents", NUMBER), + Column("max_extents", NUMBER), + Column("pct_increase", NUMBER), + Column("freelists", NUMBER), + Column("freelist_groups", NUMBER), + Column("logging", VARCHAR2(3)), + Column("backed_up", VARCHAR2(1)), + Column("num_rows", NUMBER), + Column("blocks", NUMBER), + Column("empty_blocks", NUMBER), + Column("avg_space", NUMBER), + Column("chain_cnt", NUMBER), + Column("avg_row_len", NUMBER), + Column("avg_space_freelist_blocks", NUMBER), + Column("num_freelist_blocks", NUMBER), + Column("degree", VARCHAR2(10)), + Column("instances", VARCHAR2(10)), + Column("cache", VARCHAR2(5)), + Column("table_lock", VARCHAR2(8)), + Column("sample_size", NUMBER), + Column("last_analyzed", DATE), + Column("partitioned", VARCHAR2(3)), + Column("iot_type", VARCHAR2(12)), + Column("temporary", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("nested", VARCHAR2(3)), + Column("buffer_pool", VARCHAR2(7)), + Column("flash_cache", VARCHAR2(7)), + Column("cell_flash_cache", VARCHAR2(7)), + Column("row_movement", VARCHAR2(8)), + Column("global_stats", VARCHAR2(3)), + Column("user_stats", VARCHAR2(3)), + Column("duration", VARCHAR2(15)), + Column("skip_corrupt", VARCHAR2(8)), + Column("monitoring", VARCHAR2(3)), + Column("cluster_owner", VARCHAR2(128)), + Column("dependencies", VARCHAR2(8)), + Column("compression", VARCHAR2(8)), + Column("compress_for", VARCHAR2(30)), + Column("dropped", VARCHAR2(3)), + Column("read_only", VARCHAR2(3)), + Column("segment_created", VARCHAR2(3)), + Column("result_cache", VARCHAR2(7)), + Column("clustering", VARCHAR2(3)), + Column("activity_tracking", VARCHAR2(23)), + Column("dml_timestamp", VARCHAR2(25)), + Column("has_identity", VARCHAR2(3)), + Column("container_data", VARCHAR2(3)), + Column("inmemory", VARCHAR2(8)), + Column("inmemory_priority", VARCHAR2(8)), + Column("inmemory_distribute", VARCHAR2(15)), + Column("inmemory_compression", VARCHAR2(17)), + Column("inmemory_duplicate", VARCHAR2(13)), + Column("default_collation", VARCHAR2(100)), + Column("duplicated", VARCHAR2(1)), + Column("sharded", VARCHAR2(1)), + Column("externally_sharded", VARCHAR2(1)), + Column("externally_duplicated", VARCHAR2(1)), + Column("external", VARCHAR2(3)), + Column("hybrid", VARCHAR2(3)), + Column("cellmemory", VARCHAR2(24)), + Column("containers_default", VARCHAR2(3)), + Column("container_map", VARCHAR2(3)), + Column("extended_data_link", VARCHAR2(3)), + Column("extended_data_link_map", VARCHAR2(3)), + Column("inmemory_service", VARCHAR2(12)), + Column("inmemory_service_name", VARCHAR2(1000)), + Column("container_map_object", VARCHAR2(3)), + Column("memoptimize_read", VARCHAR2(8)), + Column("memoptimize_write", VARCHAR2(8)), + Column("has_sensitive_column", VARCHAR2(3)), + Column("admit_null", VARCHAR2(3)), + Column("data_link_dml_enabled", VARCHAR2(3)), + Column("logical_replication", VARCHAR2(8)), +).alias("a_tables") + +all_views = Table( + "all_views" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("view_name", VARCHAR2(128), nullable=False), + Column("text_length", NUMBER), + Column("text", LONG), + Column("text_vc", VARCHAR2(4000)), + Column("type_text_length", NUMBER), + Column("type_text", VARCHAR2(4000)), + Column("oid_text_length", NUMBER), + Column("oid_text", VARCHAR2(4000)), + Column("view_type_owner", VARCHAR2(128)), + Column("view_type", VARCHAR2(128)), + Column("superview_name", VARCHAR2(128)), + Column("editioning_view", VARCHAR2(1)), + Column("read_only", VARCHAR2(1)), + Column("container_data", VARCHAR2(1)), + Column("bequeath", VARCHAR2(12)), + Column("origin_con_id", VARCHAR2(256)), + Column("default_collation", VARCHAR2(100)), + Column("containers_default", VARCHAR2(3)), + Column("container_map", VARCHAR2(3)), + Column("extended_data_link", VARCHAR2(3)), + Column("extended_data_link_map", VARCHAR2(3)), + Column("has_sensitive_column", VARCHAR2(3)), + Column("admit_null", VARCHAR2(3)), + Column("pdb_local_only", VARCHAR2(3)), +).alias("a_views") + +all_sequences = Table( + "all_sequences" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("sequence_owner", VARCHAR2(128), nullable=False), + Column("sequence_name", VARCHAR2(128), nullable=False), + Column("min_value", NUMBER), + Column("max_value", NUMBER), + Column("increment_by", NUMBER, nullable=False), + Column("cycle_flag", VARCHAR2(1)), + Column("order_flag", VARCHAR2(1)), + Column("cache_size", NUMBER, nullable=False), + Column("last_number", NUMBER, nullable=False), + Column("scale_flag", VARCHAR2(1)), + Column("extend_flag", VARCHAR2(1)), + Column("sharded_flag", VARCHAR2(1)), + Column("session_flag", VARCHAR2(1)), + Column("keep_value", VARCHAR2(1)), +).alias("a_sequences") + +all_users = Table( + "all_users" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("username", VARCHAR2(128), nullable=False), + Column("user_id", NUMBER, nullable=False), + Column("created", DATE, nullable=False), + Column("common", VARCHAR2(3)), + Column("oracle_maintained", VARCHAR2(1)), + Column("inherited", VARCHAR2(3)), + Column("default_collation", VARCHAR2(100)), + Column("implicit", VARCHAR2(3)), + Column("all_shard", VARCHAR2(3)), + Column("external_shard", VARCHAR2(3)), +).alias("a_users") + +all_mviews = Table( + "all_mviews" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("mview_name", VARCHAR2(128), nullable=False), + Column("container_name", VARCHAR2(128), nullable=False), + Column("query", LONG), + Column("query_len", NUMBER(38)), + Column("updatable", VARCHAR2(1)), + Column("update_log", VARCHAR2(128)), + Column("master_rollback_seg", VARCHAR2(128)), + Column("master_link", VARCHAR2(128)), + Column("rewrite_enabled", VARCHAR2(1)), + Column("rewrite_capability", VARCHAR2(9)), + Column("refresh_mode", VARCHAR2(6)), + Column("refresh_method", VARCHAR2(8)), + Column("build_mode", VARCHAR2(9)), + Column("fast_refreshable", VARCHAR2(18)), + Column("last_refresh_type", VARCHAR2(8)), + Column("last_refresh_date", DATE), + Column("last_refresh_end_time", DATE), + Column("staleness", VARCHAR2(19)), + Column("after_fast_refresh", VARCHAR2(19)), + Column("unknown_prebuilt", VARCHAR2(1)), + Column("unknown_plsql_func", VARCHAR2(1)), + Column("unknown_external_table", VARCHAR2(1)), + Column("unknown_consider_fresh", VARCHAR2(1)), + Column("unknown_import", VARCHAR2(1)), + Column("unknown_trusted_fd", VARCHAR2(1)), + Column("compile_state", VARCHAR2(19)), + Column("use_no_index", VARCHAR2(1)), + Column("stale_since", DATE), + Column("num_pct_tables", NUMBER), + Column("num_fresh_pct_regions", NUMBER), + Column("num_stale_pct_regions", NUMBER), + Column("segment_created", VARCHAR2(3)), + Column("evaluation_edition", VARCHAR2(128)), + Column("unusable_before", VARCHAR2(128)), + Column("unusable_beginning", VARCHAR2(128)), + Column("default_collation", VARCHAR2(100)), + Column("on_query_computation", VARCHAR2(1)), + Column("auto", VARCHAR2(3)), +).alias("a_mviews") + +all_tab_identity_cols = Table( + "all_tab_identity_cols" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("generation_type", VARCHAR2(10)), + Column("sequence_name", VARCHAR2(128), nullable=False), + Column("identity_options", VARCHAR2(298)), +).alias("a_tab_identity_cols") + +all_tab_cols = Table( + "all_tab_cols" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("data_type", VARCHAR2(128)), + Column("data_type_mod", VARCHAR2(3)), + Column("data_type_owner", VARCHAR2(128)), + Column("data_length", NUMBER, nullable=False), + Column("data_precision", NUMBER), + Column("data_scale", NUMBER), + Column("nullable", VARCHAR2(1)), + Column("column_id", NUMBER), + Column("default_length", NUMBER), + Column("data_default", LONG), + Column("num_distinct", NUMBER), + Column("low_value", RAW(1000)), + Column("high_value", RAW(1000)), + Column("density", NUMBER), + Column("num_nulls", NUMBER), + Column("num_buckets", NUMBER), + Column("last_analyzed", DATE), + Column("sample_size", NUMBER), + Column("character_set_name", VARCHAR2(44)), + Column("char_col_decl_length", NUMBER), + Column("global_stats", VARCHAR2(3)), + Column("user_stats", VARCHAR2(3)), + Column("avg_col_len", NUMBER), + Column("char_length", NUMBER), + Column("char_used", VARCHAR2(1)), + Column("v80_fmt_image", VARCHAR2(3)), + Column("data_upgraded", VARCHAR2(3)), + Column("hidden_column", VARCHAR2(3)), + Column("virtual_column", VARCHAR2(3)), + Column("segment_column_id", NUMBER), + Column("internal_column_id", NUMBER, nullable=False), + Column("histogram", VARCHAR2(15)), + Column("qualified_col_name", VARCHAR2(4000)), + Column("user_generated", VARCHAR2(3)), + Column("default_on_null", VARCHAR2(3)), + Column("identity_column", VARCHAR2(3)), + Column("evaluation_edition", VARCHAR2(128)), + Column("unusable_before", VARCHAR2(128)), + Column("unusable_beginning", VARCHAR2(128)), + Column("collation", VARCHAR2(100)), + Column("collated_column_id", NUMBER), +).alias("a_tab_cols") + +all_tab_comments = Table( + "all_tab_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("table_type", VARCHAR2(11)), + Column("comments", VARCHAR2(4000)), + Column("origin_con_id", NUMBER), +).alias("a_tab_comments") + +all_col_comments = Table( + "all_col_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("comments", VARCHAR2(4000)), + Column("origin_con_id", NUMBER), +).alias("a_col_comments") + +all_mview_comments = Table( + "all_mview_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("mview_name", VARCHAR2(128), nullable=False), + Column("comments", VARCHAR2(4000)), +).alias("a_mview_comments") + +all_ind_columns = Table( + "all_ind_columns" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("index_owner", VARCHAR2(128), nullable=False), + Column("index_name", VARCHAR2(128), nullable=False), + Column("table_owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(4000)), + Column("column_position", NUMBER, nullable=False), + Column("column_length", NUMBER, nullable=False), + Column("char_length", NUMBER), + Column("descend", VARCHAR2(4)), + Column("collated_column_id", NUMBER), +).alias("a_ind_columns") + +all_indexes = Table( + "all_indexes" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("index_name", VARCHAR2(128), nullable=False), + Column("index_type", VARCHAR2(27)), + Column("table_owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("table_type", CHAR(11)), + Column("uniqueness", VARCHAR2(9)), + Column("compression", VARCHAR2(13)), + Column("prefix_length", NUMBER), + Column("tablespace_name", VARCHAR2(30)), + Column("ini_trans", NUMBER), + Column("max_trans", NUMBER), + Column("initial_extent", NUMBER), + Column("next_extent", NUMBER), + Column("min_extents", NUMBER), + Column("max_extents", NUMBER), + Column("pct_increase", NUMBER), + Column("pct_threshold", NUMBER), + Column("include_column", NUMBER), + Column("freelists", NUMBER), + Column("freelist_groups", NUMBER), + Column("pct_free", NUMBER), + Column("logging", VARCHAR2(3)), + Column("blevel", NUMBER), + Column("leaf_blocks", NUMBER), + Column("distinct_keys", NUMBER), + Column("avg_leaf_blocks_per_key", NUMBER), + Column("avg_data_blocks_per_key", NUMBER), + Column("clustering_factor", NUMBER), + Column("status", VARCHAR2(8)), + Column("num_rows", NUMBER), + Column("sample_size", NUMBER), + Column("last_analyzed", DATE), + Column("degree", VARCHAR2(40)), + Column("instances", VARCHAR2(40)), + Column("partitioned", VARCHAR2(3)), + Column("temporary", VARCHAR2(1)), + Column("generated", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("buffer_pool", VARCHAR2(7)), + Column("flash_cache", VARCHAR2(7)), + Column("cell_flash_cache", VARCHAR2(7)), + Column("user_stats", VARCHAR2(3)), + Column("duration", VARCHAR2(15)), + Column("pct_direct_access", NUMBER), + Column("ityp_owner", VARCHAR2(128)), + Column("ityp_name", VARCHAR2(128)), + Column("parameters", VARCHAR2(1000)), + Column("global_stats", VARCHAR2(3)), + Column("domidx_status", VARCHAR2(12)), + Column("domidx_opstatus", VARCHAR2(6)), + Column("funcidx_status", VARCHAR2(8)), + Column("join_index", VARCHAR2(3)), + Column("iot_redundant_pkey_elim", VARCHAR2(3)), + Column("dropped", VARCHAR2(3)), + Column("visibility", VARCHAR2(9)), + Column("domidx_management", VARCHAR2(14)), + Column("segment_created", VARCHAR2(3)), + Column("orphaned_entries", VARCHAR2(3)), + Column("indexing", VARCHAR2(7)), + Column("auto", VARCHAR2(3)), +).alias("a_indexes") + +all_ind_expressions = Table( + "all_ind_expressions" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("index_owner", VARCHAR2(128), nullable=False), + Column("index_name", VARCHAR2(128), nullable=False), + Column("table_owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_expression", LONG), + Column("column_position", NUMBER, nullable=False), +).alias("a_ind_expressions") + +all_constraints = Table( + "all_constraints" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128)), + Column("constraint_name", VARCHAR2(128)), + Column("constraint_type", VARCHAR2(1)), + Column("table_name", VARCHAR2(128)), + Column("search_condition", LONG), + Column("search_condition_vc", VARCHAR2(4000)), + Column("r_owner", VARCHAR2(128)), + Column("r_constraint_name", VARCHAR2(128)), + Column("delete_rule", VARCHAR2(9)), + Column("status", VARCHAR2(8)), + Column("deferrable", VARCHAR2(14)), + Column("deferred", VARCHAR2(9)), + Column("validated", VARCHAR2(13)), + Column("generated", VARCHAR2(14)), + Column("bad", VARCHAR2(3)), + Column("rely", VARCHAR2(4)), + Column("last_change", DATE), + Column("index_owner", VARCHAR2(128)), + Column("index_name", VARCHAR2(128)), + Column("invalid", VARCHAR2(7)), + Column("view_related", VARCHAR2(14)), + Column("origin_con_id", VARCHAR2(256)), +).alias("a_constraints") + +all_cons_columns = Table( + "all_cons_columns" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("constraint_name", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(4000)), + Column("position", NUMBER), +).alias("a_cons_columns") + +# TODO figure out if it's still relevant, since there is no mention from here +# https://docs.oracle.com/en/database/oracle/oracle-database/21/refrn/ALL_DB_LINKS.html +# original note: +# using user_db_links here since all_db_links appears +# to have more restricted permissions. +# https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm +# will need to hear from more users if we are doing +# the right thing here. See [ticket:2619] +all_db_links = Table( + "all_db_links" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("db_link", VARCHAR2(128), nullable=False), + Column("username", VARCHAR2(128)), + Column("host", VARCHAR2(2000)), + Column("created", DATE, nullable=False), + Column("hidden", VARCHAR2(3)), + Column("shard_internal", VARCHAR2(3)), + Column("valid", VARCHAR2(3)), + Column("intra_cdb", VARCHAR2(3)), +).alias("a_db_links") + +all_synonyms = Table( + "all_synonyms" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128)), + Column("synonym_name", VARCHAR2(128)), + Column("table_owner", VARCHAR2(128)), + Column("table_name", VARCHAR2(128)), + Column("db_link", VARCHAR2(128)), + Column("origin_con_id", VARCHAR2(256)), +).alias("a_synonyms") + +all_objects = Table( + "all_objects" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("object_name", VARCHAR2(128), nullable=False), + Column("subobject_name", VARCHAR2(128)), + Column("object_id", NUMBER, nullable=False), + Column("data_object_id", NUMBER), + Column("object_type", VARCHAR2(23)), + Column("created", DATE, nullable=False), + Column("last_ddl_time", DATE, nullable=False), + Column("timestamp", VARCHAR2(19)), + Column("status", VARCHAR2(7)), + Column("temporary", VARCHAR2(1)), + Column("generated", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("namespace", NUMBER, nullable=False), + Column("edition_name", VARCHAR2(128)), + Column("sharing", VARCHAR2(13)), + Column("editionable", VARCHAR2(1)), + Column("oracle_maintained", VARCHAR2(1)), + Column("application", VARCHAR2(1)), + Column("default_collation", VARCHAR2(100)), + Column("duplicated", VARCHAR2(1)), + Column("sharded", VARCHAR2(1)), + Column("created_appid", NUMBER), + Column("created_vsnid", NUMBER), + Column("modified_appid", NUMBER), + Column("modified_vsnid", NUMBER), +).alias("a_objects") diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/oracledb.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/oracledb.py new file mode 100644 index 0000000..9cdec3b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/oracledb.py @@ -0,0 +1,311 @@ +# dialects/oracle/oracledb.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: oracle+oracledb + :name: python-oracledb + :dbapi: oracledb + :connectstring: oracle+oracledb://user:pass@hostname:port[/dbname][?service_name=[&key=value&key=value...]] + :url: https://oracle.github.io/python-oracledb/ + +python-oracledb is released by Oracle to supersede the cx_Oracle driver. +It is fully compatible with cx_Oracle and features both a "thin" client +mode that requires no dependencies, as well as a "thick" mode that uses +the Oracle Client Interface in the same way as cx_Oracle. + +.. seealso:: + + :ref:`cx_oracle` - all of cx_Oracle's notes apply to the oracledb driver + as well. + +The SQLAlchemy ``oracledb`` dialect provides both a sync and an async +implementation under the same dialect name. The proper version is +selected depending on how the engine is created: + +* calling :func:`_sa.create_engine` with ``oracle+oracledb://...`` will + automatically select the sync version, e.g.:: + + from sqlalchemy import create_engine + sync_engine = create_engine("oracle+oracledb://scott:tiger@localhost/?service_name=XEPDB1") + +* calling :func:`_asyncio.create_async_engine` with + ``oracle+oracledb://...`` will automatically select the async version, + e.g.:: + + from sqlalchemy.ext.asyncio import create_async_engine + asyncio_engine = create_async_engine("oracle+oracledb://scott:tiger@localhost/?service_name=XEPDB1") + +The asyncio version of the dialect may also be specified explicitly using the +``oracledb_async`` suffix, as:: + + from sqlalchemy.ext.asyncio import create_async_engine + asyncio_engine = create_async_engine("oracle+oracledb_async://scott:tiger@localhost/?service_name=XEPDB1") + +.. versionadded:: 2.0.25 added support for the async version of oracledb. + +Thick mode support +------------------ + +By default the ``python-oracledb`` is started in thin mode, that does not +require oracle client libraries to be installed in the system. The +``python-oracledb`` driver also support a "thick" mode, that behaves +similarly to ``cx_oracle`` and requires that Oracle Client Interface (OCI) +is installed. + +To enable this mode, the user may call ``oracledb.init_oracle_client`` +manually, or by passing the parameter ``thick_mode=True`` to +:func:`_sa.create_engine`. To pass custom arguments to ``init_oracle_client``, +like the ``lib_dir`` path, a dict may be passed to this parameter, as in:: + + engine = sa.create_engine("oracle+oracledb://...", thick_mode={ + "lib_dir": "/path/to/oracle/client/lib", "driver_name": "my-app" + }) + +.. seealso:: + + https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.init_oracle_client + + +.. versionadded:: 2.0.0 added support for oracledb driver. + +""" # noqa +from __future__ import annotations + +import collections +import re +from typing import Any +from typing import TYPE_CHECKING + +from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle +from ... import exc +from ... import pool +from ...connectors.asyncio import AsyncAdapt_dbapi_connection +from ...connectors.asyncio import AsyncAdapt_dbapi_cursor +from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection +from ...util import asbool +from ...util import await_fallback +from ...util import await_only + +if TYPE_CHECKING: + from oracledb import AsyncConnection + from oracledb import AsyncCursor + + +class OracleDialect_oracledb(_OracleDialect_cx_oracle): + supports_statement_cache = True + driver = "oracledb" + _min_version = (1,) + + def __init__( + self, + auto_convert_lobs=True, + coerce_to_decimal=True, + arraysize=None, + encoding_errors=None, + thick_mode=None, + **kwargs, + ): + super().__init__( + auto_convert_lobs, + coerce_to_decimal, + arraysize, + encoding_errors, + **kwargs, + ) + + if self.dbapi is not None and ( + thick_mode or isinstance(thick_mode, dict) + ): + kw = thick_mode if isinstance(thick_mode, dict) else {} + self.dbapi.init_oracle_client(**kw) + + @classmethod + def import_dbapi(cls): + import oracledb + + return oracledb + + @classmethod + def is_thin_mode(cls, connection): + return connection.connection.dbapi_connection.thin + + @classmethod + def get_async_dialect_cls(cls, url): + return OracleDialectAsync_oracledb + + def _load_version(self, dbapi_module): + version = (0, 0, 0) + if dbapi_module is not None: + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", dbapi_module.version) + if m: + version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + self.oracledb_ver = version + if ( + self.oracledb_ver > (0, 0, 0) + and self.oracledb_ver < self._min_version + ): + raise exc.InvalidRequestError( + f"oracledb version {self._min_version} and above are supported" + ) + + +class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor): + _cursor: AsyncCursor + __slots__ = () + + @property + def outputtypehandler(self): + return self._cursor.outputtypehandler + + @outputtypehandler.setter + def outputtypehandler(self, value): + self._cursor.outputtypehandler = value + + def var(self, *args, **kwargs): + return self._cursor.var(*args, **kwargs) + + def close(self): + self._rows.clear() + self._cursor.close() + + def setinputsizes(self, *args: Any, **kwargs: Any) -> Any: + return self._cursor.setinputsizes(*args, **kwargs) + + def _aenter_cursor(self, cursor: AsyncCursor) -> AsyncCursor: + try: + return cursor.__enter__() + except Exception as error: + self._adapt_connection._handle_exception(error) + + async def _execute_async(self, operation, parameters): + # override to not use mutex, oracledb already has mutex + + if parameters is None: + result = await self._cursor.execute(operation) + else: + result = await self._cursor.execute(operation, parameters) + + if self._cursor.description and not self.server_side: + self._rows = collections.deque(await self._cursor.fetchall()) + return result + + async def _executemany_async( + self, + operation, + seq_of_parameters, + ): + # override to not use mutex, oracledb already has mutex + return await self._cursor.executemany(operation, seq_of_parameters) + + def __enter__(self): + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self.close() + + +class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection): + _connection: AsyncConnection + __slots__ = () + + thin = True + + _cursor_cls = AsyncAdapt_oracledb_cursor + _ss_cursor_cls = None + + @property + def autocommit(self): + return self._connection.autocommit + + @autocommit.setter + def autocommit(self, value): + self._connection.autocommit = value + + @property + def outputtypehandler(self): + return self._connection.outputtypehandler + + @outputtypehandler.setter + def outputtypehandler(self, value): + self._connection.outputtypehandler = value + + @property + def version(self): + return self._connection.version + + @property + def stmtcachesize(self): + return self._connection.stmtcachesize + + @stmtcachesize.setter + def stmtcachesize(self, value): + self._connection.stmtcachesize = value + + def cursor(self): + return AsyncAdapt_oracledb_cursor(self) + + +class AsyncAdaptFallback_oracledb_connection( + AsyncAdaptFallback_dbapi_connection, AsyncAdapt_oracledb_connection +): + __slots__ = () + + +class OracledbAdaptDBAPI: + def __init__(self, oracledb) -> None: + self.oracledb = oracledb + + for k, v in self.oracledb.__dict__.items(): + if k != "connect": + self.__dict__[k] = v + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("async_creator_fn", self.oracledb.connect_async) + + if asbool(async_fallback): + return AsyncAdaptFallback_oracledb_connection( + self, await_fallback(creator_fn(*arg, **kw)) + ) + + else: + return AsyncAdapt_oracledb_connection( + self, await_only(creator_fn(*arg, **kw)) + ) + + +class OracleDialectAsync_oracledb(OracleDialect_oracledb): + is_async = True + supports_statement_cache = True + + _min_version = (2,) + + # thick_mode mode is not supported by asyncio, oracledb will raise + @classmethod + def import_dbapi(cls): + import oracledb + + return OracledbAdaptDBAPI(oracledb) + + @classmethod + def get_pool_class(cls, url): + async_fallback = url.query.get("async_fallback", False) + + if asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = OracleDialect_oracledb +dialect_async = OracleDialectAsync_oracledb diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/provision.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/provision.py new file mode 100644 index 0000000..b33c152 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/provision.py @@ -0,0 +1,220 @@ +# dialects/oracle/provision.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from ... import create_engine +from ... import exc +from ... import inspect +from ...engine import url as sa_url +from ...testing.provision import configure_follower +from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_post_tables +from ...testing.provision import drop_all_schema_objects_pre_tables +from ...testing.provision import drop_db +from ...testing.provision import follower_url_from_main +from ...testing.provision import log +from ...testing.provision import post_configure_engine +from ...testing.provision import run_reap_dbs +from ...testing.provision import set_default_schema_on_connection +from ...testing.provision import stop_test_class_outside_fixtures +from ...testing.provision import temp_table_keyword_args +from ...testing.provision import update_db_opts + + +@create_db.for_db("oracle") +def _oracle_create_db(cfg, eng, ident): + # NOTE: make sure you've run "ALTER DATABASE default tablespace users" or + # similar, so that the default tablespace is not "system"; reflection will + # fail otherwise + with eng.begin() as conn: + conn.exec_driver_sql("create user %s identified by xe" % ident) + conn.exec_driver_sql("create user %s_ts1 identified by xe" % ident) + conn.exec_driver_sql("create user %s_ts2 identified by xe" % ident) + conn.exec_driver_sql("grant dba to %s" % (ident,)) + conn.exec_driver_sql("grant unlimited tablespace to %s" % ident) + conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident) + conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident) + # these are needed to create materialized views + conn.exec_driver_sql("grant create table to %s" % ident) + conn.exec_driver_sql("grant create table to %s_ts1" % ident) + conn.exec_driver_sql("grant create table to %s_ts2" % ident) + + +@configure_follower.for_db("oracle") +def _oracle_configure_follower(config, ident): + config.test_schema = "%s_ts1" % ident + config.test_schema_2 = "%s_ts2" % ident + + +def _ora_drop_ignore(conn, dbname): + try: + conn.exec_driver_sql("drop user %s cascade" % dbname) + log.info("Reaped db: %s", dbname) + return True + except exc.DatabaseError as err: + log.warning("couldn't drop db: %s", err) + return False + + +@drop_all_schema_objects_pre_tables.for_db("oracle") +def _ora_drop_all_schema_objects_pre_tables(cfg, eng): + _purge_recyclebin(eng) + _purge_recyclebin(eng, cfg.test_schema) + + +@drop_all_schema_objects_post_tables.for_db("oracle") +def _ora_drop_all_schema_objects_post_tables(cfg, eng): + with eng.begin() as conn: + for syn in conn.dialect._get_synonyms(conn, None, None, None): + conn.exec_driver_sql(f"drop synonym {syn['synonym_name']}") + + for syn in conn.dialect._get_synonyms( + conn, cfg.test_schema, None, None + ): + conn.exec_driver_sql( + f"drop synonym {cfg.test_schema}.{syn['synonym_name']}" + ) + + for tmp_table in inspect(conn).get_temp_table_names(): + conn.exec_driver_sql(f"drop table {tmp_table}") + + +@drop_db.for_db("oracle") +def _oracle_drop_db(cfg, eng, ident): + with eng.begin() as conn: + # cx_Oracle seems to occasionally leak open connections when a large + # suite it run, even if we confirm we have zero references to + # connection objects. + # while there is a "kill session" command in Oracle, + # it unfortunately does not release the connection sufficiently. + _ora_drop_ignore(conn, ident) + _ora_drop_ignore(conn, "%s_ts1" % ident) + _ora_drop_ignore(conn, "%s_ts2" % ident) + + +@stop_test_class_outside_fixtures.for_db("oracle") +def _ora_stop_test_class_outside_fixtures(config, db, cls): + try: + _purge_recyclebin(db) + except exc.DatabaseError as err: + log.warning("purge recyclebin command failed: %s", err) + + # clear statement cache on all connections that were used + # https://github.com/oracle/python-cx_Oracle/issues/519 + + for cx_oracle_conn in _all_conns: + try: + sc = cx_oracle_conn.stmtcachesize + except db.dialect.dbapi.InterfaceError: + # connection closed + pass + else: + cx_oracle_conn.stmtcachesize = 0 + cx_oracle_conn.stmtcachesize = sc + _all_conns.clear() + + +def _purge_recyclebin(eng, schema=None): + with eng.begin() as conn: + if schema is None: + # run magic command to get rid of identity sequences + # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501 + conn.exec_driver_sql("purge recyclebin") + else: + # per user: https://community.oracle.com/tech/developers/discussion/2255402/how-to-clear-dba-recyclebin-for-a-particular-user # noqa: E501 + for owner, object_name, type_ in conn.exec_driver_sql( + "select owner, object_name,type from " + "dba_recyclebin where owner=:schema and type='TABLE'", + {"schema": conn.dialect.denormalize_name(schema)}, + ).all(): + conn.exec_driver_sql(f'purge {type_} {owner}."{object_name}"') + + +_all_conns = set() + + +@post_configure_engine.for_db("oracle") +def _oracle_post_configure_engine(url, engine, follower_ident): + from sqlalchemy import event + + @event.listens_for(engine, "checkout") + def checkout(dbapi_con, con_record, con_proxy): + _all_conns.add(dbapi_con) + + @event.listens_for(engine, "checkin") + def checkin(dbapi_connection, connection_record): + # work around cx_Oracle issue: + # https://github.com/oracle/python-cx_Oracle/issues/530 + # invalidate oracle connections that had 2pc set up + if "cx_oracle_xid" in connection_record.info: + connection_record.invalidate() + + +@run_reap_dbs.for_db("oracle") +def _reap_oracle_dbs(url, idents): + log.info("db reaper connecting to %r", url) + eng = create_engine(url) + with eng.begin() as conn: + log.info("identifiers in file: %s", ", ".join(idents)) + + to_reap = conn.exec_driver_sql( + "select u.username from all_users u where username " + "like 'TEST_%' and not exists (select username " + "from v$session where username=u.username)" + ) + all_names = {username.lower() for (username,) in to_reap} + to_drop = set() + for name in all_names: + if name.endswith("_ts1") or name.endswith("_ts2"): + continue + elif name in idents: + to_drop.add(name) + if "%s_ts1" % name in all_names: + to_drop.add("%s_ts1" % name) + if "%s_ts2" % name in all_names: + to_drop.add("%s_ts2" % name) + + dropped = total = 0 + for total, username in enumerate(to_drop, 1): + if _ora_drop_ignore(conn, username): + dropped += 1 + log.info( + "Dropped %d out of %d stale databases detected", dropped, total + ) + + +@follower_url_from_main.for_db("oracle") +def _oracle_follower_url_from_main(url, ident): + url = sa_url.make_url(url) + return url.set(username=ident, password="xe") + + +@temp_table_keyword_args.for_db("oracle") +def _oracle_temp_table_keyword_args(cfg, eng): + return { + "prefixes": ["GLOBAL TEMPORARY"], + "oracle_on_commit": "PRESERVE ROWS", + } + + +@set_default_schema_on_connection.for_db("oracle") +def _oracle_set_default_schema_on_connection( + cfg, dbapi_connection, schema_name +): + cursor = dbapi_connection.cursor() + cursor.execute("ALTER SESSION SET CURRENT_SCHEMA=%s" % schema_name) + cursor.close() + + +@update_db_opts.for_db("oracle") +def _update_db_opts(db_url, db_opts, options): + """Set database options (db_opts) for a test database that we created.""" + if ( + options.oracledb_thick_mode + and sa_url.make_url(db_url).get_driver_name() == "oracledb" + ): + db_opts["thick_mode"] = True diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/types.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/types.py new file mode 100644 index 0000000..36caaa0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/oracle/types.py @@ -0,0 +1,287 @@ +# dialects/oracle/types.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +import datetime as dt +from typing import Optional +from typing import Type +from typing import TYPE_CHECKING + +from ... import exc +from ...sql import sqltypes +from ...types import NVARCHAR +from ...types import VARCHAR + +if TYPE_CHECKING: + from ...engine.interfaces import Dialect + from ...sql.type_api import _LiteralProcessorType + + +class RAW(sqltypes._Binary): + __visit_name__ = "RAW" + + +OracleRaw = RAW + + +class NCLOB(sqltypes.Text): + __visit_name__ = "NCLOB" + + +class VARCHAR2(VARCHAR): + __visit_name__ = "VARCHAR2" + + +NVARCHAR2 = NVARCHAR + + +class NUMBER(sqltypes.Numeric, sqltypes.Integer): + __visit_name__ = "NUMBER" + + def __init__(self, precision=None, scale=None, asdecimal=None): + if asdecimal is None: + asdecimal = bool(scale and scale > 0) + + super().__init__(precision=precision, scale=scale, asdecimal=asdecimal) + + def adapt(self, impltype): + ret = super().adapt(impltype) + # leave a hint for the DBAPI handler + ret._is_oracle_number = True + return ret + + @property + def _type_affinity(self): + if bool(self.scale and self.scale > 0): + return sqltypes.Numeric + else: + return sqltypes.Integer + + +class FLOAT(sqltypes.FLOAT): + """Oracle FLOAT. + + This is the same as :class:`_sqltypes.FLOAT` except that + an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision` + parameter is accepted, and + the :paramref:`_sqltypes.Float.precision` parameter is not accepted. + + Oracle FLOAT types indicate precision in terms of "binary precision", which + defaults to 126. For a REAL type, the value is 63. This parameter does not + cleanly map to a specific number of decimal places but is roughly + equivalent to the desired number of decimal places divided by 0.3103. + + .. versionadded:: 2.0 + + """ + + __visit_name__ = "FLOAT" + + def __init__( + self, + binary_precision=None, + asdecimal=False, + decimal_return_scale=None, + ): + r""" + Construct a FLOAT + + :param binary_precision: Oracle binary precision value to be rendered + in DDL. This may be approximated to the number of decimal characters + using the formula "decimal precision = 0.30103 * binary precision". + The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126. + + :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal` + + :param decimal_return_scale: See + :paramref:`_sqltypes.Float.decimal_return_scale` + + """ + super().__init__( + asdecimal=asdecimal, decimal_return_scale=decimal_return_scale + ) + self.binary_precision = binary_precision + + +class BINARY_DOUBLE(sqltypes.Double): + __visit_name__ = "BINARY_DOUBLE" + + +class BINARY_FLOAT(sqltypes.Float): + __visit_name__ = "BINARY_FLOAT" + + +class BFILE(sqltypes.LargeBinary): + __visit_name__ = "BFILE" + + +class LONG(sqltypes.Text): + __visit_name__ = "LONG" + + +class _OracleDateLiteralRender: + def _literal_processor_datetime(self, dialect): + def process(value): + if getattr(value, "microsecond", None): + value = ( + f"""TO_TIMESTAMP""" + f"""('{value.isoformat().replace("T", " ")}', """ + """'YYYY-MM-DD HH24:MI:SS.FF')""" + ) + else: + value = ( + f"""TO_DATE""" + f"""('{value.isoformat().replace("T", " ")}', """ + """'YYYY-MM-DD HH24:MI:SS')""" + ) + return value + + return process + + def _literal_processor_date(self, dialect): + def process(value): + if getattr(value, "microsecond", None): + value = ( + f"""TO_TIMESTAMP""" + f"""('{value.isoformat().split("T")[0]}', """ + """'YYYY-MM-DD')""" + ) + else: + value = ( + f"""TO_DATE""" + f"""('{value.isoformat().split("T")[0]}', """ + """'YYYY-MM-DD')""" + ) + return value + + return process + + +class DATE(_OracleDateLiteralRender, sqltypes.DateTime): + """Provide the oracle DATE type. + + This type has no special Python behavior, except that it subclasses + :class:`_types.DateTime`; this is to suit the fact that the Oracle + ``DATE`` type supports a time value. + + """ + + __visit_name__ = "DATE" + + def literal_processor(self, dialect): + return self._literal_processor_datetime(dialect) + + def _compare_type_affinity(self, other): + return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) + + +class _OracleDate(_OracleDateLiteralRender, sqltypes.Date): + def literal_processor(self, dialect): + return self._literal_processor_date(dialect) + + +class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): + __visit_name__ = "INTERVAL" + + def __init__(self, day_precision=None, second_precision=None): + """Construct an INTERVAL. + + Note that only DAY TO SECOND intervals are currently supported. + This is due to a lack of support for YEAR TO MONTH intervals + within available DBAPIs. + + :param day_precision: the day precision value. this is the number of + digits to store for the day field. Defaults to "2" + :param second_precision: the second precision value. this is the + number of digits to store for the fractional seconds field. + Defaults to "6". + + """ + self.day_precision = day_precision + self.second_precision = second_precision + + @classmethod + def _adapt_from_generic_interval(cls, interval): + return INTERVAL( + day_precision=interval.day_precision, + second_precision=interval.second_precision, + ) + + @classmethod + def adapt_emulated_to_native( + cls, interval: sqltypes.Interval, **kw # type: ignore[override] + ): + return INTERVAL( + day_precision=interval.day_precision, + second_precision=interval.second_precision, + ) + + @property + def _type_affinity(self): + return sqltypes.Interval + + def as_generic(self, allow_nulltype=False): + return sqltypes.Interval( + native=True, + second_precision=self.second_precision, + day_precision=self.day_precision, + ) + + @property + def python_type(self) -> Type[dt.timedelta]: + return dt.timedelta + + def literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[dt.timedelta]]: + def process(value: dt.timedelta) -> str: + return f"NUMTODSINTERVAL({value.total_seconds()}, 'SECOND')" + + return process + + +class TIMESTAMP(sqltypes.TIMESTAMP): + """Oracle implementation of ``TIMESTAMP``, which supports additional + Oracle-specific modes + + .. versionadded:: 2.0 + + """ + + def __init__(self, timezone: bool = False, local_timezone: bool = False): + """Construct a new :class:`_oracle.TIMESTAMP`. + + :param timezone: boolean. Indicates that the TIMESTAMP type should + use Oracle's ``TIMESTAMP WITH TIME ZONE`` datatype. + + :param local_timezone: boolean. Indicates that the TIMESTAMP type + should use Oracle's ``TIMESTAMP WITH LOCAL TIME ZONE`` datatype. + + + """ + if timezone and local_timezone: + raise exc.ArgumentError( + "timezone and local_timezone are mutually exclusive" + ) + super().__init__(timezone=timezone) + self.local_timezone = local_timezone + + +class ROWID(sqltypes.TypeEngine): + """Oracle ROWID type. + + When used in a cast() or similar, generates ROWID. + + """ + + __visit_name__ = "ROWID" + + +class _OracleBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__init__.py new file mode 100644 index 0000000..325ea88 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__init__.py @@ -0,0 +1,167 @@ +# dialects/postgresql/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from types import ModuleType + +from . import array as arraylib # noqa # keep above base and other dialects +from . import asyncpg # noqa +from . import base +from . import pg8000 # noqa +from . import psycopg # noqa +from . import psycopg2 # noqa +from . import psycopg2cffi # noqa +from .array import All +from .array import Any +from .array import ARRAY +from .array import array +from .base import BIGINT +from .base import BOOLEAN +from .base import CHAR +from .base import DATE +from .base import DOMAIN +from .base import DOUBLE_PRECISION +from .base import FLOAT +from .base import INTEGER +from .base import NUMERIC +from .base import REAL +from .base import SMALLINT +from .base import TEXT +from .base import UUID +from .base import VARCHAR +from .dml import Insert +from .dml import insert +from .ext import aggregate_order_by +from .ext import array_agg +from .ext import ExcludeConstraint +from .ext import phraseto_tsquery +from .ext import plainto_tsquery +from .ext import to_tsquery +from .ext import to_tsvector +from .ext import ts_headline +from .ext import websearch_to_tsquery +from .hstore import HSTORE +from .hstore import hstore +from .json import JSON +from .json import JSONB +from .json import JSONPATH +from .named_types import CreateDomainType +from .named_types import CreateEnumType +from .named_types import DropDomainType +from .named_types import DropEnumType +from .named_types import ENUM +from .named_types import NamedType +from .ranges import AbstractMultiRange +from .ranges import AbstractRange +from .ranges import AbstractSingleRange +from .ranges import DATEMULTIRANGE +from .ranges import DATERANGE +from .ranges import INT4MULTIRANGE +from .ranges import INT4RANGE +from .ranges import INT8MULTIRANGE +from .ranges import INT8RANGE +from .ranges import MultiRange +from .ranges import NUMMULTIRANGE +from .ranges import NUMRANGE +from .ranges import Range +from .ranges import TSMULTIRANGE +from .ranges import TSRANGE +from .ranges import TSTZMULTIRANGE +from .ranges import TSTZRANGE +from .types import BIT +from .types import BYTEA +from .types import CIDR +from .types import CITEXT +from .types import INET +from .types import INTERVAL +from .types import MACADDR +from .types import MACADDR8 +from .types import MONEY +from .types import OID +from .types import REGCLASS +from .types import REGCONFIG +from .types import TIME +from .types import TIMESTAMP +from .types import TSQUERY +from .types import TSVECTOR + + +# Alias psycopg also as psycopg_async +psycopg_async = type( + "psycopg_async", (ModuleType,), {"dialect": psycopg.dialect_async} +) + +base.dialect = dialect = psycopg2.dialect + + +__all__ = ( + "INTEGER", + "BIGINT", + "SMALLINT", + "VARCHAR", + "CHAR", + "TEXT", + "NUMERIC", + "FLOAT", + "REAL", + "INET", + "CIDR", + "CITEXT", + "UUID", + "BIT", + "MACADDR", + "MACADDR8", + "MONEY", + "OID", + "REGCLASS", + "REGCONFIG", + "TSQUERY", + "TSVECTOR", + "DOUBLE_PRECISION", + "TIMESTAMP", + "TIME", + "DATE", + "BYTEA", + "BOOLEAN", + "INTERVAL", + "ARRAY", + "ENUM", + "DOMAIN", + "dialect", + "array", + "HSTORE", + "hstore", + "INT4RANGE", + "INT8RANGE", + "NUMRANGE", + "DATERANGE", + "INT4MULTIRANGE", + "INT8MULTIRANGE", + "NUMMULTIRANGE", + "DATEMULTIRANGE", + "TSVECTOR", + "TSRANGE", + "TSTZRANGE", + "TSMULTIRANGE", + "TSTZMULTIRANGE", + "JSON", + "JSONB", + "JSONPATH", + "Any", + "All", + "DropEnumType", + "DropDomainType", + "CreateDomainType", + "NamedType", + "CreateEnumType", + "ExcludeConstraint", + "Range", + "aggregate_order_by", + "array_agg", + "insert", + "Insert", +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..620cf0d Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/_psycopg_common.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/_psycopg_common.cpython-311.pyc new file mode 100644 index 0000000..fc1553f Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/_psycopg_common.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/array.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/array.cpython-311.pyc new file mode 100644 index 0000000..a257440 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/array.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/asyncpg.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/asyncpg.cpython-311.pyc new file mode 100644 index 0000000..50a38f0 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/asyncpg.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..b5aaeaa Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/base.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/dml.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/dml.cpython-311.pyc new file mode 100644 index 0000000..27eae2b Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/dml.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/ext.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/ext.cpython-311.pyc new file mode 100644 index 0000000..7ae58de Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/ext.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/hstore.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/hstore.cpython-311.pyc new file mode 100644 index 0000000..c004810 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/hstore.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/json.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/json.cpython-311.pyc new file mode 100644 index 0000000..c429892 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/json.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/named_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/named_types.cpython-311.pyc new file mode 100644 index 0000000..a075cd8 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/named_types.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/operators.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/operators.cpython-311.pyc new file mode 100644 index 0000000..93ccdfb Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/operators.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/pg8000.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/pg8000.cpython-311.pyc new file mode 100644 index 0000000..1bbbf9c Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/pg8000.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/pg_catalog.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/pg_catalog.cpython-311.pyc new file mode 100644 index 0000000..23fc72e Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/pg_catalog.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/provision.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/provision.cpython-311.pyc new file mode 100644 index 0000000..39fb4c6 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/provision.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/psycopg.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/psycopg.cpython-311.pyc new file mode 100644 index 0000000..aa442c1 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/psycopg.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/psycopg2.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/psycopg2.cpython-311.pyc new file mode 100644 index 0000000..1308229 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/psycopg2.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/psycopg2cffi.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/psycopg2cffi.cpython-311.pyc new file mode 100644 index 0000000..bd4f6c4 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/psycopg2cffi.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/ranges.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/ranges.cpython-311.pyc new file mode 100644 index 0000000..d0a785d Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/ranges.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/types.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/types.cpython-311.pyc new file mode 100644 index 0000000..ed561da Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/__pycache__/types.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/_psycopg_common.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/_psycopg_common.py new file mode 100644 index 0000000..46858c9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -0,0 +1,187 @@ +# dialects/postgresql/_psycopg_common.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +import decimal + +from .array import ARRAY as PGARRAY +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import PGDialect +from .base import PGExecutionContext +from .hstore import HSTORE +from .pg_catalog import _SpaceVector +from .pg_catalog import INT2VECTOR +from .pg_catalog import OIDVECTOR +from ... import exc +from ... import types as sqltypes +from ... import util +from ...engine import processors + +_server_side_id = util.counter() + + +class _PsycopgNumeric(sqltypes.Numeric): + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # psycopg returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # psycopg returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class _PsycopgFloat(_PsycopgNumeric): + __visit_name__ = "float" + + +class _PsycopgHStore(HSTORE): + def bind_processor(self, dialect): + if dialect._has_native_hstore: + return None + else: + return super().bind_processor(dialect) + + def result_processor(self, dialect, coltype): + if dialect._has_native_hstore: + return None + else: + return super().result_processor(dialect, coltype) + + +class _PsycopgARRAY(PGARRAY): + render_bind_cast = True + + +class _PsycopgINT2VECTOR(_SpaceVector, INT2VECTOR): + pass + + +class _PsycopgOIDVECTOR(_SpaceVector, OIDVECTOR): + pass + + +class _PGExecutionContext_common_psycopg(PGExecutionContext): + def create_server_side_cursor(self): + # use server-side cursors: + # psycopg + # https://www.psycopg.org/psycopg3/docs/advanced/cursors.html#server-side-cursors + # psycopg2 + # https://www.psycopg.org/docs/usage.html#server-side-cursors + ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) + return self._dbapi_connection.cursor(ident) + + +class _PGDialect_common_psycopg(PGDialect): + supports_statement_cache = True + supports_server_side_cursors = True + + default_paramstyle = "pyformat" + + _has_native_hstore = True + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.Numeric: _PsycopgNumeric, + sqltypes.Float: _PsycopgFloat, + HSTORE: _PsycopgHStore, + sqltypes.ARRAY: _PsycopgARRAY, + INT2VECTOR: _PsycopgINT2VECTOR, + OIDVECTOR: _PsycopgOIDVECTOR, + }, + ) + + def __init__( + self, + client_encoding=None, + use_native_hstore=True, + **kwargs, + ): + PGDialect.__init__(self, **kwargs) + if not use_native_hstore: + self._has_native_hstore = False + self.use_native_hstore = use_native_hstore + self.client_encoding = client_encoding + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user", database="dbname") + + multihosts, multiports = self._split_multihost_from_url(url) + + if opts or url.query: + if not opts: + opts = {} + if "port" in opts: + opts["port"] = int(opts["port"]) + opts.update(url.query) + + if multihosts: + opts["host"] = ",".join(multihosts) + comma_ports = ",".join(str(p) if p else "" for p in multiports) + if comma_ports: + opts["port"] = comma_ports + return ([], opts) + else: + # no connection arguments whatsoever; psycopg2.connect() + # requires that "dsn" be present as a blank string. + return ([""], opts) + + def get_isolation_level_values(self, dbapi_connection): + return ( + "AUTOCOMMIT", + "READ COMMITTED", + "READ UNCOMMITTED", + "REPEATABLE READ", + "SERIALIZABLE", + ) + + def set_deferrable(self, connection, value): + connection.deferrable = value + + def get_deferrable(self, connection): + return connection.deferrable + + def _do_autocommit(self, connection, value): + connection.autocommit = value + + def do_ping(self, dbapi_connection): + cursor = None + before_autocommit = dbapi_connection.autocommit + + if not before_autocommit: + dbapi_connection.autocommit = True + cursor = dbapi_connection.cursor() + try: + cursor.execute(self._dialect_specific_select_one) + finally: + cursor.close() + if not before_autocommit and not dbapi_connection.closed: + dbapi_connection.autocommit = before_autocommit + + return True diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/array.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/array.py new file mode 100644 index 0000000..1d63655 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/array.py @@ -0,0 +1,425 @@ +# dialects/postgresql/array.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import re +from typing import Any +from typing import Optional +from typing import TypeVar + +from .operators import CONTAINED_BY +from .operators import CONTAINS +from .operators import OVERLAP +from ... import types as sqltypes +from ... import util +from ...sql import expression +from ...sql import operators +from ...sql._typing import _TypeEngineArgument + + +_T = TypeVar("_T", bound=Any) + + +def Any(other, arrexpr, operator=operators.eq): + """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method. + See that method for details. + + """ + + return arrexpr.any(other, operator) + + +def All(other, arrexpr, operator=operators.eq): + """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method. + See that method for details. + + """ + + return arrexpr.all(other, operator) + + +class array(expression.ExpressionClauseList[_T]): + """A PostgreSQL ARRAY literal. + + This is used to produce ARRAY literals in SQL expressions, e.g.:: + + from sqlalchemy.dialects.postgresql import array + from sqlalchemy.dialects import postgresql + from sqlalchemy import select, func + + stmt = select(array([1,2]) + array([3,4,5])) + + print(stmt.compile(dialect=postgresql.dialect())) + + Produces the SQL:: + + SELECT ARRAY[%(param_1)s, %(param_2)s] || + ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1 + + An instance of :class:`.array` will always have the datatype + :class:`_types.ARRAY`. The "inner" type of the array is inferred from + the values present, unless the ``type_`` keyword argument is passed:: + + array(['foo', 'bar'], type_=CHAR) + + Multidimensional arrays are produced by nesting :class:`.array` constructs. + The dimensionality of the final :class:`_types.ARRAY` + type is calculated by + recursively adding the dimensions of the inner :class:`_types.ARRAY` + type:: + + stmt = select( + array([ + array([1, 2]), array([3, 4]), array([column('q'), column('x')]) + ]) + ) + print(stmt.compile(dialect=postgresql.dialect())) + + Produces:: + + SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s], + ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1 + + .. versionadded:: 1.3.6 added support for multidimensional array literals + + .. seealso:: + + :class:`_postgresql.ARRAY` + + """ + + __visit_name__ = "array" + + stringify_dialect = "postgresql" + inherit_cache = True + + def __init__(self, clauses, **kw): + type_arg = kw.pop("type_", None) + super().__init__(operators.comma_op, *clauses, **kw) + + self._type_tuple = [arg.type for arg in self.clauses] + + main_type = ( + type_arg + if type_arg is not None + else self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE + ) + + if isinstance(main_type, ARRAY): + self.type = ARRAY( + main_type.item_type, + dimensions=( + main_type.dimensions + 1 + if main_type.dimensions is not None + else 2 + ), + ) + else: + self.type = ARRAY(main_type) + + @property + def _select_iterable(self): + return (self,) + + def _bind_param(self, operator, obj, _assume_scalar=False, type_=None): + if _assume_scalar or operator is operators.getitem: + return expression.BindParameter( + None, + obj, + _compared_to_operator=operator, + type_=type_, + _compared_to_type=self.type, + unique=True, + ) + + else: + return array( + [ + self._bind_param( + operator, o, _assume_scalar=True, type_=type_ + ) + for o in obj + ] + ) + + def self_group(self, against=None): + if against in (operators.any_op, operators.all_op, operators.getitem): + return expression.Grouping(self) + else: + return self + + +class ARRAY(sqltypes.ARRAY): + """PostgreSQL ARRAY type. + + The :class:`_postgresql.ARRAY` type is constructed in the same way + as the core :class:`_types.ARRAY` type; a member type is required, and a + number of dimensions is recommended if the type is to be used for more + than one dimension:: + + from sqlalchemy.dialects import postgresql + + mytable = Table("mytable", metadata, + Column("data", postgresql.ARRAY(Integer, dimensions=2)) + ) + + The :class:`_postgresql.ARRAY` type provides all operations defined on the + core :class:`_types.ARRAY` type, including support for "dimensions", + indexed access, and simple matching such as + :meth:`.types.ARRAY.Comparator.any` and + :meth:`.types.ARRAY.Comparator.all`. :class:`_postgresql.ARRAY` + class also + provides PostgreSQL-specific methods for containment operations, including + :meth:`.postgresql.ARRAY.Comparator.contains` + :meth:`.postgresql.ARRAY.Comparator.contained_by`, and + :meth:`.postgresql.ARRAY.Comparator.overlap`, e.g.:: + + mytable.c.data.contains([1, 2]) + + Indexed access is one-based by default, to match that of PostgreSQL; + for zero-based indexed access, set + :paramref:`_postgresql.ARRAY.zero_indexes`. + + Additionally, the :class:`_postgresql.ARRAY` + type does not work directly in + conjunction with the :class:`.ENUM` type. For a workaround, see the + special type at :ref:`postgresql_array_of_enum`. + + .. container:: topic + + **Detecting Changes in ARRAY columns when using the ORM** + + The :class:`_postgresql.ARRAY` type, when used with the SQLAlchemy ORM, + does not detect in-place mutations to the array. In order to detect + these, the :mod:`sqlalchemy.ext.mutable` extension must be used, using + the :class:`.MutableList` class:: + + from sqlalchemy.dialects.postgresql import ARRAY + from sqlalchemy.ext.mutable import MutableList + + class SomeOrmClass(Base): + # ... + + data = Column(MutableList.as_mutable(ARRAY(Integer))) + + This extension will allow "in-place" changes such to the array + such as ``.append()`` to produce events which will be detected by the + unit of work. Note that changes to elements **inside** the array, + including subarrays that are mutated in place, are **not** detected. + + Alternatively, assigning a new array value to an ORM element that + replaces the old one will always trigger a change event. + + .. seealso:: + + :class:`_types.ARRAY` - base array type + + :class:`_postgresql.array` - produces a literal array value. + + """ + + def __init__( + self, + item_type: _TypeEngineArgument[Any], + as_tuple: bool = False, + dimensions: Optional[int] = None, + zero_indexes: bool = False, + ): + """Construct an ARRAY. + + E.g.:: + + Column('myarray', ARRAY(Integer)) + + Arguments are: + + :param item_type: The data type of items of this array. Note that + dimensionality is irrelevant here, so multi-dimensional arrays like + ``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as + ``ARRAY(ARRAY(Integer))`` or such. + + :param as_tuple=False: Specify whether return results + should be converted to tuples from lists. DBAPIs such + as psycopg2 return lists by default. When tuples are + returned, the results are hashable. + + :param dimensions: if non-None, the ARRAY will assume a fixed + number of dimensions. This will cause the DDL emitted for this + ARRAY to include the exact number of bracket clauses ``[]``, + and will also optimize the performance of the type overall. + Note that PG arrays are always implicitly "non-dimensioned", + meaning they can store any number of dimensions no matter how + they were declared. + + :param zero_indexes=False: when True, index values will be converted + between Python zero-based and PostgreSQL one-based indexes, e.g. + a value of one will be added to all index values before passing + to the database. + + """ + if isinstance(item_type, ARRAY): + raise ValueError( + "Do not nest ARRAY types; ARRAY(basetype) " + "handles multi-dimensional arrays of basetype" + ) + if isinstance(item_type, type): + item_type = item_type() + self.item_type = item_type + self.as_tuple = as_tuple + self.dimensions = dimensions + self.zero_indexes = zero_indexes + + class Comparator(sqltypes.ARRAY.Comparator): + """Define comparison operations for :class:`_types.ARRAY`. + + Note that these operations are in addition to those provided + by the base :class:`.types.ARRAY.Comparator` class, including + :meth:`.types.ARRAY.Comparator.any` and + :meth:`.types.ARRAY.Comparator.all`. + + """ + + def contains(self, other, **kwargs): + """Boolean expression. Test if elements are a superset of the + elements of the argument array expression. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) + + def contained_by(self, other): + """Boolean expression. Test if elements are a proper subset of the + elements of the argument array expression. + """ + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) + + def overlap(self, other): + """Boolean expression. Test if array has elements in common with + an argument array expression. + """ + return self.operate(OVERLAP, other, result_type=sqltypes.Boolean) + + comparator_factory = Comparator + + @property + def hashable(self): + return self.as_tuple + + @property + def python_type(self): + return list + + def compare_values(self, x, y): + return x == y + + @util.memoized_property + def _against_native_enum(self): + return ( + isinstance(self.item_type, sqltypes.Enum) + and self.item_type.native_enum + ) + + def literal_processor(self, dialect): + item_proc = self.item_type.dialect_impl(dialect).literal_processor( + dialect + ) + if item_proc is None: + return None + + def to_str(elements): + return f"ARRAY[{', '.join(elements)}]" + + def process(value): + inner = self._apply_item_processor( + value, item_proc, self.dimensions, to_str + ) + return inner + + return process + + def bind_processor(self, dialect): + item_proc = self.item_type.dialect_impl(dialect).bind_processor( + dialect + ) + + def process(value): + if value is None: + return value + else: + return self._apply_item_processor( + value, item_proc, self.dimensions, list + ) + + return process + + def result_processor(self, dialect, coltype): + item_proc = self.item_type.dialect_impl(dialect).result_processor( + dialect, coltype + ) + + def process(value): + if value is None: + return value + else: + return self._apply_item_processor( + value, + item_proc, + self.dimensions, + tuple if self.as_tuple else list, + ) + + if self._against_native_enum: + super_rp = process + pattern = re.compile(r"^{(.*)}$") + + def handle_raw_string(value): + inner = pattern.match(value).group(1) + return _split_enum_values(inner) + + def process(value): + if value is None: + return value + # isinstance(value, str) is required to handle + # the case where a TypeDecorator for and Array of Enum is + # used like was required in sa < 1.3.17 + return super_rp( + handle_raw_string(value) + if isinstance(value, str) + else value + ) + + return process + + +def _split_enum_values(array_string): + if '"' not in array_string: + # no escape char is present so it can just split on the comma + return array_string.split(",") if array_string else [] + + # handles quoted strings from: + # r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr' + # returns + # ['abc', 'quoted', 'also\\quoted', 'quoted, comma', 'esc " quot', 'qpr'] + text = array_string.replace(r"\"", "_$ESC_QUOTE$_") + text = text.replace(r"\\", "\\") + result = [] + on_quotes = re.split(r'(")', text) + in_quotes = False + for tok in on_quotes: + if tok == '"': + in_quotes = not in_quotes + elif in_quotes: + result.append(tok.replace("_$ESC_QUOTE$_", '"')) + else: + result.extend(re.findall(r"([^\s,]+),?", tok)) + return result diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/asyncpg.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/asyncpg.py new file mode 100644 index 0000000..df2656d --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/asyncpg.py @@ -0,0 +1,1262 @@ +# dialects/postgresql/asyncpg.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: postgresql+asyncpg + :name: asyncpg + :dbapi: asyncpg + :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...] + :url: https://magicstack.github.io/asyncpg/ + +The asyncpg dialect is SQLAlchemy's first Python asyncio dialect. + +Using a special asyncio mediation layer, the asyncpg dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio ` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname") + +.. versionadded:: 1.4 + +.. note:: + + By default asyncpg does not decode the ``json`` and ``jsonb`` types and + returns them as strings. SQLAlchemy sets default type decoder for ``json`` + and ``jsonb`` types using the python builtin ``json.loads`` function. + The json implementation used can be changed by setting the attribute + ``json_deserializer`` when creating the engine with + :func:`create_engine` or :func:`create_async_engine`. + +.. _asyncpg_multihost: + +Multihost Connections +-------------------------- + +The asyncpg dialect features support for multiple fallback hosts in the +same way as that of the psycopg2 and psycopg dialects. The +syntax is the same, +using ``host=:`` combinations as additional query string arguments; +however, there is no default port, so all hosts must have a complete port number +present, otherwise an exception is raised:: + + engine = create_async_engine( + "postgresql+asyncpg://user:password@/dbname?host=HostA:5432&host=HostB:5432&host=HostC:5432" + ) + +For complete background on this syntax, see :ref:`psycopg2_multi_host`. + +.. versionadded:: 2.0.18 + +.. seealso:: + + :ref:`psycopg2_multi_host` + +.. _asyncpg_prepared_statement_cache: + +Prepared Statement Cache +-------------------------- + +The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()`` +for all statements. The prepared statement objects are cached after +construction which appears to grant a 10% or more performance improvement for +statement invocation. The cache is on a per-DBAPI connection basis, which +means that the primary storage for prepared statements is within DBAPI +connections pooled within the connection pool. The size of this cache +defaults to 100 statements per DBAPI connection and may be adjusted using the +``prepared_statement_cache_size`` DBAPI argument (note that while this argument +is implemented by SQLAlchemy, it is part of the DBAPI emulation portion of the +asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect +argument):: + + + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500") + +To disable the prepared statement cache, use a value of zero:: + + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0") + +.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg. + + +.. warning:: The ``asyncpg`` database driver necessarily uses caches for + PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes + such as ``ENUM`` objects are changed via DDL operations. Additionally, + prepared statements themselves which are optionally cached by SQLAlchemy's + driver as described above may also become "stale" when DDL has been emitted + to the PostgreSQL database which modifies the tables or other objects + involved in a particular prepared statement. + + The SQLAlchemy asyncpg dialect will invalidate these caches within its local + process when statements that represent DDL are emitted on a local + connection, but this is only controllable within a single Python process / + database engine. If DDL changes are made from other database engines + and/or processes, a running application may encounter asyncpg exceptions + ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup + failed for type ")`` if it refers to pooled database connections which + operated upon the previous structures. The SQLAlchemy asyncpg dialect will + recover from these error cases when the driver raises these exceptions by + clearing its internal caches as well as those of the asyncpg driver in + response to them, but cannot prevent them from being raised in the first + place if the cached prepared statement or asyncpg type caches have gone + stale, nor can it retry the statement as the PostgreSQL transaction is + invalidated when these errors occur. + +.. _asyncpg_prepared_statement_name: + +Prepared Statement Name with PGBouncer +-------------------------------------- + +By default, asyncpg enumerates prepared statements in numeric order, which +can lead to errors if a name has already been taken for another prepared +statement. This issue can arise if your application uses database proxies +such as PgBouncer to handle connections. One possible workaround is to +use dynamic prepared statement names, which asyncpg now supports through +an optional ``name`` value for the statement name. This allows you to +generate your own unique names that won't conflict with existing ones. +To achieve this, you can provide a function that will be called every time +a prepared statement is prepared:: + + from uuid import uuid4 + + engine = create_async_engine( + "postgresql+asyncpg://user:pass@somepgbouncer/dbname", + poolclass=NullPool, + connect_args={ + 'prepared_statement_name_func': lambda: f'__asyncpg_{uuid4()}__', + }, + ) + +.. seealso:: + + https://github.com/MagicStack/asyncpg/issues/837 + + https://github.com/sqlalchemy/sqlalchemy/issues/6467 + +.. warning:: When using PGBouncer, to prevent a buildup of useless prepared statements in + your application, it's important to use the :class:`.NullPool` pool + class, and to configure PgBouncer to use `DISCARD `_ + when returning connections. The DISCARD command is used to release resources held by the db connection, + including prepared statements. Without proper setup, prepared statements can + accumulate quickly and cause performance issues. + +Disabling the PostgreSQL JIT to improve ENUM datatype handling +--------------------------------------------------------------- + +Asyncpg has an `issue `_ when +using PostgreSQL ENUM datatypes, where upon the creation of new database +connections, an expensive query may be emitted in order to retrieve metadata +regarding custom types which has been shown to negatively affect performance. +To mitigate this issue, the PostgreSQL "jit" setting may be disabled from the +client using this setting passed to :func:`_asyncio.create_async_engine`:: + + engine = create_async_engine( + "postgresql+asyncpg://user:password@localhost/tmp", + connect_args={"server_settings": {"jit": "off"}}, + ) + +.. seealso:: + + https://github.com/MagicStack/asyncpg/issues/727 + +""" # noqa + +from __future__ import annotations + +import collections +import decimal +import json as _py_json +import re +import time + +from . import json +from . import ranges +from .array import ARRAY as PGARRAY +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import ENUM +from .base import INTERVAL +from .base import OID +from .base import PGCompiler +from .base import PGDialect +from .base import PGExecutionContext +from .base import PGIdentifierPreparer +from .base import REGCLASS +from .base import REGCONFIG +from .types import BIT +from .types import BYTEA +from .types import CITEXT +from ... import exc +from ... import pool +from ... import util +from ...engine import AdaptedConnection +from ...engine import processors +from ...sql import sqltypes +from ...util.concurrency import asyncio +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncpgARRAY(PGARRAY): + render_bind_cast = True + + +class AsyncpgString(sqltypes.String): + render_bind_cast = True + + +class AsyncpgREGCONFIG(REGCONFIG): + render_bind_cast = True + + +class AsyncpgTime(sqltypes.Time): + render_bind_cast = True + + +class AsyncpgBit(BIT): + render_bind_cast = True + + +class AsyncpgByteA(BYTEA): + render_bind_cast = True + + +class AsyncpgDate(sqltypes.Date): + render_bind_cast = True + + +class AsyncpgDateTime(sqltypes.DateTime): + render_bind_cast = True + + +class AsyncpgBoolean(sqltypes.Boolean): + render_bind_cast = True + + +class AsyncPgInterval(INTERVAL): + render_bind_cast = True + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + return AsyncPgInterval(precision=interval.second_precision) + + +class AsyncPgEnum(ENUM): + render_bind_cast = True + + +class AsyncpgInteger(sqltypes.Integer): + render_bind_cast = True + + +class AsyncpgBigInteger(sqltypes.BigInteger): + render_bind_cast = True + + +class AsyncpgJSON(json.JSON): + render_bind_cast = True + + def result_processor(self, dialect, coltype): + return None + + +class AsyncpgJSONB(json.JSONB): + render_bind_cast = True + + def result_processor(self, dialect, coltype): + return None + + +class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType): + pass + + +class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): + __visit_name__ = "json_int_index" + + render_bind_cast = True + + +class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): + __visit_name__ = "json_str_index" + + render_bind_cast = True + + +class AsyncpgJSONPathType(json.JSONPathType): + def bind_processor(self, dialect): + def process(value): + if isinstance(value, str): + # If it's already a string assume that it's in json path + # format. This allows using cast with json paths literals + return value + elif value: + tokens = [str(elem) for elem in value] + return tokens + else: + return [] + + return process + + +class AsyncpgNumeric(sqltypes.Numeric): + render_bind_cast = True + + def bind_processor(self, dialect): + return None + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # pg8000 returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class AsyncpgFloat(AsyncpgNumeric, sqltypes.Float): + __visit_name__ = "float" + render_bind_cast = True + + +class AsyncpgREGCLASS(REGCLASS): + render_bind_cast = True + + +class AsyncpgOID(OID): + render_bind_cast = True + + +class AsyncpgCHAR(sqltypes.CHAR): + render_bind_cast = True + + +class _AsyncpgRange(ranges.AbstractSingleRangeImpl): + def bind_processor(self, dialect): + asyncpg_Range = dialect.dbapi.asyncpg.Range + + def to_range(value): + if isinstance(value, ranges.Range): + value = asyncpg_Range( + value.lower, + value.upper, + lower_inc=value.bounds[0] == "[", + upper_inc=value.bounds[1] == "]", + empty=value.empty, + ) + return value + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is not None: + empty = value.isempty + value = ranges.Range( + value.lower, + value.upper, + bounds=f"{'[' if empty or value.lower_inc else '('}" # type: ignore # noqa: E501 + f"{']' if not empty and value.upper_inc else ')'}", + empty=empty, + ) + return value + + return to_range + + +class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl): + def bind_processor(self, dialect): + asyncpg_Range = dialect.dbapi.asyncpg.Range + + NoneType = type(None) + + def to_range(value): + if isinstance(value, (str, NoneType)): + return value + + def to_range(value): + if isinstance(value, ranges.Range): + value = asyncpg_Range( + value.lower, + value.upper, + lower_inc=value.bounds[0] == "[", + upper_inc=value.bounds[1] == "]", + empty=value.empty, + ) + return value + + return [to_range(element) for element in value] + + return to_range + + def result_processor(self, dialect, coltype): + def to_range_array(value): + def to_range(rvalue): + if rvalue is not None: + empty = rvalue.isempty + rvalue = ranges.Range( + rvalue.lower, + rvalue.upper, + bounds=f"{'[' if empty or rvalue.lower_inc else '('}" # type: ignore # noqa: E501 + f"{']' if not empty and rvalue.upper_inc else ')'}", + empty=empty, + ) + return rvalue + + if value is not None: + value = ranges.MultiRange(to_range(elem) for elem in value) + + return value + + return to_range_array + + +class PGExecutionContext_asyncpg(PGExecutionContext): + def handle_dbapi_exception(self, e): + if isinstance( + e, + ( + self.dialect.dbapi.InvalidCachedStatementError, + self.dialect.dbapi.InternalServerError, + ), + ): + self.dialect._invalidate_schema_cache() + + def pre_exec(self): + if self.isddl: + self.dialect._invalidate_schema_cache() + + self.cursor._invalidate_schema_cache_asof = ( + self.dialect._invalidate_schema_cache_asof + ) + + if not self.compiled: + return + + def create_server_side_cursor(self): + return self._dbapi_connection.cursor(server_side=True) + + +class PGCompiler_asyncpg(PGCompiler): + pass + + +class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer): + pass + + +class AsyncAdapt_asyncpg_cursor: + __slots__ = ( + "_adapt_connection", + "_connection", + "_rows", + "description", + "arraysize", + "rowcount", + "_cursor", + "_invalidate_schema_cache_asof", + ) + + server_side = False + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self._rows = [] + self._cursor = None + self.description = None + self.arraysize = 1 + self.rowcount = -1 + self._invalidate_schema_cache_asof = 0 + + def close(self): + self._rows[:] = [] + + def _handle_exception(self, error): + self._adapt_connection._handle_exception(error) + + async def _prepare_and_execute(self, operation, parameters): + adapt_connection = self._adapt_connection + + async with adapt_connection._execute_mutex: + if not adapt_connection._started: + await adapt_connection._start_transaction() + + if parameters is None: + parameters = () + + try: + prepared_stmt, attributes = await adapt_connection._prepare( + operation, self._invalidate_schema_cache_asof + ) + + if attributes: + self.description = [ + ( + attr.name, + attr.type.oid, + None, + None, + None, + None, + None, + ) + for attr in attributes + ] + else: + self.description = None + + if self.server_side: + self._cursor = await prepared_stmt.cursor(*parameters) + self.rowcount = -1 + else: + self._rows = await prepared_stmt.fetch(*parameters) + status = prepared_stmt.get_statusmsg() + + reg = re.match( + r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", status + ) + if reg: + self.rowcount = int(reg.group(1)) + else: + self.rowcount = -1 + + except Exception as error: + self._handle_exception(error) + + async def _executemany(self, operation, seq_of_parameters): + adapt_connection = self._adapt_connection + + self.description = None + async with adapt_connection._execute_mutex: + await adapt_connection._check_type_cache_invalidation( + self._invalidate_schema_cache_asof + ) + + if not adapt_connection._started: + await adapt_connection._start_transaction() + + try: + return await self._connection.executemany( + operation, seq_of_parameters + ) + except Exception as error: + self._handle_exception(error) + + def execute(self, operation, parameters=None): + self._adapt_connection.await_( + self._prepare_and_execute(operation, parameters) + ) + + def executemany(self, operation, seq_of_parameters): + return self._adapt_connection.await_( + self._executemany(operation, seq_of_parameters) + ) + + def setinputsizes(self, *inputsizes): + raise NotImplementedError() + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): + server_side = True + __slots__ = ("_rowbuffer",) + + def __init__(self, adapt_connection): + super().__init__(adapt_connection) + self._rowbuffer = None + + def close(self): + self._cursor = None + self._rowbuffer = None + + def _buffer_rows(self): + new_rows = self._adapt_connection.await_(self._cursor.fetch(50)) + self._rowbuffer = collections.deque(new_rows) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self._rowbuffer: + self._buffer_rows() + + while True: + while self._rowbuffer: + yield self._rowbuffer.popleft() + + self._buffer_rows() + if not self._rowbuffer: + break + + def fetchone(self): + if not self._rowbuffer: + self._buffer_rows() + if not self._rowbuffer: + return None + return self._rowbuffer.popleft() + + def fetchmany(self, size=None): + if size is None: + return self.fetchall() + + if not self._rowbuffer: + self._buffer_rows() + + buf = list(self._rowbuffer) + lb = len(buf) + if size > lb: + buf.extend( + self._adapt_connection.await_(self._cursor.fetch(size - lb)) + ) + + result = buf[0:size] + self._rowbuffer = collections.deque(buf[size:]) + return result + + def fetchall(self): + ret = list(self._rowbuffer) + list( + self._adapt_connection.await_(self._all()) + ) + self._rowbuffer.clear() + return ret + + async def _all(self): + rows = [] + + # TODO: looks like we have to hand-roll some kind of batching here. + # hardcoding for the moment but this should be improved. + while True: + batch = await self._cursor.fetch(1000) + if batch: + rows.extend(batch) + continue + else: + break + return rows + + def executemany(self, operation, seq_of_parameters): + raise NotImplementedError( + "server side cursor doesn't support executemany yet" + ) + + +class AsyncAdapt_asyncpg_connection(AdaptedConnection): + __slots__ = ( + "dbapi", + "isolation_level", + "_isolation_setting", + "readonly", + "deferrable", + "_transaction", + "_started", + "_prepared_statement_cache", + "_prepared_statement_name_func", + "_invalidate_schema_cache_asof", + "_execute_mutex", + ) + + await_ = staticmethod(await_only) + + def __init__( + self, + dbapi, + connection, + prepared_statement_cache_size=100, + prepared_statement_name_func=None, + ): + self.dbapi = dbapi + self._connection = connection + self.isolation_level = self._isolation_setting = "read_committed" + self.readonly = False + self.deferrable = False + self._transaction = None + self._started = False + self._invalidate_schema_cache_asof = time.time() + self._execute_mutex = asyncio.Lock() + + if prepared_statement_cache_size: + self._prepared_statement_cache = util.LRUCache( + prepared_statement_cache_size + ) + else: + self._prepared_statement_cache = None + + if prepared_statement_name_func: + self._prepared_statement_name_func = prepared_statement_name_func + else: + self._prepared_statement_name_func = self._default_name_func + + async def _check_type_cache_invalidation(self, invalidate_timestamp): + if invalidate_timestamp > self._invalidate_schema_cache_asof: + await self._connection.reload_schema_state() + self._invalidate_schema_cache_asof = invalidate_timestamp + + async def _prepare(self, operation, invalidate_timestamp): + await self._check_type_cache_invalidation(invalidate_timestamp) + + cache = self._prepared_statement_cache + if cache is None: + prepared_stmt = await self._connection.prepare( + operation, name=self._prepared_statement_name_func() + ) + attributes = prepared_stmt.get_attributes() + return prepared_stmt, attributes + + # asyncpg uses a type cache for the "attributes" which seems to go + # stale independently of the PreparedStatement itself, so place that + # collection in the cache as well. + if operation in cache: + prepared_stmt, attributes, cached_timestamp = cache[operation] + + # preparedstatements themselves also go stale for certain DDL + # changes such as size of a VARCHAR changing, so there is also + # a cross-connection invalidation timestamp + if cached_timestamp > invalidate_timestamp: + return prepared_stmt, attributes + + prepared_stmt = await self._connection.prepare( + operation, name=self._prepared_statement_name_func() + ) + attributes = prepared_stmt.get_attributes() + cache[operation] = (prepared_stmt, attributes, time.time()) + + return prepared_stmt, attributes + + def _handle_exception(self, error): + if self._connection.is_closed(): + self._transaction = None + self._started = False + + if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error): + exception_mapping = self.dbapi._asyncpg_error_translate + + for super_ in type(error).__mro__: + if super_ in exception_mapping: + translated_error = exception_mapping[super_]( + "%s: %s" % (type(error), error) + ) + translated_error.pgcode = translated_error.sqlstate = ( + getattr(error, "sqlstate", None) + ) + raise translated_error from error + else: + raise error + else: + raise error + + @property + def autocommit(self): + return self.isolation_level == "autocommit" + + @autocommit.setter + def autocommit(self, value): + if value: + self.isolation_level = "autocommit" + else: + self.isolation_level = self._isolation_setting + + def ping(self): + try: + _ = self.await_(self._async_ping()) + except Exception as error: + self._handle_exception(error) + + async def _async_ping(self): + if self._transaction is None and self.isolation_level != "autocommit": + # create a tranasction explicitly to support pgbouncer + # transaction mode. See #10226 + tr = self._connection.transaction() + await tr.start() + try: + await self._connection.fetchrow(";") + finally: + await tr.rollback() + else: + await self._connection.fetchrow(";") + + def set_isolation_level(self, level): + if self._started: + self.rollback() + self.isolation_level = self._isolation_setting = level + + async def _start_transaction(self): + if self.isolation_level == "autocommit": + return + + try: + self._transaction = self._connection.transaction( + isolation=self.isolation_level, + readonly=self.readonly, + deferrable=self.deferrable, + ) + await self._transaction.start() + except Exception as error: + self._handle_exception(error) + else: + self._started = True + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_asyncpg_ss_cursor(self) + else: + return AsyncAdapt_asyncpg_cursor(self) + + def rollback(self): + if self._started: + try: + self.await_(self._transaction.rollback()) + except Exception as error: + self._handle_exception(error) + finally: + self._transaction = None + self._started = False + + def commit(self): + if self._started: + try: + self.await_(self._transaction.commit()) + except Exception as error: + self._handle_exception(error) + finally: + self._transaction = None + self._started = False + + def close(self): + self.rollback() + + self.await_(self._connection.close()) + + def terminate(self): + if util.concurrency.in_greenlet(): + # in a greenlet; this is the connection was invalidated + # case. + try: + # try to gracefully close; see #10717 + # timeout added in asyncpg 0.14.0 December 2017 + self.await_(self._connection.close(timeout=2)) + except ( + asyncio.TimeoutError, + OSError, + self.dbapi.asyncpg.PostgresError, + ): + # in the case where we are recycling an old connection + # that may have already been disconnected, close() will + # fail with the above timeout. in this case, terminate + # the connection without any further waiting. + # see issue #8419 + self._connection.terminate() + else: + # not in a greenlet; this is the gc cleanup case + self._connection.terminate() + self._started = False + + @staticmethod + def _default_name_func(): + return None + + +class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_asyncpg_dbapi: + def __init__(self, asyncpg): + self.asyncpg = asyncpg + self.paramstyle = "numeric_dollar" + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop("async_creator_fn", self.asyncpg.connect) + prepared_statement_cache_size = kw.pop( + "prepared_statement_cache_size", 100 + ) + prepared_statement_name_func = kw.pop( + "prepared_statement_name_func", None + ) + + if util.asbool(async_fallback): + return AsyncAdaptFallback_asyncpg_connection( + self, + await_fallback(creator_fn(*arg, **kw)), + prepared_statement_cache_size=prepared_statement_cache_size, + prepared_statement_name_func=prepared_statement_name_func, + ) + else: + return AsyncAdapt_asyncpg_connection( + self, + await_only(creator_fn(*arg, **kw)), + prepared_statement_cache_size=prepared_statement_cache_size, + prepared_statement_name_func=prepared_statement_name_func, + ) + + class Error(Exception): + pass + + class Warning(Exception): # noqa + pass + + class InterfaceError(Error): + pass + + class DatabaseError(Error): + pass + + class InternalError(DatabaseError): + pass + + class OperationalError(DatabaseError): + pass + + class ProgrammingError(DatabaseError): + pass + + class IntegrityError(DatabaseError): + pass + + class DataError(DatabaseError): + pass + + class NotSupportedError(DatabaseError): + pass + + class InternalServerError(InternalError): + pass + + class InvalidCachedStatementError(NotSupportedError): + def __init__(self, message): + super().__init__( + message + " (SQLAlchemy asyncpg dialect will now invalidate " + "all prepared caches in response to this exception)", + ) + + # pep-249 datatype placeholders. As of SQLAlchemy 2.0 these aren't + # used, however the test suite looks for these in a few cases. + STRING = util.symbol("STRING") + NUMBER = util.symbol("NUMBER") + DATETIME = util.symbol("DATETIME") + + @util.memoized_property + def _asyncpg_error_translate(self): + import asyncpg + + return { + asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError, # noqa: E501 + asyncpg.exceptions.PostgresError: self.Error, + asyncpg.exceptions.SyntaxOrAccessError: self.ProgrammingError, + asyncpg.exceptions.InterfaceError: self.InterfaceError, + asyncpg.exceptions.InvalidCachedStatementError: self.InvalidCachedStatementError, # noqa: E501 + asyncpg.exceptions.InternalServerError: self.InternalServerError, + } + + def Binary(self, value): + return value + + +class PGDialect_asyncpg(PGDialect): + driver = "asyncpg" + supports_statement_cache = True + + supports_server_side_cursors = True + + render_bind_cast = True + has_terminate = True + + default_paramstyle = "numeric_dollar" + supports_sane_multi_rowcount = False + execution_ctx_cls = PGExecutionContext_asyncpg + statement_compiler = PGCompiler_asyncpg + preparer = PGIdentifierPreparer_asyncpg + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.String: AsyncpgString, + sqltypes.ARRAY: AsyncpgARRAY, + BIT: AsyncpgBit, + CITEXT: CITEXT, + REGCONFIG: AsyncpgREGCONFIG, + sqltypes.Time: AsyncpgTime, + sqltypes.Date: AsyncpgDate, + sqltypes.DateTime: AsyncpgDateTime, + sqltypes.Interval: AsyncPgInterval, + INTERVAL: AsyncPgInterval, + sqltypes.Boolean: AsyncpgBoolean, + sqltypes.Integer: AsyncpgInteger, + sqltypes.BigInteger: AsyncpgBigInteger, + sqltypes.Numeric: AsyncpgNumeric, + sqltypes.Float: AsyncpgFloat, + sqltypes.JSON: AsyncpgJSON, + sqltypes.LargeBinary: AsyncpgByteA, + json.JSONB: AsyncpgJSONB, + sqltypes.JSON.JSONPathType: AsyncpgJSONPathType, + sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType, + sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType, + sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType, + sqltypes.Enum: AsyncPgEnum, + OID: AsyncpgOID, + REGCLASS: AsyncpgREGCLASS, + sqltypes.CHAR: AsyncpgCHAR, + ranges.AbstractSingleRange: _AsyncpgRange, + ranges.AbstractMultiRange: _AsyncpgMultiRange, + }, + ) + is_async = True + _invalidate_schema_cache_asof = 0 + + def _invalidate_schema_cache(self): + self._invalidate_schema_cache_asof = time.time() + + @util.memoized_property + def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + return tuple( + [ + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) + else: + return (99, 99, 99) + + @classmethod + def import_dbapi(cls): + return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg")) + + @util.memoized_property + def _isolation_lookup(self): + return { + "AUTOCOMMIT": "autocommit", + "READ COMMITTED": "read_committed", + "REPEATABLE READ": "repeatable_read", + "SERIALIZABLE": "serializable", + } + + def get_isolation_level_values(self, dbapi_connection): + return list(self._isolation_lookup) + + def set_isolation_level(self, dbapi_connection, level): + dbapi_connection.set_isolation_level(self._isolation_lookup[level]) + + def set_readonly(self, connection, value): + connection.readonly = value + + def get_readonly(self, connection): + return connection.readonly + + def set_deferrable(self, connection, value): + connection.deferrable = value + + def get_deferrable(self, connection): + return connection.deferrable + + def do_terminate(self, dbapi_connection) -> None: + dbapi_connection.terminate() + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + multihosts, multiports = self._split_multihost_from_url(url) + + opts.update(url.query) + + if multihosts: + assert multiports + if len(multihosts) == 1: + opts["host"] = multihosts[0] + if multiports[0] is not None: + opts["port"] = multiports[0] + elif not all(multihosts): + raise exc.ArgumentError( + "All hosts are required to be present" + " for asyncpg multiple host URL" + ) + elif not all(multiports): + raise exc.ArgumentError( + "All ports are required to be present" + " for asyncpg multiple host URL" + ) + else: + opts["host"] = list(multihosts) + opts["port"] = list(multiports) + else: + util.coerce_kw_type(opts, "port", int) + util.coerce_kw_type(opts, "prepared_statement_cache_size", int) + return ([], opts) + + def do_ping(self, dbapi_connection): + dbapi_connection.ping() + return True + + @classmethod + def get_pool_class(cls, url): + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def is_disconnect(self, e, connection, cursor): + if connection: + return connection._connection.is_closed() + else: + return isinstance( + e, self.dbapi.InterfaceError + ) and "connection is closed" in str(e) + + async def setup_asyncpg_json_codec(self, conn): + """set up JSON codec for asyncpg. + + This occurs for all new connections and + can be overridden by third party dialects. + + .. versionadded:: 1.4.27 + + """ + + asyncpg_connection = conn._connection + deserializer = self._json_deserializer or _py_json.loads + + def _json_decoder(bin_value): + return deserializer(bin_value.decode()) + + await asyncpg_connection.set_type_codec( + "json", + encoder=str.encode, + decoder=_json_decoder, + schema="pg_catalog", + format="binary", + ) + + async def setup_asyncpg_jsonb_codec(self, conn): + """set up JSONB codec for asyncpg. + + This occurs for all new connections and + can be overridden by third party dialects. + + .. versionadded:: 1.4.27 + + """ + + asyncpg_connection = conn._connection + deserializer = self._json_deserializer or _py_json.loads + + def _jsonb_encoder(str_value): + # \x01 is the prefix for jsonb used by PostgreSQL. + # asyncpg requires it when format='binary' + return b"\x01" + str_value.encode() + + deserializer = self._json_deserializer or _py_json.loads + + def _jsonb_decoder(bin_value): + # the byte is the \x01 prefix for jsonb used by PostgreSQL. + # asyncpg returns it when format='binary' + return deserializer(bin_value[1:].decode()) + + await asyncpg_connection.set_type_codec( + "jsonb", + encoder=_jsonb_encoder, + decoder=_jsonb_decoder, + schema="pg_catalog", + format="binary", + ) + + async def _disable_asyncpg_inet_codecs(self, conn): + asyncpg_connection = conn._connection + + await asyncpg_connection.set_type_codec( + "inet", + encoder=lambda s: s, + decoder=lambda s: s, + schema="pg_catalog", + format="text", + ) + + await asyncpg_connection.set_type_codec( + "cidr", + encoder=lambda s: s, + decoder=lambda s: s, + schema="pg_catalog", + format="text", + ) + + def on_connect(self): + """on_connect for asyncpg + + A major component of this for asyncpg is to set up type decoders at the + asyncpg level. + + See https://github.com/MagicStack/asyncpg/issues/623 for + notes on JSON/JSONB implementation. + + """ + + super_connect = super().on_connect() + + def connect(conn): + conn.await_(self.setup_asyncpg_json_codec(conn)) + conn.await_(self.setup_asyncpg_jsonb_codec(conn)) + + if self._native_inet_types is False: + conn.await_(self._disable_asyncpg_inet_codecs(conn)) + if super_connect is not None: + super_connect(conn) + + return connect + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = PGDialect_asyncpg diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/base.py new file mode 100644 index 0000000..4ab3ca2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/base.py @@ -0,0 +1,5007 @@ +# dialects/postgresql/base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: postgresql + :name: PostgreSQL + :full_support: 12, 13, 14, 15 + :normal_support: 9.6+ + :best_effort: 9+ + +.. _postgresql_sequences: + +Sequences/SERIAL/IDENTITY +------------------------- + +PostgreSQL supports sequences, and SQLAlchemy uses these as the default means +of creating new primary key values for integer-based primary key columns. When +creating tables, SQLAlchemy will issue the ``SERIAL`` datatype for +integer-based primary key columns, which generates a sequence and server side +default corresponding to the column. + +To specify a specific named sequence to be used for primary key generation, +use the :func:`~sqlalchemy.schema.Sequence` construct:: + + Table( + "sometable", + metadata, + Column( + "id", Integer, Sequence("some_id_seq", start=1), primary_key=True + ) + ) + +When SQLAlchemy issues a single INSERT statement, to fulfill the contract of +having the "last insert identifier" available, a RETURNING clause is added to +the INSERT statement which specifies the primary key columns should be +returned after the statement completes. The RETURNING functionality only takes +place if PostgreSQL 8.2 or later is in use. As a fallback approach, the +sequence, whether specified explicitly or implicitly via ``SERIAL``, is +executed independently beforehand, the returned value to be used in the +subsequent insert. Note that when an +:func:`~sqlalchemy.sql.expression.insert()` construct is executed using +"executemany" semantics, the "last inserted identifier" functionality does not +apply; no RETURNING clause is emitted nor is the sequence pre-executed in this +case. + + +PostgreSQL 10 and above IDENTITY columns +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +PostgreSQL 10 and above have a new IDENTITY feature that supersedes the use +of SERIAL. The :class:`_schema.Identity` construct in a +:class:`_schema.Column` can be used to control its behavior:: + + from sqlalchemy import Table, Column, MetaData, Integer, Computed + + metadata = MetaData() + + data = Table( + "data", + metadata, + Column( + 'id', Integer, Identity(start=42, cycle=True), primary_key=True + ), + Column('data', String) + ) + +The CREATE TABLE for the above :class:`_schema.Table` object would be: + +.. sourcecode:: sql + + CREATE TABLE data ( + id INTEGER GENERATED BY DEFAULT AS IDENTITY (START WITH 42 CYCLE), + data VARCHAR, + PRIMARY KEY (id) + ) + +.. versionchanged:: 1.4 Added :class:`_schema.Identity` construct + in a :class:`_schema.Column` to specify the option of an autoincrementing + column. + +.. note:: + + Previous versions of SQLAlchemy did not have built-in support for rendering + of IDENTITY, and could use the following compilation hook to replace + occurrences of SERIAL with IDENTITY:: + + from sqlalchemy.schema import CreateColumn + from sqlalchemy.ext.compiler import compiles + + + @compiles(CreateColumn, 'postgresql') + def use_identity(element, compiler, **kw): + text = compiler.visit_create_column(element, **kw) + text = text.replace( + "SERIAL", "INT GENERATED BY DEFAULT AS IDENTITY" + ) + return text + + Using the above, a table such as:: + + t = Table( + 't', m, + Column('id', Integer, primary_key=True), + Column('data', String) + ) + + Will generate on the backing database as:: + + CREATE TABLE t ( + id INT GENERATED BY DEFAULT AS IDENTITY, + data VARCHAR, + PRIMARY KEY (id) + ) + +.. _postgresql_ss_cursors: + +Server Side Cursors +------------------- + +Server-side cursor support is available for the psycopg2, asyncpg +dialects and may also be available in others. + +Server side cursors are enabled on a per-statement basis by using the +:paramref:`.Connection.execution_options.stream_results` connection execution +option:: + + with engine.connect() as conn: + result = conn.execution_options(stream_results=True).execute(text("select * from table")) + +Note that some kinds of SQL statements may not be supported with +server side cursors; generally, only SQL statements that return rows should be +used with this option. + +.. deprecated:: 1.4 The dialect-level server_side_cursors flag is deprecated + and will be removed in a future release. Please use the + :paramref:`_engine.Connection.stream_results` execution option for + unbuffered cursor support. + +.. seealso:: + + :ref:`engine_stream_results` + +.. _postgresql_isolation_level: + +Transaction Isolation Level +--------------------------- + +Most SQLAlchemy dialects support setting of transaction isolation level +using the :paramref:`_sa.create_engine.isolation_level` parameter +at the :func:`_sa.create_engine` level, and at the :class:`_engine.Connection` +level via the :paramref:`.Connection.execution_options.isolation_level` +parameter. + +For PostgreSQL dialects, this feature works either by making use of the +DBAPI-specific features, such as psycopg2's isolation level flags which will +embed the isolation level setting inline with the ``"BEGIN"`` statement, or for +DBAPIs with no direct support by emitting ``SET SESSION CHARACTERISTICS AS +TRANSACTION ISOLATION LEVEL `` ahead of the ``"BEGIN"`` statement +emitted by the DBAPI. For the special AUTOCOMMIT isolation level, +DBAPI-specific techniques are used which is typically an ``.autocommit`` +flag on the DBAPI connection object. + +To set isolation level using :func:`_sa.create_engine`:: + + engine = create_engine( + "postgresql+pg8000://scott:tiger@localhost/test", + isolation_level = "REPEATABLE READ" + ) + +To set using per-connection execution options:: + + with engine.connect() as conn: + conn = conn.execution_options( + isolation_level="REPEATABLE READ" + ) + with conn.begin(): + # ... work with transaction + +There are also more options for isolation level configurations, such as +"sub-engine" objects linked to a main :class:`_engine.Engine` which each apply +different isolation level settings. See the discussion at +:ref:`dbapi_autocommit` for background. + +Valid values for ``isolation_level`` on most PostgreSQL dialects include: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +.. seealso:: + + :ref:`dbapi_autocommit` + + :ref:`postgresql_readonly_deferrable` + + :ref:`psycopg2_isolation_level` + + :ref:`pg8000_isolation_level` + +.. _postgresql_readonly_deferrable: + +Setting READ ONLY / DEFERRABLE +------------------------------ + +Most PostgreSQL dialects support setting the "READ ONLY" and "DEFERRABLE" +characteristics of the transaction, which is in addition to the isolation level +setting. These two attributes can be established either in conjunction with or +independently of the isolation level by passing the ``postgresql_readonly`` and +``postgresql_deferrable`` flags with +:meth:`_engine.Connection.execution_options`. The example below illustrates +passing the ``"SERIALIZABLE"`` isolation level at the same time as setting +"READ ONLY" and "DEFERRABLE":: + + with engine.connect() as conn: + conn = conn.execution_options( + isolation_level="SERIALIZABLE", + postgresql_readonly=True, + postgresql_deferrable=True + ) + with conn.begin(): + # ... work with transaction + +Note that some DBAPIs such as asyncpg only support "readonly" with +SERIALIZABLE isolation. + +.. versionadded:: 1.4 added support for the ``postgresql_readonly`` + and ``postgresql_deferrable`` execution options. + +.. _postgresql_reset_on_return: + +Temporary Table / Resource Reset for Connection Pooling +------------------------------------------------------- + +The :class:`.QueuePool` connection pool implementation used +by the SQLAlchemy :class:`.Engine` object includes +:ref:`reset on return ` behavior that will invoke +the DBAPI ``.rollback()`` method when connections are returned to the pool. +While this rollback will clear out the immediate state used by the previous +transaction, it does not cover a wider range of session-level state, including +temporary tables as well as other server state such as prepared statement +handles and statement caches. The PostgreSQL database includes a variety +of commands which may be used to reset this state, including +``DISCARD``, ``RESET``, ``DEALLOCATE``, and ``UNLISTEN``. + + +To install +one or more of these commands as the means of performing reset-on-return, +the :meth:`.PoolEvents.reset` event hook may be used, as demonstrated +in the example below. The implementation +will end transactions in progress as well as discard temporary tables +using the ``CLOSE``, ``RESET`` and ``DISCARD`` commands; see the PostgreSQL +documentation for background on what each of these statements do. + +The :paramref:`_sa.create_engine.pool_reset_on_return` parameter +is set to ``None`` so that the custom scheme can replace the default behavior +completely. The custom hook implementation calls ``.rollback()`` in any case, +as it's usually important that the DBAPI's own tracking of commit/rollback +will remain consistent with the state of the transaction:: + + + from sqlalchemy import create_engine + from sqlalchemy import event + + postgresql_engine = create_engine( + "postgresql+pyscopg2://scott:tiger@hostname/dbname", + + # disable default reset-on-return scheme + pool_reset_on_return=None, + ) + + + @event.listens_for(postgresql_engine, "reset") + def _reset_postgresql(dbapi_connection, connection_record, reset_state): + if not reset_state.terminate_only: + dbapi_connection.execute("CLOSE ALL") + dbapi_connection.execute("RESET ALL") + dbapi_connection.execute("DISCARD TEMP") + + # so that the DBAPI itself knows that the connection has been + # reset + dbapi_connection.rollback() + +.. versionchanged:: 2.0.0b3 Added additional state arguments to + the :meth:`.PoolEvents.reset` event and additionally ensured the event + is invoked for all "reset" occurrences, so that it's appropriate + as a place for custom "reset" handlers. Previous schemes which + use the :meth:`.PoolEvents.checkin` handler remain usable as well. + +.. seealso:: + + :ref:`pool_reset_on_return` - in the :ref:`pooling_toplevel` documentation + +.. _postgresql_alternate_search_path: + +Setting Alternate Search Paths on Connect +------------------------------------------ + +The PostgreSQL ``search_path`` variable refers to the list of schema names +that will be implicitly referenced when a particular table or other +object is referenced in a SQL statement. As detailed in the next section +:ref:`postgresql_schema_reflection`, SQLAlchemy is generally organized around +the concept of keeping this variable at its default value of ``public``, +however, in order to have it set to any arbitrary name or names when connections +are used automatically, the "SET SESSION search_path" command may be invoked +for all connections in a pool using the following event handler, as discussed +at :ref:`schema_set_default_connections`:: + + from sqlalchemy import event + from sqlalchemy import create_engine + + engine = create_engine("postgresql+psycopg2://scott:tiger@host/dbname") + + @event.listens_for(engine, "connect", insert=True) + def set_search_path(dbapi_connection, connection_record): + existing_autocommit = dbapi_connection.autocommit + dbapi_connection.autocommit = True + cursor = dbapi_connection.cursor() + cursor.execute("SET SESSION search_path='%s'" % schema_name) + cursor.close() + dbapi_connection.autocommit = existing_autocommit + +The reason the recipe is complicated by use of the ``.autocommit`` DBAPI +attribute is so that when the ``SET SESSION search_path`` directive is invoked, +it is invoked outside of the scope of any transaction and therefore will not +be reverted when the DBAPI connection has a rollback. + +.. seealso:: + + :ref:`schema_set_default_connections` - in the :ref:`metadata_toplevel` documentation + + + + +.. _postgresql_schema_reflection: + +Remote-Schema Table Introspection and PostgreSQL search_path +------------------------------------------------------------ + +.. admonition:: Section Best Practices Summarized + + keep the ``search_path`` variable set to its default of ``public``, without + any other schema names. Ensure the username used to connect **does not** + match remote schemas, or ensure the ``"$user"`` token is **removed** from + ``search_path``. For other schema names, name these explicitly + within :class:`_schema.Table` definitions. Alternatively, the + ``postgresql_ignore_search_path`` option will cause all reflected + :class:`_schema.Table` objects to have a :attr:`_schema.Table.schema` + attribute set up. + +The PostgreSQL dialect can reflect tables from any schema, as outlined in +:ref:`metadata_reflection_schemas`. + +In all cases, the first thing SQLAlchemy does when reflecting tables is +to **determine the default schema for the current database connection**. +It does this using the PostgreSQL ``current_schema()`` +function, illustated below using a PostgreSQL client session (i.e. using +the ``psql`` tool):: + + test=> select current_schema(); + current_schema + ---------------- + public + (1 row) + +Above we see that on a plain install of PostgreSQL, the default schema name +is the name ``public``. + +However, if your database username **matches the name of a schema**, PostgreSQL's +default is to then **use that name as the default schema**. Below, we log in +using the username ``scott``. When we create a schema named ``scott``, **it +implicitly changes the default schema**:: + + test=> select current_schema(); + current_schema + ---------------- + public + (1 row) + + test=> create schema scott; + CREATE SCHEMA + test=> select current_schema(); + current_schema + ---------------- + scott + (1 row) + +The behavior of ``current_schema()`` is derived from the +`PostgreSQL search path +`_ +variable ``search_path``, which in modern PostgreSQL versions defaults to this:: + + test=> show search_path; + search_path + ----------------- + "$user", public + (1 row) + +Where above, the ``"$user"`` variable will inject the current username as the +default schema, if one exists. Otherwise, ``public`` is used. + +When a :class:`_schema.Table` object is reflected, if it is present in the +schema indicated by the ``current_schema()`` function, **the schema name assigned +to the ".schema" attribute of the Table is the Python "None" value**. Otherwise, the +".schema" attribute will be assigned the string name of that schema. + +With regards to tables which these :class:`_schema.Table` +objects refer to via foreign key constraint, a decision must be made as to how +the ``.schema`` is represented in those remote tables, in the case where that +remote schema name is also a member of the current ``search_path``. + +By default, the PostgreSQL dialect mimics the behavior encouraged by +PostgreSQL's own ``pg_get_constraintdef()`` builtin procedure. This function +returns a sample definition for a particular foreign key constraint, +omitting the referenced schema name from that definition when the name is +also in the PostgreSQL schema search path. The interaction below +illustrates this behavior:: + + test=> CREATE TABLE test_schema.referred(id INTEGER PRIMARY KEY); + CREATE TABLE + test=> CREATE TABLE referring( + test(> id INTEGER PRIMARY KEY, + test(> referred_id INTEGER REFERENCES test_schema.referred(id)); + CREATE TABLE + test=> SET search_path TO public, test_schema; + test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM + test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n + test-> ON n.oid = c.relnamespace + test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid + test-> WHERE c.relname='referring' AND r.contype = 'f' + test-> ; + pg_get_constraintdef + --------------------------------------------------- + FOREIGN KEY (referred_id) REFERENCES referred(id) + (1 row) + +Above, we created a table ``referred`` as a member of the remote schema +``test_schema``, however when we added ``test_schema`` to the +PG ``search_path`` and then asked ``pg_get_constraintdef()`` for the +``FOREIGN KEY`` syntax, ``test_schema`` was not included in the output of +the function. + +On the other hand, if we set the search path back to the typical default +of ``public``:: + + test=> SET search_path TO public; + SET + +The same query against ``pg_get_constraintdef()`` now returns the fully +schema-qualified name for us:: + + test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM + test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n + test-> ON n.oid = c.relnamespace + test-> JOIN pg_catalog.pg_constraint r ON c.oid = r.conrelid + test-> WHERE c.relname='referring' AND r.contype = 'f'; + pg_get_constraintdef + --------------------------------------------------------------- + FOREIGN KEY (referred_id) REFERENCES test_schema.referred(id) + (1 row) + +SQLAlchemy will by default use the return value of ``pg_get_constraintdef()`` +in order to determine the remote schema name. That is, if our ``search_path`` +were set to include ``test_schema``, and we invoked a table +reflection process as follows:: + + >>> from sqlalchemy import Table, MetaData, create_engine, text + >>> engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") + >>> with engine.connect() as conn: + ... conn.execute(text("SET search_path TO test_schema, public")) + ... metadata_obj = MetaData() + ... referring = Table('referring', metadata_obj, + ... autoload_with=conn) + ... + + +The above process would deliver to the :attr:`_schema.MetaData.tables` +collection +``referred`` table named **without** the schema:: + + >>> metadata_obj.tables['referred'].schema is None + True + +To alter the behavior of reflection such that the referred schema is +maintained regardless of the ``search_path`` setting, use the +``postgresql_ignore_search_path`` option, which can be specified as a +dialect-specific argument to both :class:`_schema.Table` as well as +:meth:`_schema.MetaData.reflect`:: + + >>> with engine.connect() as conn: + ... conn.execute(text("SET search_path TO test_schema, public")) + ... metadata_obj = MetaData() + ... referring = Table('referring', metadata_obj, + ... autoload_with=conn, + ... postgresql_ignore_search_path=True) + ... + + +We will now have ``test_schema.referred`` stored as schema-qualified:: + + >>> metadata_obj.tables['test_schema.referred'].schema + 'test_schema' + +.. sidebar:: Best Practices for PostgreSQL Schema reflection + + The description of PostgreSQL schema reflection behavior is complex, and + is the product of many years of dealing with widely varied use cases and + user preferences. But in fact, there's no need to understand any of it if + you just stick to the simplest use pattern: leave the ``search_path`` set + to its default of ``public`` only, never refer to the name ``public`` as + an explicit schema name otherwise, and refer to all other schema names + explicitly when building up a :class:`_schema.Table` object. The options + described here are only for those users who can't, or prefer not to, stay + within these guidelines. + +.. seealso:: + + :ref:`reflection_schema_qualified_interaction` - discussion of the issue + from a backend-agnostic perspective + + `The Schema Search Path + `_ + - on the PostgreSQL website. + +INSERT/UPDATE...RETURNING +------------------------- + +The dialect supports PG 8.2's ``INSERT..RETURNING``, ``UPDATE..RETURNING`` and +``DELETE..RETURNING`` syntaxes. ``INSERT..RETURNING`` is used by default +for single-row INSERT statements in order to fetch newly generated +primary key identifiers. To specify an explicit ``RETURNING`` clause, +use the :meth:`._UpdateBase.returning` method on a per-statement basis:: + + # INSERT..RETURNING + result = table.insert().returning(table.c.col1, table.c.col2).\ + values(name='foo') + print(result.fetchall()) + + # UPDATE..RETURNING + result = table.update().returning(table.c.col1, table.c.col2).\ + where(table.c.name=='foo').values(name='bar') + print(result.fetchall()) + + # DELETE..RETURNING + result = table.delete().returning(table.c.col1, table.c.col2).\ + where(table.c.name=='foo') + print(result.fetchall()) + +.. _postgresql_insert_on_conflict: + +INSERT...ON CONFLICT (Upsert) +------------------------------ + +Starting with version 9.5, PostgreSQL allows "upserts" (update or insert) of +rows into a table via the ``ON CONFLICT`` clause of the ``INSERT`` statement. A +candidate row will only be inserted if that row does not violate any unique +constraints. In the case of a unique constraint violation, a secondary action +can occur which can be either "DO UPDATE", indicating that the data in the +target row should be updated, or "DO NOTHING", which indicates to silently skip +this row. + +Conflicts are determined using existing unique constraints and indexes. These +constraints may be identified either using their name as stated in DDL, +or they may be inferred by stating the columns and conditions that comprise +the indexes. + +SQLAlchemy provides ``ON CONFLICT`` support via the PostgreSQL-specific +:func:`_postgresql.insert()` function, which provides +the generative methods :meth:`_postgresql.Insert.on_conflict_do_update` +and :meth:`~.postgresql.Insert.on_conflict_do_nothing`: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.dialects.postgresql import insert + >>> insert_stmt = insert(my_table).values( + ... id='some_existing_id', + ... data='inserted value') + >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing( + ... index_elements=['id'] + ... ) + >>> print(do_nothing_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO NOTHING + {stop} + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... constraint='pk_my_table', + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT ON CONSTRAINT pk_my_table DO UPDATE SET data = %(param_1)s + +.. seealso:: + + `INSERT .. ON CONFLICT + `_ + - in the PostgreSQL documentation. + +Specifying the Target +^^^^^^^^^^^^^^^^^^^^^ + +Both methods supply the "target" of the conflict using either the +named constraint or by column inference: + +* The :paramref:`_postgresql.Insert.on_conflict_do_update.index_elements` argument + specifies a sequence containing string column names, :class:`_schema.Column` + objects, and/or SQL expression elements, which would identify a unique + index: + + .. sourcecode:: pycon+sql + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s + {stop} + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... index_elements=[my_table.c.id], + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s + +* When using :paramref:`_postgresql.Insert.on_conflict_do_update.index_elements` to + infer an index, a partial index can be inferred by also specifying the + use the :paramref:`_postgresql.Insert.on_conflict_do_update.index_where` parameter: + + .. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data') + >>> stmt = stmt.on_conflict_do_update( + ... index_elements=[my_table.c.user_email], + ... index_where=my_table.c.user_email.like('%@gmail.com'), + ... set_=dict(data=stmt.excluded.data) + ... ) + >>> print(stmt) + {printsql}INSERT INTO my_table (data, user_email) + VALUES (%(data)s, %(user_email)s) ON CONFLICT (user_email) + WHERE user_email LIKE %(user_email_1)s DO UPDATE SET data = excluded.data + +* The :paramref:`_postgresql.Insert.on_conflict_do_update.constraint` argument is + used to specify an index directly rather than inferring it. This can be + the name of a UNIQUE constraint, a PRIMARY KEY constraint, or an INDEX: + + .. sourcecode:: pycon+sql + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... constraint='my_table_idx_1', + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT ON CONSTRAINT my_table_idx_1 DO UPDATE SET data = %(param_1)s + {stop} + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... constraint='my_table_pk', + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT ON CONSTRAINT my_table_pk DO UPDATE SET data = %(param_1)s + {stop} + +* The :paramref:`_postgresql.Insert.on_conflict_do_update.constraint` argument may + also refer to a SQLAlchemy construct representing a constraint, + e.g. :class:`.UniqueConstraint`, :class:`.PrimaryKeyConstraint`, + :class:`.Index`, or :class:`.ExcludeConstraint`. In this use, + if the constraint has a name, it is used directly. Otherwise, if the + constraint is unnamed, then inference will be used, where the expressions + and optional WHERE clause of the constraint will be spelled out in the + construct. This use is especially convenient + to refer to the named or unnamed primary key of a :class:`_schema.Table` + using the + :attr:`_schema.Table.primary_key` attribute: + + .. sourcecode:: pycon+sql + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... constraint=my_table.primary_key, + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s + +The SET Clause +^^^^^^^^^^^^^^^ + +``ON CONFLICT...DO UPDATE`` is used to perform an update of the already +existing row, using any combination of new values as well as values +from the proposed insertion. These values are specified using the +:paramref:`_postgresql.Insert.on_conflict_do_update.set_` parameter. This +parameter accepts a dictionary which consists of direct values +for UPDATE: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value') + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s + +.. warning:: + + The :meth:`_expression.Insert.on_conflict_do_update` + method does **not** take into + account Python-side default UPDATE values or generation functions, e.g. + those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON CONFLICT style of UPDATE, + unless they are manually specified in the + :paramref:`_postgresql.Insert.on_conflict_do_update.set_` dictionary. + +Updating using the Excluded INSERT Values +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to refer to the proposed insertion row, the special alias +:attr:`~.postgresql.Insert.excluded` is available as an attribute on +the :class:`_postgresql.Insert` object; this object is a +:class:`_expression.ColumnCollection` +which alias contains all columns of the target +table: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh' + ... ) + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author) + ... ) + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data, author) + VALUES (%(id)s, %(data)s, %(author)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s, author = excluded.author + +Additional WHERE Criteria +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :meth:`_expression.Insert.on_conflict_do_update` method also accepts +a WHERE clause using the :paramref:`_postgresql.Insert.on_conflict_do_update.where` +parameter, which will limit those rows which receive an UPDATE: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh' + ... ) + >>> on_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author), + ... where=(my_table.c.status == 2) + ... ) + >>> print(on_update_stmt) + {printsql}INSERT INTO my_table (id, data, author) + VALUES (%(id)s, %(data)s, %(author)s) + ON CONFLICT (id) DO UPDATE SET data = %(param_1)s, author = excluded.author + WHERE my_table.status = %(status_1)s + +Skipping Rows with DO NOTHING +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``ON CONFLICT`` may be used to skip inserting a row entirely +if any conflict with a unique or exclusion constraint occurs; below +this is illustrated using the +:meth:`~.postgresql.Insert.on_conflict_do_nothing` method: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id']) + >>> print(stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT (id) DO NOTHING + +If ``DO NOTHING`` is used without specifying any columns or constraint, +it has the effect of skipping the INSERT for any unique or exclusion +constraint violation which occurs: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing() + >>> print(stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) + ON CONFLICT DO NOTHING + +.. _postgresql_match: + +Full Text Search +---------------- + +PostgreSQL's full text search system is available through the use of the +:data:`.func` namespace, combined with the use of custom operators +via the :meth:`.Operators.bool_op` method. For simple cases with some +degree of cross-backend compatibility, the :meth:`.Operators.match` operator +may also be used. + +.. _postgresql_simple_match: + +Simple plain text matching with ``match()`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :meth:`.Operators.match` operator provides for cross-compatible simple +text matching. For the PostgreSQL backend, it's hardcoded to generate +an expression using the ``@@`` operator in conjunction with the +``plainto_tsquery()`` PostgreSQL function. + +On the PostgreSQL dialect, an expression like the following:: + + select(sometable.c.text.match("search string")) + +would emit to the database:: + + SELECT text @@ plainto_tsquery('search string') FROM table + +Above, passing a plain string to :meth:`.Operators.match` will automatically +make use of ``plainto_tsquery()`` to specify the type of tsquery. This +establishes basic database cross-compatibility for :meth:`.Operators.match` +with other backends. + +.. versionchanged:: 2.0 The default tsquery generation function used by the + PostgreSQL dialect with :meth:`.Operators.match` is ``plainto_tsquery()``. + + To render exactly what was rendered in 1.4, use the following form:: + + from sqlalchemy import func + + select( + sometable.c.text.bool_op("@@")(func.to_tsquery("search string")) + ) + + Which would emit:: + + SELECT text @@ to_tsquery('search string') FROM table + +Using PostgreSQL full text functions and operators directly +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Text search operations beyond the simple use of :meth:`.Operators.match` +may make use of the :data:`.func` namespace to generate PostgreSQL full-text +functions, in combination with :meth:`.Operators.bool_op` to generate +any boolean operator. + +For example, the query:: + + select( + func.to_tsquery('cat').bool_op("@>")(func.to_tsquery('cat & rat')) + ) + +would generate: + +.. sourcecode:: sql + + SELECT to_tsquery('cat') @> to_tsquery('cat & rat') + + +The :class:`_postgresql.TSVECTOR` type can provide for explicit CAST:: + + from sqlalchemy.dialects.postgresql import TSVECTOR + from sqlalchemy import select, cast + select(cast("some text", TSVECTOR)) + +produces a statement equivalent to:: + + SELECT CAST('some text' AS TSVECTOR) AS anon_1 + +The ``func`` namespace is augmented by the PostgreSQL dialect to set up +correct argument and return types for most full text search functions. +These functions are used automatically by the :attr:`_sql.func` namespace +assuming the ``sqlalchemy.dialects.postgresql`` package has been imported, +or :func:`_sa.create_engine` has been invoked using a ``postgresql`` +dialect. These functions are documented at: + +* :class:`_postgresql.to_tsvector` +* :class:`_postgresql.to_tsquery` +* :class:`_postgresql.plainto_tsquery` +* :class:`_postgresql.phraseto_tsquery` +* :class:`_postgresql.websearch_to_tsquery` +* :class:`_postgresql.ts_headline` + +Specifying the "regconfig" with ``match()`` or custom operators +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +PostgreSQL's ``plainto_tsquery()`` function accepts an optional +"regconfig" argument that is used to instruct PostgreSQL to use a +particular pre-computed GIN or GiST index in order to perform the search. +When using :meth:`.Operators.match`, this additional parameter may be +specified using the ``postgresql_regconfig`` parameter, such as:: + + select(mytable.c.id).where( + mytable.c.title.match('somestring', postgresql_regconfig='english') + ) + +Which would emit:: + + SELECT mytable.id FROM mytable + WHERE mytable.title @@ plainto_tsquery('english', 'somestring') + +When using other PostgreSQL search functions with :data:`.func`, the +"regconfig" parameter may be passed directly as the initial argument:: + + select(mytable.c.id).where( + func.to_tsvector("english", mytable.c.title).bool_op("@@")( + func.to_tsquery("english", "somestring") + ) + ) + +produces a statement equivalent to:: + + SELECT mytable.id FROM mytable + WHERE to_tsvector('english', mytable.title) @@ + to_tsquery('english', 'somestring') + +It is recommended that you use the ``EXPLAIN ANALYZE...`` tool from +PostgreSQL to ensure that you are generating queries with SQLAlchemy that +take full advantage of any indexes you may have created for full text search. + +.. seealso:: + + `Full Text Search `_ - in the PostgreSQL documentation + + +FROM ONLY ... +------------- + +The dialect supports PostgreSQL's ONLY keyword for targeting only a particular +table in an inheritance hierarchy. This can be used to produce the +``SELECT ... FROM ONLY``, ``UPDATE ONLY ...``, and ``DELETE FROM ONLY ...`` +syntaxes. It uses SQLAlchemy's hints mechanism:: + + # SELECT ... FROM ONLY ... + result = table.select().with_hint(table, 'ONLY', 'postgresql') + print(result.fetchall()) + + # UPDATE ONLY ... + table.update(values=dict(foo='bar')).with_hint('ONLY', + dialect_name='postgresql') + + # DELETE FROM ONLY ... + table.delete().with_hint('ONLY', dialect_name='postgresql') + + +.. _postgresql_indexes: + +PostgreSQL-Specific Index Options +--------------------------------- + +Several extensions to the :class:`.Index` construct are available, specific +to the PostgreSQL dialect. + +Covering Indexes +^^^^^^^^^^^^^^^^ + +The ``postgresql_include`` option renders INCLUDE(colname) for the given +string names:: + + Index("my_index", table.c.x, postgresql_include=['y']) + +would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)`` + +Note that this feature requires PostgreSQL 11 or later. + +.. versionadded:: 1.4 + +.. _postgresql_partial_indexes: + +Partial Indexes +^^^^^^^^^^^^^^^ + +Partial indexes add criterion to the index definition so that the index is +applied to a subset of rows. These can be specified on :class:`.Index` +using the ``postgresql_where`` keyword argument:: + + Index('my_index', my_table.c.id, postgresql_where=my_table.c.value > 10) + +.. _postgresql_operator_classes: + +Operator Classes +^^^^^^^^^^^^^^^^ + +PostgreSQL allows the specification of an *operator class* for each column of +an index (see +https://www.postgresql.org/docs/current/interactive/indexes-opclass.html). +The :class:`.Index` construct allows these to be specified via the +``postgresql_ops`` keyword argument:: + + Index( + 'my_index', my_table.c.id, my_table.c.data, + postgresql_ops={ + 'data': 'text_pattern_ops', + 'id': 'int4_ops' + }) + +Note that the keys in the ``postgresql_ops`` dictionaries are the +"key" name of the :class:`_schema.Column`, i.e. the name used to access it from +the ``.c`` collection of :class:`_schema.Table`, which can be configured to be +different than the actual name of the column as expressed in the database. + +If ``postgresql_ops`` is to be used against a complex SQL expression such +as a function call, then to apply to the column it must be given a label +that is identified in the dictionary by name, e.g.:: + + Index( + 'my_index', my_table.c.id, + func.lower(my_table.c.data).label('data_lower'), + postgresql_ops={ + 'data_lower': 'text_pattern_ops', + 'id': 'int4_ops' + }) + +Operator classes are also supported by the +:class:`_postgresql.ExcludeConstraint` construct using the +:paramref:`_postgresql.ExcludeConstraint.ops` parameter. See that parameter for +details. + +.. versionadded:: 1.3.21 added support for operator classes with + :class:`_postgresql.ExcludeConstraint`. + + +Index Types +^^^^^^^^^^^ + +PostgreSQL provides several index types: B-Tree, Hash, GiST, and GIN, as well +as the ability for users to create their own (see +https://www.postgresql.org/docs/current/static/indexes-types.html). These can be +specified on :class:`.Index` using the ``postgresql_using`` keyword argument:: + + Index('my_index', my_table.c.data, postgresql_using='gin') + +The value passed to the keyword argument will be simply passed through to the +underlying CREATE INDEX command, so it *must* be a valid index type for your +version of PostgreSQL. + +.. _postgresql_index_storage: + +Index Storage Parameters +^^^^^^^^^^^^^^^^^^^^^^^^ + +PostgreSQL allows storage parameters to be set on indexes. The storage +parameters available depend on the index method used by the index. Storage +parameters can be specified on :class:`.Index` using the ``postgresql_with`` +keyword argument:: + + Index('my_index', my_table.c.data, postgresql_with={"fillfactor": 50}) + +PostgreSQL allows to define the tablespace in which to create the index. +The tablespace can be specified on :class:`.Index` using the +``postgresql_tablespace`` keyword argument:: + + Index('my_index', my_table.c.data, postgresql_tablespace='my_tablespace') + +Note that the same option is available on :class:`_schema.Table` as well. + +.. _postgresql_index_concurrently: + +Indexes with CONCURRENTLY +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The PostgreSQL index option CONCURRENTLY is supported by passing the +flag ``postgresql_concurrently`` to the :class:`.Index` construct:: + + tbl = Table('testtbl', m, Column('data', Integer)) + + idx1 = Index('test_idx1', tbl.c.data, postgresql_concurrently=True) + +The above index construct will render DDL for CREATE INDEX, assuming +PostgreSQL 8.2 or higher is detected or for a connection-less dialect, as:: + + CREATE INDEX CONCURRENTLY test_idx1 ON testtbl (data) + +For DROP INDEX, assuming PostgreSQL 9.2 or higher is detected or for +a connection-less dialect, it will emit:: + + DROP INDEX CONCURRENTLY test_idx1 + +When using CONCURRENTLY, the PostgreSQL database requires that the statement +be invoked outside of a transaction block. The Python DBAPI enforces that +even for a single statement, a transaction is present, so to use this +construct, the DBAPI's "autocommit" mode must be used:: + + metadata = MetaData() + table = Table( + "foo", metadata, + Column("id", String)) + index = Index( + "foo_idx", table.c.id, postgresql_concurrently=True) + + with engine.connect() as conn: + with conn.execution_options(isolation_level='AUTOCOMMIT'): + table.create(conn) + +.. seealso:: + + :ref:`postgresql_isolation_level` + +.. _postgresql_index_reflection: + +PostgreSQL Index Reflection +--------------------------- + +The PostgreSQL database creates a UNIQUE INDEX implicitly whenever the +UNIQUE CONSTRAINT construct is used. When inspecting a table using +:class:`_reflection.Inspector`, the :meth:`_reflection.Inspector.get_indexes` +and the :meth:`_reflection.Inspector.get_unique_constraints` +will report on these +two constructs distinctly; in the case of the index, the key +``duplicates_constraint`` will be present in the index entry if it is +detected as mirroring a constraint. When performing reflection using +``Table(..., autoload_with=engine)``, the UNIQUE INDEX is **not** returned +in :attr:`_schema.Table.indexes` when it is detected as mirroring a +:class:`.UniqueConstraint` in the :attr:`_schema.Table.constraints` collection +. + +Special Reflection Options +-------------------------- + +The :class:`_reflection.Inspector` +used for the PostgreSQL backend is an instance +of :class:`.PGInspector`, which offers additional methods:: + + from sqlalchemy import create_engine, inspect + + engine = create_engine("postgresql+psycopg2://localhost/test") + insp = inspect(engine) # will be a PGInspector + + print(insp.get_enums()) + +.. autoclass:: PGInspector + :members: + +.. _postgresql_table_options: + +PostgreSQL Table Options +------------------------ + +Several options for CREATE TABLE are supported directly by the PostgreSQL +dialect in conjunction with the :class:`_schema.Table` construct: + +* ``INHERITS``:: + + Table("some_table", metadata, ..., postgresql_inherits="some_supertable") + + Table("some_table", metadata, ..., postgresql_inherits=("t1", "t2", ...)) + +* ``ON COMMIT``:: + + Table("some_table", metadata, ..., postgresql_on_commit='PRESERVE ROWS') + +* ``PARTITION BY``:: + + Table("some_table", metadata, ..., + postgresql_partition_by='LIST (part_column)') + + .. versionadded:: 1.2.6 + +* ``TABLESPACE``:: + + Table("some_table", metadata, ..., postgresql_tablespace='some_tablespace') + + The above option is also available on the :class:`.Index` construct. + +* ``USING``:: + + Table("some_table", metadata, ..., postgresql_using='heap') + + .. versionadded:: 2.0.26 + +* ``WITH OIDS``:: + + Table("some_table", metadata, ..., postgresql_with_oids=True) + +* ``WITHOUT OIDS``:: + + Table("some_table", metadata, ..., postgresql_with_oids=False) + +.. seealso:: + + `PostgreSQL CREATE TABLE options + `_ - + in the PostgreSQL documentation. + +.. _postgresql_constraint_options: + +PostgreSQL Constraint Options +----------------------------- + +The following option(s) are supported by the PostgreSQL dialect in conjunction +with selected constraint constructs: + +* ``NOT VALID``: This option applies towards CHECK and FOREIGN KEY constraints + when the constraint is being added to an existing table via ALTER TABLE, + and has the effect that existing rows are not scanned during the ALTER + operation against the constraint being added. + + When using a SQL migration tool such as `Alembic `_ + that renders ALTER TABLE constructs, the ``postgresql_not_valid`` argument + may be specified as an additional keyword argument within the operation + that creates the constraint, as in the following Alembic example:: + + def update(): + op.create_foreign_key( + "fk_user_address", + "address", + "user", + ["user_id"], + ["id"], + postgresql_not_valid=True + ) + + The keyword is ultimately accepted directly by the + :class:`_schema.CheckConstraint`, :class:`_schema.ForeignKeyConstraint` + and :class:`_schema.ForeignKey` constructs; when using a tool like + Alembic, dialect-specific keyword arguments are passed through to + these constructs from the migration operation directives:: + + CheckConstraint("some_field IS NOT NULL", postgresql_not_valid=True) + + ForeignKeyConstraint(["some_id"], ["some_table.some_id"], postgresql_not_valid=True) + + .. versionadded:: 1.4.32 + + .. seealso:: + + `PostgreSQL ALTER TABLE options + `_ - + in the PostgreSQL documentation. + +.. _postgresql_table_valued_overview: + +Table values, Table and Column valued functions, Row and Tuple objects +----------------------------------------------------------------------- + +PostgreSQL makes great use of modern SQL forms such as table-valued functions, +tables and rows as values. These constructs are commonly used as part +of PostgreSQL's support for complex datatypes such as JSON, ARRAY, and other +datatypes. SQLAlchemy's SQL expression language has native support for +most table-valued and row-valued forms. + +.. _postgresql_table_valued: + +Table-Valued Functions +^^^^^^^^^^^^^^^^^^^^^^^ + +Many PostgreSQL built-in functions are intended to be used in the FROM clause +of a SELECT statement, and are capable of returning table rows or sets of table +rows. A large portion of PostgreSQL's JSON functions for example such as +``json_array_elements()``, ``json_object_keys()``, ``json_each_text()``, +``json_each()``, ``json_to_record()``, ``json_populate_recordset()`` use such +forms. These classes of SQL function calling forms in SQLAlchemy are available +using the :meth:`_functions.FunctionElement.table_valued` method in conjunction +with :class:`_functions.Function` objects generated from the :data:`_sql.func` +namespace. + +Examples from PostgreSQL's reference documentation follow below: + +* ``json_each()``: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import select, func + >>> stmt = select(func.json_each('{"a":"foo", "b":"bar"}').table_valued("key", "value")) + >>> print(stmt) + {printsql}SELECT anon_1.key, anon_1.value + FROM json_each(:json_each_1) AS anon_1 + +* ``json_populate_record()``: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import select, func, literal_column + >>> stmt = select( + ... func.json_populate_record( + ... literal_column("null::myrowtype"), + ... '{"a":1,"b":2}' + ... ).table_valued("a", "b", name="x") + ... ) + >>> print(stmt) + {printsql}SELECT x.a, x.b + FROM json_populate_record(null::myrowtype, :json_populate_record_1) AS x + +* ``json_to_record()`` - this form uses a PostgreSQL specific form of derived + columns in the alias, where we may make use of :func:`_sql.column` elements with + types to produce them. The :meth:`_functions.FunctionElement.table_valued` + method produces a :class:`_sql.TableValuedAlias` construct, and the method + :meth:`_sql.TableValuedAlias.render_derived` method sets up the derived + columns specification: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import select, func, column, Integer, Text + >>> stmt = select( + ... func.json_to_record('{"a":1,"b":[1,2,3],"c":"bar"}').table_valued( + ... column("a", Integer), column("b", Text), column("d", Text), + ... ).render_derived(name="x", with_types=True) + ... ) + >>> print(stmt) + {printsql}SELECT x.a, x.b, x.d + FROM json_to_record(:json_to_record_1) AS x(a INTEGER, b TEXT, d TEXT) + +* ``WITH ORDINALITY`` - part of the SQL standard, ``WITH ORDINALITY`` adds an + ordinal counter to the output of a function and is accepted by a limited set + of PostgreSQL functions including ``unnest()`` and ``generate_series()``. The + :meth:`_functions.FunctionElement.table_valued` method accepts a keyword + parameter ``with_ordinality`` for this purpose, which accepts the string name + that will be applied to the "ordinality" column: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import select, func + >>> stmt = select( + ... func.generate_series(4, 1, -1). + ... table_valued("value", with_ordinality="ordinality"). + ... render_derived() + ... ) + >>> print(stmt) + {printsql}SELECT anon_1.value, anon_1.ordinality + FROM generate_series(:generate_series_1, :generate_series_2, :generate_series_3) + WITH ORDINALITY AS anon_1(value, ordinality) + +.. versionadded:: 1.4.0b2 + +.. seealso:: + + :ref:`tutorial_functions_table_valued` - in the :ref:`unified_tutorial` + +.. _postgresql_column_valued: + +Column Valued Functions +^^^^^^^^^^^^^^^^^^^^^^^ + +Similar to the table valued function, a column valued function is present +in the FROM clause, but delivers itself to the columns clause as a single +scalar value. PostgreSQL functions such as ``json_array_elements()``, +``unnest()`` and ``generate_series()`` may use this form. Column valued functions are available using the +:meth:`_functions.FunctionElement.column_valued` method of :class:`_functions.FunctionElement`: + +* ``json_array_elements()``: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import select, func + >>> stmt = select(func.json_array_elements('["one", "two"]').column_valued("x")) + >>> print(stmt) + {printsql}SELECT x + FROM json_array_elements(:json_array_elements_1) AS x + +* ``unnest()`` - in order to generate a PostgreSQL ARRAY literal, the + :func:`_postgresql.array` construct may be used: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy.dialects.postgresql import array + >>> from sqlalchemy import select, func + >>> stmt = select(func.unnest(array([1, 2])).column_valued()) + >>> print(stmt) + {printsql}SELECT anon_1 + FROM unnest(ARRAY[%(param_1)s, %(param_2)s]) AS anon_1 + + The function can of course be used against an existing table-bound column + that's of type :class:`_types.ARRAY`: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import table, column, ARRAY, Integer + >>> from sqlalchemy import select, func + >>> t = table("t", column('value', ARRAY(Integer))) + >>> stmt = select(func.unnest(t.c.value).column_valued("unnested_value")) + >>> print(stmt) + {printsql}SELECT unnested_value + FROM unnest(t.value) AS unnested_value + +.. seealso:: + + :ref:`tutorial_functions_column_valued` - in the :ref:`unified_tutorial` + + +Row Types +^^^^^^^^^ + +Built-in support for rendering a ``ROW`` may be approximated using +``func.ROW`` with the :attr:`_sa.func` namespace, or by using the +:func:`_sql.tuple_` construct: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import table, column, func, tuple_ + >>> t = table("t", column("id"), column("fk")) + >>> stmt = t.select().where( + ... tuple_(t.c.id, t.c.fk) > (1,2) + ... ).where( + ... func.ROW(t.c.id, t.c.fk) < func.ROW(3, 7) + ... ) + >>> print(stmt) + {printsql}SELECT t.id, t.fk + FROM t + WHERE (t.id, t.fk) > (:param_1, :param_2) AND ROW(t.id, t.fk) < ROW(:ROW_1, :ROW_2) + +.. seealso:: + + `PostgreSQL Row Constructors + `_ + + `PostgreSQL Row Constructor Comparison + `_ + +Table Types passed to Functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +PostgreSQL supports passing a table as an argument to a function, which is +known as a "record" type. SQLAlchemy :class:`_sql.FromClause` objects +such as :class:`_schema.Table` support this special form using the +:meth:`_sql.FromClause.table_valued` method, which is comparable to the +:meth:`_functions.FunctionElement.table_valued` method except that the collection +of columns is already established by that of the :class:`_sql.FromClause` +itself: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import table, column, func, select + >>> a = table( "a", column("id"), column("x"), column("y")) + >>> stmt = select(func.row_to_json(a.table_valued())) + >>> print(stmt) + {printsql}SELECT row_to_json(a) AS row_to_json_1 + FROM a + +.. versionadded:: 1.4.0b2 + + + +""" # noqa: E501 + +from __future__ import annotations + +from collections import defaultdict +from functools import lru_cache +import re +from typing import Any +from typing import cast +from typing import List +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from . import arraylib as _array +from . import json as _json +from . import pg_catalog +from . import ranges as _ranges +from .ext import _regconfig_fn +from .ext import aggregate_order_by +from .hstore import HSTORE +from .named_types import CreateDomainType as CreateDomainType # noqa: F401 +from .named_types import CreateEnumType as CreateEnumType # noqa: F401 +from .named_types import DOMAIN as DOMAIN # noqa: F401 +from .named_types import DropDomainType as DropDomainType # noqa: F401 +from .named_types import DropEnumType as DropEnumType # noqa: F401 +from .named_types import ENUM as ENUM # noqa: F401 +from .named_types import NamedType as NamedType # noqa: F401 +from .types import _DECIMAL_TYPES # noqa: F401 +from .types import _FLOAT_TYPES # noqa: F401 +from .types import _INT_TYPES # noqa: F401 +from .types import BIT as BIT +from .types import BYTEA as BYTEA +from .types import CIDR as CIDR +from .types import CITEXT as CITEXT +from .types import INET as INET +from .types import INTERVAL as INTERVAL +from .types import MACADDR as MACADDR +from .types import MACADDR8 as MACADDR8 +from .types import MONEY as MONEY +from .types import OID as OID +from .types import PGBit as PGBit # noqa: F401 +from .types import PGCidr as PGCidr # noqa: F401 +from .types import PGInet as PGInet # noqa: F401 +from .types import PGInterval as PGInterval # noqa: F401 +from .types import PGMacAddr as PGMacAddr # noqa: F401 +from .types import PGMacAddr8 as PGMacAddr8 # noqa: F401 +from .types import PGUuid as PGUuid +from .types import REGCLASS as REGCLASS +from .types import REGCONFIG as REGCONFIG # noqa: F401 +from .types import TIME as TIME +from .types import TIMESTAMP as TIMESTAMP +from .types import TSVECTOR as TSVECTOR +from ... import exc +from ... import schema +from ... import select +from ... import sql +from ... import util +from ...engine import characteristics +from ...engine import default +from ...engine import interfaces +from ...engine import ObjectKind +from ...engine import ObjectScope +from ...engine import reflection +from ...engine import URL +from ...engine.reflection import ReflectionDefaults +from ...sql import bindparam +from ...sql import coercions +from ...sql import compiler +from ...sql import elements +from ...sql import expression +from ...sql import roles +from ...sql import sqltypes +from ...sql import util as sql_util +from ...sql.compiler import InsertmanyvaluesSentinelOpts +from ...sql.visitors import InternalTraversal +from ...types import BIGINT +from ...types import BOOLEAN +from ...types import CHAR +from ...types import DATE +from ...types import DOUBLE_PRECISION +from ...types import FLOAT +from ...types import INTEGER +from ...types import NUMERIC +from ...types import REAL +from ...types import SMALLINT +from ...types import TEXT +from ...types import UUID as UUID +from ...types import VARCHAR +from ...util.typing import TypedDict + +IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I) + +RESERVED_WORDS = { + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "leading", + "limit", + "localtime", + "localtimestamp", + "new", + "not", + "null", + "of", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", + "authorization", + "between", + "binary", + "cross", + "current_schema", + "freeze", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "notnull", + "outer", + "over", + "overlaps", + "right", + "similar", + "verbose", +} + +colspecs = { + sqltypes.ARRAY: _array.ARRAY, + sqltypes.Interval: INTERVAL, + sqltypes.Enum: ENUM, + sqltypes.JSON.JSONPathType: _json.JSONPATH, + sqltypes.JSON: _json.JSON, + sqltypes.Uuid: PGUuid, +} + + +ischema_names = { + "_array": _array.ARRAY, + "hstore": HSTORE, + "json": _json.JSON, + "jsonb": _json.JSONB, + "int4range": _ranges.INT4RANGE, + "int8range": _ranges.INT8RANGE, + "numrange": _ranges.NUMRANGE, + "daterange": _ranges.DATERANGE, + "tsrange": _ranges.TSRANGE, + "tstzrange": _ranges.TSTZRANGE, + "int4multirange": _ranges.INT4MULTIRANGE, + "int8multirange": _ranges.INT8MULTIRANGE, + "nummultirange": _ranges.NUMMULTIRANGE, + "datemultirange": _ranges.DATEMULTIRANGE, + "tsmultirange": _ranges.TSMULTIRANGE, + "tstzmultirange": _ranges.TSTZMULTIRANGE, + "integer": INTEGER, + "bigint": BIGINT, + "smallint": SMALLINT, + "character varying": VARCHAR, + "character": CHAR, + '"char"': sqltypes.String, + "name": sqltypes.String, + "text": TEXT, + "numeric": NUMERIC, + "float": FLOAT, + "real": REAL, + "inet": INET, + "cidr": CIDR, + "citext": CITEXT, + "uuid": UUID, + "bit": BIT, + "bit varying": BIT, + "macaddr": MACADDR, + "macaddr8": MACADDR8, + "money": MONEY, + "oid": OID, + "regclass": REGCLASS, + "double precision": DOUBLE_PRECISION, + "timestamp": TIMESTAMP, + "timestamp with time zone": TIMESTAMP, + "timestamp without time zone": TIMESTAMP, + "time with time zone": TIME, + "time without time zone": TIME, + "date": DATE, + "time": TIME, + "bytea": BYTEA, + "boolean": BOOLEAN, + "interval": INTERVAL, + "tsvector": TSVECTOR, +} + + +class PGCompiler(compiler.SQLCompiler): + def visit_to_tsvector_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_to_tsquery_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_plainto_tsquery_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_phraseto_tsquery_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_websearch_to_tsquery_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def visit_ts_headline_func(self, element, **kw): + return self._assert_pg_ts_ext(element, **kw) + + def _assert_pg_ts_ext(self, element, **kw): + if not isinstance(element, _regconfig_fn): + # other options here include trying to rewrite the function + # with the correct types. however, that means we have to + # "un-SQL-ize" the first argument, which can't work in a + # generalized way. Also, parent compiler class has already added + # the incorrect return type to the result map. So let's just + # make sure the function we want is used up front. + + raise exc.CompileError( + f'Can\'t compile "{element.name}()" full text search ' + f"function construct that does not originate from the " + f'"sqlalchemy.dialects.postgresql" package. ' + f'Please ensure "import sqlalchemy.dialects.postgresql" is ' + f"called before constructing " + f'"sqlalchemy.func.{element.name}()" to ensure registration ' + f"of the correct argument and return types." + ) + + return f"{element.name}{self.function_argspec(element, **kw)}" + + def render_bind_cast(self, type_, dbapi_type, sqltext): + if dbapi_type._type_affinity is sqltypes.String and dbapi_type.length: + # use VARCHAR with no length for VARCHAR cast. + # see #9511 + dbapi_type = sqltypes.STRINGTYPE + return f"""{sqltext}::{ + self.dialect.type_compiler_instance.process( + dbapi_type, identifier_preparer=self.preparer + ) + }""" + + def visit_array(self, element, **kw): + return "ARRAY[%s]" % self.visit_clauselist(element, **kw) + + def visit_slice(self, element, **kw): + return "%s:%s" % ( + self.process(element.start, **kw), + self.process(element.stop, **kw), + ) + + def visit_bitwise_xor_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " # ", **kw) + + def visit_json_getitem_op_binary( + self, binary, operator, _cast_applied=False, **kw + ): + if ( + not _cast_applied + and binary.type._type_affinity is not sqltypes.JSON + ): + kw["_cast_applied"] = True + return self.process(sql.cast(binary, binary.type), **kw) + + kw["eager_grouping"] = True + + return self._generate_generic_binary( + binary, " -> " if not _cast_applied else " ->> ", **kw + ) + + def visit_json_path_getitem_op_binary( + self, binary, operator, _cast_applied=False, **kw + ): + if ( + not _cast_applied + and binary.type._type_affinity is not sqltypes.JSON + ): + kw["_cast_applied"] = True + return self.process(sql.cast(binary, binary.type), **kw) + + kw["eager_grouping"] = True + return self._generate_generic_binary( + binary, " #> " if not _cast_applied else " #>> ", **kw + ) + + def visit_getitem_binary(self, binary, operator, **kw): + return "%s[%s]" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_aggregate_order_by(self, element, **kw): + return "%s ORDER BY %s" % ( + self.process(element.target, **kw), + self.process(element.order_by, **kw), + ) + + def visit_match_op_binary(self, binary, operator, **kw): + if "postgresql_regconfig" in binary.modifiers: + regconfig = self.render_literal_value( + binary.modifiers["postgresql_regconfig"], sqltypes.STRINGTYPE + ) + if regconfig: + return "%s @@ plainto_tsquery(%s, %s)" % ( + self.process(binary.left, **kw), + regconfig, + self.process(binary.right, **kw), + ) + return "%s @@ plainto_tsquery(%s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_ilike_case_insensitive_operand(self, element, **kw): + return element.element._compiler_dispatch(self, **kw) + + def visit_ilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + + return "%s ILIKE %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape is not None + else "" + ) + + def visit_not_ilike_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + return "%s NOT ILIKE %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape is not None + else "" + ) + + def _regexp_match(self, base_op, binary, operator, kw): + flags = binary.modifiers["flags"] + if flags is None: + return self._generate_generic_binary( + binary, " %s " % base_op, **kw + ) + if flags == "i": + return self._generate_generic_binary( + binary, " %s* " % base_op, **kw + ) + return "%s %s CONCAT('(?', %s, ')', %s)" % ( + self.process(binary.left, **kw), + base_op, + self.render_literal_value(flags, sqltypes.STRINGTYPE), + self.process(binary.right, **kw), + ) + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._regexp_match("~", binary, operator, kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._regexp_match("!~", binary, operator, kw) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + string = self.process(binary.left, **kw) + pattern_replace = self.process(binary.right, **kw) + flags = binary.modifiers["flags"] + if flags is None: + return "REGEXP_REPLACE(%s, %s)" % ( + string, + pattern_replace, + ) + else: + return "REGEXP_REPLACE(%s, %s, %s)" % ( + string, + pattern_replace, + self.render_literal_value(flags, sqltypes.STRINGTYPE), + ) + + def visit_empty_set_expr(self, element_types, **kw): + # cast the empty set to the type we are comparing against. if + # we are comparing against the null type, pick an arbitrary + # datatype for the empty set + return "SELECT %s WHERE 1!=1" % ( + ", ".join( + "CAST(NULL AS %s)" + % self.dialect.type_compiler_instance.process( + INTEGER() if type_._isnull else type_ + ) + for type_ in element_types or [INTEGER()] + ), + ) + + def render_literal_value(self, value, type_): + value = super().render_literal_value(value, type_) + + if self.dialect._backslash_escapes: + value = value.replace("\\", "\\\\") + return value + + def visit_aggregate_strings_func(self, fn, **kw): + return "string_agg%s" % self.function_argspec(fn) + + def visit_sequence(self, seq, **kw): + return "nextval('%s')" % self.preparer.format_sequence(seq) + + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + text += " \n LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + if select._limit_clause is None: + text += "\n LIMIT ALL" + text += " OFFSET " + self.process(select._offset_clause, **kw) + return text + + def format_from_hint_text(self, sqltext, table, hint, iscrud): + if hint.upper() != "ONLY": + raise exc.CompileError("Unrecognized hint: %r" % hint) + return "ONLY " + sqltext + + def get_select_precolumns(self, select, **kw): + # Do not call super().get_select_precolumns because + # it will warn/raise when distinct on is present + if select._distinct or select._distinct_on: + if select._distinct_on: + return ( + "DISTINCT ON (" + + ", ".join( + [ + self.process(col, **kw) + for col in select._distinct_on + ] + ) + + ") " + ) + else: + return "DISTINCT " + else: + return "" + + def for_update_clause(self, select, **kw): + if select._for_update_arg.read: + if select._for_update_arg.key_share: + tmp = " FOR KEY SHARE" + else: + tmp = " FOR SHARE" + elif select._for_update_arg.key_share: + tmp = " FOR NO KEY UPDATE" + else: + tmp = " FOR UPDATE" + + if select._for_update_arg.of: + tables = util.OrderedSet() + for c in select._for_update_arg.of: + tables.update(sql_util.surface_selectables_only(c)) + + tmp += " OF " + ", ".join( + self.process(table, ashint=True, use_schema=False, **kw) + for table in tables + ) + + if select._for_update_arg.nowait: + tmp += " NOWAIT" + if select._for_update_arg.skip_locked: + tmp += " SKIP LOCKED" + + return tmp + + def visit_substring_func(self, func, **kw): + s = self.process(func.clauses.clauses[0], **kw) + start = self.process(func.clauses.clauses[1], **kw) + if len(func.clauses.clauses) > 2: + length = self.process(func.clauses.clauses[2], **kw) + return "SUBSTRING(%s FROM %s FOR %s)" % (s, start, length) + else: + return "SUBSTRING(%s FROM %s)" % (s, start) + + def _on_conflict_target(self, clause, **kw): + if clause.constraint_target is not None: + # target may be a name of an Index, UniqueConstraint or + # ExcludeConstraint. While there is a separate + # "max_identifier_length" for indexes, PostgreSQL uses the same + # length for all objects so we can use + # truncate_and_render_constraint_name + target_text = ( + "ON CONSTRAINT %s" + % self.preparer.truncate_and_render_constraint_name( + clause.constraint_target + ) + ) + elif clause.inferred_target_elements is not None: + target_text = "(%s)" % ", ".join( + ( + self.preparer.quote(c) + if isinstance(c, str) + else self.process(c, include_table=False, use_schema=False) + ) + for c in clause.inferred_target_elements + ) + if clause.inferred_target_whereclause is not None: + target_text += " WHERE %s" % self.process( + clause.inferred_target_whereclause, + include_table=False, + use_schema=False, + ) + else: + target_text = "" + + return target_text + + def visit_on_conflict_do_nothing(self, on_conflict, **kw): + target_text = self._on_conflict_target(on_conflict, **kw) + + if target_text: + return "ON CONFLICT %s DO NOTHING" % target_text + else: + return "ON CONFLICT DO NOTHING" + + def visit_on_conflict_do_update(self, on_conflict, **kw): + clause = on_conflict + + target_text = self._on_conflict_target(on_conflict, **kw) + + action_set_ops = [] + + set_parameters = dict(clause.update_values_to_set) + # create a list of column assignment clauses as tuples + + insert_statement = self.stack[-1]["selectable"] + cols = insert_statement.table.c + for c in cols: + col_key = c.key + + if col_key in set_parameters: + value = set_parameters.pop(col_key) + elif c in set_parameters: + value = set_parameters.pop(c) + else: + continue + + if coercions._is_literal(value): + value = elements.BindParameter(None, value, type_=c.type) + + else: + if ( + isinstance(value, elements.BindParameter) + and value.type._isnull + ): + value = value._clone() + value.type = c.type + value_text = self.process(value.self_group(), use_schema=False) + + key_text = self.preparer.quote(c.name) + action_set_ops.append("%s = %s" % (key_text, value_text)) + + # check for names that don't match columns + if set_parameters: + util.warn( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( + self.current_executable.table.name, + (", ".join("'%s'" % c for c in set_parameters)), + ) + ) + for k, v in set_parameters.items(): + key_text = ( + self.preparer.quote(k) + if isinstance(k, str) + else self.process(k, use_schema=False) + ) + value_text = self.process( + coercions.expect(roles.ExpressionElementRole, v), + use_schema=False, + ) + action_set_ops.append("%s = %s" % (key_text, value_text)) + + action_text = ", ".join(action_set_ops) + if clause.update_whereclause is not None: + action_text += " WHERE %s" % self.process( + clause.update_whereclause, include_table=True, use_schema=False + ) + + return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + kw["asfrom"] = True + return "FROM " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): + """Render the DELETE .. USING clause specific to PostgreSQL.""" + kw["asfrom"] = True + return "USING " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def fetch_clause(self, select, **kw): + # pg requires parens for non literal clauses. It's also required for + # bind parameters if a ::type casts is used by the driver (asyncpg), + # so it's easiest to just always add it + text = "" + if select._offset_clause is not None: + text += "\n OFFSET (%s) ROWS" % self.process( + select._offset_clause, **kw + ) + if select._fetch_clause is not None: + text += "\n FETCH FIRST (%s)%s ROWS %s" % ( + self.process(select._fetch_clause, **kw), + " PERCENT" if select._fetch_clause_options["percent"] else "", + ( + "WITH TIES" + if select._fetch_clause_options["with_ties"] + else "ONLY" + ), + ) + return text + + +class PGDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + impl_type = column.type.dialect_impl(self.dialect) + if isinstance(impl_type, sqltypes.TypeDecorator): + impl_type = impl_type.impl + + has_identity = ( + column.identity is not None + and self.dialect.supports_identity_columns + ) + + if ( + column.primary_key + and column is column.table._autoincrement_column + and ( + self.dialect.supports_smallserial + or not isinstance(impl_type, sqltypes.SmallInteger) + ) + and not has_identity + and ( + column.default is None + or ( + isinstance(column.default, schema.Sequence) + and column.default.optional + ) + ) + ): + if isinstance(impl_type, sqltypes.BigInteger): + colspec += " BIGSERIAL" + elif isinstance(impl_type, sqltypes.SmallInteger): + colspec += " SMALLSERIAL" + else: + colspec += " SERIAL" + else: + colspec += " " + self.dialect.type_compiler_instance.process( + column.type, + type_expression=column, + identifier_preparer=self.preparer, + ) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if column.computed is not None: + colspec += " " + self.process(column.computed) + if has_identity: + colspec += " " + self.process(column.identity) + + if not column.nullable and not has_identity: + colspec += " NOT NULL" + elif column.nullable and has_identity: + colspec += " NULL" + return colspec + + def _define_constraint_validity(self, constraint): + not_valid = constraint.dialect_options["postgresql"]["not_valid"] + return " NOT VALID" if not_valid else "" + + def visit_check_constraint(self, constraint, **kw): + if constraint._type_bound: + typ = list(constraint.columns)[0].type + if ( + isinstance(typ, sqltypes.ARRAY) + and isinstance(typ.item_type, sqltypes.Enum) + and not typ.item_type.native_enum + ): + raise exc.CompileError( + "PostgreSQL dialect cannot produce the CHECK constraint " + "for ARRAY of non-native ENUM; please specify " + "create_constraint=False on this Enum datatype." + ) + + text = super().visit_check_constraint(constraint) + text += self._define_constraint_validity(constraint) + return text + + def visit_foreign_key_constraint(self, constraint, **kw): + text = super().visit_foreign_key_constraint(constraint) + text += self._define_constraint_validity(constraint) + return text + + def visit_create_enum_type(self, create, **kw): + type_ = create.element + + return "CREATE TYPE %s AS ENUM (%s)" % ( + self.preparer.format_type(type_), + ", ".join( + self.sql_compiler.process(sql.literal(e), literal_binds=True) + for e in type_.enums + ), + ) + + def visit_drop_enum_type(self, drop, **kw): + type_ = drop.element + + return "DROP TYPE %s" % (self.preparer.format_type(type_)) + + def visit_create_domain_type(self, create, **kw): + domain: DOMAIN = create.element + + options = [] + if domain.collation is not None: + options.append(f"COLLATE {self.preparer.quote(domain.collation)}") + if domain.default is not None: + default = self.render_default_string(domain.default) + options.append(f"DEFAULT {default}") + if domain.constraint_name is not None: + name = self.preparer.truncate_and_render_constraint_name( + domain.constraint_name + ) + options.append(f"CONSTRAINT {name}") + if domain.not_null: + options.append("NOT NULL") + if domain.check is not None: + check = self.sql_compiler.process( + domain.check, include_table=False, literal_binds=True + ) + options.append(f"CHECK ({check})") + + return ( + f"CREATE DOMAIN {self.preparer.format_type(domain)} AS " + f"{self.type_compiler.process(domain.data_type)} " + f"{' '.join(options)}" + ) + + def visit_drop_domain_type(self, drop, **kw): + domain = drop.element + return f"DROP DOMAIN {self.preparer.format_type(domain)}" + + def visit_create_index(self, create, **kw): + preparer = self.preparer + index = create.element + self._verify_index_table(index) + text = "CREATE " + if index.unique: + text += "UNIQUE " + + text += "INDEX " + + if self.dialect._supports_create_index_concurrently: + concurrently = index.dialect_options["postgresql"]["concurrently"] + if concurrently: + text += "CONCURRENTLY " + + if create.if_not_exists: + text += "IF NOT EXISTS " + + text += "%s ON %s " % ( + self._prepared_index_name(index, include_schema=False), + preparer.format_table(index.table), + ) + + using = index.dialect_options["postgresql"]["using"] + if using: + text += ( + "USING %s " + % self.preparer.validate_sql_phrase(using, IDX_USING).lower() + ) + + ops = index.dialect_options["postgresql"]["ops"] + text += "(%s)" % ( + ", ".join( + [ + self.sql_compiler.process( + ( + expr.self_group() + if not isinstance(expr, expression.ColumnClause) + else expr + ), + include_table=False, + literal_binds=True, + ) + + ( + (" " + ops[expr.key]) + if hasattr(expr, "key") and expr.key in ops + else "" + ) + for expr in index.expressions + ] + ) + ) + + includeclause = index.dialect_options["postgresql"]["include"] + if includeclause: + inclusions = [ + index.table.c[col] if isinstance(col, str) else col + for col in includeclause + ] + text += " INCLUDE (%s)" % ", ".join( + [preparer.quote(c.name) for c in inclusions] + ) + + nulls_not_distinct = index.dialect_options["postgresql"][ + "nulls_not_distinct" + ] + if nulls_not_distinct is True: + text += " NULLS NOT DISTINCT" + elif nulls_not_distinct is False: + text += " NULLS DISTINCT" + + withclause = index.dialect_options["postgresql"]["with"] + if withclause: + text += " WITH (%s)" % ( + ", ".join( + [ + "%s = %s" % storage_parameter + for storage_parameter in withclause.items() + ] + ) + ) + + tablespace_name = index.dialect_options["postgresql"]["tablespace"] + if tablespace_name: + text += " TABLESPACE %s" % preparer.quote(tablespace_name) + + whereclause = index.dialect_options["postgresql"]["where"] + if whereclause is not None: + whereclause = coercions.expect( + roles.DDLExpressionRole, whereclause + ) + + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, literal_binds=True + ) + text += " WHERE " + where_compiled + + return text + + def define_unique_constraint_distinct(self, constraint, **kw): + nulls_not_distinct = constraint.dialect_options["postgresql"][ + "nulls_not_distinct" + ] + if nulls_not_distinct is True: + nulls_not_distinct_param = "NULLS NOT DISTINCT " + elif nulls_not_distinct is False: + nulls_not_distinct_param = "NULLS DISTINCT " + else: + nulls_not_distinct_param = "" + return nulls_not_distinct_param + + def visit_drop_index(self, drop, **kw): + index = drop.element + + text = "\nDROP INDEX " + + if self.dialect._supports_drop_index_concurrently: + concurrently = index.dialect_options["postgresql"]["concurrently"] + if concurrently: + text += "CONCURRENTLY " + + if drop.if_exists: + text += "IF EXISTS " + + text += self._prepared_index_name(index, include_schema=True) + return text + + def visit_exclude_constraint(self, constraint, **kw): + text = "" + if constraint.name is not None: + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) + elements = [] + kw["include_table"] = False + kw["literal_binds"] = True + for expr, name, op in constraint._render_exprs: + exclude_element = self.sql_compiler.process(expr, **kw) + ( + (" " + constraint.ops[expr.key]) + if hasattr(expr, "key") and expr.key in constraint.ops + else "" + ) + + elements.append("%s WITH %s" % (exclude_element, op)) + text += "EXCLUDE USING %s (%s)" % ( + self.preparer.validate_sql_phrase( + constraint.using, IDX_USING + ).lower(), + ", ".join(elements), + ) + if constraint.where is not None: + text += " WHERE (%s)" % self.sql_compiler.process( + constraint.where, literal_binds=True + ) + text += self.define_constraint_deferrability(constraint) + return text + + def post_create_table(self, table): + table_opts = [] + pg_opts = table.dialect_options["postgresql"] + + inherits = pg_opts.get("inherits") + if inherits is not None: + if not isinstance(inherits, (list, tuple)): + inherits = (inherits,) + table_opts.append( + "\n INHERITS ( " + + ", ".join(self.preparer.quote(name) for name in inherits) + + " )" + ) + + if pg_opts["partition_by"]: + table_opts.append("\n PARTITION BY %s" % pg_opts["partition_by"]) + + if pg_opts["using"]: + table_opts.append("\n USING %s" % pg_opts["using"]) + + if pg_opts["with_oids"] is True: + table_opts.append("\n WITH OIDS") + elif pg_opts["with_oids"] is False: + table_opts.append("\n WITHOUT OIDS") + + if pg_opts["on_commit"]: + on_commit_options = pg_opts["on_commit"].replace("_", " ").upper() + table_opts.append("\n ON COMMIT %s" % on_commit_options) + + if pg_opts["tablespace"]: + tablespace_name = pg_opts["tablespace"] + table_opts.append( + "\n TABLESPACE %s" % self.preparer.quote(tablespace_name) + ) + + return "".join(table_opts) + + def visit_computed_column(self, generated, **kw): + if generated.persisted is False: + raise exc.CompileError( + "PostrgreSQL computed columns do not support 'virtual' " + "persistence; set the 'persisted' flag to None or True for " + "PostgreSQL support." + ) + + return "GENERATED ALWAYS AS (%s) STORED" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + + def visit_create_sequence(self, create, **kw): + prefix = None + if create.element.data_type is not None: + prefix = " AS %s" % self.type_compiler.process( + create.element.data_type + ) + + return super().visit_create_sequence(create, prefix=prefix, **kw) + + def _can_comment_on_constraint(self, ddl_instance): + constraint = ddl_instance.element + if constraint.name is None: + raise exc.CompileError( + f"Can't emit COMMENT ON for constraint {constraint!r}: " + "it has no name" + ) + if constraint.table is None: + raise exc.CompileError( + f"Can't emit COMMENT ON for constraint {constraint!r}: " + "it has no associated table" + ) + + def visit_set_constraint_comment(self, create, **kw): + self._can_comment_on_constraint(create) + return "COMMENT ON CONSTRAINT %s ON %s IS %s" % ( + self.preparer.format_constraint(create.element), + self.preparer.format_table(create.element.table), + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.String() + ), + ) + + def visit_drop_constraint_comment(self, drop, **kw): + self._can_comment_on_constraint(drop) + return "COMMENT ON CONSTRAINT %s ON %s IS NULL" % ( + self.preparer.format_constraint(drop.element), + self.preparer.format_table(drop.element.table), + ) + + +class PGTypeCompiler(compiler.GenericTypeCompiler): + def visit_TSVECTOR(self, type_, **kw): + return "TSVECTOR" + + def visit_TSQUERY(self, type_, **kw): + return "TSQUERY" + + def visit_INET(self, type_, **kw): + return "INET" + + def visit_CIDR(self, type_, **kw): + return "CIDR" + + def visit_CITEXT(self, type_, **kw): + return "CITEXT" + + def visit_MACADDR(self, type_, **kw): + return "MACADDR" + + def visit_MACADDR8(self, type_, **kw): + return "MACADDR8" + + def visit_MONEY(self, type_, **kw): + return "MONEY" + + def visit_OID(self, type_, **kw): + return "OID" + + def visit_REGCONFIG(self, type_, **kw): + return "REGCONFIG" + + def visit_REGCLASS(self, type_, **kw): + return "REGCLASS" + + def visit_FLOAT(self, type_, **kw): + if not type_.precision: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {"precision": type_.precision} + + def visit_double(self, type_, **kw): + return self.visit_DOUBLE_PRECISION(type, **kw) + + def visit_BIGINT(self, type_, **kw): + return "BIGINT" + + def visit_HSTORE(self, type_, **kw): + return "HSTORE" + + def visit_JSON(self, type_, **kw): + return "JSON" + + def visit_JSONB(self, type_, **kw): + return "JSONB" + + def visit_INT4MULTIRANGE(self, type_, **kw): + return "INT4MULTIRANGE" + + def visit_INT8MULTIRANGE(self, type_, **kw): + return "INT8MULTIRANGE" + + def visit_NUMMULTIRANGE(self, type_, **kw): + return "NUMMULTIRANGE" + + def visit_DATEMULTIRANGE(self, type_, **kw): + return "DATEMULTIRANGE" + + def visit_TSMULTIRANGE(self, type_, **kw): + return "TSMULTIRANGE" + + def visit_TSTZMULTIRANGE(self, type_, **kw): + return "TSTZMULTIRANGE" + + def visit_INT4RANGE(self, type_, **kw): + return "INT4RANGE" + + def visit_INT8RANGE(self, type_, **kw): + return "INT8RANGE" + + def visit_NUMRANGE(self, type_, **kw): + return "NUMRANGE" + + def visit_DATERANGE(self, type_, **kw): + return "DATERANGE" + + def visit_TSRANGE(self, type_, **kw): + return "TSRANGE" + + def visit_TSTZRANGE(self, type_, **kw): + return "TSTZRANGE" + + def visit_json_int_index(self, type_, **kw): + return "INT" + + def visit_json_str_index(self, type_, **kw): + return "TEXT" + + def visit_datetime(self, type_, **kw): + return self.visit_TIMESTAMP(type_, **kw) + + def visit_enum(self, type_, **kw): + if not type_.native_enum or not self.dialect.supports_native_enum: + return super().visit_enum(type_, **kw) + else: + return self.visit_ENUM(type_, **kw) + + def visit_ENUM(self, type_, identifier_preparer=None, **kw): + if identifier_preparer is None: + identifier_preparer = self.dialect.identifier_preparer + return identifier_preparer.format_type(type_) + + def visit_DOMAIN(self, type_, identifier_preparer=None, **kw): + if identifier_preparer is None: + identifier_preparer = self.dialect.identifier_preparer + return identifier_preparer.format_type(type_) + + def visit_TIMESTAMP(self, type_, **kw): + return "TIMESTAMP%s %s" % ( + ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "" + ), + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", + ) + + def visit_TIME(self, type_, **kw): + return "TIME%s %s" % ( + ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "" + ), + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", + ) + + def visit_INTERVAL(self, type_, **kw): + text = "INTERVAL" + if type_.fields is not None: + text += " " + type_.fields + if type_.precision is not None: + text += " (%d)" % type_.precision + return text + + def visit_BIT(self, type_, **kw): + if type_.varying: + compiled = "BIT VARYING" + if type_.length is not None: + compiled += "(%d)" % type_.length + else: + compiled = "BIT(%d)" % type_.length + return compiled + + def visit_uuid(self, type_, **kw): + if type_.native_uuid: + return self.visit_UUID(type_, **kw) + else: + return super().visit_uuid(type_, **kw) + + def visit_UUID(self, type_, **kw): + return "UUID" + + def visit_large_binary(self, type_, **kw): + return self.visit_BYTEA(type_, **kw) + + def visit_BYTEA(self, type_, **kw): + return "BYTEA" + + def visit_ARRAY(self, type_, **kw): + inner = self.process(type_.item_type, **kw) + return re.sub( + r"((?: COLLATE.*)?)$", + ( + r"%s\1" + % ( + "[]" + * (type_.dimensions if type_.dimensions is not None else 1) + ) + ), + inner, + count=1, + ) + + def visit_json_path(self, type_, **kw): + return self.visit_JSONPATH(type_, **kw) + + def visit_JSONPATH(self, type_, **kw): + return "JSONPATH" + + +class PGIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = RESERVED_WORDS + + def _unquote_identifier(self, value): + if value[0] == self.initial_quote: + value = value[1:-1].replace( + self.escape_to_quote, self.escape_quote + ) + return value + + def format_type(self, type_, use_schema=True): + if not type_.name: + raise exc.CompileError( + f"PostgreSQL {type_.__class__.__name__} type requires a name." + ) + + name = self.quote(type_.name) + effective_schema = self.schema_for_object(type_) + + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): + name = f"{self.quote_schema(effective_schema)}.{name}" + return name + + +class ReflectedNamedType(TypedDict): + """Represents a reflected named type.""" + + name: str + """Name of the type.""" + schema: str + """The schema of the type.""" + visible: bool + """Indicates if this type is in the current search path.""" + + +class ReflectedDomainConstraint(TypedDict): + """Represents a reflect check constraint of a domain.""" + + name: str + """Name of the constraint.""" + check: str + """The check constraint text.""" + + +class ReflectedDomain(ReflectedNamedType): + """Represents a reflected enum.""" + + type: str + """The string name of the underlying data type of the domain.""" + nullable: bool + """Indicates if the domain allows null or not.""" + default: Optional[str] + """The string representation of the default value of this domain + or ``None`` if none present. + """ + constraints: List[ReflectedDomainConstraint] + """The constraints defined in the domain, if any. + The constraint are in order of evaluation by postgresql. + """ + collation: Optional[str] + """The collation for the domain.""" + + +class ReflectedEnum(ReflectedNamedType): + """Represents a reflected enum.""" + + labels: List[str] + """The labels that compose the enum.""" + + +class PGInspector(reflection.Inspector): + dialect: PGDialect + + def get_table_oid( + self, table_name: str, schema: Optional[str] = None + ) -> int: + """Return the OID for the given table name. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + """ + + with self._operation_context() as conn: + return self.dialect.get_table_oid( + conn, table_name, schema, info_cache=self.info_cache + ) + + def get_domains( + self, schema: Optional[str] = None + ) -> List[ReflectedDomain]: + """Return a list of DOMAIN objects. + + Each member is a dictionary containing these fields: + + * name - name of the domain + * schema - the schema name for the domain. + * visible - boolean, whether or not this domain is visible + in the default search path. + * type - the type defined by this domain. + * nullable - Indicates if this domain can be ``NULL``. + * default - The default value of the domain or ``None`` if the + domain has no default. + * constraints - A list of dict wit the constraint defined by this + domain. Each element constaints two keys: ``name`` of the + constraint and ``check`` with the constraint text. + + :param schema: schema name. If None, the default schema + (typically 'public') is used. May also be set to ``'*'`` to + indicate load domains for all schemas. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect._load_domains( + conn, schema, info_cache=self.info_cache + ) + + def get_enums(self, schema: Optional[str] = None) -> List[ReflectedEnum]: + """Return a list of ENUM objects. + + Each member is a dictionary containing these fields: + + * name - name of the enum + * schema - the schema name for the enum. + * visible - boolean, whether or not this enum is visible + in the default search path. + * labels - a list of string labels that apply to the enum. + + :param schema: schema name. If None, the default schema + (typically 'public') is used. May also be set to ``'*'`` to + indicate load enums for all schemas. + + """ + with self._operation_context() as conn: + return self.dialect._load_enums( + conn, schema, info_cache=self.info_cache + ) + + def get_foreign_table_names( + self, schema: Optional[str] = None + ) -> List[str]: + """Return a list of FOREIGN TABLE names. + + Behavior is similar to that of + :meth:`_reflection.Inspector.get_table_names`, + except that the list is limited to those tables that report a + ``relkind`` value of ``f``. + + """ + with self._operation_context() as conn: + return self.dialect._get_foreign_table_names( + conn, schema, info_cache=self.info_cache + ) + + def has_type( + self, type_name: str, schema: Optional[str] = None, **kw: Any + ) -> bool: + """Return if the database has the specified type in the provided + schema. + + :param type_name: the type to check. + :param schema: schema name. If None, the default schema + (typically 'public') is used. May also be set to ``'*'`` to + check in all schemas. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect.has_type( + conn, type_name, schema, info_cache=self.info_cache + ) + + +class PGExecutionContext(default.DefaultExecutionContext): + def fire_sequence(self, seq, type_): + return self._execute_scalar( + ( + "select nextval('%s')" + % self.identifier_preparer.format_sequence(seq) + ), + type_, + ) + + def get_insert_default(self, column): + if column.primary_key and column is column.table._autoincrement_column: + if column.server_default and column.server_default.has_argument: + # pre-execute passive defaults on primary key columns + return self._execute_scalar( + "select %s" % column.server_default.arg, column.type + ) + + elif column.default is None or ( + column.default.is_sequence and column.default.optional + ): + # execute the sequence associated with a SERIAL primary + # key column. for non-primary-key SERIAL, the ID just + # generates server side. + + try: + seq_name = column._postgresql_seq_name + except AttributeError: + tab = column.table.name + col = column.name + tab = tab[0 : 29 + max(0, (29 - len(col)))] + col = col[0 : 29 + max(0, (29 - len(tab)))] + name = "%s_%s_seq" % (tab, col) + column._postgresql_seq_name = seq_name = name + + if column.table is not None: + effective_schema = self.connection.schema_for_object( + column.table + ) + else: + effective_schema = None + + if effective_schema is not None: + exc = 'select nextval(\'"%s"."%s"\')' % ( + effective_schema, + seq_name, + ) + else: + exc = "select nextval('\"%s\"')" % (seq_name,) + + return self._execute_scalar(exc, column.type) + + return super().get_insert_default(column) + + +class PGReadOnlyConnectionCharacteristic( + characteristics.ConnectionCharacteristic +): + transactional = True + + def reset_characteristic(self, dialect, dbapi_conn): + dialect.set_readonly(dbapi_conn, False) + + def set_characteristic(self, dialect, dbapi_conn, value): + dialect.set_readonly(dbapi_conn, value) + + def get_characteristic(self, dialect, dbapi_conn): + return dialect.get_readonly(dbapi_conn) + + +class PGDeferrableConnectionCharacteristic( + characteristics.ConnectionCharacteristic +): + transactional = True + + def reset_characteristic(self, dialect, dbapi_conn): + dialect.set_deferrable(dbapi_conn, False) + + def set_characteristic(self, dialect, dbapi_conn, value): + dialect.set_deferrable(dbapi_conn, value) + + def get_characteristic(self, dialect, dbapi_conn): + return dialect.get_deferrable(dbapi_conn) + + +class PGDialect(default.DefaultDialect): + name = "postgresql" + supports_statement_cache = True + supports_alter = True + max_identifier_length = 63 + supports_sane_rowcount = True + + bind_typing = interfaces.BindTyping.RENDER_CASTS + + supports_native_enum = True + supports_native_boolean = True + supports_native_uuid = True + supports_smallserial = True + + supports_sequences = True + sequences_optional = True + preexecute_autoincrement_sequences = True + postfetch_lastrowid = False + use_insertmanyvalues = True + + returns_native_bytes = True + + insertmanyvalues_implicit_sentinel = ( + InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT + | InsertmanyvaluesSentinelOpts.USE_INSERT_FROM_SELECT + | InsertmanyvaluesSentinelOpts.RENDER_SELECT_COL_CASTS + ) + + supports_comments = True + supports_constraint_comments = True + supports_default_values = True + + supports_default_metavalue = True + + supports_empty_insert = False + supports_multivalues_insert = True + + supports_identity_columns = True + + default_paramstyle = "pyformat" + ischema_names = ischema_names + colspecs = colspecs + + statement_compiler = PGCompiler + ddl_compiler = PGDDLCompiler + type_compiler_cls = PGTypeCompiler + preparer = PGIdentifierPreparer + execution_ctx_cls = PGExecutionContext + inspector = PGInspector + + update_returning = True + delete_returning = True + insert_returning = True + update_returning_multifrom = True + delete_returning_multifrom = True + + connection_characteristics = ( + default.DefaultDialect.connection_characteristics + ) + connection_characteristics = connection_characteristics.union( + { + "postgresql_readonly": PGReadOnlyConnectionCharacteristic(), + "postgresql_deferrable": PGDeferrableConnectionCharacteristic(), + } + ) + + construct_arguments = [ + ( + schema.Index, + { + "using": False, + "include": None, + "where": None, + "ops": {}, + "concurrently": False, + "with": {}, + "tablespace": None, + "nulls_not_distinct": None, + }, + ), + ( + schema.Table, + { + "ignore_search_path": False, + "tablespace": None, + "partition_by": None, + "with_oids": None, + "on_commit": None, + "inherits": None, + "using": None, + }, + ), + ( + schema.CheckConstraint, + { + "not_valid": False, + }, + ), + ( + schema.ForeignKeyConstraint, + { + "not_valid": False, + }, + ), + ( + schema.UniqueConstraint, + {"nulls_not_distinct": None}, + ), + ] + + reflection_options = ("postgresql_ignore_search_path",) + + _backslash_escapes = True + _supports_create_index_concurrently = True + _supports_drop_index_concurrently = True + + def __init__( + self, + native_inet_types=None, + json_serializer=None, + json_deserializer=None, + **kwargs, + ): + default.DefaultDialect.__init__(self, **kwargs) + + self._native_inet_types = native_inet_types + self._json_deserializer = json_deserializer + self._json_serializer = json_serializer + + def initialize(self, connection): + super().initialize(connection) + + # https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 + self.supports_smallserial = self.server_version_info >= (9, 2) + + self._set_backslash_escapes(connection) + + self._supports_drop_index_concurrently = self.server_version_info >= ( + 9, + 2, + ) + self.supports_identity_columns = self.server_version_info >= (10,) + + def get_isolation_level_values(self, dbapi_conn): + # note the generic dialect doesn't have AUTOCOMMIT, however + # all postgresql dialects should include AUTOCOMMIT. + return ( + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + ) + + def set_isolation_level(self, dbapi_connection, level): + cursor = dbapi_connection.cursor() + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION " + f"ISOLATION LEVEL {level}" + ) + cursor.execute("COMMIT") + cursor.close() + + def get_isolation_level(self, dbapi_connection): + cursor = dbapi_connection.cursor() + cursor.execute("show transaction isolation level") + val = cursor.fetchone()[0] + cursor.close() + return val.upper() + + def set_readonly(self, connection, value): + raise NotImplementedError() + + def get_readonly(self, connection): + raise NotImplementedError() + + def set_deferrable(self, connection, value): + raise NotImplementedError() + + def get_deferrable(self, connection): + raise NotImplementedError() + + def _split_multihost_from_url(self, url: URL) -> Union[ + Tuple[None, None], + Tuple[Tuple[Optional[str], ...], Tuple[Optional[int], ...]], + ]: + hosts: Optional[Tuple[Optional[str], ...]] = None + ports_str: Union[str, Tuple[Optional[str], ...], None] = None + + integrated_multihost = False + + if "host" in url.query: + if isinstance(url.query["host"], (list, tuple)): + integrated_multihost = True + hosts, ports_str = zip( + *[ + token.split(":") if ":" in token else (token, None) + for token in url.query["host"] + ] + ) + + elif isinstance(url.query["host"], str): + hosts = tuple(url.query["host"].split(",")) + + if ( + "port" not in url.query + and len(hosts) == 1 + and ":" in hosts[0] + ): + # internet host is alphanumeric plus dots or hyphens. + # this is essentially rfc1123, which refers to rfc952. + # https://stackoverflow.com/questions/3523028/ + # valid-characters-of-a-hostname + host_port_match = re.match( + r"^([a-zA-Z0-9\-\.]*)(?:\:(\d*))?$", hosts[0] + ) + if host_port_match: + integrated_multihost = True + h, p = host_port_match.group(1, 2) + if TYPE_CHECKING: + assert isinstance(h, str) + assert isinstance(p, str) + hosts = (h,) + ports_str = cast( + "Tuple[Optional[str], ...]", (p,) if p else (None,) + ) + + if "port" in url.query: + if integrated_multihost: + raise exc.ArgumentError( + "Can't mix 'multihost' formats together; use " + '"host=h1,h2,h3&port=p1,p2,p3" or ' + '"host=h1:p1&host=h2:p2&host=h3:p3" separately' + ) + if isinstance(url.query["port"], (list, tuple)): + ports_str = url.query["port"] + elif isinstance(url.query["port"], str): + ports_str = tuple(url.query["port"].split(",")) + + ports: Optional[Tuple[Optional[int], ...]] = None + + if ports_str: + try: + ports = tuple(int(x) if x else None for x in ports_str) + except ValueError: + raise exc.ArgumentError( + f"Received non-integer port arguments: {ports_str}" + ) from None + + if ports and ( + (not hosts and len(ports) > 1) + or ( + hosts + and ports + and len(hosts) != len(ports) + and (len(hosts) > 1 or len(ports) > 1) + ) + ): + raise exc.ArgumentError("number of hosts and ports don't match") + + if hosts is not None: + if ports is None: + ports = tuple(None for _ in hosts) + + return hosts, ports # type: ignore + + def do_begin_twophase(self, connection, xid): + self.do_begin(connection.connection) + + def do_prepare_twophase(self, connection, xid): + connection.exec_driver_sql("PREPARE TRANSACTION '%s'" % xid) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + if recover: + # FIXME: ugly hack to get out of transaction + # context when committing recoverable transactions + # Must find out a way how to make the dbapi not + # open a transaction. + connection.exec_driver_sql("ROLLBACK") + connection.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid) + connection.exec_driver_sql("BEGIN") + self.do_rollback(connection.connection) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + if recover: + connection.exec_driver_sql("ROLLBACK") + connection.exec_driver_sql("COMMIT PREPARED '%s'" % xid) + connection.exec_driver_sql("BEGIN") + self.do_rollback(connection.connection) + else: + self.do_commit(connection.connection) + + def do_recover_twophase(self, connection): + return connection.scalars( + sql.text("SELECT gid FROM pg_prepared_xacts") + ).all() + + def _get_default_schema_name(self, connection): + return connection.exec_driver_sql("select current_schema()").scalar() + + @reflection.cache + def has_schema(self, connection, schema, **kw): + query = select(pg_catalog.pg_namespace.c.nspname).where( + pg_catalog.pg_namespace.c.nspname == schema + ) + return bool(connection.scalar(query)) + + def _pg_class_filter_scope_schema( + self, query, schema, scope, pg_class_table=None + ): + if pg_class_table is None: + pg_class_table = pg_catalog.pg_class + query = query.join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace, + ) + + if scope is ObjectScope.DEFAULT: + query = query.where(pg_class_table.c.relpersistence != "t") + elif scope is ObjectScope.TEMPORARY: + query = query.where(pg_class_table.c.relpersistence == "t") + + if schema is None: + query = query.where( + pg_catalog.pg_table_is_visible(pg_class_table.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + else: + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + return query + + def _pg_class_relkind_condition(self, relkinds, pg_class_table=None): + if pg_class_table is None: + pg_class_table = pg_catalog.pg_class + # uses the any form instead of in otherwise postgresql complaings + # that 'IN could not convert type character to "char"' + return pg_class_table.c.relkind == sql.any_(_array.array(relkinds)) + + @lru_cache() + def _has_table_query(self, schema): + query = select(pg_catalog.pg_class.c.relname).where( + pg_catalog.pg_class.c.relname == bindparam("table_name"), + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_ALL_TABLE_LIKE + ), + ) + return self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): + self._ensure_has_table_connection(connection) + query = self._has_table_query(schema) + return bool(connection.scalar(query, {"table_name": table_name})) + + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): + query = select(pg_catalog.pg_class.c.relname).where( + pg_catalog.pg_class.c.relkind == "S", + pg_catalog.pg_class.c.relname == sequence_name, + ) + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + return bool(connection.scalar(query)) + + @reflection.cache + def has_type(self, connection, type_name, schema=None, **kw): + query = ( + select(pg_catalog.pg_type.c.typname) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, + ) + .where(pg_catalog.pg_type.c.typname == type_name) + ) + if schema is None: + query = query.where( + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + elif schema != "*": + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + + return bool(connection.scalar(query)) + + def _get_server_version_info(self, connection): + v = connection.exec_driver_sql("select pg_catalog.version()").scalar() + m = re.match( + r".*(?:PostgreSQL|EnterpriseDB) " + r"(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?", + v, + ) + if not m: + raise AssertionError( + "Could not determine version from string '%s'" % v + ) + return tuple([int(x) for x in m.group(1, 2, 3) if x is not None]) + + @reflection.cache + def get_table_oid(self, connection, table_name, schema=None, **kw): + """Fetch the oid for schema.table_name.""" + query = select(pg_catalog.pg_class.c.oid).where( + pg_catalog.pg_class.c.relname == table_name, + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_ALL_TABLE_LIKE + ), + ) + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + table_oid = connection.scalar(query) + if table_oid is None: + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + return table_oid + + @reflection.cache + def get_schema_names(self, connection, **kw): + query = ( + select(pg_catalog.pg_namespace.c.nspname) + .where(pg_catalog.pg_namespace.c.nspname.not_like("pg_%")) + .order_by(pg_catalog.pg_namespace.c.nspname) + ) + return connection.scalars(query).all() + + def _get_relnames_for_relkinds(self, connection, schema, relkinds, scope): + query = select(pg_catalog.pg_class.c.relname).where( + self._pg_class_relkind_condition(relkinds) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope=scope) + return connection.scalars(query).all() + + @reflection.cache + def get_table_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_TABLE_NO_FOREIGN, + scope=ObjectScope.DEFAULT, + ) + + @reflection.cache + def get_temp_table_names(self, connection, **kw): + return self._get_relnames_for_relkinds( + connection, + schema=None, + relkinds=pg_catalog.RELKINDS_TABLE_NO_FOREIGN, + scope=ObjectScope.TEMPORARY, + ) + + @reflection.cache + def _get_foreign_table_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, schema, relkinds=("f",), scope=ObjectScope.ANY + ) + + @reflection.cache + def get_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_VIEW, + scope=ObjectScope.DEFAULT, + ) + + @reflection.cache + def get_materialized_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_MAT_VIEW, + scope=ObjectScope.DEFAULT, + ) + + @reflection.cache + def get_temp_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + # NOTE: do not include temp materialzied views (that do not + # seem to be a thing at least up to version 14) + pg_catalog.RELKINDS_VIEW, + scope=ObjectScope.TEMPORARY, + ) + + @reflection.cache + def get_sequence_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, schema, relkinds=("S",), scope=ObjectScope.ANY + ) + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + query = ( + select(pg_catalog.pg_get_viewdef(pg_catalog.pg_class.c.oid)) + .select_from(pg_catalog.pg_class) + .where( + pg_catalog.pg_class.c.relname == view_name, + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_VIEW + pg_catalog.RELKINDS_MAT_VIEW + ), + ) + ) + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + res = connection.scalar(query) + if res is None: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) + else: + return res + + def _value_or_raise(self, data, table, schema): + try: + return dict(data)[(schema, table)] + except KeyError: + raise exc.NoSuchTableError( + f"{schema}.{table}" if schema else table + ) from None + + def _prepare_filter_names(self, filter_names): + if filter_names: + return True, {"filter_names": filter_names} + else: + return False, {} + + def _kind_to_relkinds(self, kind: ObjectKind) -> Tuple[str, ...]: + if kind is ObjectKind.ANY: + return pg_catalog.RELKINDS_ALL_TABLE_LIKE + relkinds = () + if ObjectKind.TABLE in kind: + relkinds += pg_catalog.RELKINDS_TABLE + if ObjectKind.VIEW in kind: + relkinds += pg_catalog.RELKINDS_VIEW + if ObjectKind.MATERIALIZED_VIEW in kind: + relkinds += pg_catalog.RELKINDS_MAT_VIEW + return relkinds + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + data = self.get_multi_columns( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _columns_query(self, schema, has_filter_names, scope, kind): + # NOTE: the query with the default and identity options scalar + # subquery is faster than trying to use outer joins for them + generated = ( + pg_catalog.pg_attribute.c.attgenerated.label("generated") + if self.server_version_info >= (12,) + else sql.null().label("generated") + ) + if self.server_version_info >= (10,): + # join lateral performs worse (~2x slower) than a scalar_subquery + identity = ( + select( + sql.func.json_build_object( + "always", + pg_catalog.pg_attribute.c.attidentity == "a", + "start", + pg_catalog.pg_sequence.c.seqstart, + "increment", + pg_catalog.pg_sequence.c.seqincrement, + "minvalue", + pg_catalog.pg_sequence.c.seqmin, + "maxvalue", + pg_catalog.pg_sequence.c.seqmax, + "cache", + pg_catalog.pg_sequence.c.seqcache, + "cycle", + pg_catalog.pg_sequence.c.seqcycle, + ) + ) + .select_from(pg_catalog.pg_sequence) + .where( + # attidentity != '' is required or it will reflect also + # serial columns as identity. + pg_catalog.pg_attribute.c.attidentity != "", + pg_catalog.pg_sequence.c.seqrelid + == sql.cast( + sql.cast( + pg_catalog.pg_get_serial_sequence( + sql.cast( + sql.cast( + pg_catalog.pg_attribute.c.attrelid, + REGCLASS, + ), + TEXT, + ), + pg_catalog.pg_attribute.c.attname, + ), + REGCLASS, + ), + OID, + ), + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery() + .label("identity_options") + ) + else: + identity = sql.null().label("identity_options") + + # join lateral performs the same as scalar_subquery here + default = ( + select( + pg_catalog.pg_get_expr( + pg_catalog.pg_attrdef.c.adbin, + pg_catalog.pg_attrdef.c.adrelid, + ) + ) + .select_from(pg_catalog.pg_attrdef) + .where( + pg_catalog.pg_attrdef.c.adrelid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attrdef.c.adnum + == pg_catalog.pg_attribute.c.attnum, + pg_catalog.pg_attribute.c.atthasdef, + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery() + .label("default") + ) + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_attribute.c.attname.label("name"), + pg_catalog.format_type( + pg_catalog.pg_attribute.c.atttypid, + pg_catalog.pg_attribute.c.atttypmod, + ).label("format_type"), + default, + pg_catalog.pg_attribute.c.attnotnull.label("not_null"), + pg_catalog.pg_class.c.relname.label("table_name"), + pg_catalog.pg_description.c.description.label("comment"), + generated, + identity, + ) + .select_from(pg_catalog.pg_class) + # NOTE: postgresql support table with no user column, meaning + # there is no row with pg_attribute.attnum > 0. use a left outer + # join to avoid filtering these tables. + .outerjoin( + pg_catalog.pg_attribute, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attribute.c.attnum > 0, + ~pg_catalog.pg_attribute.c.attisdropped, + ), + ) + .outerjoin( + pg_catalog.pg_description, + sql.and_( + pg_catalog.pg_description.c.objoid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_description.c.objsubid + == pg_catalog.pg_attribute.c.attnum, + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + .order_by( + pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum + ) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope=scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + def get_multi_columns( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._columns_query(schema, has_filter_names, scope, kind) + rows = connection.execute(query, params).mappings() + + # dictionary with (name, ) if default search path or (schema, name) + # as keys + domains = { + ((d["schema"], d["name"]) if not d["visible"] else (d["name"],)): d + for d in self._load_domains( + connection, schema="*", info_cache=kw.get("info_cache") + ) + } + + # dictionary with (name, ) if default search path or (schema, name) + # as keys + enums = dict( + ( + ((rec["name"],), rec) + if rec["visible"] + else ((rec["schema"], rec["name"]), rec) + ) + for rec in self._load_enums( + connection, schema="*", info_cache=kw.get("info_cache") + ) + ) + + columns = self._get_columns_info(rows, domains, enums, schema) + + return columns.items() + + _format_type_args_pattern = re.compile(r"\((.*)\)") + _format_type_args_delim = re.compile(r"\s*,\s*") + _format_array_spec_pattern = re.compile(r"((?:\[\])*)$") + + def _reflect_type( + self, + format_type: Optional[str], + domains: dict[str, ReflectedDomain], + enums: dict[str, ReflectedEnum], + type_description: str, + ) -> sqltypes.TypeEngine[Any]: + """ + Attempts to reconstruct a column type defined in ischema_names based + on the information available in the format_type. + + If the `format_type` cannot be associated with a known `ischema_names`, + it is treated as a reference to a known PostgreSQL named `ENUM` or + `DOMAIN` type. + """ + type_description = type_description or "unknown type" + if format_type is None: + util.warn( + "PostgreSQL format_type() returned NULL for %s" + % type_description + ) + return sqltypes.NULLTYPE + + attype_args_match = self._format_type_args_pattern.search(format_type) + if attype_args_match and attype_args_match.group(1): + attype_args = self._format_type_args_delim.split( + attype_args_match.group(1) + ) + else: + attype_args = () + + match_array_dim = self._format_array_spec_pattern.search(format_type) + # Each "[]" in array specs corresponds to an array dimension + array_dim = len(match_array_dim.group(1) or "") // 2 + + # Remove all parameters and array specs from format_type to obtain an + # ischema_name candidate + attype = self._format_type_args_pattern.sub("", format_type) + attype = self._format_array_spec_pattern.sub("", attype) + + schema_type = self.ischema_names.get(attype.lower(), None) + args, kwargs = (), {} + + if attype == "numeric": + if len(attype_args) == 2: + precision, scale = map(int, attype_args) + args = (precision, scale) + + elif attype == "double precision": + args = (53,) + + elif attype == "integer": + args = () + + elif attype in ("timestamp with time zone", "time with time zone"): + kwargs["timezone"] = True + if len(attype_args) == 1: + kwargs["precision"] = int(attype_args[0]) + + elif attype in ( + "timestamp without time zone", + "time without time zone", + "time", + ): + kwargs["timezone"] = False + if len(attype_args) == 1: + kwargs["precision"] = int(attype_args[0]) + + elif attype == "bit varying": + kwargs["varying"] = True + if len(attype_args) == 1: + charlen = int(attype_args[0]) + args = (charlen,) + + elif attype.startswith("interval"): + schema_type = INTERVAL + + field_match = re.match(r"interval (.+)", attype) + if field_match: + kwargs["fields"] = field_match.group(1) + + if len(attype_args) == 1: + kwargs["precision"] = int(attype_args[0]) + + else: + enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + + if enum_or_domain_key in enums: + schema_type = ENUM + enum = enums[enum_or_domain_key] + + args = tuple(enum["labels"]) + kwargs["name"] = enum["name"] + + if not enum["visible"]: + kwargs["schema"] = enum["schema"] + args = tuple(enum["labels"]) + elif enum_or_domain_key in domains: + schema_type = DOMAIN + domain = domains[enum_or_domain_key] + + data_type = self._reflect_type( + domain["type"], + domains, + enums, + type_description="DOMAIN '%s'" % domain["name"], + ) + args = (domain["name"], data_type) + + kwargs["collation"] = domain["collation"] + kwargs["default"] = domain["default"] + kwargs["not_null"] = not domain["nullable"] + kwargs["create_type"] = False + + if domain["constraints"]: + # We only support a single constraint + check_constraint = domain["constraints"][0] + + kwargs["constraint_name"] = check_constraint["name"] + kwargs["check"] = check_constraint["check"] + + if not domain["visible"]: + kwargs["schema"] = domain["schema"] + + else: + try: + charlen = int(attype_args[0]) + args = (charlen, *attype_args[1:]) + except (ValueError, IndexError): + args = attype_args + + if not schema_type: + util.warn( + "Did not recognize type '%s' of %s" + % (attype, type_description) + ) + return sqltypes.NULLTYPE + + data_type = schema_type(*args, **kwargs) + if array_dim >= 1: + # postgres does not preserve dimensionality or size of array types. + data_type = _array.ARRAY(data_type) + + return data_type + + def _get_columns_info(self, rows, domains, enums, schema): + columns = defaultdict(list) + for row_dict in rows: + # ensure that each table has an entry, even if it has no columns + if row_dict["name"] is None: + columns[(schema, row_dict["table_name"])] = ( + ReflectionDefaults.columns() + ) + continue + table_cols = columns[(schema, row_dict["table_name"])] + + coltype = self._reflect_type( + row_dict["format_type"], + domains, + enums, + type_description="column '%s'" % row_dict["name"], + ) + + default = row_dict["default"] + name = row_dict["name"] + generated = row_dict["generated"] + nullable = not row_dict["not_null"] + + if isinstance(coltype, DOMAIN): + if not default: + # domain can override the default value but + # cant set it to None + if coltype.default is not None: + default = coltype.default + + nullable = nullable and not coltype.not_null + + identity = row_dict["identity_options"] + + # If a zero byte or blank string depending on driver (is also + # absent for older PG versions), then not a generated column. + # Otherwise, s = stored. (Other values might be added in the + # future.) + if generated not in (None, "", b"\x00"): + computed = dict( + sqltext=default, persisted=generated in ("s", b"s") + ) + default = None + else: + computed = None + + # adjust the default value + autoincrement = False + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + if issubclass(coltype._type_affinity, sqltypes.Integer): + autoincrement = True + # the default is related to a Sequence + if "." not in match.group(2) and schema is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / + # "quote schema" + default = ( + match.group(1) + + ('"%s"' % schema) + + "." + + match.group(2) + + match.group(3) + ) + + column_info = { + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": autoincrement or identity is not None, + "comment": row_dict["comment"], + } + if computed is not None: + column_info["computed"] = computed + if identity is not None: + column_info["identity"] = identity + + table_cols.append(column_info) + + return columns + + @lru_cache() + def _table_oids_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + oid_q = select( + pg_catalog.pg_class.c.oid, pg_catalog.pg_class.c.relname + ).where(self._pg_class_relkind_condition(relkinds)) + oid_q = self._pg_class_filter_scope_schema(oid_q, schema, scope=scope) + + if has_filter_names: + oid_q = oid_q.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return oid_q + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("filter_names", InternalTraversal.dp_string_list), + ("kind", InternalTraversal.dp_plain_obj), + ("scope", InternalTraversal.dp_plain_obj), + ) + def _get_table_oids( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + oid_q = self._table_oids_query(schema, has_filter_names, scope, kind) + result = connection.execute(oid_q, params) + return result.all() + + @lru_cache() + def _constraint_query(self, is_unique): + con_sq = ( + select( + pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.conname, + pg_catalog.pg_constraint.c.conindid, + sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label( + "attnum" + ), + sql.func.generate_subscripts( + pg_catalog.pg_constraint.c.conkey, 1 + ).label("ord"), + pg_catalog.pg_description.c.description, + ) + .outerjoin( + pg_catalog.pg_description, + pg_catalog.pg_description.c.objoid + == pg_catalog.pg_constraint.c.oid, + ) + .where( + pg_catalog.pg_constraint.c.contype == bindparam("contype"), + pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")), + ) + .subquery("con") + ) + + attr_sq = ( + select( + con_sq.c.conrelid, + con_sq.c.conname, + con_sq.c.conindid, + con_sq.c.description, + con_sq.c.ord, + pg_catalog.pg_attribute.c.attname, + ) + .select_from(pg_catalog.pg_attribute) + .join( + con_sq, + sql.and_( + pg_catalog.pg_attribute.c.attnum == con_sq.c.attnum, + pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid, + ), + ) + .where( + # NOTE: restate the condition here, since pg15 otherwise + # seems to get confused on pscopg2 sometimes, doing + # a sequential scan of pg_attribute. + # The condition in the con_sq subquery is not actually needed + # in pg15, but it may be needed in older versions. Keeping it + # does not seems to have any inpact in any case. + con_sq.c.conrelid.in_(bindparam("oids")) + ) + .subquery("attr") + ) + + constraint_query = ( + select( + attr_sq.c.conrelid, + sql.func.array_agg( + # NOTE: cast since some postgresql derivatives may + # not support array_agg on the name type + aggregate_order_by( + attr_sq.c.attname.cast(TEXT), attr_sq.c.ord + ) + ).label("cols"), + attr_sq.c.conname, + sql.func.min(attr_sq.c.description).label("description"), + ) + .group_by(attr_sq.c.conrelid, attr_sq.c.conname) + .order_by(attr_sq.c.conrelid, attr_sq.c.conname) + ) + + if is_unique: + if self.server_version_info >= (15,): + constraint_query = constraint_query.join( + pg_catalog.pg_index, + attr_sq.c.conindid == pg_catalog.pg_index.c.indexrelid, + ).add_columns( + sql.func.bool_and( + pg_catalog.pg_index.c.indnullsnotdistinct + ).label("indnullsnotdistinct") + ) + else: + constraint_query = constraint_query.add_columns( + sql.false().label("indnullsnotdistinct") + ) + else: + constraint_query = constraint_query.add_columns( + sql.null().label("extra") + ) + return constraint_query + + def _reflect_constraint( + self, connection, contype, schema, filter_names, scope, kind, **kw + ): + # used to reflect primary and unique constraint + table_oids = self._get_table_oids( + connection, schema, filter_names, scope, kind, **kw + ) + batches = list(table_oids) + is_unique = contype == "u" + + while batches: + batch = batches[0:3000] + batches[0:3000] = [] + + result = connection.execute( + self._constraint_query(is_unique), + {"oids": [r[0] for r in batch], "contype": contype}, + ) + + result_by_oid = defaultdict(list) + for oid, cols, constraint_name, comment, extra in result: + result_by_oid[oid].append( + (cols, constraint_name, comment, extra) + ) + + for oid, tablename in batch: + for_oid = result_by_oid.get(oid, ()) + if for_oid: + for cols, constraint, comment, extra in for_oid: + if is_unique: + yield tablename, cols, constraint, comment, { + "nullsnotdistinct": extra + } + else: + yield tablename, cols, constraint, comment, None + else: + yield tablename, None, None, None, None + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + data = self.get_multi_pk_constraint( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def get_multi_pk_constraint( + self, connection, schema, filter_names, scope, kind, **kw + ): + result = self._reflect_constraint( + connection, "p", schema, filter_names, scope, kind, **kw + ) + + # only a single pk can be present for each table. Return an entry + # even if a table has no primary key + default = ReflectionDefaults.pk_constraint + return ( + ( + (schema, table_name), + ( + { + "constrained_columns": [] if cols is None else cols, + "name": pk_name, + "comment": comment, + } + if pk_name is not None + else default() + ), + ) + for table_name, cols, pk_name, comment, _ in result + ) + + @reflection.cache + def get_foreign_keys( + self, + connection, + table_name, + schema=None, + postgresql_ignore_search_path=False, + **kw, + ): + data = self.get_multi_foreign_keys( + connection, + schema=schema, + filter_names=[table_name], + postgresql_ignore_search_path=postgresql_ignore_search_path, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _foreing_key_query(self, schema, has_filter_names, scope, kind): + pg_class_ref = pg_catalog.pg_class.alias("cls_ref") + pg_namespace_ref = pg_catalog.pg_namespace.alias("nsp_ref") + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + # NOTE: avoid calling pg_get_constraintdef when not needed + # to speed up the query + sql.case( + ( + pg_catalog.pg_constraint.c.oid.is_not(None), + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid, True + ), + ), + else_=None, + ), + pg_namespace_ref.c.nspname, + pg_catalog.pg_description.c.description, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.contype == "f", + ), + ) + .outerjoin( + pg_class_ref, + pg_class_ref.c.oid == pg_catalog.pg_constraint.c.confrelid, + ) + .outerjoin( + pg_namespace_ref, + pg_class_ref.c.relnamespace == pg_namespace_ref.c.oid, + ) + .outerjoin( + pg_catalog.pg_description, + pg_catalog.pg_description.c.objoid + == pg_catalog.pg_constraint.c.oid, + ) + .order_by( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + @util.memoized_property + def _fk_regex_pattern(self): + # optionally quoted token + qtoken = '(?:"[^"]+"|[A-Za-z0-9_]+?)' + + # https://www.postgresql.org/docs/current/static/sql-createtable.html + return re.compile( + r"FOREIGN KEY \((.*?)\) " + rf"REFERENCES (?:({qtoken})\.)?({qtoken})\(((?:{qtoken}(?: *, *)?)+)\)" # noqa: E501 + r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?" + r"[\s]?(ON UPDATE " + r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" + r"[\s]?(ON DELETE " + r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" + r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?" + r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?" + ) + + def get_multi_foreign_keys( + self, + connection, + schema, + filter_names, + scope, + kind, + postgresql_ignore_search_path=False, + **kw, + ): + preparer = self.identifier_preparer + + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._foreing_key_query(schema, has_filter_names, scope, kind) + result = connection.execute(query, params) + + FK_REGEX = self._fk_regex_pattern + + fkeys = defaultdict(list) + default = ReflectionDefaults.foreign_keys + for table_name, conname, condef, conschema, comment in result: + # ensure that each table has an entry, even if it has + # no foreign keys + if conname is None: + fkeys[(schema, table_name)] = default() + continue + table_fks = fkeys[(schema, table_name)] + m = re.search(FK_REGEX, condef).groups() + + ( + constrained_columns, + referred_schema, + referred_table, + referred_columns, + _, + match, + _, + onupdate, + _, + ondelete, + deferrable, + _, + initially, + ) = m + + if deferrable is not None: + deferrable = True if deferrable == "DEFERRABLE" else False + constrained_columns = [ + preparer._unquote_identifier(x) + for x in re.split(r"\s*,\s*", constrained_columns) + ] + + if postgresql_ignore_search_path: + # when ignoring search path, we use the actual schema + # provided it isn't the "default" schema + if conschema != self.default_schema_name: + referred_schema = conschema + else: + referred_schema = schema + elif referred_schema: + # referred_schema is the schema that we regexp'ed from + # pg_get_constraintdef(). If the schema is in the search + # path, pg_get_constraintdef() will give us None. + referred_schema = preparer._unquote_identifier(referred_schema) + elif schema is not None and schema == conschema: + # If the actual schema matches the schema of the table + # we're reflecting, then we will use that. + referred_schema = schema + + referred_table = preparer._unquote_identifier(referred_table) + referred_columns = [ + preparer._unquote_identifier(x) + for x in re.split(r"\s*,\s", referred_columns) + ] + options = { + k: v + for k, v in [ + ("onupdate", onupdate), + ("ondelete", ondelete), + ("initially", initially), + ("deferrable", deferrable), + ("match", match), + ] + if v is not None and v != "NO ACTION" + } + fkey_d = { + "name": conname, + "constrained_columns": constrained_columns, + "referred_schema": referred_schema, + "referred_table": referred_table, + "referred_columns": referred_columns, + "options": options, + "comment": comment, + } + table_fks.append(fkey_d) + return fkeys.items() + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + data = self.get_multi_indexes( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @util.memoized_property + def _index_query(self): + pg_class_index = pg_catalog.pg_class.alias("cls_idx") + # NOTE: repeating oids clause improve query performance + + # subquery to get the columns + idx_sq = ( + select( + pg_catalog.pg_index.c.indexrelid, + pg_catalog.pg_index.c.indrelid, + sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"), + sql.func.generate_subscripts( + pg_catalog.pg_index.c.indkey, 1 + ).label("ord"), + ) + .where( + ~pg_catalog.pg_index.c.indisprimary, + pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")), + ) + .subquery("idx") + ) + + attr_sq = ( + select( + idx_sq.c.indexrelid, + idx_sq.c.indrelid, + idx_sq.c.ord, + # NOTE: always using pg_get_indexdef is too slow so just + # invoke when the element is an expression + sql.case( + ( + idx_sq.c.attnum == 0, + pg_catalog.pg_get_indexdef( + idx_sq.c.indexrelid, idx_sq.c.ord + 1, True + ), + ), + # NOTE: need to cast this since attname is of type "name" + # that's limited to 63 bytes, while pg_get_indexdef + # returns "text" so its output may get cut + else_=pg_catalog.pg_attribute.c.attname.cast(TEXT), + ).label("element"), + (idx_sq.c.attnum == 0).label("is_expr"), + ) + .select_from(idx_sq) + .outerjoin( + # do not remove rows where idx_sq.c.attnum is 0 + pg_catalog.pg_attribute, + sql.and_( + pg_catalog.pg_attribute.c.attnum == idx_sq.c.attnum, + pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid, + ), + ) + .where(idx_sq.c.indrelid.in_(bindparam("oids"))) + .subquery("idx_attr") + ) + + cols_sq = ( + select( + attr_sq.c.indexrelid, + sql.func.min(attr_sq.c.indrelid), + sql.func.array_agg( + aggregate_order_by(attr_sq.c.element, attr_sq.c.ord) + ).label("elements"), + sql.func.array_agg( + aggregate_order_by(attr_sq.c.is_expr, attr_sq.c.ord) + ).label("elements_is_expr"), + ) + .group_by(attr_sq.c.indexrelid) + .subquery("idx_cols") + ) + + if self.server_version_info >= (11, 0): + indnkeyatts = pg_catalog.pg_index.c.indnkeyatts + else: + indnkeyatts = sql.null().label("indnkeyatts") + + if self.server_version_info >= (15,): + nulls_not_distinct = pg_catalog.pg_index.c.indnullsnotdistinct + else: + nulls_not_distinct = sql.false().label("indnullsnotdistinct") + + return ( + select( + pg_catalog.pg_index.c.indrelid, + pg_class_index.c.relname.label("relname_index"), + pg_catalog.pg_index.c.indisunique, + pg_catalog.pg_constraint.c.conrelid.is_not(None).label( + "has_constraint" + ), + pg_catalog.pg_index.c.indoption, + pg_class_index.c.reloptions, + pg_catalog.pg_am.c.amname, + # NOTE: pg_get_expr is very fast so this case has almost no + # performance impact + sql.case( + ( + pg_catalog.pg_index.c.indpred.is_not(None), + pg_catalog.pg_get_expr( + pg_catalog.pg_index.c.indpred, + pg_catalog.pg_index.c.indrelid, + ), + ), + else_=None, + ).label("filter_definition"), + indnkeyatts, + nulls_not_distinct, + cols_sq.c.elements, + cols_sq.c.elements_is_expr, + ) + .select_from(pg_catalog.pg_index) + .where( + pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")), + ~pg_catalog.pg_index.c.indisprimary, + ) + .join( + pg_class_index, + pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid, + ) + .join( + pg_catalog.pg_am, + pg_class_index.c.relam == pg_catalog.pg_am.c.oid, + ) + .outerjoin( + cols_sq, + pg_catalog.pg_index.c.indexrelid == cols_sq.c.indexrelid, + ) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_index.c.indrelid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_index.c.indexrelid + == pg_catalog.pg_constraint.c.conindid, + pg_catalog.pg_constraint.c.contype + == sql.any_(_array.array(("p", "u", "x"))), + ), + ) + .order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname) + ) + + def get_multi_indexes( + self, connection, schema, filter_names, scope, kind, **kw + ): + table_oids = self._get_table_oids( + connection, schema, filter_names, scope, kind, **kw + ) + + indexes = defaultdict(list) + default = ReflectionDefaults.indexes + + batches = list(table_oids) + + while batches: + batch = batches[0:3000] + batches[0:3000] = [] + + result = connection.execute( + self._index_query, {"oids": [r[0] for r in batch]} + ).mappings() + + result_by_oid = defaultdict(list) + for row_dict in result: + result_by_oid[row_dict["indrelid"]].append(row_dict) + + for oid, table_name in batch: + if oid not in result_by_oid: + # ensure that each table has an entry, even if reflection + # is skipped because not supported + indexes[(schema, table_name)] = default() + continue + + for row in result_by_oid[oid]: + index_name = row["relname_index"] + + table_indexes = indexes[(schema, table_name)] + + all_elements = row["elements"] + all_elements_is_expr = row["elements_is_expr"] + indnkeyatts = row["indnkeyatts"] + # "The number of key columns in the index, not counting any + # included columns, which are merely stored and do not + # participate in the index semantics" + if indnkeyatts and len(all_elements) > indnkeyatts: + # this is a "covering index" which has INCLUDE columns + # as well as regular index columns + inc_cols = all_elements[indnkeyatts:] + idx_elements = all_elements[:indnkeyatts] + idx_elements_is_expr = all_elements_is_expr[ + :indnkeyatts + ] + # postgresql does not support expression on included + # columns as of v14: "ERROR: expressions are not + # supported in included columns". + assert all( + not is_expr + for is_expr in all_elements_is_expr[indnkeyatts:] + ) + else: + idx_elements = all_elements + idx_elements_is_expr = all_elements_is_expr + inc_cols = [] + + index = {"name": index_name, "unique": row["indisunique"]} + if any(idx_elements_is_expr): + index["column_names"] = [ + None if is_expr else expr + for expr, is_expr in zip( + idx_elements, idx_elements_is_expr + ) + ] + index["expressions"] = idx_elements + else: + index["column_names"] = idx_elements + + sorting = {} + for col_index, col_flags in enumerate(row["indoption"]): + col_sorting = () + # try to set flags only if they differ from PG + # defaults... + if col_flags & 0x01: + col_sorting += ("desc",) + if not (col_flags & 0x02): + col_sorting += ("nulls_last",) + else: + if col_flags & 0x02: + col_sorting += ("nulls_first",) + if col_sorting: + sorting[idx_elements[col_index]] = col_sorting + if sorting: + index["column_sorting"] = sorting + if row["has_constraint"]: + index["duplicates_constraint"] = index_name + + dialect_options = {} + if row["reloptions"]: + dialect_options["postgresql_with"] = dict( + [option.split("=") for option in row["reloptions"]] + ) + # it *might* be nice to include that this is 'btree' in the + # reflection info. But we don't want an Index object + # to have a ``postgresql_using`` in it that is just the + # default, so for the moment leaving this out. + amname = row["amname"] + if amname != "btree": + dialect_options["postgresql_using"] = row["amname"] + if row["filter_definition"]: + dialect_options["postgresql_where"] = row[ + "filter_definition" + ] + if self.server_version_info >= (11,): + # NOTE: this is legacy, this is part of + # dialect_options now as of #7382 + index["include_columns"] = inc_cols + dialect_options["postgresql_include"] = inc_cols + if row["indnullsnotdistinct"]: + # the default is False, so ignore it. + dialect_options["postgresql_nulls_not_distinct"] = row[ + "indnullsnotdistinct" + ] + + if dialect_options: + index["dialect_options"] = dialect_options + + table_indexes.append(index) + return indexes.items() + + @reflection.cache + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + data = self.get_multi_unique_constraints( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def get_multi_unique_constraints( + self, + connection, + schema, + filter_names, + scope, + kind, + **kw, + ): + result = self._reflect_constraint( + connection, "u", schema, filter_names, scope, kind, **kw + ) + + # each table can have multiple unique constraints + uniques = defaultdict(list) + default = ReflectionDefaults.unique_constraints + for table_name, cols, con_name, comment, options in result: + # ensure a list is created for each table. leave it empty if + # the table has no unique cosntraint + if con_name is None: + uniques[(schema, table_name)] = default() + continue + + uc_dict = { + "column_names": cols, + "name": con_name, + "comment": comment, + } + if options: + if options["nullsnotdistinct"]: + uc_dict["dialect_options"] = { + "postgresql_nulls_not_distinct": options[ + "nullsnotdistinct" + ] + } + + uniques[(schema, table_name)].append(uc_dict) + return uniques.items() + + @reflection.cache + def get_table_comment(self, connection, table_name, schema=None, **kw): + data = self.get_multi_table_comment( + connection, + schema, + [table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _comment_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_description.c.description, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_description, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_description.c.objoid, + pg_catalog.pg_description.c.objsubid == 0, + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + def get_multi_table_comment( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._comment_query(schema, has_filter_names, scope, kind) + result = connection.execute(query, params) + + default = ReflectionDefaults.table_comment + return ( + ( + (schema, table), + {"text": comment} if comment is not None else default(), + ) + for table, comment in result + ) + + @reflection.cache + def get_check_constraints(self, connection, table_name, schema=None, **kw): + data = self.get_multi_check_constraints( + connection, + schema, + [table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _check_constraint_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + # NOTE: avoid calling pg_get_constraintdef when not needed + # to speed up the query + sql.case( + ( + pg_catalog.pg_constraint.c.oid.is_not(None), + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid, True + ), + ), + else_=None, + ), + pg_catalog.pg_description.c.description, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.contype == "c", + ), + ) + .outerjoin( + pg_catalog.pg_description, + pg_catalog.pg_description.c.objoid + == pg_catalog.pg_constraint.c.oid, + ) + .order_by( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + def get_multi_check_constraints( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._check_constraint_query( + schema, has_filter_names, scope, kind + ) + result = connection.execute(query, params) + + check_constraints = defaultdict(list) + default = ReflectionDefaults.check_constraints + for table_name, check_name, src, comment in result: + # only two cases for check_name and src: both null or both defined + if check_name is None and src is None: + check_constraints[(schema, table_name)] = default() + continue + # samples: + # "CHECK (((a > 1) AND (a < 5)))" + # "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))" + # "CHECK (((a > 1) AND (a < 5))) NOT VALID" + # "CHECK (some_boolean_function(a))" + # "CHECK (((a\n < 1)\n OR\n (a\n >= 5))\n)" + # "CHECK (a NOT NULL) NO INHERIT" + # "CHECK (a NOT NULL) NO INHERIT NOT VALID" + + m = re.match( + r"^CHECK *\((.+)\)( NO INHERIT)?( NOT VALID)?$", + src, + flags=re.DOTALL, + ) + if not m: + util.warn("Could not parse CHECK constraint text: %r" % src) + sqltext = "" + else: + sqltext = re.compile( + r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL + ).sub(r"\1", m.group(1)) + entry = { + "name": check_name, + "sqltext": sqltext, + "comment": comment, + } + if m: + do = {} + if " NOT VALID" in m.groups(): + do["not_valid"] = True + if " NO INHERIT" in m.groups(): + do["no_inherit"] = True + if do: + entry["dialect_options"] = do + + check_constraints[(schema, table_name)].append(entry) + return check_constraints.items() + + def _pg_type_filter_schema(self, query, schema): + if schema is None: + query = query.where( + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + elif schema != "*": + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + return query + + @lru_cache() + def _enum_query(self, schema): + lbl_agg_sq = ( + select( + pg_catalog.pg_enum.c.enumtypid, + sql.func.array_agg( + aggregate_order_by( + # NOTE: cast since some postgresql derivatives may + # not support array_agg on the name type + pg_catalog.pg_enum.c.enumlabel.cast(TEXT), + pg_catalog.pg_enum.c.enumsortorder, + ) + ).label("labels"), + ) + .group_by(pg_catalog.pg_enum.c.enumtypid) + .subquery("lbl_agg") + ) + + query = ( + select( + pg_catalog.pg_type.c.typname.label("name"), + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label( + "visible" + ), + pg_catalog.pg_namespace.c.nspname.label("schema"), + lbl_agg_sq.c.labels.label("labels"), + ) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, + ) + .outerjoin( + lbl_agg_sq, pg_catalog.pg_type.c.oid == lbl_agg_sq.c.enumtypid + ) + .where(pg_catalog.pg_type.c.typtype == "e") + .order_by( + pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname + ) + ) + + return self._pg_type_filter_schema(query, schema) + + @reflection.cache + def _load_enums(self, connection, schema=None, **kw): + if not self.supports_native_enum: + return [] + + result = connection.execute(self._enum_query(schema)) + + enums = [] + for name, visible, schema, labels in result: + enums.append( + { + "name": name, + "schema": schema, + "visible": visible, + "labels": [] if labels is None else labels, + } + ) + return enums + + @lru_cache() + def _domain_query(self, schema): + con_sq = ( + select( + pg_catalog.pg_constraint.c.contypid, + sql.func.array_agg( + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid, True + ) + ).label("condefs"), + sql.func.array_agg( + # NOTE: cast since some postgresql derivatives may + # not support array_agg on the name type + pg_catalog.pg_constraint.c.conname.cast(TEXT) + ).label("connames"), + ) + # The domain this constraint is on; zero if not a domain constraint + .where(pg_catalog.pg_constraint.c.contypid != 0) + .group_by(pg_catalog.pg_constraint.c.contypid) + .subquery("domain_constraints") + ) + + query = ( + select( + pg_catalog.pg_type.c.typname.label("name"), + pg_catalog.format_type( + pg_catalog.pg_type.c.typbasetype, + pg_catalog.pg_type.c.typtypmod, + ).label("attype"), + (~pg_catalog.pg_type.c.typnotnull).label("nullable"), + pg_catalog.pg_type.c.typdefault.label("default"), + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label( + "visible" + ), + pg_catalog.pg_namespace.c.nspname.label("schema"), + con_sq.c.condefs, + con_sq.c.connames, + pg_catalog.pg_collation.c.collname, + ) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, + ) + .outerjoin( + pg_catalog.pg_collation, + pg_catalog.pg_type.c.typcollation + == pg_catalog.pg_collation.c.oid, + ) + .outerjoin( + con_sq, + pg_catalog.pg_type.c.oid == con_sq.c.contypid, + ) + .where(pg_catalog.pg_type.c.typtype == "d") + .order_by( + pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname + ) + ) + return self._pg_type_filter_schema(query, schema) + + @reflection.cache + def _load_domains(self, connection, schema=None, **kw): + result = connection.execute(self._domain_query(schema)) + + domains: List[ReflectedDomain] = [] + for domain in result.mappings(): + # strip (30) from character varying(30) + attype = re.search(r"([^\(]+)", domain["attype"]).group(1) + constraints: List[ReflectedDomainConstraint] = [] + if domain["connames"]: + # When a domain has multiple CHECK constraints, they will + # be tested in alphabetical order by name. + sorted_constraints = sorted( + zip(domain["connames"], domain["condefs"]), + key=lambda t: t[0], + ) + for name, def_ in sorted_constraints: + # constraint is in the form "CHECK (expression)". + # remove "CHECK (" and the tailing ")". + check = def_[7:-1] + constraints.append({"name": name, "check": check}) + + domain_rec: ReflectedDomain = { + "name": domain["name"], + "schema": domain["schema"], + "visible": domain["visible"], + "type": attype, + "nullable": domain["nullable"], + "default": domain["default"], + "constraints": constraints, + "collation": domain["collname"], + } + domains.append(domain_rec) + + return domains + + def _set_backslash_escapes(self, connection): + # this method is provided as an override hook for descendant + # dialects (e.g. Redshift), so removing it may break them + std_string = connection.exec_driver_sql( + "show standard_conforming_strings" + ).scalar() + self._backslash_escapes = std_string == "off" diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/dml.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/dml.py new file mode 100644 index 0000000..4404ecd --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/dml.py @@ -0,0 +1,310 @@ +# dialects/postgresql/dml.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +from typing import Any +from typing import Optional + +from . import ext +from .._typing import _OnConflictConstraintT +from .._typing import _OnConflictIndexElementsT +from .._typing import _OnConflictIndexWhereT +from .._typing import _OnConflictSetT +from .._typing import _OnConflictWhereT +from ... import util +from ...sql import coercions +from ...sql import roles +from ...sql import schema +from ...sql._typing import _DMLTableArgument +from ...sql.base import _exclusive_against +from ...sql.base import _generative +from ...sql.base import ColumnCollection +from ...sql.base import ReadOnlyColumnCollection +from ...sql.dml import Insert as StandardInsert +from ...sql.elements import ClauseElement +from ...sql.elements import KeyedColumnElement +from ...sql.expression import alias +from ...util.typing import Self + + +__all__ = ("Insert", "insert") + + +def insert(table: _DMLTableArgument) -> Insert: + """Construct a PostgreSQL-specific variant :class:`_postgresql.Insert` + construct. + + .. container:: inherited_member + + The :func:`sqlalchemy.dialects.postgresql.insert` function creates + a :class:`sqlalchemy.dialects.postgresql.Insert`. This class is based + on the dialect-agnostic :class:`_sql.Insert` construct which may + be constructed using the :func:`_sql.insert` function in + SQLAlchemy Core. + + The :class:`_postgresql.Insert` construct includes additional methods + :meth:`_postgresql.Insert.on_conflict_do_update`, + :meth:`_postgresql.Insert.on_conflict_do_nothing`. + + """ + return Insert(table) + + +class Insert(StandardInsert): + """PostgreSQL-specific implementation of INSERT. + + Adds methods for PG-specific syntaxes such as ON CONFLICT. + + The :class:`_postgresql.Insert` object is created using the + :func:`sqlalchemy.dialects.postgresql.insert` function. + + """ + + stringify_dialect = "postgresql" + inherit_cache = False + + @util.memoized_property + def excluded( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: + """Provide the ``excluded`` namespace for an ON CONFLICT statement + + PG's ON CONFLICT clause allows reference to the row that would + be inserted, known as ``excluded``. This attribute provides + all columns in this row to be referenceable. + + .. tip:: The :attr:`_postgresql.Insert.excluded` attribute is an + instance of :class:`_expression.ColumnCollection`, which provides + an interface the same as that of the :attr:`_schema.Table.c` + collection described at :ref:`metadata_tables_and_columns`. + With this collection, ordinary names are accessible like attributes + (e.g. ``stmt.excluded.some_column``), but special names and + dictionary method names should be accessed using indexed access, + such as ``stmt.excluded["column name"]`` or + ``stmt.excluded["values"]``. See the docstring for + :class:`_expression.ColumnCollection` for further examples. + + .. seealso:: + + :ref:`postgresql_insert_on_conflict` - example of how + to use :attr:`_expression.Insert.excluded` + + """ + return alias(self.table, name="excluded").columns + + _on_conflict_exclusive = _exclusive_against( + "_post_values_clause", + msgs={ + "_post_values_clause": "This Insert construct already has " + "an ON CONFLICT clause established" + }, + ) + + @_generative + @_on_conflict_exclusive + def on_conflict_do_update( + self, + constraint: _OnConflictConstraintT = None, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + set_: _OnConflictSetT = None, + where: _OnConflictWhereT = None, + ) -> Self: + r""" + Specifies a DO UPDATE SET action for ON CONFLICT clause. + + Either the ``constraint`` or ``index_elements`` argument is + required, but only one of these can be specified. + + :param constraint: + The name of a unique or exclusion constraint on the table, + or the constraint object itself if it has a .name attribute. + + :param index_elements: + A sequence consisting of string column names, :class:`_schema.Column` + objects, or other column expression objects that will be used + to infer a target index. + + :param index_where: + Additional WHERE criterion that can be used to infer a + conditional target index. + + :param set\_: + A dictionary or other mapping object + where the keys are either names of columns in the target table, + or :class:`_schema.Column` objects or other ORM-mapped columns + matching that of the target table, and expressions or literals + as values, specifying the ``SET`` actions to take. + + .. versionadded:: 1.4 The + :paramref:`_postgresql.Insert.on_conflict_do_update.set_` + parameter supports :class:`_schema.Column` objects from the target + :class:`_schema.Table` as keys. + + .. warning:: This dictionary does **not** take into account + Python-specified default UPDATE values or generation functions, + e.g. those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON CONFLICT style of + UPDATE, unless they are manually specified in the + :paramref:`.Insert.on_conflict_do_update.set_` dictionary. + + :param where: + Optional argument. If present, can be a literal SQL + string or an acceptable expression for a ``WHERE`` clause + that restricts the rows affected by ``DO UPDATE SET``. Rows + not meeting the ``WHERE`` condition will not be updated + (effectively a ``DO NOTHING`` for those rows). + + + .. seealso:: + + :ref:`postgresql_insert_on_conflict` + + """ + self._post_values_clause = OnConflictDoUpdate( + constraint, index_elements, index_where, set_, where + ) + return self + + @_generative + @_on_conflict_exclusive + def on_conflict_do_nothing( + self, + constraint: _OnConflictConstraintT = None, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + ) -> Self: + """ + Specifies a DO NOTHING action for ON CONFLICT clause. + + The ``constraint`` and ``index_elements`` arguments + are optional, but only one of these can be specified. + + :param constraint: + The name of a unique or exclusion constraint on the table, + or the constraint object itself if it has a .name attribute. + + :param index_elements: + A sequence consisting of string column names, :class:`_schema.Column` + objects, or other column expression objects that will be used + to infer a target index. + + :param index_where: + Additional WHERE criterion that can be used to infer a + conditional target index. + + .. seealso:: + + :ref:`postgresql_insert_on_conflict` + + """ + self._post_values_clause = OnConflictDoNothing( + constraint, index_elements, index_where + ) + return self + + +class OnConflictClause(ClauseElement): + stringify_dialect = "postgresql" + + constraint_target: Optional[str] + inferred_target_elements: _OnConflictIndexElementsT + inferred_target_whereclause: _OnConflictIndexWhereT + + def __init__( + self, + constraint: _OnConflictConstraintT = None, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + ): + if constraint is not None: + if not isinstance(constraint, str) and isinstance( + constraint, + (schema.Constraint, ext.ExcludeConstraint), + ): + constraint = getattr(constraint, "name") or constraint + + if constraint is not None: + if index_elements is not None: + raise ValueError( + "'constraint' and 'index_elements' are mutually exclusive" + ) + + if isinstance(constraint, str): + self.constraint_target = constraint + self.inferred_target_elements = None + self.inferred_target_whereclause = None + elif isinstance(constraint, schema.Index): + index_elements = constraint.expressions + index_where = constraint.dialect_options["postgresql"].get( + "where" + ) + elif isinstance(constraint, ext.ExcludeConstraint): + index_elements = constraint.columns + index_where = constraint.where + else: + index_elements = constraint.columns + index_where = constraint.dialect_options["postgresql"].get( + "where" + ) + + if index_elements is not None: + self.constraint_target = None + self.inferred_target_elements = index_elements + self.inferred_target_whereclause = index_where + elif constraint is None: + self.constraint_target = self.inferred_target_elements = ( + self.inferred_target_whereclause + ) = None + + +class OnConflictDoNothing(OnConflictClause): + __visit_name__ = "on_conflict_do_nothing" + + +class OnConflictDoUpdate(OnConflictClause): + __visit_name__ = "on_conflict_do_update" + + def __init__( + self, + constraint: _OnConflictConstraintT = None, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + set_: _OnConflictSetT = None, + where: _OnConflictWhereT = None, + ): + super().__init__( + constraint=constraint, + index_elements=index_elements, + index_where=index_where, + ) + + if ( + self.inferred_target_elements is None + and self.constraint_target is None + ): + raise ValueError( + "Either constraint or index_elements, " + "but not both, must be specified unless DO NOTHING" + ) + + if isinstance(set_, dict): + if not set_: + raise ValueError("set parameter dictionary must not be empty") + elif isinstance(set_, ColumnCollection): + set_ = dict(set_) + else: + raise ValueError( + "set parameter must be a non-empty dictionary " + "or a ColumnCollection such as the `.c.` collection " + "of a Table object" + ) + self.update_values_to_set = [ + (coercions.expect(roles.DMLColumnRole, key), value) + for key, value in set_.items() + ] + self.update_whereclause = where diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/ext.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/ext.py new file mode 100644 index 0000000..7fc0895 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/ext.py @@ -0,0 +1,496 @@ +# dialects/postgresql/ext.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +from typing import Any +from typing import TYPE_CHECKING +from typing import TypeVar + +from . import types +from .array import ARRAY +from ...sql import coercions +from ...sql import elements +from ...sql import expression +from ...sql import functions +from ...sql import roles +from ...sql import schema +from ...sql.schema import ColumnCollectionConstraint +from ...sql.sqltypes import TEXT +from ...sql.visitors import InternalTraversal + +_T = TypeVar("_T", bound=Any) + +if TYPE_CHECKING: + from ...sql.visitors import _TraverseInternalsType + + +class aggregate_order_by(expression.ColumnElement): + """Represent a PostgreSQL aggregate order by expression. + + E.g.:: + + from sqlalchemy.dialects.postgresql import aggregate_order_by + expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc())) + stmt = select(expr) + + would represent the expression:: + + SELECT array_agg(a ORDER BY b DESC) FROM table; + + Similarly:: + + expr = func.string_agg( + table.c.a, + aggregate_order_by(literal_column("','"), table.c.a) + ) + stmt = select(expr) + + Would represent:: + + SELECT string_agg(a, ',' ORDER BY a) FROM table; + + .. versionchanged:: 1.2.13 - the ORDER BY argument may be multiple terms + + .. seealso:: + + :class:`_functions.array_agg` + + """ + + __visit_name__ = "aggregate_order_by" + + stringify_dialect = "postgresql" + _traverse_internals: _TraverseInternalsType = [ + ("target", InternalTraversal.dp_clauseelement), + ("type", InternalTraversal.dp_type), + ("order_by", InternalTraversal.dp_clauseelement), + ] + + def __init__(self, target, *order_by): + self.target = coercions.expect(roles.ExpressionElementRole, target) + self.type = self.target.type + + _lob = len(order_by) + if _lob == 0: + raise TypeError("at least one ORDER BY element is required") + elif _lob == 1: + self.order_by = coercions.expect( + roles.ExpressionElementRole, order_by[0] + ) + else: + self.order_by = elements.ClauseList( + *order_by, _literal_as_text_role=roles.ExpressionElementRole + ) + + def self_group(self, against=None): + return self + + def get_children(self, **kwargs): + return self.target, self.order_by + + def _copy_internals(self, clone=elements._clone, **kw): + self.target = clone(self.target, **kw) + self.order_by = clone(self.order_by, **kw) + + @property + def _from_objects(self): + return self.target._from_objects + self.order_by._from_objects + + +class ExcludeConstraint(ColumnCollectionConstraint): + """A table-level EXCLUDE constraint. + + Defines an EXCLUDE constraint as described in the `PostgreSQL + documentation`__. + + __ https://www.postgresql.org/docs/current/static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE + + """ # noqa + + __visit_name__ = "exclude_constraint" + + where = None + inherit_cache = False + + create_drop_stringify_dialect = "postgresql" + + @elements._document_text_coercion( + "where", + ":class:`.ExcludeConstraint`", + ":paramref:`.ExcludeConstraint.where`", + ) + def __init__(self, *elements, **kw): + r""" + Create an :class:`.ExcludeConstraint` object. + + E.g.:: + + const = ExcludeConstraint( + (Column('period'), '&&'), + (Column('group'), '='), + where=(Column('group') != 'some group'), + ops={'group': 'my_operator_class'} + ) + + The constraint is normally embedded into the :class:`_schema.Table` + construct + directly, or added later using :meth:`.append_constraint`:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('period', TSRANGE()), + Column('group', String) + ) + + some_table.append_constraint( + ExcludeConstraint( + (some_table.c.period, '&&'), + (some_table.c.group, '='), + where=some_table.c.group != 'some group', + name='some_table_excl_const', + ops={'group': 'my_operator_class'} + ) + ) + + The exclude constraint defined in this example requires the + ``btree_gist`` extension, that can be created using the + command ``CREATE EXTENSION btree_gist;``. + + :param \*elements: + + A sequence of two tuples of the form ``(column, operator)`` where + "column" is either a :class:`_schema.Column` object, or a SQL + expression element (e.g. ``func.int8range(table.from, table.to)``) + or the name of a column as string, and "operator" is a string + containing the operator to use (e.g. `"&&"` or `"="`). + + In order to specify a column name when a :class:`_schema.Column` + object is not available, while ensuring + that any necessary quoting rules take effect, an ad-hoc + :class:`_schema.Column` or :func:`_expression.column` + object should be used. + The ``column`` may also be a string SQL expression when + passed as :func:`_expression.literal_column` or + :func:`_expression.text` + + :param name: + Optional, the in-database name of this constraint. + + :param deferrable: + Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when + issuing DDL for this constraint. + + :param initially: + Optional string. If set, emit INITIALLY when issuing DDL + for this constraint. + + :param using: + Optional string. If set, emit USING when issuing DDL + for this constraint. Defaults to 'gist'. + + :param where: + Optional SQL expression construct or literal SQL string. + If set, emit WHERE when issuing DDL + for this constraint. + + :param ops: + Optional dictionary. Used to define operator classes for the + elements; works the same way as that of the + :ref:`postgresql_ops ` + parameter specified to the :class:`_schema.Index` construct. + + .. versionadded:: 1.3.21 + + .. seealso:: + + :ref:`postgresql_operator_classes` - general description of how + PostgreSQL operator classes are specified. + + """ + columns = [] + render_exprs = [] + self.operators = {} + + expressions, operators = zip(*elements) + + for (expr, column, strname, add_element), operator in zip( + coercions.expect_col_expression_collection( + roles.DDLConstraintColumnRole, expressions + ), + operators, + ): + if add_element is not None: + columns.append(add_element) + + name = column.name if column is not None else strname + + if name is not None: + # backwards compat + self.operators[name] = operator + + render_exprs.append((expr, name, operator)) + + self._render_exprs = render_exprs + + ColumnCollectionConstraint.__init__( + self, + *columns, + name=kw.get("name"), + deferrable=kw.get("deferrable"), + initially=kw.get("initially"), + ) + self.using = kw.get("using", "gist") + where = kw.get("where") + if where is not None: + self.where = coercions.expect(roles.StatementOptionRole, where) + + self.ops = kw.get("ops", {}) + + def _set_parent(self, table, **kw): + super()._set_parent(table) + + self._render_exprs = [ + ( + expr if not isinstance(expr, str) else table.c[expr], + name, + operator, + ) + for expr, name, operator in (self._render_exprs) + ] + + def _copy(self, target_table=None, **kw): + elements = [ + ( + schema._copy_expression(expr, self.parent, target_table), + operator, + ) + for expr, _, operator in self._render_exprs + ] + c = self.__class__( + *elements, + name=self.name, + deferrable=self.deferrable, + initially=self.initially, + where=self.where, + using=self.using, + ) + c.dispatch._update(self.dispatch) + return c + + +def array_agg(*arg, **kw): + """PostgreSQL-specific form of :class:`_functions.array_agg`, ensures + return type is :class:`_postgresql.ARRAY` and not + the plain :class:`_types.ARRAY`, unless an explicit ``type_`` + is passed. + + """ + kw["_default_array_type"] = ARRAY + return functions.func.array_agg(*arg, **kw) + + +class _regconfig_fn(functions.GenericFunction[_T]): + inherit_cache = True + + def __init__(self, *args, **kwargs): + args = list(args) + if len(args) > 1: + initial_arg = coercions.expect( + roles.ExpressionElementRole, + args.pop(0), + name=getattr(self, "name", None), + apply_propagate_attrs=self, + type_=types.REGCONFIG, + ) + initial_arg = [initial_arg] + else: + initial_arg = [] + + addtl_args = [ + coercions.expect( + roles.ExpressionElementRole, + c, + name=getattr(self, "name", None), + apply_propagate_attrs=self, + ) + for c in args + ] + super().__init__(*(initial_arg + addtl_args), **kwargs) + + +class to_tsvector(_regconfig_fn): + """The PostgreSQL ``to_tsvector`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSVECTOR`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.to_tsvector` will be used automatically when invoking + ``sqlalchemy.func.to_tsvector()``, ensuring the correct argument and return + type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0rc1 + + """ + + inherit_cache = True + type = types.TSVECTOR + + +class to_tsquery(_regconfig_fn): + """The PostgreSQL ``to_tsquery`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSQUERY`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.to_tsquery` will be used automatically when invoking + ``sqlalchemy.func.to_tsquery()``, ensuring the correct argument and return + type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0rc1 + + """ + + inherit_cache = True + type = types.TSQUERY + + +class plainto_tsquery(_regconfig_fn): + """The PostgreSQL ``plainto_tsquery`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSQUERY`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.plainto_tsquery` will be used automatically when + invoking ``sqlalchemy.func.plainto_tsquery()``, ensuring the correct + argument and return type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0rc1 + + """ + + inherit_cache = True + type = types.TSQUERY + + +class phraseto_tsquery(_regconfig_fn): + """The PostgreSQL ``phraseto_tsquery`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSQUERY`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.phraseto_tsquery` will be used automatically when + invoking ``sqlalchemy.func.phraseto_tsquery()``, ensuring the correct + argument and return type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0rc1 + + """ + + inherit_cache = True + type = types.TSQUERY + + +class websearch_to_tsquery(_regconfig_fn): + """The PostgreSQL ``websearch_to_tsquery`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_postgresql.TSQUERY`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.websearch_to_tsquery` will be used automatically when + invoking ``sqlalchemy.func.websearch_to_tsquery()``, ensuring the correct + argument and return type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0rc1 + + """ + + inherit_cache = True + type = types.TSQUERY + + +class ts_headline(_regconfig_fn): + """The PostgreSQL ``ts_headline`` SQL function. + + This function applies automatic casting of the REGCONFIG argument + to use the :class:`_postgresql.REGCONFIG` datatype automatically, + and applies a return type of :class:`_types.TEXT`. + + Assuming the PostgreSQL dialect has been imported, either by invoking + ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL + engine using ``create_engine("postgresql...")``, + :class:`_postgresql.ts_headline` will be used automatically when invoking + ``sqlalchemy.func.ts_headline()``, ensuring the correct argument and return + type handlers are used at compile and execution time. + + .. versionadded:: 2.0.0rc1 + + """ + + inherit_cache = True + type = TEXT + + def __init__(self, *args, **kwargs): + args = list(args) + + # parse types according to + # https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-HEADLINE + if len(args) < 2: + # invalid args; don't do anything + has_regconfig = False + elif ( + isinstance(args[1], elements.ColumnElement) + and args[1].type._type_affinity is types.TSQUERY + ): + # tsquery is second argument, no regconfig argument + has_regconfig = False + else: + has_regconfig = True + + if has_regconfig: + initial_arg = coercions.expect( + roles.ExpressionElementRole, + args.pop(0), + apply_propagate_attrs=self, + name=getattr(self, "name", None), + type_=types.REGCONFIG, + ) + initial_arg = [initial_arg] + else: + initial_arg = [] + + addtl_args = [ + coercions.expect( + roles.ExpressionElementRole, + c, + name=getattr(self, "name", None), + apply_propagate_attrs=self, + ) + for c in args + ] + super().__init__(*(initial_arg + addtl_args), **kwargs) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/hstore.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/hstore.py new file mode 100644 index 0000000..04c8cf1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/hstore.py @@ -0,0 +1,397 @@ +# dialects/postgresql/hstore.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +import re + +from .array import ARRAY +from .operators import CONTAINED_BY +from .operators import CONTAINS +from .operators import GETITEM +from .operators import HAS_ALL +from .operators import HAS_ANY +from .operators import HAS_KEY +from ... import types as sqltypes +from ...sql import functions as sqlfunc + + +__all__ = ("HSTORE", "hstore") + + +class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): + """Represent the PostgreSQL HSTORE type. + + The :class:`.HSTORE` type stores dictionaries containing strings, e.g.:: + + data_table = Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', HSTORE) + ) + + with engine.connect() as conn: + conn.execute( + data_table.insert(), + data = {"key1": "value1", "key2": "value2"} + ) + + :class:`.HSTORE` provides for a wide range of operations, including: + + * Index operations:: + + data_table.c.data['some key'] == 'some value' + + * Containment operations:: + + data_table.c.data.has_key('some key') + + data_table.c.data.has_all(['one', 'two', 'three']) + + * Concatenation:: + + data_table.c.data + {"k1": "v1"} + + For a full list of special methods see + :class:`.HSTORE.comparator_factory`. + + .. container:: topic + + **Detecting Changes in HSTORE columns when using the ORM** + + For usage with the SQLAlchemy ORM, it may be desirable to combine the + usage of :class:`.HSTORE` with :class:`.MutableDict` dictionary now + part of the :mod:`sqlalchemy.ext.mutable` extension. This extension + will allow "in-place" changes to the dictionary, e.g. addition of new + keys or replacement/removal of existing keys to/from the current + dictionary, to produce events which will be detected by the unit of + work:: + + from sqlalchemy.ext.mutable import MutableDict + + class MyClass(Base): + __tablename__ = 'data_table' + + id = Column(Integer, primary_key=True) + data = Column(MutableDict.as_mutable(HSTORE)) + + my_object = session.query(MyClass).one() + + # in-place mutation, requires Mutable extension + # in order for the ORM to detect + my_object.data['some_key'] = 'some value' + + session.commit() + + When the :mod:`sqlalchemy.ext.mutable` extension is not used, the ORM + will not be alerted to any changes to the contents of an existing + dictionary, unless that dictionary value is re-assigned to the + HSTORE-attribute itself, thus generating a change event. + + .. seealso:: + + :class:`.hstore` - render the PostgreSQL ``hstore()`` function. + + + """ + + __visit_name__ = "HSTORE" + hashable = False + text_type = sqltypes.Text() + + def __init__(self, text_type=None): + """Construct a new :class:`.HSTORE`. + + :param text_type: the type that should be used for indexed values. + Defaults to :class:`_types.Text`. + + """ + if text_type is not None: + self.text_type = text_type + + class Comparator( + sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator + ): + """Define comparison operations for :class:`.HSTORE`.""" + + def has_key(self, other): + """Boolean expression. Test for presence of a key. Note that the + key may be a SQLA expression. + """ + return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean) + + def has_all(self, other): + """Boolean expression. Test for presence of all keys in jsonb""" + return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean) + + def has_any(self, other): + """Boolean expression. Test for presence of any key in jsonb""" + return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean) + + def contains(self, other, **kwargs): + """Boolean expression. Test if keys (or array) are a superset + of/contained the keys of the argument jsonb expression. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) + + def contained_by(self, other): + """Boolean expression. Test if keys are a proper subset of the + keys of the argument jsonb expression. + """ + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) + + def _setup_getitem(self, index): + return GETITEM, index, self.type.text_type + + def defined(self, key): + """Boolean expression. Test for presence of a non-NULL value for + the key. Note that the key may be a SQLA expression. + """ + return _HStoreDefinedFunction(self.expr, key) + + def delete(self, key): + """HStore expression. Returns the contents of this hstore with the + given key deleted. Note that the key may be a SQLA expression. + """ + if isinstance(key, dict): + key = _serialize_hstore(key) + return _HStoreDeleteFunction(self.expr, key) + + def slice(self, array): + """HStore expression. Returns a subset of an hstore defined by + array of keys. + """ + return _HStoreSliceFunction(self.expr, array) + + def keys(self): + """Text array expression. Returns array of keys.""" + return _HStoreKeysFunction(self.expr) + + def vals(self): + """Text array expression. Returns array of values.""" + return _HStoreValsFunction(self.expr) + + def array(self): + """Text array expression. Returns array of alternating keys and + values. + """ + return _HStoreArrayFunction(self.expr) + + def matrix(self): + """Text array expression. Returns array of [key, value] pairs.""" + return _HStoreMatrixFunction(self.expr) + + comparator_factory = Comparator + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, dict): + return _serialize_hstore(value) + else: + return value + + return process + + def result_processor(self, dialect, coltype): + def process(value): + if value is not None: + return _parse_hstore(value) + else: + return value + + return process + + +class hstore(sqlfunc.GenericFunction): + """Construct an hstore value within a SQL expression using the + PostgreSQL ``hstore()`` function. + + The :class:`.hstore` function accepts one or two arguments as described + in the PostgreSQL documentation. + + E.g.:: + + from sqlalchemy.dialects.postgresql import array, hstore + + select(hstore('key1', 'value1')) + + select( + hstore( + array(['key1', 'key2', 'key3']), + array(['value1', 'value2', 'value3']) + ) + ) + + .. seealso:: + + :class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype. + + """ + + type = HSTORE + name = "hstore" + inherit_cache = True + + +class _HStoreDefinedFunction(sqlfunc.GenericFunction): + type = sqltypes.Boolean + name = "defined" + inherit_cache = True + + +class _HStoreDeleteFunction(sqlfunc.GenericFunction): + type = HSTORE + name = "delete" + inherit_cache = True + + +class _HStoreSliceFunction(sqlfunc.GenericFunction): + type = HSTORE + name = "slice" + inherit_cache = True + + +class _HStoreKeysFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = "akeys" + inherit_cache = True + + +class _HStoreValsFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = "avals" + inherit_cache = True + + +class _HStoreArrayFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = "hstore_to_array" + inherit_cache = True + + +class _HStoreMatrixFunction(sqlfunc.GenericFunction): + type = ARRAY(sqltypes.Text) + name = "hstore_to_matrix" + inherit_cache = True + + +# +# parsing. note that none of this is used with the psycopg2 backend, +# which provides its own native extensions. +# + +# My best guess at the parsing rules of hstore literals, since no formal +# grammar is given. This is mostly reverse engineered from PG's input parser +# behavior. +HSTORE_PAIR_RE = re.compile( + r""" +( + "(?P (\\ . | [^"])* )" # Quoted key +) +[ ]* => [ ]* # Pair operator, optional adjoining whitespace +( + (?P NULL ) # NULL value + | "(?P (\\ . | [^"])* )" # Quoted value +) +""", + re.VERBOSE, +) + +HSTORE_DELIMITER_RE = re.compile( + r""" +[ ]* , [ ]* +""", + re.VERBOSE, +) + + +def _parse_error(hstore_str, pos): + """format an unmarshalling error.""" + + ctx = 20 + hslen = len(hstore_str) + + parsed_tail = hstore_str[max(pos - ctx - 1, 0) : min(pos, hslen)] + residual = hstore_str[min(pos, hslen) : min(pos + ctx + 1, hslen)] + + if len(parsed_tail) > ctx: + parsed_tail = "[...]" + parsed_tail[1:] + if len(residual) > ctx: + residual = residual[:-1] + "[...]" + + return "After %r, could not parse residual at position %d: %r" % ( + parsed_tail, + pos, + residual, + ) + + +def _parse_hstore(hstore_str): + """Parse an hstore from its literal string representation. + + Attempts to approximate PG's hstore input parsing rules as closely as + possible. Although currently this is not strictly necessary, since the + current implementation of hstore's output syntax is stricter than what it + accepts as input, the documentation makes no guarantees that will always + be the case. + + + + """ + result = {} + pos = 0 + pair_match = HSTORE_PAIR_RE.match(hstore_str) + + while pair_match is not None: + key = pair_match.group("key").replace(r"\"", '"').replace("\\\\", "\\") + if pair_match.group("value_null"): + value = None + else: + value = ( + pair_match.group("value") + .replace(r"\"", '"') + .replace("\\\\", "\\") + ) + result[key] = value + + pos += pair_match.end() + + delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:]) + if delim_match is not None: + pos += delim_match.end() + + pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:]) + + if pos != len(hstore_str): + raise ValueError(_parse_error(hstore_str, pos)) + + return result + + +def _serialize_hstore(val): + """Serialize a dictionary into an hstore literal. Keys and values must + both be strings (except None for values). + + """ + + def esc(s, position): + if position == "value" and s is None: + return "NULL" + elif isinstance(s, str): + return '"%s"' % s.replace("\\", "\\\\").replace('"', r"\"") + else: + raise ValueError( + "%r in %s position is not a string." % (s, position) + ) + + return ", ".join( + "%s=>%s" % (esc(k, "key"), esc(v, "value")) for k, v in val.items() + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/json.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/json.py new file mode 100644 index 0000000..3790fa3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/json.py @@ -0,0 +1,325 @@ +# dialects/postgresql/json.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from .array import ARRAY +from .array import array as _pg_array +from .operators import ASTEXT +from .operators import CONTAINED_BY +from .operators import CONTAINS +from .operators import DELETE_PATH +from .operators import HAS_ALL +from .operators import HAS_ANY +from .operators import HAS_KEY +from .operators import JSONPATH_ASTEXT +from .operators import PATH_EXISTS +from .operators import PATH_MATCH +from ... import types as sqltypes +from ...sql import cast + +__all__ = ("JSON", "JSONB") + + +class JSONPathType(sqltypes.JSON.JSONPathType): + def _processor(self, dialect, super_proc): + def process(value): + if isinstance(value, str): + # If it's already a string assume that it's in json path + # format. This allows using cast with json paths literals + return value + elif value: + # If it's already a string assume that it's in json path + # format. This allows using cast with json paths literals + value = "{%s}" % (", ".join(map(str, value))) + else: + value = "{}" + if super_proc: + value = super_proc(value) + return value + + return process + + def bind_processor(self, dialect): + return self._processor(dialect, self.string_bind_processor(dialect)) + + def literal_processor(self, dialect): + return self._processor(dialect, self.string_literal_processor(dialect)) + + +class JSONPATH(JSONPathType): + """JSON Path Type. + + This is usually required to cast literal values to json path when using + json search like function, such as ``jsonb_path_query_array`` or + ``jsonb_path_exists``:: + + stmt = sa.select( + sa.func.jsonb_path_query_array( + table.c.jsonb_col, cast("$.address.id", JSONPATH) + ) + ) + + """ + + __visit_name__ = "JSONPATH" + + +class JSON(sqltypes.JSON): + """Represent the PostgreSQL JSON type. + + :class:`_postgresql.JSON` is used automatically whenever the base + :class:`_types.JSON` datatype is used against a PostgreSQL backend, + however base :class:`_types.JSON` datatype does not provide Python + accessors for PostgreSQL-specific comparison methods such as + :meth:`_postgresql.JSON.Comparator.astext`; additionally, to use + PostgreSQL ``JSONB``, the :class:`_postgresql.JSONB` datatype should + be used explicitly. + + .. seealso:: + + :class:`_types.JSON` - main documentation for the generic + cross-platform JSON datatype. + + The operators provided by the PostgreSQL version of :class:`_types.JSON` + include: + + * Index operations (the ``->`` operator):: + + data_table.c.data['some key'] + + data_table.c.data[5] + + + * Index operations returning text (the ``->>`` operator):: + + data_table.c.data['some key'].astext == 'some value' + + Note that equivalent functionality is available via the + :attr:`.JSON.Comparator.as_string` accessor. + + * Index operations with CAST + (equivalent to ``CAST(col ->> ['some key'] AS )``):: + + data_table.c.data['some key'].astext.cast(Integer) == 5 + + Note that equivalent functionality is available via the + :attr:`.JSON.Comparator.as_integer` and similar accessors. + + * Path index operations (the ``#>`` operator):: + + data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')] + + * Path index operations returning text (the ``#>>`` operator):: + + data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value' + + Index operations return an expression object whose type defaults to + :class:`_types.JSON` by default, + so that further JSON-oriented instructions + may be called upon the result type. + + Custom serializers and deserializers are specified at the dialect level, + that is using :func:`_sa.create_engine`. The reason for this is that when + using psycopg2, the DBAPI only allows serializers at the per-cursor + or per-connection level. E.g.:: + + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test", + json_serializer=my_serialize_fn, + json_deserializer=my_deserialize_fn + ) + + When using the psycopg2 dialect, the json_deserializer is registered + against the database using ``psycopg2.extras.register_default_json``. + + .. seealso:: + + :class:`_types.JSON` - Core level JSON type + + :class:`_postgresql.JSONB` + + """ # noqa + + astext_type = sqltypes.Text() + + def __init__(self, none_as_null=False, astext_type=None): + """Construct a :class:`_types.JSON` type. + + :param none_as_null: if True, persist the value ``None`` as a + SQL NULL value, not the JSON encoding of ``null``. Note that + when this flag is False, the :func:`.null` construct can still + be used to persist a NULL value:: + + from sqlalchemy import null + conn.execute(table.insert(), {"data": null()}) + + .. seealso:: + + :attr:`_types.JSON.NULL` + + :param astext_type: the type to use for the + :attr:`.JSON.Comparator.astext` + accessor on indexed attributes. Defaults to :class:`_types.Text`. + + """ + super().__init__(none_as_null=none_as_null) + if astext_type is not None: + self.astext_type = astext_type + + class Comparator(sqltypes.JSON.Comparator): + """Define comparison operations for :class:`_types.JSON`.""" + + @property + def astext(self): + """On an indexed expression, use the "astext" (e.g. "->>") + conversion when rendered in SQL. + + E.g.:: + + select(data_table.c.data['some key'].astext) + + .. seealso:: + + :meth:`_expression.ColumnElement.cast` + + """ + if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType): + return self.expr.left.operate( + JSONPATH_ASTEXT, + self.expr.right, + result_type=self.type.astext_type, + ) + else: + return self.expr.left.operate( + ASTEXT, self.expr.right, result_type=self.type.astext_type + ) + + comparator_factory = Comparator + + +class JSONB(JSON): + """Represent the PostgreSQL JSONB type. + + The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data, + e.g.:: + + data_table = Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', JSONB) + ) + + with engine.connect() as conn: + conn.execute( + data_table.insert(), + data = {"key1": "value1", "key2": "value2"} + ) + + The :class:`_postgresql.JSONB` type includes all operations provided by + :class:`_types.JSON`, including the same behaviors for indexing + operations. + It also adds additional operators specific to JSONB, including + :meth:`.JSONB.Comparator.has_key`, :meth:`.JSONB.Comparator.has_all`, + :meth:`.JSONB.Comparator.has_any`, :meth:`.JSONB.Comparator.contains`, + :meth:`.JSONB.Comparator.contained_by`, + :meth:`.JSONB.Comparator.delete_path`, + :meth:`.JSONB.Comparator.path_exists` and + :meth:`.JSONB.Comparator.path_match`. + + Like the :class:`_types.JSON` type, the :class:`_postgresql.JSONB` + type does not detect + in-place changes when used with the ORM, unless the + :mod:`sqlalchemy.ext.mutable` extension is used. + + Custom serializers and deserializers + are shared with the :class:`_types.JSON` class, + using the ``json_serializer`` + and ``json_deserializer`` keyword arguments. These must be specified + at the dialect level using :func:`_sa.create_engine`. When using + psycopg2, the serializers are associated with the jsonb type using + ``psycopg2.extras.register_default_jsonb`` on a per-connection basis, + in the same way that ``psycopg2.extras.register_default_json`` is used + to register these handlers with the json type. + + .. seealso:: + + :class:`_types.JSON` + + """ + + __visit_name__ = "JSONB" + + class Comparator(JSON.Comparator): + """Define comparison operations for :class:`_types.JSON`.""" + + def has_key(self, other): + """Boolean expression. Test for presence of a key. Note that the + key may be a SQLA expression. + """ + return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean) + + def has_all(self, other): + """Boolean expression. Test for presence of all keys in jsonb""" + return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean) + + def has_any(self, other): + """Boolean expression. Test for presence of any key in jsonb""" + return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean) + + def contains(self, other, **kwargs): + """Boolean expression. Test if keys (or array) are a superset + of/contained the keys of the argument jsonb expression. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) + + def contained_by(self, other): + """Boolean expression. Test if keys are a proper subset of the + keys of the argument jsonb expression. + """ + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) + + def delete_path(self, array): + """JSONB expression. Deletes field or array element specified in + the argument array. + + The input may be a list of strings that will be coerced to an + ``ARRAY`` or an instance of :meth:`_postgres.array`. + + .. versionadded:: 2.0 + """ + if not isinstance(array, _pg_array): + array = _pg_array(array) + right_side = cast(array, ARRAY(sqltypes.TEXT)) + return self.operate(DELETE_PATH, right_side, result_type=JSONB) + + def path_exists(self, other): + """Boolean expression. Test for presence of item given by the + argument JSONPath expression. + + .. versionadded:: 2.0 + """ + return self.operate( + PATH_EXISTS, other, result_type=sqltypes.Boolean + ) + + def path_match(self, other): + """Boolean expression. Test if JSONPath predicate given by the + argument JSONPath expression matches. + + Only the first item of the result is taken into account. + + .. versionadded:: 2.0 + """ + return self.operate( + PATH_MATCH, other, result_type=sqltypes.Boolean + ) + + comparator_factory = Comparator diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/named_types.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/named_types.py new file mode 100644 index 0000000..16e5c86 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/named_types.py @@ -0,0 +1,509 @@ +# dialects/postgresql/named_types.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from __future__ import annotations + +from typing import Any +from typing import Optional +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from ... import schema +from ... import util +from ...sql import coercions +from ...sql import elements +from ...sql import roles +from ...sql import sqltypes +from ...sql import type_api +from ...sql.base import _NoArg +from ...sql.ddl import InvokeCreateDDLBase +from ...sql.ddl import InvokeDropDDLBase + +if TYPE_CHECKING: + from ...sql._typing import _TypeEngineArgument + + +class NamedType(sqltypes.TypeEngine): + """Base for named types.""" + + __abstract__ = True + DDLGenerator: Type[NamedTypeGenerator] + DDLDropper: Type[NamedTypeDropper] + create_type: bool + + def create(self, bind, checkfirst=True, **kw): + """Emit ``CREATE`` DDL for this type. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type does not exist already before + creating. + + """ + bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst) + + def drop(self, bind, checkfirst=True, **kw): + """Emit ``DROP`` DDL for this type. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type actually exists before dropping. + + """ + bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst) + + def _check_for_name_in_memos(self, checkfirst, kw): + """Look in the 'ddl runner' for 'memos', then + note our name in that collection. + + This to ensure a particular named type is operated + upon only once within any kind of create/drop + sequence without relying upon "checkfirst". + + """ + if not self.create_type: + return True + if "_ddl_runner" in kw: + ddl_runner = kw["_ddl_runner"] + type_name = f"pg_{self.__visit_name__}" + if type_name in ddl_runner.memo: + existing = ddl_runner.memo[type_name] + else: + existing = ddl_runner.memo[type_name] = set() + present = (self.schema, self.name) in existing + existing.add((self.schema, self.name)) + return present + else: + return False + + def _on_table_create(self, target, bind, checkfirst=False, **kw): + if ( + checkfirst + or ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + ) + ) and not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_table_drop(self, target, bind, checkfirst=False, **kw): + if ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + and not self._check_for_name_in_memos(checkfirst, kw) + ): + self.drop(bind=bind, checkfirst=checkfirst) + + def _on_metadata_create(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.drop(bind=bind, checkfirst=checkfirst) + + +class NamedTypeGenerator(InvokeCreateDDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super().__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_create_type(self, type_): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(type_) + return not self.connection.dialect.has_type( + self.connection, type_.name, schema=effective_schema + ) + + +class NamedTypeDropper(InvokeDropDDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super().__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_drop_type(self, type_): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(type_) + return self.connection.dialect.has_type( + self.connection, type_.name, schema=effective_schema + ) + + +class EnumGenerator(NamedTypeGenerator): + def visit_enum(self, enum): + if not self._can_create_type(enum): + return + + with self.with_ddl_events(enum): + self.connection.execute(CreateEnumType(enum)) + + +class EnumDropper(NamedTypeDropper): + def visit_enum(self, enum): + if not self._can_drop_type(enum): + return + + with self.with_ddl_events(enum): + self.connection.execute(DropEnumType(enum)) + + +class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): + """PostgreSQL ENUM type. + + This is a subclass of :class:`_types.Enum` which includes + support for PG's ``CREATE TYPE`` and ``DROP TYPE``. + + When the builtin type :class:`_types.Enum` is used and the + :paramref:`.Enum.native_enum` flag is left at its default of + True, the PostgreSQL backend will use a :class:`_postgresql.ENUM` + type as the implementation, so the special create/drop rules + will be used. + + The create/drop behavior of ENUM is necessarily intricate, due to the + awkward relationship the ENUM type has in relationship to the + parent table, in that it may be "owned" by just a single table, or + may be shared among many tables. + + When using :class:`_types.Enum` or :class:`_postgresql.ENUM` + in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted + corresponding to when the :meth:`_schema.Table.create` and + :meth:`_schema.Table.drop` + methods are called:: + + table = Table('sometable', metadata, + Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) + ) + + table.create(engine) # will emit CREATE ENUM and CREATE TABLE + table.drop(engine) # will emit DROP TABLE and DROP ENUM + + To use a common enumerated type between multiple tables, the best + practice is to declare the :class:`_types.Enum` or + :class:`_postgresql.ENUM` independently, and associate it with the + :class:`_schema.MetaData` object itself:: + + my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) + + t1 = Table('sometable_one', metadata, + Column('some_enum', myenum) + ) + + t2 = Table('sometable_two', metadata, + Column('some_enum', myenum) + ) + + When this pattern is used, care must still be taken at the level + of individual table creates. Emitting CREATE TABLE without also + specifying ``checkfirst=True`` will still cause issues:: + + t1.create(engine) # will fail: no such type 'myenum' + + If we specify ``checkfirst=True``, the individual table-level create + operation will check for the ``ENUM`` and create if not exists:: + + # will check if enum exists, and emit CREATE TYPE if not + t1.create(engine, checkfirst=True) + + When using a metadata-level ENUM type, the type will always be created + and dropped if either the metadata-wide create/drop is called:: + + metadata.create_all(engine) # will emit CREATE TYPE + metadata.drop_all(engine) # will emit DROP TYPE + + The type can also be created and dropped directly:: + + my_enum.create(engine) + my_enum.drop(engine) + + """ + + native_enum = True + DDLGenerator = EnumGenerator + DDLDropper = EnumDropper + + def __init__( + self, + *enums, + name: Union[str, _NoArg, None] = _NoArg.NO_ARG, + create_type: bool = True, + **kw, + ): + """Construct an :class:`_postgresql.ENUM`. + + Arguments are the same as that of + :class:`_types.Enum`, but also including + the following parameters. + + :param create_type: Defaults to True. + Indicates that ``CREATE TYPE`` should be + emitted, after optionally checking for the + presence of the type, when the parent + table is being created; and additionally + that ``DROP TYPE`` is called when the table + is dropped. When ``False``, no check + will be performed and no ``CREATE TYPE`` + or ``DROP TYPE`` is emitted, unless + :meth:`~.postgresql.ENUM.create` + or :meth:`~.postgresql.ENUM.drop` + are called directly. + Setting to ``False`` is helpful + when invoking a creation scheme to a SQL file + without access to the actual database - + the :meth:`~.postgresql.ENUM.create` and + :meth:`~.postgresql.ENUM.drop` methods can + be used to emit SQL to a target bind. + + """ + native_enum = kw.pop("native_enum", None) + if native_enum is False: + util.warn( + "the native_enum flag does not apply to the " + "sqlalchemy.dialects.postgresql.ENUM datatype; this type " + "always refers to ENUM. Use sqlalchemy.types.Enum for " + "non-native enum." + ) + self.create_type = create_type + if name is not _NoArg.NO_ARG: + kw["name"] = name + super().__init__(*enums, **kw) + + def coerce_compared_value(self, op, value): + super_coerced_type = super().coerce_compared_value(op, value) + if ( + super_coerced_type._type_affinity + is type_api.STRINGTYPE._type_affinity + ): + return self + else: + return super_coerced_type + + @classmethod + def __test_init__(cls): + return cls(name="name") + + @classmethod + def adapt_emulated_to_native(cls, impl, **kw): + """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain + :class:`.Enum`. + + """ + kw.setdefault("validate_strings", impl.validate_strings) + kw.setdefault("name", impl.name) + kw.setdefault("schema", impl.schema) + kw.setdefault("inherit_schema", impl.inherit_schema) + kw.setdefault("metadata", impl.metadata) + kw.setdefault("_create_events", False) + kw.setdefault("values_callable", impl.values_callable) + kw.setdefault("omit_aliases", impl._omit_aliases) + kw.setdefault("_adapted_from", impl) + if type_api._is_native_for_emulated(impl.__class__): + kw.setdefault("create_type", impl.create_type) + + return cls(**kw) + + def create(self, bind=None, checkfirst=True): + """Emit ``CREATE TYPE`` for this + :class:`_postgresql.ENUM`. + + If the underlying dialect does not support + PostgreSQL CREATE TYPE, no action is taken. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type does not exist already before + creating. + + """ + if not bind.dialect.supports_native_enum: + return + + super().create(bind, checkfirst=checkfirst) + + def drop(self, bind=None, checkfirst=True): + """Emit ``DROP TYPE`` for this + :class:`_postgresql.ENUM`. + + If the underlying dialect does not support + PostgreSQL DROP TYPE, no action is taken. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type actually exists before dropping. + + """ + if not bind.dialect.supports_native_enum: + return + + super().drop(bind, checkfirst=checkfirst) + + def get_dbapi_type(self, dbapi): + """dont return dbapi.STRING for ENUM in PostgreSQL, since that's + a different type""" + + return None + + +class DomainGenerator(NamedTypeGenerator): + def visit_DOMAIN(self, domain): + if not self._can_create_type(domain): + return + with self.with_ddl_events(domain): + self.connection.execute(CreateDomainType(domain)) + + +class DomainDropper(NamedTypeDropper): + def visit_DOMAIN(self, domain): + if not self._can_drop_type(domain): + return + + with self.with_ddl_events(domain): + self.connection.execute(DropDomainType(domain)) + + +class DOMAIN(NamedType, sqltypes.SchemaType): + r"""Represent the DOMAIN PostgreSQL type. + + A domain is essentially a data type with optional constraints + that restrict the allowed set of values. E.g.:: + + PositiveInt = DOMAIN( + "pos_int", Integer, check="VALUE > 0", not_null=True + ) + + UsPostalCode = DOMAIN( + "us_postal_code", + Text, + check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'" + ) + + See the `PostgreSQL documentation`__ for additional details + + __ https://www.postgresql.org/docs/current/sql-createdomain.html + + .. versionadded:: 2.0 + + """ + + DDLGenerator = DomainGenerator + DDLDropper = DomainDropper + + __visit_name__ = "DOMAIN" + + def __init__( + self, + name: str, + data_type: _TypeEngineArgument[Any], + *, + collation: Optional[str] = None, + default: Union[elements.TextClause, str, None] = None, + constraint_name: Optional[str] = None, + not_null: Optional[bool] = None, + check: Union[elements.TextClause, str, None] = None, + create_type: bool = True, + **kw: Any, + ): + """ + Construct a DOMAIN. + + :param name: the name of the domain + :param data_type: The underlying data type of the domain. + This can include array specifiers. + :param collation: An optional collation for the domain. + If no collation is specified, the underlying data type's default + collation is used. The underlying type must be collatable if + ``collation`` is specified. + :param default: The DEFAULT clause specifies a default value for + columns of the domain data type. The default should be a string + or a :func:`_expression.text` value. + If no default value is specified, then the default value is + the null value. + :param constraint_name: An optional name for a constraint. + If not specified, the backend generates a name. + :param not_null: Values of this domain are prevented from being null. + By default domain are allowed to be null. If not specified + no nullability clause will be emitted. + :param check: CHECK clause specify integrity constraint or test + which values of the domain must satisfy. A constraint must be + an expression producing a Boolean result that can use the key + word VALUE to refer to the value being tested. + Differently from PostgreSQL, only a single check clause is + currently allowed in SQLAlchemy. + :param schema: optional schema name + :param metadata: optional :class:`_schema.MetaData` object which + this :class:`_postgresql.DOMAIN` will be directly associated + :param create_type: Defaults to True. + Indicates that ``CREATE TYPE`` should be emitted, after optionally + checking for the presence of the type, when the parent table is + being created; and additionally that ``DROP TYPE`` is called + when the table is dropped. + + """ + self.data_type = type_api.to_instance(data_type) + self.default = default + self.collation = collation + self.constraint_name = constraint_name + self.not_null = bool(not_null) + if check is not None: + check = coercions.expect(roles.DDLExpressionRole, check) + self.check = check + self.create_type = create_type + super().__init__(name=name, **kw) + + @classmethod + def __test_init__(cls): + return cls("name", sqltypes.Integer) + + def adapt(self, impl, **kw): + if self.default: + kw["default"] = self.default + if self.constraint_name is not None: + kw["constraint_name"] = self.constraint_name + if self.not_null: + kw["not_null"] = self.not_null + if self.check is not None: + kw["check"] = str(self.check) + if self.create_type: + kw["create_type"] = self.create_type + + return super().adapt(impl, **kw) + + +class CreateEnumType(schema._CreateDropBase): + __visit_name__ = "create_enum_type" + + +class DropEnumType(schema._CreateDropBase): + __visit_name__ = "drop_enum_type" + + +class CreateDomainType(schema._CreateDropBase): + """Represent a CREATE DOMAIN statement.""" + + __visit_name__ = "create_domain_type" + + +class DropDomainType(schema._CreateDropBase): + """Represent a DROP DOMAIN statement.""" + + __visit_name__ = "drop_domain_type" diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/operators.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/operators.py new file mode 100644 index 0000000..53e175f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/operators.py @@ -0,0 +1,129 @@ +# dialects/postgresql/operators.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from ...sql import operators + + +_getitem_precedence = operators._PRECEDENCE[operators.json_getitem_op] +_eq_precedence = operators._PRECEDENCE[operators.eq] + +# JSON + JSONB +ASTEXT = operators.custom_op( + "->>", + precedence=_getitem_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +JSONPATH_ASTEXT = operators.custom_op( + "#>>", + precedence=_getitem_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +# JSONB + HSTORE +HAS_KEY = operators.custom_op( + "?", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +HAS_ALL = operators.custom_op( + "?&", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +HAS_ANY = operators.custom_op( + "?|", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +# JSONB +DELETE_PATH = operators.custom_op( + "#-", + precedence=_getitem_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +PATH_EXISTS = operators.custom_op( + "@?", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +PATH_MATCH = operators.custom_op( + "@@", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +# JSONB + ARRAY + HSTORE + RANGE +CONTAINS = operators.custom_op( + "@>", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +CONTAINED_BY = operators.custom_op( + "<@", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +# ARRAY + RANGE +OVERLAP = operators.custom_op( + "&&", + precedence=_eq_precedence, + is_comparison=True, +) + +# RANGE +STRICTLY_LEFT_OF = operators.custom_op( + "<<", precedence=_eq_precedence, is_comparison=True +) + +STRICTLY_RIGHT_OF = operators.custom_op( + ">>", precedence=_eq_precedence, is_comparison=True +) + +NOT_EXTEND_RIGHT_OF = operators.custom_op( + "&<", precedence=_eq_precedence, is_comparison=True +) + +NOT_EXTEND_LEFT_OF = operators.custom_op( + "&>", precedence=_eq_precedence, is_comparison=True +) + +ADJACENT_TO = operators.custom_op( + "-|-", precedence=_eq_precedence, is_comparison=True +) + +# HSTORE +GETITEM = operators.custom_op( + "->", + precedence=_getitem_precedence, + natural_self_precedent=True, + eager_grouping=True, +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/pg8000.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/pg8000.py new file mode 100644 index 0000000..0151be0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/pg8000.py @@ -0,0 +1,662 @@ +# dialects/postgresql/pg8000.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: postgresql+pg8000 + :name: pg8000 + :dbapi: pg8000 + :connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/pg8000/ + +.. versionchanged:: 1.4 The pg8000 dialect has been updated for version + 1.16.6 and higher, and is again part of SQLAlchemy's continuous integration + with full feature support. + +.. _pg8000_unicode: + +Unicode +------- + +pg8000 will encode / decode string values between it and the server using the +PostgreSQL ``client_encoding`` parameter; by default this is the value in +the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``. +Typically, this can be changed to ``utf-8``, as a more useful default:: + + #client_encoding = sql_ascii # actually, defaults to database + # encoding + client_encoding = utf8 + +The ``client_encoding`` can be overridden for a session by executing the SQL: + +SET CLIENT_ENCODING TO 'utf8'; + +SQLAlchemy will execute this SQL on all new connections based on the value +passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter:: + + engine = create_engine( + "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8') + +.. _pg8000_ssl: + +SSL Connections +--------------- + +pg8000 accepts a Python ``SSLContext`` object which may be specified using the +:paramref:`_sa.create_engine.connect_args` dictionary:: + + import ssl + ssl_context = ssl.create_default_context() + engine = sa.create_engine( + "postgresql+pg8000://scott:tiger@192.168.0.199/test", + connect_args={"ssl_context": ssl_context}, + ) + +If the server uses an automatically-generated certificate that is self-signed +or does not match the host name (as seen from the client), it may also be +necessary to disable hostname checking:: + + import ssl + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + engine = sa.create_engine( + "postgresql+pg8000://scott:tiger@192.168.0.199/test", + connect_args={"ssl_context": ssl_context}, + ) + +.. _pg8000_isolation_level: + +pg8000 Transaction Isolation Level +------------------------------------- + +The pg8000 dialect offers the same isolation level settings as that +of the :ref:`psycopg2 ` dialect: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +.. seealso:: + + :ref:`postgresql_isolation_level` + + :ref:`psycopg2_isolation_level` + + +""" # noqa +import decimal +import re + +from . import ranges +from .array import ARRAY as PGARRAY +from .base import _DECIMAL_TYPES +from .base import _FLOAT_TYPES +from .base import _INT_TYPES +from .base import ENUM +from .base import INTERVAL +from .base import PGCompiler +from .base import PGDialect +from .base import PGExecutionContext +from .base import PGIdentifierPreparer +from .json import JSON +from .json import JSONB +from .json import JSONPathType +from .pg_catalog import _SpaceVector +from .pg_catalog import OIDVECTOR +from .types import CITEXT +from ... import exc +from ... import util +from ...engine import processors +from ...sql import sqltypes +from ...sql.elements import quoted_name + + +class _PGString(sqltypes.String): + render_bind_cast = True + + +class _PGNumeric(sqltypes.Numeric): + render_bind_cast = True + + def result_processor(self, dialect, coltype): + if self.asdecimal: + if coltype in _FLOAT_TYPES: + return processors.to_decimal_processor_factory( + decimal.Decimal, self._effective_decimal_return_scale + ) + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + # pg8000 returns Decimal natively for 1700 + return None + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + else: + if coltype in _FLOAT_TYPES: + # pg8000 returns float natively for 701 + return None + elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: + return processors.to_float + else: + raise exc.InvalidRequestError( + "Unknown PG numeric type: %d" % coltype + ) + + +class _PGFloat(_PGNumeric, sqltypes.Float): + __visit_name__ = "float" + render_bind_cast = True + + +class _PGNumericNoBind(_PGNumeric): + def bind_processor(self, dialect): + return None + + +class _PGJSON(JSON): + render_bind_cast = True + + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONB(JSONB): + render_bind_cast = True + + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONIndexType(sqltypes.JSON.JSONIndexType): + def get_dbapi_type(self, dbapi): + raise NotImplementedError("should not be here") + + +class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): + __visit_name__ = "json_int_index" + + render_bind_cast = True + + +class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): + __visit_name__ = "json_str_index" + + render_bind_cast = True + + +class _PGJSONPathType(JSONPathType): + pass + + # DBAPI type 1009 + + +class _PGEnum(ENUM): + def get_dbapi_type(self, dbapi): + return dbapi.UNKNOWN + + +class _PGInterval(INTERVAL): + render_bind_cast = True + + def get_dbapi_type(self, dbapi): + return dbapi.INTERVAL + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + return _PGInterval(precision=interval.second_precision) + + +class _PGTimeStamp(sqltypes.DateTime): + render_bind_cast = True + + +class _PGDate(sqltypes.Date): + render_bind_cast = True + + +class _PGTime(sqltypes.Time): + render_bind_cast = True + + +class _PGInteger(sqltypes.Integer): + render_bind_cast = True + + +class _PGSmallInteger(sqltypes.SmallInteger): + render_bind_cast = True + + +class _PGNullType(sqltypes.NullType): + pass + + +class _PGBigInteger(sqltypes.BigInteger): + render_bind_cast = True + + +class _PGBoolean(sqltypes.Boolean): + render_bind_cast = True + + +class _PGARRAY(PGARRAY): + render_bind_cast = True + + +class _PGOIDVECTOR(_SpaceVector, OIDVECTOR): + pass + + +class _Pg8000Range(ranges.AbstractSingleRangeImpl): + def bind_processor(self, dialect): + pg8000_Range = dialect.dbapi.Range + + def to_range(value): + if isinstance(value, ranges.Range): + value = pg8000_Range( + value.lower, value.upper, value.bounds, value.empty + ) + return value + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is not None: + value = ranges.Range( + value.lower, + value.upper, + bounds=value.bounds, + empty=value.is_empty, + ) + return value + + return to_range + + +class _Pg8000MultiRange(ranges.AbstractMultiRangeImpl): + def bind_processor(self, dialect): + pg8000_Range = dialect.dbapi.Range + + def to_multirange(value): + if isinstance(value, list): + mr = [] + for v in value: + if isinstance(v, ranges.Range): + mr.append( + pg8000_Range(v.lower, v.upper, v.bounds, v.empty) + ) + else: + mr.append(v) + return mr + else: + return value + + return to_multirange + + def result_processor(self, dialect, coltype): + def to_multirange(value): + if value is None: + return None + else: + return ranges.MultiRange( + ranges.Range( + v.lower, v.upper, bounds=v.bounds, empty=v.is_empty + ) + for v in value + ) + + return to_multirange + + +_server_side_id = util.counter() + + +class PGExecutionContext_pg8000(PGExecutionContext): + def create_server_side_cursor(self): + ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) + return ServerSideCursor(self._dbapi_connection.cursor(), ident) + + def pre_exec(self): + if not self.compiled: + return + + +class ServerSideCursor: + server_side = True + + def __init__(self, cursor, ident): + self.ident = ident + self.cursor = cursor + + @property + def connection(self): + return self.cursor.connection + + @property + def rowcount(self): + return self.cursor.rowcount + + @property + def description(self): + return self.cursor.description + + def execute(self, operation, args=(), stream=None): + op = "DECLARE " + self.ident + " NO SCROLL CURSOR FOR " + operation + self.cursor.execute(op, args, stream=stream) + return self + + def executemany(self, operation, param_sets): + self.cursor.executemany(operation, param_sets) + return self + + def fetchone(self): + self.cursor.execute("FETCH FORWARD 1 FROM " + self.ident) + return self.cursor.fetchone() + + def fetchmany(self, num=None): + if num is None: + return self.fetchall() + else: + self.cursor.execute( + "FETCH FORWARD " + str(int(num)) + " FROM " + self.ident + ) + return self.cursor.fetchall() + + def fetchall(self): + self.cursor.execute("FETCH FORWARD ALL FROM " + self.ident) + return self.cursor.fetchall() + + def close(self): + self.cursor.execute("CLOSE " + self.ident) + self.cursor.close() + + def setinputsizes(self, *sizes): + self.cursor.setinputsizes(*sizes) + + def setoutputsize(self, size, column=None): + pass + + +class PGCompiler_pg8000(PGCompiler): + def visit_mod_binary(self, binary, operator, **kw): + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) + + +class PGIdentifierPreparer_pg8000(PGIdentifierPreparer): + def __init__(self, *args, **kwargs): + PGIdentifierPreparer.__init__(self, *args, **kwargs) + self._double_percents = False + + +class PGDialect_pg8000(PGDialect): + driver = "pg8000" + supports_statement_cache = True + + supports_unicode_statements = True + + supports_unicode_binds = True + + default_paramstyle = "format" + supports_sane_multi_rowcount = True + execution_ctx_cls = PGExecutionContext_pg8000 + statement_compiler = PGCompiler_pg8000 + preparer = PGIdentifierPreparer_pg8000 + supports_server_side_cursors = True + + render_bind_cast = True + + # reversed as of pg8000 1.16.6. 1.16.5 and lower + # are no longer compatible + description_encoding = None + # description_encoding = "use_encoding" + + colspecs = util.update_copy( + PGDialect.colspecs, + { + sqltypes.String: _PGString, + sqltypes.Numeric: _PGNumericNoBind, + sqltypes.Float: _PGFloat, + sqltypes.JSON: _PGJSON, + sqltypes.Boolean: _PGBoolean, + sqltypes.NullType: _PGNullType, + JSONB: _PGJSONB, + CITEXT: CITEXT, + sqltypes.JSON.JSONPathType: _PGJSONPathType, + sqltypes.JSON.JSONIndexType: _PGJSONIndexType, + sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType, + sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType, + sqltypes.Interval: _PGInterval, + INTERVAL: _PGInterval, + sqltypes.DateTime: _PGTimeStamp, + sqltypes.DateTime: _PGTimeStamp, + sqltypes.Date: _PGDate, + sqltypes.Time: _PGTime, + sqltypes.Integer: _PGInteger, + sqltypes.SmallInteger: _PGSmallInteger, + sqltypes.BigInteger: _PGBigInteger, + sqltypes.Enum: _PGEnum, + sqltypes.ARRAY: _PGARRAY, + OIDVECTOR: _PGOIDVECTOR, + ranges.INT4RANGE: _Pg8000Range, + ranges.INT8RANGE: _Pg8000Range, + ranges.NUMRANGE: _Pg8000Range, + ranges.DATERANGE: _Pg8000Range, + ranges.TSRANGE: _Pg8000Range, + ranges.TSTZRANGE: _Pg8000Range, + ranges.INT4MULTIRANGE: _Pg8000MultiRange, + ranges.INT8MULTIRANGE: _Pg8000MultiRange, + ranges.NUMMULTIRANGE: _Pg8000MultiRange, + ranges.DATEMULTIRANGE: _Pg8000MultiRange, + ranges.TSMULTIRANGE: _Pg8000MultiRange, + ranges.TSTZMULTIRANGE: _Pg8000MultiRange, + }, + ) + + def __init__(self, client_encoding=None, **kwargs): + PGDialect.__init__(self, **kwargs) + self.client_encoding = client_encoding + + if self._dbapi_version < (1, 16, 6): + raise NotImplementedError("pg8000 1.16.6 or greater is required") + + if self._native_inet_types: + raise NotImplementedError( + "The pg8000 dialect does not fully implement " + "ipaddress type handling; INET is supported by default, " + "CIDR is not" + ) + + @util.memoized_property + def _dbapi_version(self): + if self.dbapi and hasattr(self.dbapi, "__version__"): + return tuple( + [ + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) + else: + return (99, 99, 99) + + @classmethod + def import_dbapi(cls): + return __import__("pg8000") + + def create_connect_args(self, url): + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["port"] = int(opts["port"]) + opts.update(url.query) + return ([], opts) + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.InterfaceError) and "network error" in str( + e + ): + # new as of pg8000 1.19.0 for broken connections + return True + + # connection was closed normally + return "connection is closed" in str(e) + + def get_isolation_level_values(self, dbapi_connection): + return ( + "AUTOCOMMIT", + "READ COMMITTED", + "READ UNCOMMITTED", + "REPEATABLE READ", + "SERIALIZABLE", + ) + + def set_isolation_level(self, dbapi_connection, level): + level = level.replace("_", " ") + + if level == "AUTOCOMMIT": + dbapi_connection.autocommit = True + else: + dbapi_connection.autocommit = False + cursor = dbapi_connection.cursor() + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION " + f"ISOLATION LEVEL {level}" + ) + cursor.execute("COMMIT") + cursor.close() + + def set_readonly(self, connection, value): + cursor = connection.cursor() + try: + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION %s" + % ("READ ONLY" if value else "READ WRITE") + ) + cursor.execute("COMMIT") + finally: + cursor.close() + + def get_readonly(self, connection): + cursor = connection.cursor() + try: + cursor.execute("show transaction_read_only") + val = cursor.fetchone()[0] + finally: + cursor.close() + + return val == "on" + + def set_deferrable(self, connection, value): + cursor = connection.cursor() + try: + cursor.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION %s" + % ("DEFERRABLE" if value else "NOT DEFERRABLE") + ) + cursor.execute("COMMIT") + finally: + cursor.close() + + def get_deferrable(self, connection): + cursor = connection.cursor() + try: + cursor.execute("show transaction_deferrable") + val = cursor.fetchone()[0] + finally: + cursor.close() + + return val == "on" + + def _set_client_encoding(self, dbapi_connection, client_encoding): + cursor = dbapi_connection.cursor() + cursor.execute( + f"""SET CLIENT_ENCODING TO '{ + client_encoding.replace("'", "''") + }'""" + ) + cursor.execute("COMMIT") + cursor.close() + + def do_begin_twophase(self, connection, xid): + connection.connection.tpc_begin((0, xid, "")) + + def do_prepare_twophase(self, connection, xid): + connection.connection.tpc_prepare() + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + connection.connection.tpc_rollback((0, xid, "")) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + connection.connection.tpc_commit((0, xid, "")) + + def do_recover_twophase(self, connection): + return [row[1] for row in connection.connection.tpc_recover()] + + def on_connect(self): + fns = [] + + def on_connect(conn): + conn.py_types[quoted_name] = conn.py_types[str] + + fns.append(on_connect) + + if self.client_encoding is not None: + + def on_connect(conn): + self._set_client_encoding(conn, self.client_encoding) + + fns.append(on_connect) + + if self._native_inet_types is False: + + def on_connect(conn): + # inet + conn.register_in_adapter(869, lambda s: s) + + # cidr + conn.register_in_adapter(650, lambda s: s) + + fns.append(on_connect) + + if self._json_deserializer: + + def on_connect(conn): + # json + conn.register_in_adapter(114, self._json_deserializer) + + # jsonb + conn.register_in_adapter(3802, self._json_deserializer) + + fns.append(on_connect) + + if len(fns) > 0: + + def on_connect(conn): + for fn in fns: + fn(conn) + + return on_connect + else: + return None + + @util.memoized_property + def _dialect_specific_select_one(self): + return ";" + + +dialect = PGDialect_pg8000 diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/pg_catalog.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/pg_catalog.py new file mode 100644 index 0000000..9b5562c --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/pg_catalog.py @@ -0,0 +1,300 @@ +# dialects/postgresql/pg_catalog.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .array import ARRAY +from .types import OID +from .types import REGCLASS +from ... import Column +from ... import func +from ... import MetaData +from ... import Table +from ...types import BigInteger +from ...types import Boolean +from ...types import CHAR +from ...types import Float +from ...types import Integer +from ...types import SmallInteger +from ...types import String +from ...types import Text +from ...types import TypeDecorator + + +# types +class NAME(TypeDecorator): + impl = String(64, collation="C") + cache_ok = True + + +class PG_NODE_TREE(TypeDecorator): + impl = Text(collation="C") + cache_ok = True + + +class INT2VECTOR(TypeDecorator): + impl = ARRAY(SmallInteger) + cache_ok = True + + +class OIDVECTOR(TypeDecorator): + impl = ARRAY(OID) + cache_ok = True + + +class _SpaceVector: + def result_processor(self, dialect, coltype): + def process(value): + if value is None: + return value + return [int(p) for p in value.split(" ")] + + return process + + +REGPROC = REGCLASS # seems an alias + +# functions +_pg_cat = func.pg_catalog +quote_ident = _pg_cat.quote_ident +pg_table_is_visible = _pg_cat.pg_table_is_visible +pg_type_is_visible = _pg_cat.pg_type_is_visible +pg_get_viewdef = _pg_cat.pg_get_viewdef +pg_get_serial_sequence = _pg_cat.pg_get_serial_sequence +format_type = _pg_cat.format_type +pg_get_expr = _pg_cat.pg_get_expr +pg_get_constraintdef = _pg_cat.pg_get_constraintdef +pg_get_indexdef = _pg_cat.pg_get_indexdef + +# constants +RELKINDS_TABLE_NO_FOREIGN = ("r", "p") +RELKINDS_TABLE = RELKINDS_TABLE_NO_FOREIGN + ("f",) +RELKINDS_VIEW = ("v",) +RELKINDS_MAT_VIEW = ("m",) +RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW + +# tables +pg_catalog_meta = MetaData(schema="pg_catalog") + +pg_namespace = Table( + "pg_namespace", + pg_catalog_meta, + Column("oid", OID), + Column("nspname", NAME), + Column("nspowner", OID), +) + +pg_class = Table( + "pg_class", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("relname", NAME), + Column("relnamespace", OID), + Column("reltype", OID), + Column("reloftype", OID), + Column("relowner", OID), + Column("relam", OID), + Column("relfilenode", OID), + Column("reltablespace", OID), + Column("relpages", Integer), + Column("reltuples", Float), + Column("relallvisible", Integer, info={"server_version": (9, 2)}), + Column("reltoastrelid", OID), + Column("relhasindex", Boolean), + Column("relisshared", Boolean), + Column("relpersistence", CHAR, info={"server_version": (9, 1)}), + Column("relkind", CHAR), + Column("relnatts", SmallInteger), + Column("relchecks", SmallInteger), + Column("relhasrules", Boolean), + Column("relhastriggers", Boolean), + Column("relhassubclass", Boolean), + Column("relrowsecurity", Boolean), + Column("relforcerowsecurity", Boolean, info={"server_version": (9, 5)}), + Column("relispopulated", Boolean, info={"server_version": (9, 3)}), + Column("relreplident", CHAR, info={"server_version": (9, 4)}), + Column("relispartition", Boolean, info={"server_version": (10,)}), + Column("relrewrite", OID, info={"server_version": (11,)}), + Column("reloptions", ARRAY(Text)), +) + +pg_type = Table( + "pg_type", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("typname", NAME), + Column("typnamespace", OID), + Column("typowner", OID), + Column("typlen", SmallInteger), + Column("typbyval", Boolean), + Column("typtype", CHAR), + Column("typcategory", CHAR), + Column("typispreferred", Boolean), + Column("typisdefined", Boolean), + Column("typdelim", CHAR), + Column("typrelid", OID), + Column("typelem", OID), + Column("typarray", OID), + Column("typinput", REGPROC), + Column("typoutput", REGPROC), + Column("typreceive", REGPROC), + Column("typsend", REGPROC), + Column("typmodin", REGPROC), + Column("typmodout", REGPROC), + Column("typanalyze", REGPROC), + Column("typalign", CHAR), + Column("typstorage", CHAR), + Column("typnotnull", Boolean), + Column("typbasetype", OID), + Column("typtypmod", Integer), + Column("typndims", Integer), + Column("typcollation", OID, info={"server_version": (9, 1)}), + Column("typdefault", Text), +) + +pg_index = Table( + "pg_index", + pg_catalog_meta, + Column("indexrelid", OID), + Column("indrelid", OID), + Column("indnatts", SmallInteger), + Column("indnkeyatts", SmallInteger, info={"server_version": (11,)}), + Column("indisunique", Boolean), + Column("indnullsnotdistinct", Boolean, info={"server_version": (15,)}), + Column("indisprimary", Boolean), + Column("indisexclusion", Boolean, info={"server_version": (9, 1)}), + Column("indimmediate", Boolean), + Column("indisclustered", Boolean), + Column("indisvalid", Boolean), + Column("indcheckxmin", Boolean), + Column("indisready", Boolean), + Column("indislive", Boolean, info={"server_version": (9, 3)}), # 9.3 + Column("indisreplident", Boolean), + Column("indkey", INT2VECTOR), + Column("indcollation", OIDVECTOR, info={"server_version": (9, 1)}), # 9.1 + Column("indclass", OIDVECTOR), + Column("indoption", INT2VECTOR), + Column("indexprs", PG_NODE_TREE), + Column("indpred", PG_NODE_TREE), +) + +pg_attribute = Table( + "pg_attribute", + pg_catalog_meta, + Column("attrelid", OID), + Column("attname", NAME), + Column("atttypid", OID), + Column("attstattarget", Integer), + Column("attlen", SmallInteger), + Column("attnum", SmallInteger), + Column("attndims", Integer), + Column("attcacheoff", Integer), + Column("atttypmod", Integer), + Column("attbyval", Boolean), + Column("attstorage", CHAR), + Column("attalign", CHAR), + Column("attnotnull", Boolean), + Column("atthasdef", Boolean), + Column("atthasmissing", Boolean, info={"server_version": (11,)}), + Column("attidentity", CHAR, info={"server_version": (10,)}), + Column("attgenerated", CHAR, info={"server_version": (12,)}), + Column("attisdropped", Boolean), + Column("attislocal", Boolean), + Column("attinhcount", Integer), + Column("attcollation", OID, info={"server_version": (9, 1)}), +) + +pg_constraint = Table( + "pg_constraint", + pg_catalog_meta, + Column("oid", OID), # 9.3 + Column("conname", NAME), + Column("connamespace", OID), + Column("contype", CHAR), + Column("condeferrable", Boolean), + Column("condeferred", Boolean), + Column("convalidated", Boolean, info={"server_version": (9, 1)}), + Column("conrelid", OID), + Column("contypid", OID), + Column("conindid", OID), + Column("conparentid", OID, info={"server_version": (11,)}), + Column("confrelid", OID), + Column("confupdtype", CHAR), + Column("confdeltype", CHAR), + Column("confmatchtype", CHAR), + Column("conislocal", Boolean), + Column("coninhcount", Integer), + Column("connoinherit", Boolean, info={"server_version": (9, 2)}), + Column("conkey", ARRAY(SmallInteger)), + Column("confkey", ARRAY(SmallInteger)), +) + +pg_sequence = Table( + "pg_sequence", + pg_catalog_meta, + Column("seqrelid", OID), + Column("seqtypid", OID), + Column("seqstart", BigInteger), + Column("seqincrement", BigInteger), + Column("seqmax", BigInteger), + Column("seqmin", BigInteger), + Column("seqcache", BigInteger), + Column("seqcycle", Boolean), + info={"server_version": (10,)}, +) + +pg_attrdef = Table( + "pg_attrdef", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("adrelid", OID), + Column("adnum", SmallInteger), + Column("adbin", PG_NODE_TREE), +) + +pg_description = Table( + "pg_description", + pg_catalog_meta, + Column("objoid", OID), + Column("classoid", OID), + Column("objsubid", Integer), + Column("description", Text(collation="C")), +) + +pg_enum = Table( + "pg_enum", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("enumtypid", OID), + Column("enumsortorder", Float(), info={"server_version": (9, 1)}), + Column("enumlabel", NAME), +) + +pg_am = Table( + "pg_am", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("amname", NAME), + Column("amhandler", REGPROC, info={"server_version": (9, 6)}), + Column("amtype", CHAR, info={"server_version": (9, 6)}), +) + +pg_collation = Table( + "pg_collation", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("collname", NAME), + Column("collnamespace", OID), + Column("collowner", OID), + Column("collprovider", CHAR, info={"server_version": (10,)}), + Column("collisdeterministic", Boolean, info={"server_version": (12,)}), + Column("collencoding", Integer), + Column("collcollate", Text), + Column("collctype", Text), + Column("colliculocale", Text), + Column("collicurules", Text, info={"server_version": (16,)}), + Column("collversion", Text, info={"server_version": (10,)}), +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/provision.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/provision.py new file mode 100644 index 0000000..a87bb93 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/provision.py @@ -0,0 +1,175 @@ +# dialects/postgresql/provision.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import time + +from ... import exc +from ... import inspect +from ... import text +from ...testing import warn_test_suite +from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_post_tables +from ...testing.provision import drop_all_schema_objects_pre_tables +from ...testing.provision import drop_db +from ...testing.provision import log +from ...testing.provision import post_configure_engine +from ...testing.provision import prepare_for_drop_tables +from ...testing.provision import set_default_schema_on_connection +from ...testing.provision import temp_table_keyword_args +from ...testing.provision import upsert + + +@create_db.for_db("postgresql") +def _pg_create_db(cfg, eng, ident): + template_db = cfg.options.postgresql_templatedb + + with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn: + if not template_db: + template_db = conn.exec_driver_sql( + "select current_database()" + ).scalar() + + attempt = 0 + while True: + try: + conn.exec_driver_sql( + "CREATE DATABASE %s TEMPLATE %s" % (ident, template_db) + ) + except exc.OperationalError as err: + attempt += 1 + if attempt >= 3: + raise + if "accessed by other users" in str(err): + log.info( + "Waiting to create %s, URI %r, " + "template DB %s is in use sleeping for .5", + ident, + eng.url, + template_db, + ) + time.sleep(0.5) + except: + raise + else: + break + + +@drop_db.for_db("postgresql") +def _pg_drop_db(cfg, eng, ident): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + with conn.begin(): + conn.execute( + text( + "select pg_terminate_backend(pid) from pg_stat_activity " + "where usename=current_user and pid != pg_backend_pid() " + "and datname=:dname" + ), + dict(dname=ident), + ) + conn.exec_driver_sql("DROP DATABASE %s" % ident) + + +@temp_table_keyword_args.for_db("postgresql") +def _postgresql_temp_table_keyword_args(cfg, eng): + return {"prefixes": ["TEMPORARY"]} + + +@set_default_schema_on_connection.for_db("postgresql") +def _postgresql_set_default_schema_on_connection( + cfg, dbapi_connection, schema_name +): + existing_autocommit = dbapi_connection.autocommit + dbapi_connection.autocommit = True + cursor = dbapi_connection.cursor() + cursor.execute("SET SESSION search_path='%s'" % schema_name) + cursor.close() + dbapi_connection.autocommit = existing_autocommit + + +@drop_all_schema_objects_pre_tables.for_db("postgresql") +def drop_all_schema_objects_pre_tables(cfg, eng): + with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn: + for xid in conn.exec_driver_sql( + "select gid from pg_prepared_xacts" + ).scalars(): + conn.execute("ROLLBACK PREPARED '%s'" % xid) + + +@drop_all_schema_objects_post_tables.for_db("postgresql") +def drop_all_schema_objects_post_tables(cfg, eng): + from sqlalchemy.dialects import postgresql + + inspector = inspect(eng) + with eng.begin() as conn: + for enum in inspector.get_enums("*"): + conn.execute( + postgresql.DropEnumType( + postgresql.ENUM(name=enum["name"], schema=enum["schema"]) + ) + ) + + +@prepare_for_drop_tables.for_db("postgresql") +def prepare_for_drop_tables(config, connection): + """Ensure there are no locks on the current username/database.""" + + result = connection.exec_driver_sql( + "select pid, state, wait_event_type, query " + # "select pg_terminate_backend(pid), state, wait_event_type " + "from pg_stat_activity where " + "usename=current_user " + "and datname=current_database() and state='idle in transaction' " + "and pid != pg_backend_pid()" + ) + rows = result.all() # noqa + if rows: + warn_test_suite( + "PostgreSQL may not be able to DROP tables due to " + "idle in transaction: %s" + % ("; ".join(row._mapping["query"] for row in rows)) + ) + + +@upsert.for_db("postgresql") +def _upsert( + cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False +): + from sqlalchemy.dialects.postgresql import insert + + stmt = insert(table) + + table_pk = inspect(table).selectable + + if set_lambda: + stmt = stmt.on_conflict_do_update( + index_elements=table_pk.primary_key, set_=set_lambda(stmt.excluded) + ) + else: + stmt = stmt.on_conflict_do_nothing() + + stmt = stmt.returning( + *returning, sort_by_parameter_order=sort_by_parameter_order + ) + return stmt + + +_extensions = [ + ("citext", (13,)), + ("hstore", (13,)), +] + + +@post_configure_engine.for_db("postgresql") +def _create_citext_extension(url, engine, follower_ident): + with engine.connect() as conn: + for extension, min_version in _extensions: + if conn.dialect.server_version_info >= min_version: + conn.execute( + text(f"CREATE EXTENSION IF NOT EXISTS {extension}") + ) + conn.commit() diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/psycopg.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/psycopg.py new file mode 100644 index 0000000..90177a4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/psycopg.py @@ -0,0 +1,749 @@ +# dialects/postgresql/psycopg.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: postgresql+psycopg + :name: psycopg (a.k.a. psycopg 3) + :dbapi: psycopg + :connectstring: postgresql+psycopg://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/psycopg/ + +``psycopg`` is the package and module name for version 3 of the ``psycopg`` +database driver, formerly known as ``psycopg2``. This driver is different +enough from its ``psycopg2`` predecessor that SQLAlchemy supports it +via a totally separate dialect; support for ``psycopg2`` is expected to remain +for as long as that package continues to function for modern Python versions, +and also remains the default dialect for the ``postgresql://`` dialect +series. + +The SQLAlchemy ``psycopg`` dialect provides both a sync and an async +implementation under the same dialect name. The proper version is +selected depending on how the engine is created: + +* calling :func:`_sa.create_engine` with ``postgresql+psycopg://...`` will + automatically select the sync version, e.g.:: + + from sqlalchemy import create_engine + sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test") + +* calling :func:`_asyncio.create_async_engine` with + ``postgresql+psycopg://...`` will automatically select the async version, + e.g.:: + + from sqlalchemy.ext.asyncio import create_async_engine + asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test") + +The asyncio version of the dialect may also be specified explicitly using the +``psycopg_async`` suffix, as:: + + from sqlalchemy.ext.asyncio import create_async_engine + asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test") + +.. seealso:: + + :ref:`postgresql_psycopg2` - The SQLAlchemy ``psycopg`` + dialect shares most of its behavior with the ``psycopg2`` dialect. + Further documentation is available there. + +""" # noqa +from __future__ import annotations + +import logging +import re +from typing import cast +from typing import TYPE_CHECKING + +from . import ranges +from ._psycopg_common import _PGDialect_common_psycopg +from ._psycopg_common import _PGExecutionContext_common_psycopg +from .base import INTERVAL +from .base import PGCompiler +from .base import PGIdentifierPreparer +from .base import REGCONFIG +from .json import JSON +from .json import JSONB +from .json import JSONPathType +from .types import CITEXT +from ... import pool +from ... import util +from ...engine import AdaptedConnection +from ...sql import sqltypes +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + +if TYPE_CHECKING: + from typing import Iterable + + from psycopg import AsyncConnection + +logger = logging.getLogger("sqlalchemy.dialects.postgresql") + + +class _PGString(sqltypes.String): + render_bind_cast = True + + +class _PGREGCONFIG(REGCONFIG): + render_bind_cast = True + + +class _PGJSON(JSON): + render_bind_cast = True + + def bind_processor(self, dialect): + return self._make_bind_processor(None, dialect._psycopg_Json) + + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONB(JSONB): + render_bind_cast = True + + def bind_processor(self, dialect): + return self._make_bind_processor(None, dialect._psycopg_Jsonb) + + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): + __visit_name__ = "json_int_index" + + render_bind_cast = True + + +class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): + __visit_name__ = "json_str_index" + + render_bind_cast = True + + +class _PGJSONPathType(JSONPathType): + pass + + +class _PGInterval(INTERVAL): + render_bind_cast = True + + +class _PGTimeStamp(sqltypes.DateTime): + render_bind_cast = True + + +class _PGDate(sqltypes.Date): + render_bind_cast = True + + +class _PGTime(sqltypes.Time): + render_bind_cast = True + + +class _PGInteger(sqltypes.Integer): + render_bind_cast = True + + +class _PGSmallInteger(sqltypes.SmallInteger): + render_bind_cast = True + + +class _PGNullType(sqltypes.NullType): + render_bind_cast = True + + +class _PGBigInteger(sqltypes.BigInteger): + render_bind_cast = True + + +class _PGBoolean(sqltypes.Boolean): + render_bind_cast = True + + +class _PsycopgRange(ranges.AbstractSingleRangeImpl): + def bind_processor(self, dialect): + psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range + + def to_range(value): + if isinstance(value, ranges.Range): + value = psycopg_Range( + value.lower, value.upper, value.bounds, value.empty + ) + return value + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is not None: + value = ranges.Range( + value._lower, + value._upper, + bounds=value._bounds if value._bounds else "[)", + empty=not value._bounds, + ) + return value + + return to_range + + +class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl): + def bind_processor(self, dialect): + psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range + psycopg_Multirange = cast( + PGDialect_psycopg, dialect + )._psycopg_Multirange + + NoneType = type(None) + + def to_range(value): + if isinstance(value, (str, NoneType, psycopg_Multirange)): + return value + + return psycopg_Multirange( + [ + psycopg_Range( + element.lower, + element.upper, + element.bounds, + element.empty, + ) + for element in cast("Iterable[ranges.Range]", value) + ] + ) + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is None: + return None + else: + return ranges.MultiRange( + ranges.Range( + elem._lower, + elem._upper, + bounds=elem._bounds if elem._bounds else "[)", + empty=not elem._bounds, + ) + for elem in value + ) + + return to_range + + +class PGExecutionContext_psycopg(_PGExecutionContext_common_psycopg): + pass + + +class PGCompiler_psycopg(PGCompiler): + pass + + +class PGIdentifierPreparer_psycopg(PGIdentifierPreparer): + pass + + +def _log_notices(diagnostic): + logger.info("%s: %s", diagnostic.severity, diagnostic.message_primary) + + +class PGDialect_psycopg(_PGDialect_common_psycopg): + driver = "psycopg" + + supports_statement_cache = True + supports_server_side_cursors = True + default_paramstyle = "pyformat" + supports_sane_multi_rowcount = True + + execution_ctx_cls = PGExecutionContext_psycopg + statement_compiler = PGCompiler_psycopg + preparer = PGIdentifierPreparer_psycopg + psycopg_version = (0, 0) + + _has_native_hstore = True + _psycopg_adapters_map = None + + colspecs = util.update_copy( + _PGDialect_common_psycopg.colspecs, + { + sqltypes.String: _PGString, + REGCONFIG: _PGREGCONFIG, + JSON: _PGJSON, + CITEXT: CITEXT, + sqltypes.JSON: _PGJSON, + JSONB: _PGJSONB, + sqltypes.JSON.JSONPathType: _PGJSONPathType, + sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType, + sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType, + sqltypes.Interval: _PGInterval, + INTERVAL: _PGInterval, + sqltypes.Date: _PGDate, + sqltypes.DateTime: _PGTimeStamp, + sqltypes.Time: _PGTime, + sqltypes.Integer: _PGInteger, + sqltypes.SmallInteger: _PGSmallInteger, + sqltypes.BigInteger: _PGBigInteger, + ranges.AbstractSingleRange: _PsycopgRange, + ranges.AbstractMultiRange: _PsycopgMultiRange, + }, + ) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + if self.dbapi: + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) + if m: + self.psycopg_version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + + if self.psycopg_version < (3, 0, 2): + raise ImportError( + "psycopg version 3.0.2 or higher is required." + ) + + from psycopg.adapt import AdaptersMap + + self._psycopg_adapters_map = adapters_map = AdaptersMap( + self.dbapi.adapters + ) + + if self._native_inet_types is False: + import psycopg.types.string + + adapters_map.register_loader( + "inet", psycopg.types.string.TextLoader + ) + adapters_map.register_loader( + "cidr", psycopg.types.string.TextLoader + ) + + if self._json_deserializer: + from psycopg.types.json import set_json_loads + + set_json_loads(self._json_deserializer, adapters_map) + + if self._json_serializer: + from psycopg.types.json import set_json_dumps + + set_json_dumps(self._json_serializer, adapters_map) + + def create_connect_args(self, url): + # see https://github.com/psycopg/psycopg/issues/83 + cargs, cparams = super().create_connect_args(url) + + if self._psycopg_adapters_map: + cparams["context"] = self._psycopg_adapters_map + if self.client_encoding is not None: + cparams["client_encoding"] = self.client_encoding + return cargs, cparams + + def _type_info_fetch(self, connection, name): + from psycopg.types import TypeInfo + + return TypeInfo.fetch(connection.connection.driver_connection, name) + + def initialize(self, connection): + super().initialize(connection) + + # PGDialect.initialize() checks server version for <= 8.2 and sets + # this flag to False if so + if not self.insert_returning: + self.insert_executemany_returning = False + + # HSTORE can't be registered until we have a connection so that + # we can look up its OID, so we set up this adapter in + # initialize() + if self.use_native_hstore: + info = self._type_info_fetch(connection, "hstore") + self._has_native_hstore = info is not None + if self._has_native_hstore: + from psycopg.types.hstore import register_hstore + + # register the adapter for connections made subsequent to + # this one + register_hstore(info, self._psycopg_adapters_map) + + # register the adapter for this connection + register_hstore(info, connection.connection) + + @classmethod + def import_dbapi(cls): + import psycopg + + return psycopg + + @classmethod + def get_async_dialect_cls(cls, url): + return PGDialectAsync_psycopg + + @util.memoized_property + def _isolation_lookup(self): + return { + "READ COMMITTED": self.dbapi.IsolationLevel.READ_COMMITTED, + "READ UNCOMMITTED": self.dbapi.IsolationLevel.READ_UNCOMMITTED, + "REPEATABLE READ": self.dbapi.IsolationLevel.REPEATABLE_READ, + "SERIALIZABLE": self.dbapi.IsolationLevel.SERIALIZABLE, + } + + @util.memoized_property + def _psycopg_Json(self): + from psycopg.types import json + + return json.Json + + @util.memoized_property + def _psycopg_Jsonb(self): + from psycopg.types import json + + return json.Jsonb + + @util.memoized_property + def _psycopg_TransactionStatus(self): + from psycopg.pq import TransactionStatus + + return TransactionStatus + + @util.memoized_property + def _psycopg_Range(self): + from psycopg.types.range import Range + + return Range + + @util.memoized_property + def _psycopg_Multirange(self): + from psycopg.types.multirange import Multirange + + return Multirange + + def _do_isolation_level(self, connection, autocommit, isolation_level): + connection.autocommit = autocommit + connection.isolation_level = isolation_level + + def get_isolation_level(self, dbapi_connection): + status_before = dbapi_connection.info.transaction_status + value = super().get_isolation_level(dbapi_connection) + + # don't rely on psycopg providing enum symbols, compare with + # eq/ne + if status_before == self._psycopg_TransactionStatus.IDLE: + dbapi_connection.rollback() + return value + + def set_isolation_level(self, dbapi_connection, level): + if level == "AUTOCOMMIT": + self._do_isolation_level( + dbapi_connection, autocommit=True, isolation_level=None + ) + else: + self._do_isolation_level( + dbapi_connection, + autocommit=False, + isolation_level=self._isolation_lookup[level], + ) + + def set_readonly(self, connection, value): + connection.read_only = value + + def get_readonly(self, connection): + return connection.read_only + + def on_connect(self): + def notices(conn): + conn.add_notice_handler(_log_notices) + + fns = [notices] + + if self.isolation_level is not None: + + def on_connect(conn): + self.set_isolation_level(conn, self.isolation_level) + + fns.append(on_connect) + + # fns always has the notices function + def on_connect(conn): + for fn in fns: + fn(conn) + + return on_connect + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error) and connection is not None: + if connection.closed or connection.broken: + return True + return False + + def _do_prepared_twophase(self, connection, command, recover=False): + dbapi_conn = connection.connection.dbapi_connection + if ( + recover + # don't rely on psycopg providing enum symbols, compare with + # eq/ne + or dbapi_conn.info.transaction_status + != self._psycopg_TransactionStatus.IDLE + ): + dbapi_conn.rollback() + before_autocommit = dbapi_conn.autocommit + try: + if not before_autocommit: + self._do_autocommit(dbapi_conn, True) + dbapi_conn.execute(command) + finally: + if not before_autocommit: + self._do_autocommit(dbapi_conn, before_autocommit) + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + self._do_prepared_twophase( + connection, f"ROLLBACK PREPARED '{xid}'", recover=recover + ) + else: + self.do_rollback(connection.connection) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + if is_prepared: + self._do_prepared_twophase( + connection, f"COMMIT PREPARED '{xid}'", recover=recover + ) + else: + self.do_commit(connection.connection) + + @util.memoized_property + def _dialect_specific_select_one(self): + return ";" + + +class AsyncAdapt_psycopg_cursor: + __slots__ = ("_cursor", "await_", "_rows") + + _psycopg_ExecStatus = None + + def __init__(self, cursor, await_) -> None: + self._cursor = cursor + self.await_ = await_ + self._rows = [] + + def __getattr__(self, name): + return getattr(self._cursor, name) + + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + def close(self): + self._rows.clear() + # Normal cursor just call _close() in a non-sync way. + self._cursor._close() + + def execute(self, query, params=None, **kw): + result = self.await_(self._cursor.execute(query, params, **kw)) + # sqlalchemy result is not async, so need to pull all rows here + res = self._cursor.pgresult + + # don't rely on psycopg providing enum symbols, compare with + # eq/ne + if res and res.status == self._psycopg_ExecStatus.TUPLES_OK: + rows = self.await_(self._cursor.fetchall()) + if not isinstance(rows, list): + self._rows = list(rows) + else: + self._rows = rows + return result + + def executemany(self, query, params_seq): + return self.await_(self._cursor.executemany(query, params_seq)) + + def __iter__(self): + # TODO: try to avoid pop(0) on a list + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + # TODO: try to avoid pop(0) on a list + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self._cursor.arraysize + + retval = self._rows[0:size] + self._rows = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows + self._rows = [] + return retval + + +class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor): + def execute(self, query, params=None, **kw): + self.await_(self._cursor.execute(query, params, **kw)) + return self + + def close(self): + self.await_(self._cursor.close()) + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=0): + return self.await_(self._cursor.fetchmany(size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + def __iter__(self): + iterator = self._cursor.__aiter__() + while True: + try: + yield self.await_(iterator.__anext__()) + except StopAsyncIteration: + break + + +class AsyncAdapt_psycopg_connection(AdaptedConnection): + _connection: AsyncConnection + __slots__ = () + await_ = staticmethod(await_only) + + def __init__(self, connection) -> None: + self._connection = connection + + def __getattr__(self, name): + return getattr(self._connection, name) + + def execute(self, query, params=None, **kw): + cursor = self.await_(self._connection.execute(query, params, **kw)) + return AsyncAdapt_psycopg_cursor(cursor, self.await_) + + def cursor(self, *args, **kw): + cursor = self._connection.cursor(*args, **kw) + if hasattr(cursor, "name"): + return AsyncAdapt_psycopg_ss_cursor(cursor, self.await_) + else: + return AsyncAdapt_psycopg_cursor(cursor, self.await_) + + def commit(self): + self.await_(self._connection.commit()) + + def rollback(self): + self.await_(self._connection.rollback()) + + def close(self): + self.await_(self._connection.close()) + + @property + def autocommit(self): + return self._connection.autocommit + + @autocommit.setter + def autocommit(self, value): + self.set_autocommit(value) + + def set_autocommit(self, value): + self.await_(self._connection.set_autocommit(value)) + + def set_isolation_level(self, value): + self.await_(self._connection.set_isolation_level(value)) + + def set_read_only(self, value): + self.await_(self._connection.set_read_only(value)) + + def set_deferrable(self, value): + self.await_(self._connection.set_deferrable(value)) + + +class AsyncAdaptFallback_psycopg_connection(AsyncAdapt_psycopg_connection): + __slots__ = () + await_ = staticmethod(await_fallback) + + +class PsycopgAdaptDBAPI: + def __init__(self, psycopg) -> None: + self.psycopg = psycopg + + for k, v in self.psycopg.__dict__.items(): + if k != "connect": + self.__dict__[k] = v + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + creator_fn = kw.pop( + "async_creator_fn", self.psycopg.AsyncConnection.connect + ) + if util.asbool(async_fallback): + return AsyncAdaptFallback_psycopg_connection( + await_fallback(creator_fn(*arg, **kw)) + ) + else: + return AsyncAdapt_psycopg_connection( + await_only(creator_fn(*arg, **kw)) + ) + + +class PGDialectAsync_psycopg(PGDialect_psycopg): + is_async = True + supports_statement_cache = True + + @classmethod + def import_dbapi(cls): + import psycopg + from psycopg.pq import ExecStatus + + AsyncAdapt_psycopg_cursor._psycopg_ExecStatus = ExecStatus + + return PsycopgAdaptDBAPI(psycopg) + + @classmethod + def get_pool_class(cls, url): + async_fallback = url.query.get("async_fallback", False) + + if util.asbool(async_fallback): + return pool.FallbackAsyncAdaptedQueuePool + else: + return pool.AsyncAdaptedQueuePool + + def _type_info_fetch(self, connection, name): + from psycopg.types import TypeInfo + + adapted = connection.connection + return adapted.await_(TypeInfo.fetch(adapted.driver_connection, name)) + + def _do_isolation_level(self, connection, autocommit, isolation_level): + connection.set_autocommit(autocommit) + connection.set_isolation_level(isolation_level) + + def _do_autocommit(self, connection, value): + connection.set_autocommit(value) + + def set_readonly(self, connection, value): + connection.set_read_only(value) + + def set_deferrable(self, connection, value): + connection.set_deferrable(value) + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = PGDialect_psycopg +dialect_async = PGDialectAsync_psycopg diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/psycopg2.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/psycopg2.py new file mode 100644 index 0000000..9bf2e49 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/psycopg2.py @@ -0,0 +1,876 @@ +# dialects/postgresql/psycopg2.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: postgresql+psycopg2 + :name: psycopg2 + :dbapi: psycopg2 + :connectstring: postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/psycopg2/ + +.. _psycopg2_toplevel: + +psycopg2 Connect Arguments +-------------------------- + +Keyword arguments that are specific to the SQLAlchemy psycopg2 dialect +may be passed to :func:`_sa.create_engine()`, and include the following: + + +* ``isolation_level``: This option, available for all PostgreSQL dialects, + includes the ``AUTOCOMMIT`` isolation level when using the psycopg2 + dialect. This option sets the **default** isolation level for the + connection that is set immediately upon connection to the database before + the connection is pooled. This option is generally superseded by the more + modern :paramref:`_engine.Connection.execution_options.isolation_level` + execution option, detailed at :ref:`dbapi_autocommit`. + + .. seealso:: + + :ref:`psycopg2_isolation_level` + + :ref:`dbapi_autocommit` + + +* ``client_encoding``: sets the client encoding in a libpq-agnostic way, + using psycopg2's ``set_client_encoding()`` method. + + .. seealso:: + + :ref:`psycopg2_unicode` + + +* ``executemany_mode``, ``executemany_batch_page_size``, + ``executemany_values_page_size``: Allows use of psycopg2 + extensions for optimizing "executemany"-style queries. See the referenced + section below for details. + + .. seealso:: + + :ref:`psycopg2_executemany_mode` + +.. tip:: + + The above keyword arguments are **dialect** keyword arguments, meaning + that they are passed as explicit keyword arguments to :func:`_sa.create_engine()`:: + + engine = create_engine( + "postgresql+psycopg2://scott:tiger@localhost/test", + isolation_level="SERIALIZABLE", + ) + + These should not be confused with **DBAPI** connect arguments, which + are passed as part of the :paramref:`_sa.create_engine.connect_args` + dictionary and/or are passed in the URL query string, as detailed in + the section :ref:`custom_dbapi_args`. + +.. _psycopg2_ssl: + +SSL Connections +--------------- + +The psycopg2 module has a connection argument named ``sslmode`` for +controlling its behavior regarding secure (SSL) connections. The default is +``sslmode=prefer``; it will attempt an SSL connection and if that fails it +will fall back to an unencrypted connection. ``sslmode=require`` may be used +to ensure that only secure connections are established. Consult the +psycopg2 / libpq documentation for further options that are available. + +Note that ``sslmode`` is specific to psycopg2 so it is included in the +connection URI:: + + engine = sa.create_engine( + "postgresql+psycopg2://scott:tiger@192.168.0.199:5432/test?sslmode=require" + ) + + +Unix Domain Connections +------------------------ + +psycopg2 supports connecting via Unix domain connections. When the ``host`` +portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2, +which specifies Unix-domain communication rather than TCP/IP communication:: + + create_engine("postgresql+psycopg2://user:password@/dbname") + +By default, the socket file used is to connect to a Unix-domain socket +in ``/tmp``, or whatever socket directory was specified when PostgreSQL +was built. This value can be overridden by passing a pathname to psycopg2, +using ``host`` as an additional keyword argument:: + + create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql") + +.. warning:: The format accepted here allows for a hostname in the main URL + in addition to the "host" query string argument. **When using this URL + format, the initial host is silently ignored**. That is, this URL:: + + engine = create_engine("postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2") + + Above, the hostname ``myhost1`` is **silently ignored and discarded.** The + host which is connected is the ``myhost2`` host. + + This is to maintain some degree of compatibility with PostgreSQL's own URL + format which has been tested to behave the same way and for which tools like + PifPaf hardcode two hostnames. + +.. seealso:: + + `PQconnectdbParams \ + `_ + +.. _psycopg2_multi_host: + +Specifying multiple fallback hosts +----------------------------------- + +psycopg2 supports multiple connection points in the connection string. +When the ``host`` parameter is used multiple times in the query section of +the URL, SQLAlchemy will create a single string of the host and port +information provided to make the connections. Tokens may consist of +``host::port`` or just ``host``; in the latter case, the default port +is selected by libpq. In the example below, three host connections +are specified, for ``HostA::PortA``, ``HostB`` connecting to the default port, +and ``HostC::PortC``:: + + create_engine( + "postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC" + ) + +As an alternative, libpq query string format also may be used; this specifies +``host`` and ``port`` as single query string arguments with comma-separated +lists - the default port can be chosen by indicating an empty value +in the comma separated list:: + + create_engine( + "postgresql+psycopg2://user:password@/dbname?host=HostA,HostB,HostC&port=PortA,,PortC" + ) + +With either URL style, connections to each host is attempted based on a +configurable strategy, which may be configured using the libpq +``target_session_attrs`` parameter. Per libpq this defaults to ``any`` +which indicates a connection to each host is then attempted until a connection is successful. +Other strategies include ``primary``, ``prefer-standby``, etc. The complete +list is documented by PostgreSQL at +`libpq connection strings `_. + +For example, to indicate two hosts using the ``primary`` strategy:: + + create_engine( + "postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC&target_session_attrs=primary" + ) + +.. versionchanged:: 1.4.40 Port specification in psycopg2 multiple host format + is repaired, previously ports were not correctly interpreted in this context. + libpq comma-separated format is also now supported. + +.. versionadded:: 1.3.20 Support for multiple hosts in PostgreSQL connection + string. + +.. seealso:: + + `libpq connection strings `_ - please refer + to this section in the libpq documentation for complete background on multiple host support. + + +Empty DSN Connections / Environment Variable Connections +--------------------------------------------------------- + +The psycopg2 DBAPI can connect to PostgreSQL by passing an empty DSN to the +libpq client library, which by default indicates to connect to a localhost +PostgreSQL database that is open for "trust" connections. This behavior can be +further tailored using a particular set of environment variables which are +prefixed with ``PG_...``, which are consumed by ``libpq`` to take the place of +any or all elements of the connection string. + +For this form, the URL can be passed without any elements other than the +initial scheme:: + + engine = create_engine('postgresql+psycopg2://') + +In the above form, a blank "dsn" string is passed to the ``psycopg2.connect()`` +function which in turn represents an empty DSN passed to libpq. + +.. versionadded:: 1.3.2 support for parameter-less connections with psycopg2. + +.. seealso:: + + `Environment Variables\ + `_ - + PostgreSQL documentation on how to use ``PG_...`` + environment variables for connections. + +.. _psycopg2_execution_options: + +Per-Statement/Connection Execution Options +------------------------------------------- + +The following DBAPI-specific options are respected when used with +:meth:`_engine.Connection.execution_options`, +:meth:`.Executable.execution_options`, +:meth:`_query.Query.execution_options`, +in addition to those not specific to DBAPIs: + +* ``isolation_level`` - Set the transaction isolation level for the lifespan + of a :class:`_engine.Connection` (can only be set on a connection, + not a statement + or query). See :ref:`psycopg2_isolation_level`. + +* ``stream_results`` - Enable or disable usage of psycopg2 server side + cursors - this feature makes use of "named" cursors in combination with + special result handling methods so that result rows are not fully buffered. + Defaults to False, meaning cursors are buffered by default. + +* ``max_row_buffer`` - when using ``stream_results``, an integer value that + specifies the maximum number of rows to buffer at a time. This is + interpreted by the :class:`.BufferedRowCursorResult`, and if omitted the + buffer will grow to ultimately store 1000 rows at a time. + + .. versionchanged:: 1.4 The ``max_row_buffer`` size can now be greater than + 1000, and the buffer will grow to that size. + +.. _psycopg2_batch_mode: + +.. _psycopg2_executemany_mode: + +Psycopg2 Fast Execution Helpers +------------------------------- + +Modern versions of psycopg2 include a feature known as +`Fast Execution Helpers \ +`_, which +have been shown in benchmarking to improve psycopg2's executemany() +performance, primarily with INSERT statements, by at least +an order of magnitude. + +SQLAlchemy implements a native form of the "insert many values" +handler that will rewrite a single-row INSERT statement to accommodate for +many values at once within an extended VALUES clause; this handler is +equivalent to psycopg2's ``execute_values()`` handler; an overview of this +feature and its configuration are at :ref:`engine_insertmanyvalues`. + +.. versionadded:: 2.0 Replaced psycopg2's ``execute_values()`` fast execution + helper with a native SQLAlchemy mechanism known as + :ref:`insertmanyvalues `. + +The psycopg2 dialect retains the ability to use the psycopg2-specific +``execute_batch()`` feature, although it is not expected that this is a widely +used feature. The use of this extension may be enabled using the +``executemany_mode`` flag which may be passed to :func:`_sa.create_engine`:: + + engine = create_engine( + "postgresql+psycopg2://scott:tiger@host/dbname", + executemany_mode='values_plus_batch') + + +Possible options for ``executemany_mode`` include: + +* ``values_only`` - this is the default value. SQLAlchemy's native + :ref:`insertmanyvalues ` handler is used for qualifying + INSERT statements, assuming + :paramref:`_sa.create_engine.use_insertmanyvalues` is left at + its default value of ``True``. This handler rewrites simple + INSERT statements to include multiple VALUES clauses so that many + parameter sets can be inserted with one statement. + +* ``'values_plus_batch'``- SQLAlchemy's native + :ref:`insertmanyvalues ` handler is used for qualifying + INSERT statements, assuming + :paramref:`_sa.create_engine.use_insertmanyvalues` is left at its default + value of ``True``. Then, psycopg2's ``execute_batch()`` handler is used for + qualifying UPDATE and DELETE statements when executed with multiple parameter + sets. When using this mode, the :attr:`_engine.CursorResult.rowcount` + attribute will not contain a value for executemany-style executions against + UPDATE and DELETE statements. + +.. versionchanged:: 2.0 Removed the ``'batch'`` and ``'None'`` options + from psycopg2 ``executemany_mode``. Control over batching for INSERT + statements is now configured via the + :paramref:`_sa.create_engine.use_insertmanyvalues` engine-level parameter. + +The term "qualifying statements" refers to the statement being executed +being a Core :func:`_expression.insert`, :func:`_expression.update` +or :func:`_expression.delete` construct, and **not** a plain textual SQL +string or one constructed using :func:`_expression.text`. It also may **not** be +a special "extension" statement such as an "ON CONFLICT" "upsert" statement. +When using the ORM, all insert/update/delete statements used by the ORM flush process +are qualifying. + +The "page size" for the psycopg2 "batch" strategy can be affected +by using the ``executemany_batch_page_size`` parameter, which defaults to +100. + +For the "insertmanyvalues" feature, the page size can be controlled using the +:paramref:`_sa.create_engine.insertmanyvalues_page_size` parameter, +which defaults to 1000. An example of modifying both parameters +is below:: + + engine = create_engine( + "postgresql+psycopg2://scott:tiger@host/dbname", + executemany_mode='values_plus_batch', + insertmanyvalues_page_size=5000, executemany_batch_page_size=500) + +.. seealso:: + + :ref:`engine_insertmanyvalues` - background on "insertmanyvalues" + + :ref:`tutorial_multiple_parameters` - General information on using the + :class:`_engine.Connection` + object to execute statements in such a way as to make + use of the DBAPI ``.executemany()`` method. + + +.. _psycopg2_unicode: + +Unicode with Psycopg2 +---------------------- + +The psycopg2 DBAPI driver supports Unicode data transparently. + +The client character encoding can be controlled for the psycopg2 dialect +in the following ways: + +* For PostgreSQL 9.1 and above, the ``client_encoding`` parameter may be + passed in the database URL; this parameter is consumed by the underlying + ``libpq`` PostgreSQL client library:: + + engine = create_engine("postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8") + + Alternatively, the above ``client_encoding`` value may be passed using + :paramref:`_sa.create_engine.connect_args` for programmatic establishment with + ``libpq``:: + + engine = create_engine( + "postgresql+psycopg2://user:pass@host/dbname", + connect_args={'client_encoding': 'utf8'} + ) + +* For all PostgreSQL versions, psycopg2 supports a client-side encoding + value that will be passed to database connections when they are first + established. The SQLAlchemy psycopg2 dialect supports this using the + ``client_encoding`` parameter passed to :func:`_sa.create_engine`:: + + engine = create_engine( + "postgresql+psycopg2://user:pass@host/dbname", + client_encoding="utf8" + ) + + .. tip:: The above ``client_encoding`` parameter admittedly is very similar + in appearance to usage of the parameter within the + :paramref:`_sa.create_engine.connect_args` dictionary; the difference + above is that the parameter is consumed by psycopg2 and is + passed to the database connection using ``SET client_encoding TO + 'utf8'``; in the previously mentioned style, the parameter is instead + passed through psycopg2 and consumed by the ``libpq`` library. + +* A common way to set up client encoding with PostgreSQL databases is to + ensure it is configured within the server-side postgresql.conf file; + this is the recommended way to set encoding for a server that is + consistently of one encoding in all databases:: + + # postgresql.conf file + + # client_encoding = sql_ascii # actually, defaults to database + # encoding + client_encoding = utf8 + + + +Transactions +------------ + +The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations. + +.. _psycopg2_isolation_level: + +Psycopg2 Transaction Isolation Level +------------------------------------- + +As discussed in :ref:`postgresql_isolation_level`, +all PostgreSQL dialects support setting of transaction isolation level +both via the ``isolation_level`` parameter passed to :func:`_sa.create_engine` +, +as well as the ``isolation_level`` argument used by +:meth:`_engine.Connection.execution_options`. When using the psycopg2 dialect +, these +options make use of psycopg2's ``set_isolation_level()`` connection method, +rather than emitting a PostgreSQL directive; this is because psycopg2's +API-level setting is always emitted at the start of each transaction in any +case. + +The psycopg2 dialect supports these constants for isolation level: + +* ``READ COMMITTED`` +* ``READ UNCOMMITTED`` +* ``REPEATABLE READ`` +* ``SERIALIZABLE`` +* ``AUTOCOMMIT`` + +.. seealso:: + + :ref:`postgresql_isolation_level` + + :ref:`pg8000_isolation_level` + + +NOTICE logging +--------------- + +The psycopg2 dialect will log PostgreSQL NOTICE messages +via the ``sqlalchemy.dialects.postgresql`` logger. When this logger +is set to the ``logging.INFO`` level, notice messages will be logged:: + + import logging + + logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) + +Above, it is assumed that logging is configured externally. If this is not +the case, configuration such as ``logging.basicConfig()`` must be utilized:: + + import logging + + logging.basicConfig() # log messages to stdout + logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) + +.. seealso:: + + `Logging HOWTO `_ - on the python.org website + +.. _psycopg2_hstore: + +HSTORE type +------------ + +The ``psycopg2`` DBAPI includes an extension to natively handle marshalling of +the HSTORE type. The SQLAlchemy psycopg2 dialect will enable this extension +by default when psycopg2 version 2.4 or greater is used, and +it is detected that the target database has the HSTORE type set up for use. +In other words, when the dialect makes the first +connection, a sequence like the following is performed: + +1. Request the available HSTORE oids using + ``psycopg2.extras.HstoreAdapter.get_oids()``. + If this function returns a list of HSTORE identifiers, we then determine + that the ``HSTORE`` extension is present. + This function is **skipped** if the version of psycopg2 installed is + less than version 2.4. + +2. If the ``use_native_hstore`` flag is at its default of ``True``, and + we've detected that ``HSTORE`` oids are available, the + ``psycopg2.extensions.register_hstore()`` extension is invoked for all + connections. + +The ``register_hstore()`` extension has the effect of **all Python +dictionaries being accepted as parameters regardless of the type of target +column in SQL**. The dictionaries are converted by this extension into a +textual HSTORE expression. If this behavior is not desired, disable the +use of the hstore extension by setting ``use_native_hstore`` to ``False`` as +follows:: + + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test", + use_native_hstore=False) + +The ``HSTORE`` type is **still supported** when the +``psycopg2.extensions.register_hstore()`` extension is not used. It merely +means that the coercion between Python dictionaries and the HSTORE +string format, on both the parameter side and the result side, will take +place within SQLAlchemy's own marshalling logic, and not that of ``psycopg2`` +which may be more performant. + +""" # noqa +from __future__ import annotations + +import collections.abc as collections_abc +import logging +import re +from typing import cast + +from . import ranges +from ._psycopg_common import _PGDialect_common_psycopg +from ._psycopg_common import _PGExecutionContext_common_psycopg +from .base import PGIdentifierPreparer +from .json import JSON +from .json import JSONB +from ... import types as sqltypes +from ... import util +from ...util import FastIntFlag +from ...util import parse_user_argument_for_enum + +logger = logging.getLogger("sqlalchemy.dialects.postgresql") + + +class _PGJSON(JSON): + def result_processor(self, dialect, coltype): + return None + + +class _PGJSONB(JSONB): + def result_processor(self, dialect, coltype): + return None + + +class _Psycopg2Range(ranges.AbstractSingleRangeImpl): + _psycopg2_range_cls = "none" + + def bind_processor(self, dialect): + psycopg2_Range = getattr( + cast(PGDialect_psycopg2, dialect)._psycopg2_extras, + self._psycopg2_range_cls, + ) + + def to_range(value): + if isinstance(value, ranges.Range): + value = psycopg2_Range( + value.lower, value.upper, value.bounds, value.empty + ) + return value + + return to_range + + def result_processor(self, dialect, coltype): + def to_range(value): + if value is not None: + value = ranges.Range( + value._lower, + value._upper, + bounds=value._bounds if value._bounds else "[)", + empty=not value._bounds, + ) + return value + + return to_range + + +class _Psycopg2NumericRange(_Psycopg2Range): + _psycopg2_range_cls = "NumericRange" + + +class _Psycopg2DateRange(_Psycopg2Range): + _psycopg2_range_cls = "DateRange" + + +class _Psycopg2DateTimeRange(_Psycopg2Range): + _psycopg2_range_cls = "DateTimeRange" + + +class _Psycopg2DateTimeTZRange(_Psycopg2Range): + _psycopg2_range_cls = "DateTimeTZRange" + + +class PGExecutionContext_psycopg2(_PGExecutionContext_common_psycopg): + _psycopg2_fetched_rows = None + + def post_exec(self): + self._log_notices(self.cursor) + + def _log_notices(self, cursor): + # check also that notices is an iterable, after it's already + # established that we will be iterating through it. This is to get + # around test suites such as SQLAlchemy's using a Mock object for + # cursor + if not cursor.connection.notices or not isinstance( + cursor.connection.notices, collections_abc.Iterable + ): + return + + for notice in cursor.connection.notices: + # NOTICE messages have a + # newline character at the end + logger.info(notice.rstrip()) + + cursor.connection.notices[:] = [] + + +class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): + pass + + +class ExecutemanyMode(FastIntFlag): + EXECUTEMANY_VALUES = 0 + EXECUTEMANY_VALUES_PLUS_BATCH = 1 + + +( + EXECUTEMANY_VALUES, + EXECUTEMANY_VALUES_PLUS_BATCH, +) = ExecutemanyMode.__members__.values() + + +class PGDialect_psycopg2(_PGDialect_common_psycopg): + driver = "psycopg2" + + supports_statement_cache = True + supports_server_side_cursors = True + + default_paramstyle = "pyformat" + # set to true based on psycopg2 version + supports_sane_multi_rowcount = False + execution_ctx_cls = PGExecutionContext_psycopg2 + preparer = PGIdentifierPreparer_psycopg2 + psycopg2_version = (0, 0) + use_insertmanyvalues_wo_returning = True + + returns_native_bytes = False + + _has_native_hstore = True + + colspecs = util.update_copy( + _PGDialect_common_psycopg.colspecs, + { + JSON: _PGJSON, + sqltypes.JSON: _PGJSON, + JSONB: _PGJSONB, + ranges.INT4RANGE: _Psycopg2NumericRange, + ranges.INT8RANGE: _Psycopg2NumericRange, + ranges.NUMRANGE: _Psycopg2NumericRange, + ranges.DATERANGE: _Psycopg2DateRange, + ranges.TSRANGE: _Psycopg2DateTimeRange, + ranges.TSTZRANGE: _Psycopg2DateTimeTZRange, + }, + ) + + def __init__( + self, + executemany_mode="values_only", + executemany_batch_page_size=100, + **kwargs, + ): + _PGDialect_common_psycopg.__init__(self, **kwargs) + + if self._native_inet_types: + raise NotImplementedError( + "The psycopg2 dialect does not implement " + "ipaddress type handling; native_inet_types cannot be set " + "to ``True`` when using this dialect." + ) + + # Parse executemany_mode argument, allowing it to be only one of the + # symbol names + self.executemany_mode = parse_user_argument_for_enum( + executemany_mode, + { + EXECUTEMANY_VALUES: ["values_only"], + EXECUTEMANY_VALUES_PLUS_BATCH: ["values_plus_batch"], + }, + "executemany_mode", + ) + + self.executemany_batch_page_size = executemany_batch_page_size + + if self.dbapi and hasattr(self.dbapi, "__version__"): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) + if m: + self.psycopg2_version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + + if self.psycopg2_version < (2, 7): + raise ImportError( + "psycopg2 version 2.7 or higher is required." + ) + + def initialize(self, connection): + super().initialize(connection) + self._has_native_hstore = ( + self.use_native_hstore + and self._hstore_oids(connection.connection.dbapi_connection) + is not None + ) + + self.supports_sane_multi_rowcount = ( + self.executemany_mode is not EXECUTEMANY_VALUES_PLUS_BATCH + ) + + @classmethod + def import_dbapi(cls): + import psycopg2 + + return psycopg2 + + @util.memoized_property + def _psycopg2_extensions(cls): + from psycopg2 import extensions + + return extensions + + @util.memoized_property + def _psycopg2_extras(cls): + from psycopg2 import extras + + return extras + + @util.memoized_property + def _isolation_lookup(self): + extensions = self._psycopg2_extensions + return { + "AUTOCOMMIT": extensions.ISOLATION_LEVEL_AUTOCOMMIT, + "READ COMMITTED": extensions.ISOLATION_LEVEL_READ_COMMITTED, + "READ UNCOMMITTED": extensions.ISOLATION_LEVEL_READ_UNCOMMITTED, + "REPEATABLE READ": extensions.ISOLATION_LEVEL_REPEATABLE_READ, + "SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE, + } + + def set_isolation_level(self, dbapi_connection, level): + dbapi_connection.set_isolation_level(self._isolation_lookup[level]) + + def set_readonly(self, connection, value): + connection.readonly = value + + def get_readonly(self, connection): + return connection.readonly + + def set_deferrable(self, connection, value): + connection.deferrable = value + + def get_deferrable(self, connection): + return connection.deferrable + + def on_connect(self): + extras = self._psycopg2_extras + + fns = [] + if self.client_encoding is not None: + + def on_connect(dbapi_conn): + dbapi_conn.set_client_encoding(self.client_encoding) + + fns.append(on_connect) + + if self.dbapi: + + def on_connect(dbapi_conn): + extras.register_uuid(None, dbapi_conn) + + fns.append(on_connect) + + if self.dbapi and self.use_native_hstore: + + def on_connect(dbapi_conn): + hstore_oids = self._hstore_oids(dbapi_conn) + if hstore_oids is not None: + oid, array_oid = hstore_oids + kw = {"oid": oid} + kw["array_oid"] = array_oid + extras.register_hstore(dbapi_conn, **kw) + + fns.append(on_connect) + + if self.dbapi and self._json_deserializer: + + def on_connect(dbapi_conn): + extras.register_default_json( + dbapi_conn, loads=self._json_deserializer + ) + extras.register_default_jsonb( + dbapi_conn, loads=self._json_deserializer + ) + + fns.append(on_connect) + + if fns: + + def on_connect(dbapi_conn): + for fn in fns: + fn(dbapi_conn) + + return on_connect + else: + return None + + def do_executemany(self, cursor, statement, parameters, context=None): + if self.executemany_mode is EXECUTEMANY_VALUES_PLUS_BATCH: + if self.executemany_batch_page_size: + kwargs = {"page_size": self.executemany_batch_page_size} + else: + kwargs = {} + self._psycopg2_extras.execute_batch( + cursor, statement, parameters, **kwargs + ) + else: + cursor.executemany(statement, parameters) + + def do_begin_twophase(self, connection, xid): + connection.connection.tpc_begin(xid) + + def do_prepare_twophase(self, connection, xid): + connection.connection.tpc_prepare() + + def _do_twophase(self, dbapi_conn, operation, xid, recover=False): + if recover: + if dbapi_conn.status != self._psycopg2_extensions.STATUS_READY: + dbapi_conn.rollback() + operation(xid) + else: + operation() + + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + dbapi_conn = connection.connection.dbapi_connection + self._do_twophase( + dbapi_conn, dbapi_conn.tpc_rollback, xid, recover=recover + ) + + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): + dbapi_conn = connection.connection.dbapi_connection + self._do_twophase( + dbapi_conn, dbapi_conn.tpc_commit, xid, recover=recover + ) + + @util.memoized_instancemethod + def _hstore_oids(self, dbapi_connection): + extras = self._psycopg2_extras + oids = extras.HstoreAdapter.get_oids(dbapi_connection) + if oids is not None and oids[0]: + return oids[0:2] + else: + return None + + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.Error): + # check the "closed" flag. this might not be + # present on old psycopg2 versions. Also, + # this flag doesn't actually help in a lot of disconnect + # situations, so don't rely on it. + if getattr(connection, "closed", False): + return True + + # checks based on strings. in the case that .closed + # didn't cut it, fall back onto these. + str_e = str(e).partition("\n")[0] + for msg in [ + # these error messages from libpq: interfaces/libpq/fe-misc.c + # and interfaces/libpq/fe-secure.c. + "terminating connection", + "closed the connection", + "connection not open", + "could not receive data from server", + "could not send data to server", + # psycopg2 client errors, psycopg2/connection.h, + # psycopg2/cursor.h + "connection already closed", + "cursor already closed", + # not sure where this path is originally from, it may + # be obsolete. It really says "losed", not "closed". + "losed the connection unexpectedly", + # these can occur in newer SSL + "connection has been closed unexpectedly", + "SSL error: decryption failed or bad record mac", + "SSL SYSCALL error: Bad file descriptor", + "SSL SYSCALL error: EOF detected", + "SSL SYSCALL error: Operation timed out", + "SSL SYSCALL error: Bad address", + ]: + idx = str_e.find(msg) + if idx >= 0 and '"' not in str_e[:idx]: + return True + return False + + +dialect = PGDialect_psycopg2 diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/psycopg2cffi.py new file mode 100644 index 0000000..3cc3b69 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/psycopg2cffi.py @@ -0,0 +1,61 @@ +# dialects/postgresql/psycopg2cffi.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r""" +.. dialect:: postgresql+psycopg2cffi + :name: psycopg2cffi + :dbapi: psycopg2cffi + :connectstring: postgresql+psycopg2cffi://user:password@host:port/dbname[?key=value&key=value...] + :url: https://pypi.org/project/psycopg2cffi/ + +``psycopg2cffi`` is an adaptation of ``psycopg2``, using CFFI for the C +layer. This makes it suitable for use in e.g. PyPy. Documentation +is as per ``psycopg2``. + +.. seealso:: + + :mod:`sqlalchemy.dialects.postgresql.psycopg2` + +""" # noqa +from .psycopg2 import PGDialect_psycopg2 +from ... import util + + +class PGDialect_psycopg2cffi(PGDialect_psycopg2): + driver = "psycopg2cffi" + supports_unicode_statements = True + supports_statement_cache = True + + # psycopg2cffi's first release is 2.5.0, but reports + # __version__ as 2.4.4. Subsequent releases seem to have + # fixed this. + + FEATURE_VERSION_MAP = dict( + native_json=(2, 4, 4), + native_jsonb=(2, 7, 1), + sane_multi_rowcount=(2, 4, 4), + array_oid=(2, 4, 4), + hstore_adapter=(2, 4, 4), + ) + + @classmethod + def import_dbapi(cls): + return __import__("psycopg2cffi") + + @util.memoized_property + def _psycopg2_extensions(cls): + root = __import__("psycopg2cffi", fromlist=["extensions"]) + return root.extensions + + @util.memoized_property + def _psycopg2_extras(cls): + root = __import__("psycopg2cffi", fromlist=["extras"]) + return root.extras + + +dialect = PGDialect_psycopg2cffi diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/ranges.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/ranges.py new file mode 100644 index 0000000..b793ca4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/ranges.py @@ -0,0 +1,1029 @@ +# dialects/postgresql/ranges.py +# Copyright (C) 2013-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import dataclasses +from datetime import date +from datetime import datetime +from datetime import timedelta +from decimal import Decimal +from typing import Any +from typing import cast +from typing import Generic +from typing import List +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .operators import ADJACENT_TO +from .operators import CONTAINED_BY +from .operators import CONTAINS +from .operators import NOT_EXTEND_LEFT_OF +from .operators import NOT_EXTEND_RIGHT_OF +from .operators import OVERLAP +from .operators import STRICTLY_LEFT_OF +from .operators import STRICTLY_RIGHT_OF +from ... import types as sqltypes +from ...sql import operators +from ...sql.type_api import TypeEngine +from ...util import py310 +from ...util.typing import Literal + +if TYPE_CHECKING: + from ...sql.elements import ColumnElement + from ...sql.type_api import _TE + from ...sql.type_api import TypeEngineMixin + +_T = TypeVar("_T", bound=Any) + +_BoundsType = Literal["()", "[)", "(]", "[]"] + +if py310: + dc_slots = {"slots": True} + dc_kwonly = {"kw_only": True} +else: + dc_slots = {} + dc_kwonly = {} + + +@dataclasses.dataclass(frozen=True, **dc_slots) +class Range(Generic[_T]): + """Represent a PostgreSQL range. + + E.g.:: + + r = Range(10, 50, bounds="()") + + The calling style is similar to that of psycopg and psycopg2, in part + to allow easier migration from previous SQLAlchemy versions that used + these objects directly. + + :param lower: Lower bound value, or None + :param upper: Upper bound value, or None + :param bounds: keyword-only, optional string value that is one of + ``"()"``, ``"[)"``, ``"(]"``, ``"[]"``. Defaults to ``"[)"``. + :param empty: keyword-only, optional bool indicating this is an "empty" + range + + .. versionadded:: 2.0 + + """ + + lower: Optional[_T] = None + """the lower bound""" + + upper: Optional[_T] = None + """the upper bound""" + + if TYPE_CHECKING: + bounds: _BoundsType = dataclasses.field(default="[)") + empty: bool = dataclasses.field(default=False) + else: + bounds: _BoundsType = dataclasses.field(default="[)", **dc_kwonly) + empty: bool = dataclasses.field(default=False, **dc_kwonly) + + if not py310: + + def __init__( + self, + lower: Optional[_T] = None, + upper: Optional[_T] = None, + *, + bounds: _BoundsType = "[)", + empty: bool = False, + ): + # no __slots__ either so we can update dict + self.__dict__.update( + { + "lower": lower, + "upper": upper, + "bounds": bounds, + "empty": empty, + } + ) + + def __bool__(self) -> bool: + return not self.empty + + @property + def isempty(self) -> bool: + "A synonym for the 'empty' attribute." + + return self.empty + + @property + def is_empty(self) -> bool: + "A synonym for the 'empty' attribute." + + return self.empty + + @property + def lower_inc(self) -> bool: + """Return True if the lower bound is inclusive.""" + + return self.bounds[0] == "[" + + @property + def lower_inf(self) -> bool: + """Return True if this range is non-empty and lower bound is + infinite.""" + + return not self.empty and self.lower is None + + @property + def upper_inc(self) -> bool: + """Return True if the upper bound is inclusive.""" + + return self.bounds[1] == "]" + + @property + def upper_inf(self) -> bool: + """Return True if this range is non-empty and the upper bound is + infinite.""" + + return not self.empty and self.upper is None + + @property + def __sa_type_engine__(self) -> AbstractSingleRange[_T]: + return AbstractSingleRange() + + def _contains_value(self, value: _T) -> bool: + """Return True if this range contains the given value.""" + + if self.empty: + return False + + if self.lower is None: + return self.upper is None or ( + value < self.upper + if self.bounds[1] == ")" + else value <= self.upper + ) + + if self.upper is None: + return ( # type: ignore + value > self.lower + if self.bounds[0] == "(" + else value >= self.lower + ) + + return ( # type: ignore + value > self.lower + if self.bounds[0] == "(" + else value >= self.lower + ) and ( + value < self.upper + if self.bounds[1] == ")" + else value <= self.upper + ) + + def _get_discrete_step(self) -> Any: + "Determine the “step” for this range, if it is a discrete one." + + # See + # https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-DISCRETE + # for the rationale + + if isinstance(self.lower, int) or isinstance(self.upper, int): + return 1 + elif isinstance(self.lower, datetime) or isinstance( + self.upper, datetime + ): + # This is required, because a `isinstance(datetime.now(), date)` + # is True + return None + elif isinstance(self.lower, date) or isinstance(self.upper, date): + return timedelta(days=1) + else: + return None + + def _compare_edges( + self, + value1: Optional[_T], + bound1: str, + value2: Optional[_T], + bound2: str, + only_values: bool = False, + ) -> int: + """Compare two range bounds. + + Return -1, 0 or 1 respectively when `value1` is less than, + equal to or greater than `value2`. + + When `only_value` is ``True``, do not consider the *inclusivity* + of the edges, just their values. + """ + + value1_is_lower_bound = bound1 in {"[", "("} + value2_is_lower_bound = bound2 in {"[", "("} + + # Infinite edges are equal when they are on the same side, + # otherwise a lower edge is considered less than the upper end + if value1 is value2 is None: + if value1_is_lower_bound == value2_is_lower_bound: + return 0 + else: + return -1 if value1_is_lower_bound else 1 + elif value1 is None: + return -1 if value1_is_lower_bound else 1 + elif value2 is None: + return 1 if value2_is_lower_bound else -1 + + # Short path for trivial case + if bound1 == bound2 and value1 == value2: + return 0 + + value1_inc = bound1 in {"[", "]"} + value2_inc = bound2 in {"[", "]"} + step = self._get_discrete_step() + + if step is not None: + # "Normalize" the two edges as '[)', to simplify successive + # logic when the range is discrete: otherwise we would need + # to handle the comparison between ``(0`` and ``[1`` that + # are equal when dealing with integers while for floats the + # former is lesser than the latter + + if value1_is_lower_bound: + if not value1_inc: + value1 += step + value1_inc = True + else: + if value1_inc: + value1 += step + value1_inc = False + if value2_is_lower_bound: + if not value2_inc: + value2 += step + value2_inc = True + else: + if value2_inc: + value2 += step + value2_inc = False + + if value1 < value2: # type: ignore + return -1 + elif value1 > value2: # type: ignore + return 1 + elif only_values: + return 0 + else: + # Neither one is infinite but are equal, so we + # need to consider the respective inclusive/exclusive + # flag + + if value1_inc and value2_inc: + return 0 + elif not value1_inc and not value2_inc: + if value1_is_lower_bound == value2_is_lower_bound: + return 0 + else: + return 1 if value1_is_lower_bound else -1 + elif not value1_inc: + return 1 if value1_is_lower_bound else -1 + elif not value2_inc: + return -1 if value2_is_lower_bound else 1 + else: + return 0 + + def __eq__(self, other: Any) -> bool: + """Compare this range to the `other` taking into account + bounds inclusivity, returning ``True`` if they are equal. + """ + + if not isinstance(other, Range): + return NotImplemented + + if self.empty and other.empty: + return True + elif self.empty != other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + olower = other.lower + olower_b = other.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + oupper = other.upper + oupper_b = other.bounds[1] + + return ( + self._compare_edges(slower, slower_b, olower, olower_b) == 0 + and self._compare_edges(supper, supper_b, oupper, oupper_b) == 0 + ) + + def contained_by(self, other: Range[_T]) -> bool: + "Determine whether this range is a contained by `other`." + + # Any range contains the empty one + if self.empty: + return True + + # An empty range does not contain any range except the empty one + if other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + olower = other.lower + olower_b = other.bounds[0] + + if self._compare_edges(slower, slower_b, olower, olower_b) < 0: + return False + + supper = self.upper + supper_b = self.bounds[1] + oupper = other.upper + oupper_b = other.bounds[1] + + if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0: + return False + + return True + + def contains(self, value: Union[_T, Range[_T]]) -> bool: + "Determine whether this range contains `value`." + + if isinstance(value, Range): + return value.contained_by(self) + else: + return self._contains_value(value) + + def overlaps(self, other: Range[_T]) -> bool: + "Determine whether this range overlaps with `other`." + + # Empty ranges never overlap with any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + # Check whether this lower bound is contained in the other range + if ( + self._compare_edges(slower, slower_b, olower, olower_b) >= 0 + and self._compare_edges(slower, slower_b, oupper, oupper_b) <= 0 + ): + return True + + # Check whether other lower bound is contained in this range + if ( + self._compare_edges(olower, olower_b, slower, slower_b) >= 0 + and self._compare_edges(olower, olower_b, supper, supper_b) <= 0 + ): + return True + + return False + + def strictly_left_of(self, other: Range[_T]) -> bool: + "Determine whether this range is completely to the left of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + + # Check whether this upper edge is less than other's lower end + return self._compare_edges(supper, supper_b, olower, olower_b) < 0 + + __lshift__ = strictly_left_of + + def strictly_right_of(self, other: Range[_T]) -> bool: + "Determine whether this range is completely to the right of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + # Check whether this lower edge is greater than other's upper end + return self._compare_edges(slower, slower_b, oupper, oupper_b) > 0 + + __rshift__ = strictly_right_of + + def not_extend_left_of(self, other: Range[_T]) -> bool: + "Determine whether this does not extend to the left of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + olower = other.lower + olower_b = other.bounds[0] + + # Check whether this lower edge is not less than other's lower end + return self._compare_edges(slower, slower_b, olower, olower_b) >= 0 + + def not_extend_right_of(self, other: Range[_T]) -> bool: + "Determine whether this does not extend to the right of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + supper = self.upper + supper_b = self.bounds[1] + oupper = other.upper + oupper_b = other.bounds[1] + + # Check whether this upper edge is not greater than other's upper end + return self._compare_edges(supper, supper_b, oupper, oupper_b) <= 0 + + def _upper_edge_adjacent_to_lower( + self, + value1: Optional[_T], + bound1: str, + value2: Optional[_T], + bound2: str, + ) -> bool: + """Determine whether an upper bound is immediately successive to a + lower bound.""" + + # Since we need a peculiar way to handle the bounds inclusivity, + # just do a comparison by value here + res = self._compare_edges(value1, bound1, value2, bound2, True) + if res == -1: + step = self._get_discrete_step() + if step is None: + return False + if bound1 == "]": + if bound2 == "[": + return value1 == value2 - step # type: ignore + else: + return value1 == value2 + else: + if bound2 == "[": + return value1 == value2 + else: + return value1 == value2 - step # type: ignore + elif res == 0: + # Cover cases like [0,0] -|- [1,] and [0,2) -|- (1,3] + if ( + bound1 == "]" + and bound2 == "[" + or bound1 == ")" + and bound2 == "(" + ): + step = self._get_discrete_step() + if step is not None: + return True + return ( + bound1 == ")" + and bound2 == "[" + or bound1 == "]" + and bound2 == "(" + ) + else: + return False + + def adjacent_to(self, other: Range[_T]) -> bool: + "Determine whether this range is adjacent to the `other`." + + # Empty ranges are not adjacent to any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + return self._upper_edge_adjacent_to_lower( + supper, supper_b, olower, olower_b + ) or self._upper_edge_adjacent_to_lower( + oupper, oupper_b, slower, slower_b + ) + + def union(self, other: Range[_T]) -> Range[_T]: + """Compute the union of this range with the `other`. + + This raises a ``ValueError`` exception if the two ranges are + "disjunct", that is neither adjacent nor overlapping. + """ + + # Empty ranges are "additive identities" + if self.empty: + return other + if other.empty: + return self + + if not self.overlaps(other) and not self.adjacent_to(other): + raise ValueError( + "Adding non-overlapping and non-adjacent" + " ranges is not implemented" + ) + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + if self._compare_edges(slower, slower_b, olower, olower_b) < 0: + rlower = slower + rlower_b = slower_b + else: + rlower = olower + rlower_b = olower_b + + if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0: + rupper = supper + rupper_b = supper_b + else: + rupper = oupper + rupper_b = oupper_b + + return Range( + rlower, rupper, bounds=cast(_BoundsType, rlower_b + rupper_b) + ) + + def __add__(self, other: Range[_T]) -> Range[_T]: + return self.union(other) + + def difference(self, other: Range[_T]) -> Range[_T]: + """Compute the difference between this range and the `other`. + + This raises a ``ValueError`` exception if the two ranges are + "disjunct", that is neither adjacent nor overlapping. + """ + + # Subtracting an empty range is a no-op + if self.empty or other.empty: + return self + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + sl_vs_ol = self._compare_edges(slower, slower_b, olower, olower_b) + su_vs_ou = self._compare_edges(supper, supper_b, oupper, oupper_b) + if sl_vs_ol < 0 and su_vs_ou > 0: + raise ValueError( + "Subtracting a strictly inner range is not implemented" + ) + + sl_vs_ou = self._compare_edges(slower, slower_b, oupper, oupper_b) + su_vs_ol = self._compare_edges(supper, supper_b, olower, olower_b) + + # If the ranges do not overlap, result is simply the first + if sl_vs_ou > 0 or su_vs_ol < 0: + return self + + # If this range is completely contained by the other, result is empty + if sl_vs_ol >= 0 and su_vs_ou <= 0: + return Range(None, None, empty=True) + + # If this range extends to the left of the other and ends in its + # middle + if sl_vs_ol <= 0 and su_vs_ol >= 0 and su_vs_ou <= 0: + rupper_b = ")" if olower_b == "[" else "]" + if ( + slower_b != "[" + and rupper_b != "]" + and self._compare_edges(slower, slower_b, olower, rupper_b) + == 0 + ): + return Range(None, None, empty=True) + else: + return Range( + slower, + olower, + bounds=cast(_BoundsType, slower_b + rupper_b), + ) + + # If this range starts in the middle of the other and extends to its + # right + if sl_vs_ol >= 0 and su_vs_ou >= 0 and sl_vs_ou <= 0: + rlower_b = "(" if oupper_b == "]" else "[" + if ( + rlower_b != "[" + and supper_b != "]" + and self._compare_edges(oupper, rlower_b, supper, supper_b) + == 0 + ): + return Range(None, None, empty=True) + else: + return Range( + oupper, + supper, + bounds=cast(_BoundsType, rlower_b + supper_b), + ) + + assert False, f"Unhandled case computing {self} - {other}" + + def __sub__(self, other: Range[_T]) -> Range[_T]: + return self.difference(other) + + def intersection(self, other: Range[_T]) -> Range[_T]: + """Compute the intersection of this range with the `other`. + + .. versionadded:: 2.0.10 + + """ + if self.empty or other.empty or not self.overlaps(other): + return Range(None, None, empty=True) + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + if self._compare_edges(slower, slower_b, olower, olower_b) < 0: + rlower = olower + rlower_b = olower_b + else: + rlower = slower + rlower_b = slower_b + + if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0: + rupper = oupper + rupper_b = oupper_b + else: + rupper = supper + rupper_b = supper_b + + return Range( + rlower, + rupper, + bounds=cast(_BoundsType, rlower_b + rupper_b), + ) + + def __mul__(self, other: Range[_T]) -> Range[_T]: + return self.intersection(other) + + def __str__(self) -> str: + return self._stringify() + + def _stringify(self) -> str: + if self.empty: + return "empty" + + l, r = self.lower, self.upper + l = "" if l is None else l # type: ignore + r = "" if r is None else r # type: ignore + + b0, b1 = cast("Tuple[str, str]", self.bounds) + + return f"{b0}{l},{r}{b1}" + + +class MultiRange(List[Range[_T]]): + """Represents a multirange sequence. + + This list subclass is an utility to allow automatic type inference of + the proper multi-range SQL type depending on the single range values. + This is useful when operating on literal multi-ranges:: + + import sqlalchemy as sa + from sqlalchemy.dialects.postgresql import MultiRange, Range + + value = literal(MultiRange([Range(2, 4)])) + + select(tbl).where(tbl.c.value.op("@")(MultiRange([Range(-3, 7)]))) + + .. versionadded:: 2.0.26 + + .. seealso:: + + - :ref:`postgresql_multirange_list_use`. + """ + + @property + def __sa_type_engine__(self) -> AbstractMultiRange[_T]: + return AbstractMultiRange() + + +class AbstractRange(sqltypes.TypeEngine[_T]): + """Base class for single and multi Range SQL types.""" + + render_bind_cast = True + + __abstract__ = True + + @overload + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ... + + @overload + def adapt( + self, cls: Type[TypeEngineMixin], **kw: Any + ) -> TypeEngine[Any]: ... + + def adapt( + self, + cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], + **kw: Any, + ) -> TypeEngine[Any]: + """Dynamically adapt a range type to an abstract impl. + + For example ``INT4RANGE().adapt(_Psycopg2NumericRange)`` should + produce a type that will have ``_Psycopg2NumericRange`` behaviors + and also render as ``INT4RANGE`` in SQL and DDL. + + """ + if ( + issubclass(cls, (AbstractSingleRangeImpl, AbstractMultiRangeImpl)) + and cls is not self.__class__ + ): + # two ways to do this are: 1. create a new type on the fly + # or 2. have AbstractRangeImpl(visit_name) constructor and a + # visit_abstract_range_impl() method in the PG compiler. + # I'm choosing #1 as the resulting type object + # will then make use of the same mechanics + # as if we had made all these sub-types explicitly, and will + # also look more obvious under pdb etc. + # The adapt() operation here is cached per type-class-per-dialect, + # so is not much of a performance concern + visit_name = self.__visit_name__ + return type( # type: ignore + f"{visit_name}RangeImpl", + (cls, self.__class__), + {"__visit_name__": visit_name}, + )() + else: + return super().adapt(cls) + + class comparator_factory(TypeEngine.Comparator[Range[Any]]): + """Define comparison operations for range types.""" + + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the right hand operand, + which can be an element or a range, is contained within the + column. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + return self.expr.operate(CONTAINS, other) + + def contained_by(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the column is contained + within the right hand operand. + """ + return self.expr.operate(CONTAINED_BY, other) + + def overlaps(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the column overlaps + (has points in common with) the right hand operand. + """ + return self.expr.operate(OVERLAP, other) + + def strictly_left_of(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the column is strictly + left of the right hand operand. + """ + return self.expr.operate(STRICTLY_LEFT_OF, other) + + __lshift__ = strictly_left_of + + def strictly_right_of(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the column is strictly + right of the right hand operand. + """ + return self.expr.operate(STRICTLY_RIGHT_OF, other) + + __rshift__ = strictly_right_of + + def not_extend_right_of(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the range in the column + does not extend right of the range in the operand. + """ + return self.expr.operate(NOT_EXTEND_RIGHT_OF, other) + + def not_extend_left_of(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the range in the column + does not extend left of the range in the operand. + """ + return self.expr.operate(NOT_EXTEND_LEFT_OF, other) + + def adjacent_to(self, other: Any) -> ColumnElement[bool]: + """Boolean expression. Returns true if the range in the column + is adjacent to the range in the operand. + """ + return self.expr.operate(ADJACENT_TO, other) + + def union(self, other: Any) -> ColumnElement[bool]: + """Range expression. Returns the union of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + return self.expr.operate(operators.add, other) + + def difference(self, other: Any) -> ColumnElement[bool]: + """Range expression. Returns the union of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + return self.expr.operate(operators.sub, other) + + def intersection(self, other: Any) -> ColumnElement[Range[_T]]: + """Range expression. Returns the intersection of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + return self.expr.operate(operators.mul, other) + + +class AbstractSingleRange(AbstractRange[Range[_T]]): + """Base for PostgreSQL RANGE types. + + These are types that return a single :class:`_postgresql.Range` object. + + .. seealso:: + + `PostgreSQL range functions `_ + + """ # noqa: E501 + + __abstract__ = True + + def _resolve_for_literal(self, value: Range[Any]) -> Any: + spec = value.lower if value.lower is not None else value.upper + + if isinstance(spec, int): + # pg is unreasonably picky here: the query + # "select 1::INTEGER <@ '[1, 4)'::INT8RANGE" raises + # "operator does not exist: integer <@ int8range" as of pg 16 + if _is_int32(value): + return INT4RANGE() + else: + return INT8RANGE() + elif isinstance(spec, (Decimal, float)): + return NUMRANGE() + elif isinstance(spec, datetime): + return TSRANGE() if not spec.tzinfo else TSTZRANGE() + elif isinstance(spec, date): + return DATERANGE() + else: + # empty Range, SQL datatype can't be determined here + return sqltypes.NULLTYPE + + +class AbstractSingleRangeImpl(AbstractSingleRange[_T]): + """Marker for AbstractSingleRange that will apply a subclass-specific + adaptation""" + + +class AbstractMultiRange(AbstractRange[Sequence[Range[_T]]]): + """Base for PostgreSQL MULTIRANGE types. + + these are types that return a sequence of :class:`_postgresql.Range` + objects. + + """ + + __abstract__ = True + + def _resolve_for_literal(self, value: Sequence[Range[Any]]) -> Any: + if not value: + # empty MultiRange, SQL datatype can't be determined here + return sqltypes.NULLTYPE + first = value[0] + spec = first.lower if first.lower is not None else first.upper + + if isinstance(spec, int): + # pg is unreasonably picky here: the query + # "select 1::INTEGER <@ '{[1, 4),[6,19)}'::INT8MULTIRANGE" raises + # "operator does not exist: integer <@ int8multirange" as of pg 16 + if all(_is_int32(r) for r in value): + return INT4MULTIRANGE() + else: + return INT8MULTIRANGE() + elif isinstance(spec, (Decimal, float)): + return NUMMULTIRANGE() + elif isinstance(spec, datetime): + return TSMULTIRANGE() if not spec.tzinfo else TSTZMULTIRANGE() + elif isinstance(spec, date): + return DATEMULTIRANGE() + else: + # empty Range, SQL datatype can't be determined here + return sqltypes.NULLTYPE + + +class AbstractMultiRangeImpl(AbstractMultiRange[_T]): + """Marker for AbstractMultiRange that will apply a subclass-specific + adaptation""" + + +class INT4RANGE(AbstractSingleRange[int]): + """Represent the PostgreSQL INT4RANGE type.""" + + __visit_name__ = "INT4RANGE" + + +class INT8RANGE(AbstractSingleRange[int]): + """Represent the PostgreSQL INT8RANGE type.""" + + __visit_name__ = "INT8RANGE" + + +class NUMRANGE(AbstractSingleRange[Decimal]): + """Represent the PostgreSQL NUMRANGE type.""" + + __visit_name__ = "NUMRANGE" + + +class DATERANGE(AbstractSingleRange[date]): + """Represent the PostgreSQL DATERANGE type.""" + + __visit_name__ = "DATERANGE" + + +class TSRANGE(AbstractSingleRange[datetime]): + """Represent the PostgreSQL TSRANGE type.""" + + __visit_name__ = "TSRANGE" + + +class TSTZRANGE(AbstractSingleRange[datetime]): + """Represent the PostgreSQL TSTZRANGE type.""" + + __visit_name__ = "TSTZRANGE" + + +class INT4MULTIRANGE(AbstractMultiRange[int]): + """Represent the PostgreSQL INT4MULTIRANGE type.""" + + __visit_name__ = "INT4MULTIRANGE" + + +class INT8MULTIRANGE(AbstractMultiRange[int]): + """Represent the PostgreSQL INT8MULTIRANGE type.""" + + __visit_name__ = "INT8MULTIRANGE" + + +class NUMMULTIRANGE(AbstractMultiRange[Decimal]): + """Represent the PostgreSQL NUMMULTIRANGE type.""" + + __visit_name__ = "NUMMULTIRANGE" + + +class DATEMULTIRANGE(AbstractMultiRange[date]): + """Represent the PostgreSQL DATEMULTIRANGE type.""" + + __visit_name__ = "DATEMULTIRANGE" + + +class TSMULTIRANGE(AbstractMultiRange[datetime]): + """Represent the PostgreSQL TSRANGE type.""" + + __visit_name__ = "TSMULTIRANGE" + + +class TSTZMULTIRANGE(AbstractMultiRange[datetime]): + """Represent the PostgreSQL TSTZRANGE type.""" + + __visit_name__ = "TSTZMULTIRANGE" + + +_max_int_32 = 2**31 - 1 +_min_int_32 = -(2**31) + + +def _is_int32(r: Range[int]) -> bool: + return (r.lower is None or _min_int_32 <= r.lower <= _max_int_32) and ( + r.upper is None or _min_int_32 <= r.upper <= _max_int_32 + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/types.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/types.py new file mode 100644 index 0000000..2acf63b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/postgresql/types.py @@ -0,0 +1,303 @@ +# dialects/postgresql/types.py +# Copyright (C) 2013-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import datetime as dt +from typing import Any +from typing import Optional +from typing import overload +from typing import Type +from typing import TYPE_CHECKING +from uuid import UUID as _python_UUID + +from ...sql import sqltypes +from ...sql import type_api +from ...util.typing import Literal + +if TYPE_CHECKING: + from ...engine.interfaces import Dialect + from ...sql.operators import OperatorType + from ...sql.type_api import _LiteralProcessorType + from ...sql.type_api import TypeEngine + +_DECIMAL_TYPES = (1231, 1700) +_FLOAT_TYPES = (700, 701, 1021, 1022) +_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) + + +class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]): + render_bind_cast = True + render_literal_cast = True + + if TYPE_CHECKING: + + @overload + def __init__( + self: PGUuid[_python_UUID], as_uuid: Literal[True] = ... + ) -> None: ... + + @overload + def __init__( + self: PGUuid[str], as_uuid: Literal[False] = ... + ) -> None: ... + + def __init__(self, as_uuid: bool = True) -> None: ... + + +class BYTEA(sqltypes.LargeBinary): + __visit_name__ = "BYTEA" + + +class INET(sqltypes.TypeEngine[str]): + __visit_name__ = "INET" + + +PGInet = INET + + +class CIDR(sqltypes.TypeEngine[str]): + __visit_name__ = "CIDR" + + +PGCidr = CIDR + + +class MACADDR(sqltypes.TypeEngine[str]): + __visit_name__ = "MACADDR" + + +PGMacAddr = MACADDR + + +class MACADDR8(sqltypes.TypeEngine[str]): + __visit_name__ = "MACADDR8" + + +PGMacAddr8 = MACADDR8 + + +class MONEY(sqltypes.TypeEngine[str]): + r"""Provide the PostgreSQL MONEY type. + + Depending on driver, result rows using this type may return a + string value which includes currency symbols. + + For this reason, it may be preferable to provide conversion to a + numerically-based currency datatype using :class:`_types.TypeDecorator`:: + + import re + import decimal + from sqlalchemy import Dialect + from sqlalchemy import TypeDecorator + + class NumericMoney(TypeDecorator): + impl = MONEY + + def process_result_value( + self, value: Any, dialect: Dialect + ) -> None: + if value is not None: + # adjust this for the currency and numeric + m = re.match(r"\$([\d.]+)", value) + if m: + value = decimal.Decimal(m.group(1)) + return value + + Alternatively, the conversion may be applied as a CAST using + the :meth:`_types.TypeDecorator.column_expression` method as follows:: + + import decimal + from sqlalchemy import cast + from sqlalchemy import TypeDecorator + + class NumericMoney(TypeDecorator): + impl = MONEY + + def column_expression(self, column: Any): + return cast(column, Numeric()) + + .. versionadded:: 1.2 + + """ + + __visit_name__ = "MONEY" + + +class OID(sqltypes.TypeEngine[int]): + """Provide the PostgreSQL OID type.""" + + __visit_name__ = "OID" + + +class REGCONFIG(sqltypes.TypeEngine[str]): + """Provide the PostgreSQL REGCONFIG type. + + .. versionadded:: 2.0.0rc1 + + """ + + __visit_name__ = "REGCONFIG" + + +class TSQUERY(sqltypes.TypeEngine[str]): + """Provide the PostgreSQL TSQUERY type. + + .. versionadded:: 2.0.0rc1 + + """ + + __visit_name__ = "TSQUERY" + + +class REGCLASS(sqltypes.TypeEngine[str]): + """Provide the PostgreSQL REGCLASS type. + + .. versionadded:: 1.2.7 + + """ + + __visit_name__ = "REGCLASS" + + +class TIMESTAMP(sqltypes.TIMESTAMP): + """Provide the PostgreSQL TIMESTAMP type.""" + + __visit_name__ = "TIMESTAMP" + + def __init__( + self, timezone: bool = False, precision: Optional[int] = None + ) -> None: + """Construct a TIMESTAMP. + + :param timezone: boolean value if timezone present, default False + :param precision: optional integer precision value + + .. versionadded:: 1.4 + + """ + super().__init__(timezone=timezone) + self.precision = precision + + +class TIME(sqltypes.TIME): + """PostgreSQL TIME type.""" + + __visit_name__ = "TIME" + + def __init__( + self, timezone: bool = False, precision: Optional[int] = None + ) -> None: + """Construct a TIME. + + :param timezone: boolean value if timezone present, default False + :param precision: optional integer precision value + + .. versionadded:: 1.4 + + """ + super().__init__(timezone=timezone) + self.precision = precision + + +class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval): + """PostgreSQL INTERVAL type.""" + + __visit_name__ = "INTERVAL" + native = True + + def __init__( + self, precision: Optional[int] = None, fields: Optional[str] = None + ) -> None: + """Construct an INTERVAL. + + :param precision: optional integer precision value + :param fields: string fields specifier. allows storage of fields + to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``, + etc. + + .. versionadded:: 1.2 + + """ + self.precision = precision + self.fields = fields + + @classmethod + def adapt_emulated_to_native( + cls, interval: sqltypes.Interval, **kw: Any # type: ignore[override] + ) -> INTERVAL: + return INTERVAL(precision=interval.second_precision) + + @property + def _type_affinity(self) -> Type[sqltypes.Interval]: + return sqltypes.Interval + + def as_generic(self, allow_nulltype: bool = False) -> sqltypes.Interval: + return sqltypes.Interval(native=True, second_precision=self.precision) + + @property + def python_type(self) -> Type[dt.timedelta]: + return dt.timedelta + + def literal_processor( + self, dialect: Dialect + ) -> Optional[_LiteralProcessorType[dt.timedelta]]: + def process(value: dt.timedelta) -> str: + return f"make_interval(secs=>{value.total_seconds()})" + + return process + + +PGInterval = INTERVAL + + +class BIT(sqltypes.TypeEngine[int]): + __visit_name__ = "BIT" + + def __init__( + self, length: Optional[int] = None, varying: bool = False + ) -> None: + if varying: + # BIT VARYING can be unlimited-length, so no default + self.length = length + else: + # BIT without VARYING defaults to length 1 + self.length = length or 1 + self.varying = varying + + +PGBit = BIT + + +class TSVECTOR(sqltypes.TypeEngine[str]): + """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL + text search type TSVECTOR. + + It can be used to do full text queries on natural language + documents. + + .. seealso:: + + :ref:`postgresql_match` + + """ + + __visit_name__ = "TSVECTOR" + + +class CITEXT(sqltypes.TEXT): + """Provide the PostgreSQL CITEXT type. + + .. versionadded:: 2.0.7 + + """ + + __visit_name__ = "CITEXT" + + def coerce_compared_value( + self, op: Optional[OperatorType], value: Any + ) -> TypeEngine[Any]: + return self diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__init__.py new file mode 100644 index 0000000..45f088e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__init__.py @@ -0,0 +1,57 @@ +# dialects/sqlite/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from . import aiosqlite # noqa +from . import base # noqa +from . import pysqlcipher # noqa +from . import pysqlite # noqa +from .base import BLOB +from .base import BOOLEAN +from .base import CHAR +from .base import DATE +from .base import DATETIME +from .base import DECIMAL +from .base import FLOAT +from .base import INTEGER +from .base import JSON +from .base import NUMERIC +from .base import REAL +from .base import SMALLINT +from .base import TEXT +from .base import TIME +from .base import TIMESTAMP +from .base import VARCHAR +from .dml import Insert +from .dml import insert + +# default dialect +base.dialect = dialect = pysqlite.dialect + + +__all__ = ( + "BLOB", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "FLOAT", + "INTEGER", + "JSON", + "NUMERIC", + "SMALLINT", + "TEXT", + "TIME", + "TIMESTAMP", + "VARCHAR", + "REAL", + "Insert", + "insert", + "dialect", +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..e4a9b51 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/aiosqlite.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/aiosqlite.cpython-311.pyc new file mode 100644 index 0000000..41466a4 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/aiosqlite.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..e7f5c22 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/base.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/dml.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/dml.cpython-311.pyc new file mode 100644 index 0000000..eb0f448 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/dml.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/json.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/json.cpython-311.pyc new file mode 100644 index 0000000..ad4323c Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/json.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/provision.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/provision.cpython-311.pyc new file mode 100644 index 0000000..d139ba3 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/provision.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/pysqlcipher.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/pysqlcipher.cpython-311.pyc new file mode 100644 index 0000000..d26e7b3 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/pysqlcipher.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/pysqlite.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/pysqlite.cpython-311.pyc new file mode 100644 index 0000000..df08288 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/__pycache__/pysqlite.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/aiosqlite.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/aiosqlite.py new file mode 100644 index 0000000..6c91563 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -0,0 +1,396 @@ +# dialects/sqlite/aiosqlite.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" + +.. dialect:: sqlite+aiosqlite + :name: aiosqlite + :dbapi: aiosqlite + :connectstring: sqlite+aiosqlite:///file_path + :url: https://pypi.org/project/aiosqlite/ + +The aiosqlite dialect provides support for the SQLAlchemy asyncio interface +running on top of pysqlite. + +aiosqlite is a wrapper around pysqlite that uses a background thread for +each connection. It does not actually use non-blocking IO, as SQLite +databases are not socket-based. However it does provide a working asyncio +interface that's useful for testing and prototyping purposes. + +Using a special asyncio mediation layer, the aiosqlite dialect is usable +as the backend for the :ref:`SQLAlchemy asyncio ` +extension package. + +This dialect should normally be used only with the +:func:`_asyncio.create_async_engine` engine creation function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("sqlite+aiosqlite:///filename") + +The URL passes through all arguments to the ``pysqlite`` driver, so all +connection arguments are the same as they are for that of :ref:`pysqlite`. + +.. _aiosqlite_udfs: + +User-Defined Functions +---------------------- + +aiosqlite extends pysqlite to support async, so we can create our own user-defined functions (UDFs) +in Python and use them directly in SQLite queries as described here: :ref:`pysqlite_udfs`. + +.. _aiosqlite_serializable: + +Serializable isolation / Savepoints / Transactional DDL (asyncio version) +------------------------------------------------------------------------- + +Similarly to pysqlite, aiosqlite does not support SAVEPOINT feature. + +The solution is similar to :ref:`pysqlite_serializable`. This is achieved by the event listeners in async:: + + from sqlalchemy import create_engine, event + from sqlalchemy.ext.asyncio import create_async_engine + + engine = create_async_engine("sqlite+aiosqlite:///myfile.db") + + @event.listens_for(engine.sync_engine, "connect") + def do_connect(dbapi_connection, connection_record): + # disable aiosqlite's emitting of the BEGIN statement entirely. + # also stops it from emitting COMMIT before any DDL. + dbapi_connection.isolation_level = None + + @event.listens_for(engine.sync_engine, "begin") + def do_begin(conn): + # emit our own BEGIN + conn.exec_driver_sql("BEGIN") + +.. warning:: When using the above recipe, it is advised to not use the + :paramref:`.Connection.execution_options.isolation_level` setting on + :class:`_engine.Connection` and :func:`_sa.create_engine` + with the SQLite driver, + as this function necessarily will also alter the ".isolation_level" setting. + +""" # noqa + +import asyncio +from functools import partial + +from .base import SQLiteExecutionContext +from .pysqlite import SQLiteDialect_pysqlite +from ... import pool +from ... import util +from ...engine import AdaptedConnection +from ...util.concurrency import await_fallback +from ...util.concurrency import await_only + + +class AsyncAdapt_aiosqlite_cursor: + # TODO: base on connectors/asyncio.py + # see #10415 + + __slots__ = ( + "_adapt_connection", + "_connection", + "description", + "await_", + "_rows", + "arraysize", + "rowcount", + "lastrowid", + ) + + server_side = False + + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + self.arraysize = 1 + self.rowcount = -1 + self.description = None + self._rows = [] + + def close(self): + self._rows[:] = [] + + def execute(self, operation, parameters=None): + try: + _cursor = self.await_(self._connection.cursor()) + + if parameters is None: + self.await_(_cursor.execute(operation)) + else: + self.await_(_cursor.execute(operation, parameters)) + + if _cursor.description: + self.description = _cursor.description + self.lastrowid = self.rowcount = -1 + + if not self.server_side: + self._rows = self.await_(_cursor.fetchall()) + else: + self.description = None + self.lastrowid = _cursor.lastrowid + self.rowcount = _cursor.rowcount + + if not self.server_side: + self.await_(_cursor.close()) + else: + self._cursor = _cursor + except Exception as error: + self._adapt_connection._handle_exception(error) + + def executemany(self, operation, seq_of_parameters): + try: + _cursor = self.await_(self._connection.cursor()) + self.await_(_cursor.executemany(operation, seq_of_parameters)) + self.description = None + self.lastrowid = _cursor.lastrowid + self.rowcount = _cursor.rowcount + self.await_(_cursor.close()) + except Exception as error: + self._adapt_connection._handle_exception(error) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor): + # TODO: base on connectors/asyncio.py + # see #10415 + __slots__ = "_cursor" + + server_side = True + + def __init__(self, *arg, **kw): + super().__init__(*arg, **kw) + self._cursor = None + + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_aiosqlite_connection(AdaptedConnection): + await_ = staticmethod(await_only) + __slots__ = ("dbapi",) + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + + @property + def isolation_level(self): + return self._connection.isolation_level + + @isolation_level.setter + def isolation_level(self, value): + # aiosqlite's isolation_level setter works outside the Thread + # that it's supposed to, necessitating setting check_same_thread=False. + # for improved stability, we instead invent our own awaitable version + # using aiosqlite's async queue directly. + + def set_iso(connection, value): + connection.isolation_level = value + + function = partial(set_iso, self._connection._conn, value) + future = asyncio.get_event_loop().create_future() + + self._connection._tx.put_nowait((future, function)) + + try: + return self.await_(future) + except Exception as error: + self._handle_exception(error) + + def create_function(self, *args, **kw): + try: + self.await_(self._connection.create_function(*args, **kw)) + except Exception as error: + self._handle_exception(error) + + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_aiosqlite_ss_cursor(self) + else: + return AsyncAdapt_aiosqlite_cursor(self) + + def execute(self, *args, **kw): + return self.await_(self._connection.execute(*args, **kw)) + + def rollback(self): + try: + self.await_(self._connection.rollback()) + except Exception as error: + self._handle_exception(error) + + def commit(self): + try: + self.await_(self._connection.commit()) + except Exception as error: + self._handle_exception(error) + + def close(self): + try: + self.await_(self._connection.close()) + except ValueError: + # this is undocumented for aiosqlite, that ValueError + # was raised if .close() was called more than once, which is + # both not customary for DBAPI and is also not a DBAPI.Error + # exception. This is now fixed in aiosqlite via my PR + # https://github.com/omnilib/aiosqlite/pull/238, so we can be + # assured this will not become some other kind of exception, + # since it doesn't raise anymore. + + pass + except Exception as error: + self._handle_exception(error) + + def _handle_exception(self, error): + if ( + isinstance(error, ValueError) + and error.args[0] == "no active connection" + ): + raise self.dbapi.sqlite.OperationalError( + "no active connection" + ) from error + else: + raise error + + +class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection): + __slots__ = () + + await_ = staticmethod(await_fallback) + + +class AsyncAdapt_aiosqlite_dbapi: + def __init__(self, aiosqlite, sqlite): + self.aiosqlite = aiosqlite + self.sqlite = sqlite + self.paramstyle = "qmark" + self._init_dbapi_attributes() + + def _init_dbapi_attributes(self): + for name in ( + "DatabaseError", + "Error", + "IntegrityError", + "NotSupportedError", + "OperationalError", + "ProgrammingError", + "sqlite_version", + "sqlite_version_info", + ): + setattr(self, name, getattr(self.aiosqlite, name)) + + for name in ("PARSE_COLNAMES", "PARSE_DECLTYPES"): + setattr(self, name, getattr(self.sqlite, name)) + + for name in ("Binary",): + setattr(self, name, getattr(self.sqlite, name)) + + def connect(self, *arg, **kw): + async_fallback = kw.pop("async_fallback", False) + + creator_fn = kw.pop("async_creator_fn", None) + if creator_fn: + connection = creator_fn(*arg, **kw) + else: + connection = self.aiosqlite.connect(*arg, **kw) + # it's a Thread. you'll thank us later + connection.daemon = True + + if util.asbool(async_fallback): + return AsyncAdaptFallback_aiosqlite_connection( + self, + await_fallback(connection), + ) + else: + return AsyncAdapt_aiosqlite_connection( + self, + await_only(connection), + ) + + +class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext): + def create_server_side_cursor(self): + return self._dbapi_connection.cursor(server_side=True) + + +class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): + driver = "aiosqlite" + supports_statement_cache = True + + is_async = True + + supports_server_side_cursors = True + + execution_ctx_cls = SQLiteExecutionContext_aiosqlite + + @classmethod + def import_dbapi(cls): + return AsyncAdapt_aiosqlite_dbapi( + __import__("aiosqlite"), __import__("sqlite3") + ) + + @classmethod + def get_pool_class(cls, url): + if cls._is_url_file_db(url): + return pool.NullPool + else: + return pool.StaticPool + + def is_disconnect(self, e, connection, cursor): + if isinstance( + e, self.dbapi.OperationalError + ) and "no active connection" in str(e): + return True + + return super().is_disconnect(e, connection, cursor) + + def get_driver_connection(self, connection): + return connection._connection + + +dialect = SQLiteDialect_aiosqlite diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/base.py new file mode 100644 index 0000000..6db8214 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/base.py @@ -0,0 +1,2782 @@ +# dialects/sqlite/base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" +.. dialect:: sqlite + :name: SQLite + :full_support: 3.36.0 + :normal_support: 3.12+ + :best_effort: 3.7.16+ + +.. _sqlite_datetime: + +Date and Time Types +------------------- + +SQLite does not have built-in DATE, TIME, or DATETIME types, and pysqlite does +not provide out of the box functionality for translating values between Python +`datetime` objects and a SQLite-supported format. SQLAlchemy's own +:class:`~sqlalchemy.types.DateTime` and related types provide date formatting +and parsing functionality when SQLite is used. The implementation classes are +:class:`_sqlite.DATETIME`, :class:`_sqlite.DATE` and :class:`_sqlite.TIME`. +These types represent dates and times as ISO formatted strings, which also +nicely support ordering. There's no reliance on typical "libc" internals for +these functions so historical dates are fully supported. + +Ensuring Text affinity +^^^^^^^^^^^^^^^^^^^^^^ + +The DDL rendered for these types is the standard ``DATE``, ``TIME`` +and ``DATETIME`` indicators. However, custom storage formats can also be +applied to these types. When the +storage format is detected as containing no alpha characters, the DDL for +these types is rendered as ``DATE_CHAR``, ``TIME_CHAR``, and ``DATETIME_CHAR``, +so that the column continues to have textual affinity. + +.. seealso:: + + `Type Affinity `_ - + in the SQLite documentation + +.. _sqlite_autoincrement: + +SQLite Auto Incrementing Behavior +---------------------------------- + +Background on SQLite's autoincrement is at: https://sqlite.org/autoinc.html + +Key concepts: + +* SQLite has an implicit "auto increment" feature that takes place for any + non-composite primary-key column that is specifically created using + "INTEGER PRIMARY KEY" for the type + primary key. + +* SQLite also has an explicit "AUTOINCREMENT" keyword, that is **not** + equivalent to the implicit autoincrement feature; this keyword is not + recommended for general use. SQLAlchemy does not render this keyword + unless a special SQLite-specific directive is used (see below). However, + it still requires that the column's type is named "INTEGER". + +Using the AUTOINCREMENT Keyword +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To specifically render the AUTOINCREMENT keyword on the primary key column +when rendering DDL, add the flag ``sqlite_autoincrement=True`` to the Table +construct:: + + Table('sometable', metadata, + Column('id', Integer, primary_key=True), + sqlite_autoincrement=True) + +Allowing autoincrement behavior SQLAlchemy types other than Integer/INTEGER +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +SQLite's typing model is based on naming conventions. Among other things, this +means that any type name which contains the substring ``"INT"`` will be +determined to be of "integer affinity". A type named ``"BIGINT"``, +``"SPECIAL_INT"`` or even ``"XYZINTQPR"``, will be considered by SQLite to be +of "integer" affinity. However, **the SQLite autoincrement feature, whether +implicitly or explicitly enabled, requires that the name of the column's type +is exactly the string "INTEGER"**. Therefore, if an application uses a type +like :class:`.BigInteger` for a primary key, on SQLite this type will need to +be rendered as the name ``"INTEGER"`` when emitting the initial ``CREATE +TABLE`` statement in order for the autoincrement behavior to be available. + +One approach to achieve this is to use :class:`.Integer` on SQLite +only using :meth:`.TypeEngine.with_variant`:: + + table = Table( + "my_table", metadata, + Column("id", BigInteger().with_variant(Integer, "sqlite"), primary_key=True) + ) + +Another is to use a subclass of :class:`.BigInteger` that overrides its DDL +name to be ``INTEGER`` when compiled against SQLite:: + + from sqlalchemy import BigInteger + from sqlalchemy.ext.compiler import compiles + + class SLBigInteger(BigInteger): + pass + + @compiles(SLBigInteger, 'sqlite') + def bi_c(element, compiler, **kw): + return "INTEGER" + + @compiles(SLBigInteger) + def bi_c(element, compiler, **kw): + return compiler.visit_BIGINT(element, **kw) + + + table = Table( + "my_table", metadata, + Column("id", SLBigInteger(), primary_key=True) + ) + +.. seealso:: + + :meth:`.TypeEngine.with_variant` + + :ref:`sqlalchemy.ext.compiler_toplevel` + + `Datatypes In SQLite Version 3 `_ + +.. _sqlite_concurrency: + +Database Locking Behavior / Concurrency +--------------------------------------- + +SQLite is not designed for a high level of write concurrency. The database +itself, being a file, is locked completely during write operations within +transactions, meaning exactly one "connection" (in reality a file handle) +has exclusive access to the database during this period - all other +"connections" will be blocked during this time. + +The Python DBAPI specification also calls for a connection model that is +always in a transaction; there is no ``connection.begin()`` method, +only ``connection.commit()`` and ``connection.rollback()``, upon which a +new transaction is to be begun immediately. This may seem to imply +that the SQLite driver would in theory allow only a single filehandle on a +particular database file at any time; however, there are several +factors both within SQLite itself as well as within the pysqlite driver +which loosen this restriction significantly. + +However, no matter what locking modes are used, SQLite will still always +lock the database file once a transaction is started and DML (e.g. INSERT, +UPDATE, DELETE) has at least been emitted, and this will block +other transactions at least at the point that they also attempt to emit DML. +By default, the length of time on this block is very short before it times out +with an error. + +This behavior becomes more critical when used in conjunction with the +SQLAlchemy ORM. SQLAlchemy's :class:`.Session` object by default runs +within a transaction, and with its autoflush model, may emit DML preceding +any SELECT statement. This may lead to a SQLite database that locks +more quickly than is expected. The locking mode of SQLite and the pysqlite +driver can be manipulated to some degree, however it should be noted that +achieving a high degree of write-concurrency with SQLite is a losing battle. + +For more information on SQLite's lack of write concurrency by design, please +see +`Situations Where Another RDBMS May Work Better - High Concurrency +`_ near the bottom of the page. + +The following subsections introduce areas that are impacted by SQLite's +file-based architecture and additionally will usually require workarounds to +work when using the pysqlite driver. + +.. _sqlite_isolation_level: + +Transaction Isolation Level / Autocommit +---------------------------------------- + +SQLite supports "transaction isolation" in a non-standard way, along two +axes. One is that of the +`PRAGMA read_uncommitted `_ +instruction. This setting can essentially switch SQLite between its +default mode of ``SERIALIZABLE`` isolation, and a "dirty read" isolation +mode normally referred to as ``READ UNCOMMITTED``. + +SQLAlchemy ties into this PRAGMA statement using the +:paramref:`_sa.create_engine.isolation_level` parameter of +:func:`_sa.create_engine`. +Valid values for this parameter when used with SQLite are ``"SERIALIZABLE"`` +and ``"READ UNCOMMITTED"`` corresponding to a value of 0 and 1, respectively. +SQLite defaults to ``SERIALIZABLE``, however its behavior is impacted by +the pysqlite driver's default behavior. + +When using the pysqlite driver, the ``"AUTOCOMMIT"`` isolation level is also +available, which will alter the pysqlite connection using the ``.isolation_level`` +attribute on the DBAPI connection and set it to None for the duration +of the setting. + +.. versionadded:: 1.3.16 added support for SQLite AUTOCOMMIT isolation level + when using the pysqlite / sqlite3 SQLite driver. + + +The other axis along which SQLite's transactional locking is impacted is +via the nature of the ``BEGIN`` statement used. The three varieties +are "deferred", "immediate", and "exclusive", as described at +`BEGIN TRANSACTION `_. A straight +``BEGIN`` statement uses the "deferred" mode, where the database file is +not locked until the first read or write operation, and read access remains +open to other transactions until the first write operation. But again, +it is critical to note that the pysqlite driver interferes with this behavior +by *not even emitting BEGIN* until the first write operation. + +.. warning:: + + SQLite's transactional scope is impacted by unresolved + issues in the pysqlite driver, which defers BEGIN statements to a greater + degree than is often feasible. See the section :ref:`pysqlite_serializable` + or :ref:`aiosqlite_serializable` for techniques to work around this behavior. + +.. seealso:: + + :ref:`dbapi_autocommit` + +INSERT/UPDATE/DELETE...RETURNING +--------------------------------- + +The SQLite dialect supports SQLite 3.35's ``INSERT|UPDATE|DELETE..RETURNING`` +syntax. ``INSERT..RETURNING`` may be used +automatically in some cases in order to fetch newly generated identifiers in +place of the traditional approach of using ``cursor.lastrowid``, however +``cursor.lastrowid`` is currently still preferred for simple single-statement +cases for its better performance. + +To specify an explicit ``RETURNING`` clause, use the +:meth:`._UpdateBase.returning` method on a per-statement basis:: + + # INSERT..RETURNING + result = connection.execute( + table.insert(). + values(name='foo'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + + # UPDATE..RETURNING + result = connection.execute( + table.update(). + where(table.c.name=='foo'). + values(name='bar'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + + # DELETE..RETURNING + result = connection.execute( + table.delete(). + where(table.c.name=='foo'). + returning(table.c.col1, table.c.col2) + ) + print(result.all()) + +.. versionadded:: 2.0 Added support for SQLite RETURNING + +SAVEPOINT Support +---------------------------- + +SQLite supports SAVEPOINTs, which only function once a transaction is +begun. SQLAlchemy's SAVEPOINT support is available using the +:meth:`_engine.Connection.begin_nested` method at the Core level, and +:meth:`.Session.begin_nested` at the ORM level. However, SAVEPOINTs +won't work at all with pysqlite unless workarounds are taken. + +.. warning:: + + SQLite's SAVEPOINT feature is impacted by unresolved + issues in the pysqlite and aiosqlite drivers, which defer BEGIN statements + to a greater degree than is often feasible. See the sections + :ref:`pysqlite_serializable` and :ref:`aiosqlite_serializable` + for techniques to work around this behavior. + +Transactional DDL +---------------------------- + +The SQLite database supports transactional :term:`DDL` as well. +In this case, the pysqlite driver is not only failing to start transactions, +it also is ending any existing transaction when DDL is detected, so again, +workarounds are required. + +.. warning:: + + SQLite's transactional DDL is impacted by unresolved issues + in the pysqlite driver, which fails to emit BEGIN and additionally + forces a COMMIT to cancel any transaction when DDL is encountered. + See the section :ref:`pysqlite_serializable` + for techniques to work around this behavior. + +.. _sqlite_foreign_keys: + +Foreign Key Support +------------------- + +SQLite supports FOREIGN KEY syntax when emitting CREATE statements for tables, +however by default these constraints have no effect on the operation of the +table. + +Constraint checking on SQLite has three prerequisites: + +* At least version 3.6.19 of SQLite must be in use +* The SQLite library must be compiled *without* the SQLITE_OMIT_FOREIGN_KEY + or SQLITE_OMIT_TRIGGER symbols enabled. +* The ``PRAGMA foreign_keys = ON`` statement must be emitted on all + connections before use -- including the initial call to + :meth:`sqlalchemy.schema.MetaData.create_all`. + +SQLAlchemy allows for the ``PRAGMA`` statement to be emitted automatically for +new connections through the usage of events:: + + from sqlalchemy.engine import Engine + from sqlalchemy import event + + @event.listens_for(Engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + +.. warning:: + + When SQLite foreign keys are enabled, it is **not possible** + to emit CREATE or DROP statements for tables that contain + mutually-dependent foreign key constraints; + to emit the DDL for these tables requires that ALTER TABLE be used to + create or drop these constraints separately, for which SQLite has + no support. + +.. seealso:: + + `SQLite Foreign Key Support `_ + - on the SQLite web site. + + :ref:`event_toplevel` - SQLAlchemy event API. + + :ref:`use_alter` - more information on SQLAlchemy's facilities for handling + mutually-dependent foreign key constraints. + +.. _sqlite_on_conflict_ddl: + +ON CONFLICT support for constraints +----------------------------------- + +.. seealso:: This section describes the :term:`DDL` version of "ON CONFLICT" for + SQLite, which occurs within a CREATE TABLE statement. For "ON CONFLICT" as + applied to an INSERT statement, see :ref:`sqlite_on_conflict_insert`. + +SQLite supports a non-standard DDL clause known as ON CONFLICT which can be applied +to primary key, unique, check, and not null constraints. In DDL, it is +rendered either within the "CONSTRAINT" clause or within the column definition +itself depending on the location of the target constraint. To render this +clause within DDL, the extension parameter ``sqlite_on_conflict`` can be +specified with a string conflict resolution algorithm within the +:class:`.PrimaryKeyConstraint`, :class:`.UniqueConstraint`, +:class:`.CheckConstraint` objects. Within the :class:`_schema.Column` object, +there +are individual parameters ``sqlite_on_conflict_not_null``, +``sqlite_on_conflict_primary_key``, ``sqlite_on_conflict_unique`` which each +correspond to the three types of relevant constraint types that can be +indicated from a :class:`_schema.Column` object. + +.. seealso:: + + `ON CONFLICT `_ - in the SQLite + documentation + +.. versionadded:: 1.3 + + +The ``sqlite_on_conflict`` parameters accept a string argument which is just +the resolution name to be chosen, which on SQLite can be one of ROLLBACK, +ABORT, FAIL, IGNORE, and REPLACE. For example, to add a UNIQUE constraint +that specifies the IGNORE algorithm:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer), + UniqueConstraint('id', 'data', sqlite_on_conflict='IGNORE') + ) + +The above renders CREATE TABLE DDL as:: + + CREATE TABLE some_table ( + id INTEGER NOT NULL, + data INTEGER, + PRIMARY KEY (id), + UNIQUE (id, data) ON CONFLICT IGNORE + ) + + +When using the :paramref:`_schema.Column.unique` +flag to add a UNIQUE constraint +to a single column, the ``sqlite_on_conflict_unique`` parameter can +be added to the :class:`_schema.Column` as well, which will be added to the +UNIQUE constraint in the DDL:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer, unique=True, + sqlite_on_conflict_unique='IGNORE') + ) + +rendering:: + + CREATE TABLE some_table ( + id INTEGER NOT NULL, + data INTEGER, + PRIMARY KEY (id), + UNIQUE (data) ON CONFLICT IGNORE + ) + +To apply the FAIL algorithm for a NOT NULL constraint, +``sqlite_on_conflict_not_null`` is used:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer, nullable=False, + sqlite_on_conflict_not_null='FAIL') + ) + +this renders the column inline ON CONFLICT phrase:: + + CREATE TABLE some_table ( + id INTEGER NOT NULL, + data INTEGER NOT NULL ON CONFLICT FAIL, + PRIMARY KEY (id) + ) + + +Similarly, for an inline primary key, use ``sqlite_on_conflict_primary_key``:: + + some_table = Table( + 'some_table', metadata, + Column('id', Integer, primary_key=True, + sqlite_on_conflict_primary_key='FAIL') + ) + +SQLAlchemy renders the PRIMARY KEY constraint separately, so the conflict +resolution algorithm is applied to the constraint itself:: + + CREATE TABLE some_table ( + id INTEGER NOT NULL, + PRIMARY KEY (id) ON CONFLICT FAIL + ) + +.. _sqlite_on_conflict_insert: + +INSERT...ON CONFLICT (Upsert) +----------------------------------- + +.. seealso:: This section describes the :term:`DML` version of "ON CONFLICT" for + SQLite, which occurs within an INSERT statement. For "ON CONFLICT" as + applied to a CREATE TABLE statement, see :ref:`sqlite_on_conflict_ddl`. + +From version 3.24.0 onwards, SQLite supports "upserts" (update or insert) +of rows into a table via the ``ON CONFLICT`` clause of the ``INSERT`` +statement. A candidate row will only be inserted if that row does not violate +any unique or primary key constraints. In the case of a unique constraint violation, a +secondary action can occur which can be either "DO UPDATE", indicating that +the data in the target row should be updated, or "DO NOTHING", which indicates +to silently skip this row. + +Conflicts are determined using columns that are part of existing unique +constraints and indexes. These constraints are identified by stating the +columns and conditions that comprise the indexes. + +SQLAlchemy provides ``ON CONFLICT`` support via the SQLite-specific +:func:`_sqlite.insert()` function, which provides +the generative methods :meth:`_sqlite.Insert.on_conflict_do_update` +and :meth:`_sqlite.Insert.on_conflict_do_nothing`: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.dialects.sqlite import insert + + >>> insert_stmt = insert(my_table).values( + ... id='some_existing_id', + ... data='inserted value') + + >>> do_update_stmt = insert_stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value') + ... ) + + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) + ON CONFLICT (id) DO UPDATE SET data = ?{stop} + + >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing( + ... index_elements=['id'] + ... ) + + >>> print(do_nothing_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) + ON CONFLICT (id) DO NOTHING + +.. versionadded:: 1.4 + +.. seealso:: + + `Upsert + `_ + - in the SQLite documentation. + + +Specifying the Target +^^^^^^^^^^^^^^^^^^^^^ + +Both methods supply the "target" of the conflict using column inference: + +* The :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements` argument + specifies a sequence containing string column names, :class:`_schema.Column` + objects, and/or SQL expression elements, which would identify a unique index + or unique constraint. + +* When using :paramref:`_sqlite.Insert.on_conflict_do_update.index_elements` + to infer an index, a partial index can be inferred by also specifying the + :paramref:`_sqlite.Insert.on_conflict_do_update.index_where` parameter: + + .. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data') + + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=[my_table.c.user_email], + ... index_where=my_table.c.user_email.like('%@gmail.com'), + ... set_=dict(data=stmt.excluded.data) + ... ) + + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (data, user_email) VALUES (?, ?) + ON CONFLICT (user_email) + WHERE user_email LIKE '%@gmail.com' + DO UPDATE SET data = excluded.data + +The SET Clause +^^^^^^^^^^^^^^^ + +``ON CONFLICT...DO UPDATE`` is used to perform an update of the already +existing row, using any combination of new values as well as values +from the proposed insertion. These values are specified using the +:paramref:`_sqlite.Insert.on_conflict_do_update.set_` parameter. This +parameter accepts a dictionary which consists of direct values +for UPDATE: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value') + ... ) + + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) + ON CONFLICT (id) DO UPDATE SET data = ? + +.. warning:: + + The :meth:`_sqlite.Insert.on_conflict_do_update` method does **not** take + into account Python-side default UPDATE values or generation functions, + e.g. those specified using :paramref:`_schema.Column.onupdate`. These + values will not be exercised for an ON CONFLICT style of UPDATE, unless + they are manually specified in the + :paramref:`_sqlite.Insert.on_conflict_do_update.set_` dictionary. + +Updating using the Excluded INSERT Values +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to refer to the proposed insertion row, the special alias +:attr:`~.sqlite.Insert.excluded` is available as an attribute on +the :class:`_sqlite.Insert` object; this object creates an "excluded." prefix +on a column, that informs the DO UPDATE to update the row with the value that +would have been inserted had the constraint not failed: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh' + ... ) + + >>> do_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author) + ... ) + + >>> print(do_update_stmt) + {printsql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?) + ON CONFLICT (id) DO UPDATE SET data = ?, author = excluded.author + +Additional WHERE Criteria +^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :meth:`_sqlite.Insert.on_conflict_do_update` method also accepts +a WHERE clause using the :paramref:`_sqlite.Insert.on_conflict_do_update.where` +parameter, which will limit those rows which receive an UPDATE: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values( + ... id='some_id', + ... data='inserted value', + ... author='jlh' + ... ) + + >>> on_update_stmt = stmt.on_conflict_do_update( + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author), + ... where=(my_table.c.status == 2) + ... ) + >>> print(on_update_stmt) + {printsql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?) + ON CONFLICT (id) DO UPDATE SET data = ?, author = excluded.author + WHERE my_table.status = ? + + +Skipping Rows with DO NOTHING +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``ON CONFLICT`` may be used to skip inserting a row entirely +if any conflict with a unique constraint occurs; below this is illustrated +using the :meth:`_sqlite.Insert.on_conflict_do_nothing` method: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id']) + >>> print(stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT (id) DO NOTHING + + +If ``DO NOTHING`` is used without specifying any columns or constraint, +it has the effect of skipping the INSERT for any unique violation which +occurs: + +.. sourcecode:: pycon+sql + + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing() + >>> print(stmt) + {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT DO NOTHING + +.. _sqlite_type_reflection: + +Type Reflection +--------------- + +SQLite types are unlike those of most other database backends, in that +the string name of the type usually does not correspond to a "type" in a +one-to-one fashion. Instead, SQLite links per-column typing behavior +to one of five so-called "type affinities" based on a string matching +pattern for the type. + +SQLAlchemy's reflection process, when inspecting types, uses a simple +lookup table to link the keywords returned to provided SQLAlchemy types. +This lookup table is present within the SQLite dialect as it is for all +other dialects. However, the SQLite dialect has a different "fallback" +routine for when a particular type name is not located in the lookup map; +it instead implements the SQLite "type affinity" scheme located at +https://www.sqlite.org/datatype3.html section 2.1. + +The provided typemap will make direct associations from an exact string +name match for the following types: + +:class:`_types.BIGINT`, :class:`_types.BLOB`, +:class:`_types.BOOLEAN`, :class:`_types.BOOLEAN`, +:class:`_types.CHAR`, :class:`_types.DATE`, +:class:`_types.DATETIME`, :class:`_types.FLOAT`, +:class:`_types.DECIMAL`, :class:`_types.FLOAT`, +:class:`_types.INTEGER`, :class:`_types.INTEGER`, +:class:`_types.NUMERIC`, :class:`_types.REAL`, +:class:`_types.SMALLINT`, :class:`_types.TEXT`, +:class:`_types.TIME`, :class:`_types.TIMESTAMP`, +:class:`_types.VARCHAR`, :class:`_types.NVARCHAR`, +:class:`_types.NCHAR` + +When a type name does not match one of the above types, the "type affinity" +lookup is used instead: + +* :class:`_types.INTEGER` is returned if the type name includes the + string ``INT`` +* :class:`_types.TEXT` is returned if the type name includes the + string ``CHAR``, ``CLOB`` or ``TEXT`` +* :class:`_types.NullType` is returned if the type name includes the + string ``BLOB`` +* :class:`_types.REAL` is returned if the type name includes the string + ``REAL``, ``FLOA`` or ``DOUB``. +* Otherwise, the :class:`_types.NUMERIC` type is used. + +.. _sqlite_partial_index: + +Partial Indexes +--------------- + +A partial index, e.g. one which uses a WHERE clause, can be specified +with the DDL system using the argument ``sqlite_where``:: + + tbl = Table('testtbl', m, Column('data', Integer)) + idx = Index('test_idx1', tbl.c.data, + sqlite_where=and_(tbl.c.data > 5, tbl.c.data < 10)) + +The index will be rendered at create time as:: + + CREATE INDEX test_idx1 ON testtbl (data) + WHERE data > 5 AND data < 10 + +.. _sqlite_dotted_column_names: + +Dotted Column Names +------------------- + +Using table or column names that explicitly have periods in them is +**not recommended**. While this is generally a bad idea for relational +databases in general, as the dot is a syntactically significant character, +the SQLite driver up until version **3.10.0** of SQLite has a bug which +requires that SQLAlchemy filter out these dots in result sets. + +The bug, entirely outside of SQLAlchemy, can be illustrated thusly:: + + import sqlite3 + + assert sqlite3.sqlite_version_info < (3, 10, 0), "bug is fixed in this version" + + conn = sqlite3.connect(":memory:") + cursor = conn.cursor() + + cursor.execute("create table x (a integer, b integer)") + cursor.execute("insert into x (a, b) values (1, 1)") + cursor.execute("insert into x (a, b) values (2, 2)") + + cursor.execute("select x.a, x.b from x") + assert [c[0] for c in cursor.description] == ['a', 'b'] + + cursor.execute(''' + select x.a, x.b from x where a=1 + union + select x.a, x.b from x where a=2 + ''') + assert [c[0] for c in cursor.description] == ['a', 'b'], \ + [c[0] for c in cursor.description] + +The second assertion fails:: + + Traceback (most recent call last): + File "test.py", line 19, in + [c[0] for c in cursor.description] + AssertionError: ['x.a', 'x.b'] + +Where above, the driver incorrectly reports the names of the columns +including the name of the table, which is entirely inconsistent vs. +when the UNION is not present. + +SQLAlchemy relies upon column names being predictable in how they match +to the original statement, so the SQLAlchemy dialect has no choice but +to filter these out:: + + + from sqlalchemy import create_engine + + eng = create_engine("sqlite://") + conn = eng.connect() + + conn.exec_driver_sql("create table x (a integer, b integer)") + conn.exec_driver_sql("insert into x (a, b) values (1, 1)") + conn.exec_driver_sql("insert into x (a, b) values (2, 2)") + + result = conn.exec_driver_sql("select x.a, x.b from x") + assert result.keys() == ["a", "b"] + + result = conn.exec_driver_sql(''' + select x.a, x.b from x where a=1 + union + select x.a, x.b from x where a=2 + ''') + assert result.keys() == ["a", "b"] + +Note that above, even though SQLAlchemy filters out the dots, *both +names are still addressable*:: + + >>> row = result.first() + >>> row["a"] + 1 + >>> row["x.a"] + 1 + >>> row["b"] + 1 + >>> row["x.b"] + 1 + +Therefore, the workaround applied by SQLAlchemy only impacts +:meth:`_engine.CursorResult.keys` and :meth:`.Row.keys()` in the public API. In +the very specific case where an application is forced to use column names that +contain dots, and the functionality of :meth:`_engine.CursorResult.keys` and +:meth:`.Row.keys()` is required to return these dotted names unmodified, +the ``sqlite_raw_colnames`` execution option may be provided, either on a +per-:class:`_engine.Connection` basis:: + + result = conn.execution_options(sqlite_raw_colnames=True).exec_driver_sql(''' + select x.a, x.b from x where a=1 + union + select x.a, x.b from x where a=2 + ''') + assert result.keys() == ["x.a", "x.b"] + +or on a per-:class:`_engine.Engine` basis:: + + engine = create_engine("sqlite://", execution_options={"sqlite_raw_colnames": True}) + +When using the per-:class:`_engine.Engine` execution option, note that +**Core and ORM queries that use UNION may not function properly**. + +SQLite-specific table options +----------------------------- + +One option for CREATE TABLE is supported directly by the SQLite +dialect in conjunction with the :class:`_schema.Table` construct: + +* ``WITHOUT ROWID``:: + + Table("some_table", metadata, ..., sqlite_with_rowid=False) + +.. seealso:: + + `SQLite CREATE TABLE options + `_ + + +.. _sqlite_include_internal: + +Reflecting internal schema tables +---------------------------------- + +Reflection methods that return lists of tables will omit so-called +"SQLite internal schema object" names, which are considered by SQLite +as any object name that is prefixed with ``sqlite_``. An example of +such an object is the ``sqlite_sequence`` table that's generated when +the ``AUTOINCREMENT`` column parameter is used. In order to return +these objects, the parameter ``sqlite_include_internal=True`` may be +passed to methods such as :meth:`_schema.MetaData.reflect` or +:meth:`.Inspector.get_table_names`. + +.. versionadded:: 2.0 Added the ``sqlite_include_internal=True`` parameter. + Previously, these tables were not ignored by SQLAlchemy reflection + methods. + +.. note:: + + The ``sqlite_include_internal`` parameter does not refer to the + "system" tables that are present in schemas such as ``sqlite_master``. + +.. seealso:: + + `SQLite Internal Schema Objects `_ - in the SQLite + documentation. + +""" # noqa +from __future__ import annotations + +import datetime +import numbers +import re +from typing import Optional + +from .json import JSON +from .json import JSONIndexType +from .json import JSONPathType +from ... import exc +from ... import schema as sa_schema +from ... import sql +from ... import text +from ... import types as sqltypes +from ... import util +from ...engine import default +from ...engine import processors +from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import coercions +from ...sql import ColumnElement +from ...sql import compiler +from ...sql import elements +from ...sql import roles +from ...sql import schema +from ...types import BLOB # noqa +from ...types import BOOLEAN # noqa +from ...types import CHAR # noqa +from ...types import DECIMAL # noqa +from ...types import FLOAT # noqa +from ...types import INTEGER # noqa +from ...types import NUMERIC # noqa +from ...types import REAL # noqa +from ...types import SMALLINT # noqa +from ...types import TEXT # noqa +from ...types import TIMESTAMP # noqa +from ...types import VARCHAR # noqa + + +class _SQliteJson(JSON): + def result_processor(self, dialect, coltype): + default_processor = super().result_processor(dialect, coltype) + + def process(value): + try: + return default_processor(value) + except TypeError: + if isinstance(value, numbers.Number): + return value + else: + raise + + return process + + +class _DateTimeMixin: + _reg = None + _storage_format = None + + def __init__(self, storage_format=None, regexp=None, **kw): + super().__init__(**kw) + if regexp is not None: + self._reg = re.compile(regexp) + if storage_format is not None: + self._storage_format = storage_format + + @property + def format_is_text_affinity(self): + """return True if the storage format will automatically imply + a TEXT affinity. + + If the storage format contains no non-numeric characters, + it will imply a NUMERIC storage format on SQLite; in this case, + the type will generate its DDL as DATE_CHAR, DATETIME_CHAR, + TIME_CHAR. + + """ + spec = self._storage_format % { + "year": 0, + "month": 0, + "day": 0, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0, + } + return bool(re.search(r"[^0-9]", spec)) + + def adapt(self, cls, **kw): + if issubclass(cls, _DateTimeMixin): + if self._storage_format: + kw["storage_format"] = self._storage_format + if self._reg: + kw["regexp"] = self._reg + return super().adapt(cls, **kw) + + def literal_processor(self, dialect): + bp = self.bind_processor(dialect) + + def process(value): + return "'%s'" % bp(value) + + return process + + +class DATETIME(_DateTimeMixin, sqltypes.DateTime): + r"""Represent a Python datetime object in SQLite using a string. + + The default string storage format is:: + + "%(year)04d-%(month)02d-%(day)02d %(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + + e.g.:: + + 2021-03-15 12:05:57.105542 + + The incoming storage format is by default parsed using the + Python ``datetime.fromisoformat()`` function. + + .. versionchanged:: 2.0 ``datetime.fromisoformat()`` is used for default + datetime string parsing. + + The storage format can be customized to some degree using the + ``storage_format`` and ``regexp`` parameters, such as:: + + import re + from sqlalchemy.dialects.sqlite import DATETIME + + dt = DATETIME(storage_format="%(year)04d/%(month)02d/%(day)02d " + "%(hour)02d:%(minute)02d:%(second)02d", + regexp=r"(\d+)/(\d+)/(\d+) (\d+)-(\d+)-(\d+)" + ) + + :param storage_format: format string which will be applied to the dict + with keys year, month, day, hour, minute, second, and microsecond. + + :param regexp: regular expression which will be applied to incoming result + rows, replacing the use of ``datetime.fromisoformat()`` to parse incoming + strings. If the regexp contains named groups, the resulting match dict is + applied to the Python datetime() constructor as keyword arguments. + Otherwise, if positional groups are used, the datetime() constructor + is called with positional arguments via + ``*map(int, match_obj.groups(0))``. + + """ # noqa + + _storage_format = ( + "%(year)04d-%(month)02d-%(day)02d " + "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + ) + + def __init__(self, *args, **kwargs): + truncate_microseconds = kwargs.pop("truncate_microseconds", False) + super().__init__(*args, **kwargs) + if truncate_microseconds: + assert "storage_format" not in kwargs, ( + "You can specify only " + "one of truncate_microseconds or storage_format." + ) + assert "regexp" not in kwargs, ( + "You can specify only one of " + "truncate_microseconds or regexp." + ) + self._storage_format = ( + "%(year)04d-%(month)02d-%(day)02d " + "%(hour)02d:%(minute)02d:%(second)02d" + ) + + def bind_processor(self, dialect): + datetime_datetime = datetime.datetime + datetime_date = datetime.date + format_ = self._storage_format + + def process(value): + if value is None: + return None + elif isinstance(value, datetime_datetime): + return format_ % { + "year": value.year, + "month": value.month, + "day": value.day, + "hour": value.hour, + "minute": value.minute, + "second": value.second, + "microsecond": value.microsecond, + } + elif isinstance(value, datetime_date): + return format_ % { + "year": value.year, + "month": value.month, + "day": value.day, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0, + } + else: + raise TypeError( + "SQLite DateTime type only accepts Python " + "datetime and date objects as input." + ) + + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.datetime + ) + else: + return processors.str_to_datetime + + +class DATE(_DateTimeMixin, sqltypes.Date): + r"""Represent a Python date object in SQLite using a string. + + The default string storage format is:: + + "%(year)04d-%(month)02d-%(day)02d" + + e.g.:: + + 2011-03-15 + + The incoming storage format is by default parsed using the + Python ``date.fromisoformat()`` function. + + .. versionchanged:: 2.0 ``date.fromisoformat()`` is used for default + date string parsing. + + + The storage format can be customized to some degree using the + ``storage_format`` and ``regexp`` parameters, such as:: + + import re + from sqlalchemy.dialects.sqlite import DATE + + d = DATE( + storage_format="%(month)02d/%(day)02d/%(year)04d", + regexp=re.compile("(?P\d+)/(?P\d+)/(?P\d+)") + ) + + :param storage_format: format string which will be applied to the + dict with keys year, month, and day. + + :param regexp: regular expression which will be applied to + incoming result rows, replacing the use of ``date.fromisoformat()`` to + parse incoming strings. If the regexp contains named groups, the resulting + match dict is applied to the Python date() constructor as keyword + arguments. Otherwise, if positional groups are used, the date() + constructor is called with positional arguments via + ``*map(int, match_obj.groups(0))``. + + """ + + _storage_format = "%(year)04d-%(month)02d-%(day)02d" + + def bind_processor(self, dialect): + datetime_date = datetime.date + format_ = self._storage_format + + def process(value): + if value is None: + return None + elif isinstance(value, datetime_date): + return format_ % { + "year": value.year, + "month": value.month, + "day": value.day, + } + else: + raise TypeError( + "SQLite Date type only accepts Python " + "date objects as input." + ) + + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.date + ) + else: + return processors.str_to_date + + +class TIME(_DateTimeMixin, sqltypes.Time): + r"""Represent a Python time object in SQLite using a string. + + The default string storage format is:: + + "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + + e.g.:: + + 12:05:57.10558 + + The incoming storage format is by default parsed using the + Python ``time.fromisoformat()`` function. + + .. versionchanged:: 2.0 ``time.fromisoformat()`` is used for default + time string parsing. + + The storage format can be customized to some degree using the + ``storage_format`` and ``regexp`` parameters, such as:: + + import re + from sqlalchemy.dialects.sqlite import TIME + + t = TIME(storage_format="%(hour)02d-%(minute)02d-" + "%(second)02d-%(microsecond)06d", + regexp=re.compile("(\d+)-(\d+)-(\d+)-(?:-(\d+))?") + ) + + :param storage_format: format string which will be applied to the dict + with keys hour, minute, second, and microsecond. + + :param regexp: regular expression which will be applied to incoming result + rows, replacing the use of ``datetime.fromisoformat()`` to parse incoming + strings. If the regexp contains named groups, the resulting match dict is + applied to the Python time() constructor as keyword arguments. Otherwise, + if positional groups are used, the time() constructor is called with + positional arguments via ``*map(int, match_obj.groups(0))``. + + """ + + _storage_format = "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" + + def __init__(self, *args, **kwargs): + truncate_microseconds = kwargs.pop("truncate_microseconds", False) + super().__init__(*args, **kwargs) + if truncate_microseconds: + assert "storage_format" not in kwargs, ( + "You can specify only " + "one of truncate_microseconds or storage_format." + ) + assert "regexp" not in kwargs, ( + "You can specify only one of " + "truncate_microseconds or regexp." + ) + self._storage_format = "%(hour)02d:%(minute)02d:%(second)02d" + + def bind_processor(self, dialect): + datetime_time = datetime.time + format_ = self._storage_format + + def process(value): + if value is None: + return None + elif isinstance(value, datetime_time): + return format_ % { + "hour": value.hour, + "minute": value.minute, + "second": value.second, + "microsecond": value.microsecond, + } + else: + raise TypeError( + "SQLite Time type only accepts Python " + "time objects as input." + ) + + return process + + def result_processor(self, dialect, coltype): + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.time + ) + else: + return processors.str_to_time + + +colspecs = { + sqltypes.Date: DATE, + sqltypes.DateTime: DATETIME, + sqltypes.JSON: _SQliteJson, + sqltypes.JSON.JSONIndexType: JSONIndexType, + sqltypes.JSON.JSONPathType: JSONPathType, + sqltypes.Time: TIME, +} + +ischema_names = { + "BIGINT": sqltypes.BIGINT, + "BLOB": sqltypes.BLOB, + "BOOL": sqltypes.BOOLEAN, + "BOOLEAN": sqltypes.BOOLEAN, + "CHAR": sqltypes.CHAR, + "DATE": sqltypes.DATE, + "DATE_CHAR": sqltypes.DATE, + "DATETIME": sqltypes.DATETIME, + "DATETIME_CHAR": sqltypes.DATETIME, + "DOUBLE": sqltypes.DOUBLE, + "DECIMAL": sqltypes.DECIMAL, + "FLOAT": sqltypes.FLOAT, + "INT": sqltypes.INTEGER, + "INTEGER": sqltypes.INTEGER, + "JSON": JSON, + "NUMERIC": sqltypes.NUMERIC, + "REAL": sqltypes.REAL, + "SMALLINT": sqltypes.SMALLINT, + "TEXT": sqltypes.TEXT, + "TIME": sqltypes.TIME, + "TIME_CHAR": sqltypes.TIME, + "TIMESTAMP": sqltypes.TIMESTAMP, + "VARCHAR": sqltypes.VARCHAR, + "NVARCHAR": sqltypes.NVARCHAR, + "NCHAR": sqltypes.NCHAR, +} + + +class SQLiteCompiler(compiler.SQLCompiler): + extract_map = util.update_copy( + compiler.SQLCompiler.extract_map, + { + "month": "%m", + "day": "%d", + "year": "%Y", + "second": "%S", + "hour": "%H", + "doy": "%j", + "minute": "%M", + "epoch": "%s", + "dow": "%w", + "week": "%W", + }, + ) + + def visit_truediv_binary(self, binary, operator, **kw): + return ( + self.process(binary.left, **kw) + + " / " + + "(%s + 0.0)" % self.process(binary.right, **kw) + ) + + def visit_now_func(self, fn, **kw): + return "CURRENT_TIMESTAMP" + + def visit_localtimestamp_func(self, func, **kw): + return 'DATETIME(CURRENT_TIMESTAMP, "localtime")' + + def visit_true(self, expr, **kw): + return "1" + + def visit_false(self, expr, **kw): + return "0" + + def visit_char_length_func(self, fn, **kw): + return "length%s" % self.function_argspec(fn) + + def visit_aggregate_strings_func(self, fn, **kw): + return "group_concat%s" % self.function_argspec(fn) + + def visit_cast(self, cast, **kwargs): + if self.dialect.supports_cast: + return super().visit_cast(cast, **kwargs) + else: + return self.process(cast.clause, **kwargs) + + def visit_extract(self, extract, **kw): + try: + return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( + self.extract_map[extract.field], + self.process(extract.expr, **kw), + ) + except KeyError as err: + raise exc.CompileError( + "%s is not a valid extract argument." % extract.field + ) from err + + def returning_clause( + self, + stmt, + returning_cols, + *, + populate_result_map, + **kw, + ): + kw["include_table"] = False + return super().returning_clause( + stmt, returning_cols, populate_result_map=populate_result_map, **kw + ) + + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + text += "\n LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + if select._limit_clause is None: + text += "\n LIMIT " + self.process(sql.literal(-1)) + text += " OFFSET " + self.process(select._offset_clause, **kw) + else: + text += " OFFSET " + self.process(sql.literal(0), **kw) + return text + + def for_update_clause(self, select, **kw): + # sqlite has no "FOR UPDATE" AFAICT + return "" + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + kw["asfrom"] = True + return "FROM " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def visit_is_distinct_from_binary(self, binary, operator, **kw): + return "%s IS NOT %s" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): + return "%s IS %s" % ( + self.process(binary.left), + self.process(binary.right), + ) + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + if binary.type._type_affinity is sqltypes.JSON: + expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))" + else: + expr = "JSON_EXTRACT(%s, %s)" + + return expr % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + if binary.type._type_affinity is sqltypes.JSON: + expr = "JSON_QUOTE(JSON_EXTRACT(%s, %s))" + else: + expr = "JSON_EXTRACT(%s, %s)" + + return expr % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_empty_set_op_expr(self, type_, expand_op, **kw): + # slightly old SQLite versions don't seem to be able to handle + # the empty set impl + return self.visit_empty_set_expr(type_) + + def visit_empty_set_expr(self, element_types, **kw): + return "SELECT %s FROM (SELECT %s) WHERE 1!=1" % ( + ", ".join("1" for type_ in element_types or [INTEGER()]), + ", ".join("1" for type_ in element_types or [INTEGER()]), + ) + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " REGEXP ", **kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " NOT REGEXP ", **kw) + + def _on_conflict_target(self, clause, **kw): + if clause.constraint_target is not None: + target_text = "(%s)" % clause.constraint_target + elif clause.inferred_target_elements is not None: + target_text = "(%s)" % ", ".join( + ( + self.preparer.quote(c) + if isinstance(c, str) + else self.process(c, include_table=False, use_schema=False) + ) + for c in clause.inferred_target_elements + ) + if clause.inferred_target_whereclause is not None: + target_text += " WHERE %s" % self.process( + clause.inferred_target_whereclause, + include_table=False, + use_schema=False, + literal_binds=True, + ) + + else: + target_text = "" + + return target_text + + def visit_on_conflict_do_nothing(self, on_conflict, **kw): + target_text = self._on_conflict_target(on_conflict, **kw) + + if target_text: + return "ON CONFLICT %s DO NOTHING" % target_text + else: + return "ON CONFLICT DO NOTHING" + + def visit_on_conflict_do_update(self, on_conflict, **kw): + clause = on_conflict + + target_text = self._on_conflict_target(on_conflict, **kw) + + action_set_ops = [] + + set_parameters = dict(clause.update_values_to_set) + # create a list of column assignment clauses as tuples + + insert_statement = self.stack[-1]["selectable"] + cols = insert_statement.table.c + for c in cols: + col_key = c.key + + if col_key in set_parameters: + value = set_parameters.pop(col_key) + elif c in set_parameters: + value = set_parameters.pop(c) + else: + continue + + if coercions._is_literal(value): + value = elements.BindParameter(None, value, type_=c.type) + + else: + if ( + isinstance(value, elements.BindParameter) + and value.type._isnull + ): + value = value._clone() + value.type = c.type + value_text = self.process(value.self_group(), use_schema=False) + + key_text = self.preparer.quote(c.name) + action_set_ops.append("%s = %s" % (key_text, value_text)) + + # check for names that don't match columns + if set_parameters: + util.warn( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( + self.current_executable.table.name, + (", ".join("'%s'" % c for c in set_parameters)), + ) + ) + for k, v in set_parameters.items(): + key_text = ( + self.preparer.quote(k) + if isinstance(k, str) + else self.process(k, use_schema=False) + ) + value_text = self.process( + coercions.expect(roles.ExpressionElementRole, v), + use_schema=False, + ) + action_set_ops.append("%s = %s" % (key_text, value_text)) + + action_text = ", ".join(action_set_ops) + if clause.update_whereclause is not None: + action_text += " WHERE %s" % self.process( + clause.update_whereclause, include_table=True, use_schema=False + ) + + return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text) + + +class SQLiteDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + coltype = self.dialect.type_compiler_instance.process( + column.type, type_expression=column + ) + colspec = self.preparer.format_column(column) + " " + coltype + default = self.get_column_default_string(column) + if default is not None: + if isinstance(column.server_default.arg, ColumnElement): + default = "(" + default + ")" + colspec += " DEFAULT " + default + + if not column.nullable: + colspec += " NOT NULL" + + on_conflict_clause = column.dialect_options["sqlite"][ + "on_conflict_not_null" + ] + if on_conflict_clause is not None: + colspec += " ON CONFLICT " + on_conflict_clause + + if column.primary_key: + if ( + column.autoincrement is True + and len(column.table.primary_key.columns) != 1 + ): + raise exc.CompileError( + "SQLite does not support autoincrement for " + "composite primary keys" + ) + + if ( + column.table.dialect_options["sqlite"]["autoincrement"] + and len(column.table.primary_key.columns) == 1 + and issubclass(column.type._type_affinity, sqltypes.Integer) + and not column.foreign_keys + ): + colspec += " PRIMARY KEY" + + on_conflict_clause = column.dialect_options["sqlite"][ + "on_conflict_primary_key" + ] + if on_conflict_clause is not None: + colspec += " ON CONFLICT " + on_conflict_clause + + colspec += " AUTOINCREMENT" + + if column.computed is not None: + colspec += " " + self.process(column.computed) + + return colspec + + def visit_primary_key_constraint(self, constraint, **kw): + # for columns with sqlite_autoincrement=True, + # the PRIMARY KEY constraint can only be inline + # with the column itself. + if len(constraint.columns) == 1: + c = list(constraint)[0] + if ( + c.primary_key + and c.table.dialect_options["sqlite"]["autoincrement"] + and issubclass(c.type._type_affinity, sqltypes.Integer) + and not c.foreign_keys + ): + return None + + text = super().visit_primary_key_constraint(constraint) + + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] + if on_conflict_clause is None and len(constraint.columns) == 1: + on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][ + "on_conflict_primary_key" + ] + + if on_conflict_clause is not None: + text += " ON CONFLICT " + on_conflict_clause + + return text + + def visit_unique_constraint(self, constraint, **kw): + text = super().visit_unique_constraint(constraint) + + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] + if on_conflict_clause is None and len(constraint.columns) == 1: + col1 = list(constraint)[0] + if isinstance(col1, schema.SchemaItem): + on_conflict_clause = list(constraint)[0].dialect_options[ + "sqlite" + ]["on_conflict_unique"] + + if on_conflict_clause is not None: + text += " ON CONFLICT " + on_conflict_clause + + return text + + def visit_check_constraint(self, constraint, **kw): + text = super().visit_check_constraint(constraint) + + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] + + if on_conflict_clause is not None: + text += " ON CONFLICT " + on_conflict_clause + + return text + + def visit_column_check_constraint(self, constraint, **kw): + text = super().visit_column_check_constraint(constraint) + + if constraint.dialect_options["sqlite"]["on_conflict"] is not None: + raise exc.CompileError( + "SQLite does not support on conflict clause for " + "column check constraint" + ) + + return text + + def visit_foreign_key_constraint(self, constraint, **kw): + local_table = constraint.elements[0].parent.table + remote_table = constraint.elements[0].column.table + + if local_table.schema != remote_table.schema: + return None + else: + return super().visit_foreign_key_constraint(constraint) + + def define_constraint_remote_table(self, constraint, table, preparer): + """Format the remote table clause of a CREATE CONSTRAINT clause.""" + + return preparer.format_table(table, use_schema=False) + + def visit_create_index( + self, create, include_schema=False, include_table_schema=True, **kw + ): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + + text += "INDEX " + + if create.if_not_exists: + text += "IF NOT EXISTS " + + text += "%s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=True), + preparer.format_table(index.table, use_schema=False), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) + + whereclause = index.dialect_options["sqlite"]["where"] + if whereclause is not None: + where_compiled = self.sql_compiler.process( + whereclause, include_table=False, literal_binds=True + ) + text += " WHERE " + where_compiled + + return text + + def post_create_table(self, table): + if table.dialect_options["sqlite"]["with_rowid"] is False: + return "\n WITHOUT ROWID" + return "" + + +class SQLiteTypeCompiler(compiler.GenericTypeCompiler): + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_) + + def visit_DATETIME(self, type_, **kw): + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): + return super().visit_DATETIME(type_) + else: + return "DATETIME_CHAR" + + def visit_DATE(self, type_, **kw): + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): + return super().visit_DATE(type_) + else: + return "DATE_CHAR" + + def visit_TIME(self, type_, **kw): + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): + return super().visit_TIME(type_) + else: + return "TIME_CHAR" + + def visit_JSON(self, type_, **kw): + # note this name provides NUMERIC affinity, not TEXT. + # should not be an issue unless the JSON value consists of a single + # numeric value. JSONTEXT can be used if this case is required. + return "JSON" + + +class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = { + "add", + "after", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "attach", + "autoincrement", + "before", + "begin", + "between", + "by", + "cascade", + "case", + "cast", + "check", + "collate", + "column", + "commit", + "conflict", + "constraint", + "create", + "cross", + "current_date", + "current_time", + "current_timestamp", + "database", + "default", + "deferrable", + "deferred", + "delete", + "desc", + "detach", + "distinct", + "drop", + "each", + "else", + "end", + "escape", + "except", + "exclusive", + "exists", + "explain", + "false", + "fail", + "for", + "foreign", + "from", + "full", + "glob", + "group", + "having", + "if", + "ignore", + "immediate", + "in", + "index", + "indexed", + "initially", + "inner", + "insert", + "instead", + "intersect", + "into", + "is", + "isnull", + "join", + "key", + "left", + "like", + "limit", + "match", + "natural", + "not", + "notnull", + "null", + "of", + "offset", + "on", + "or", + "order", + "outer", + "plan", + "pragma", + "primary", + "query", + "raise", + "references", + "reindex", + "rename", + "replace", + "restrict", + "right", + "rollback", + "row", + "select", + "set", + "table", + "temp", + "temporary", + "then", + "to", + "transaction", + "trigger", + "true", + "union", + "unique", + "update", + "using", + "vacuum", + "values", + "view", + "virtual", + "when", + "where", + } + + +class SQLiteExecutionContext(default.DefaultExecutionContext): + @util.memoized_property + def _preserve_raw_colnames(self): + return ( + not self.dialect._broken_dotted_colnames + or self.execution_options.get("sqlite_raw_colnames", False) + ) + + def _translate_colname(self, colname): + # TODO: detect SQLite version 3.10.0 or greater; + # see [ticket:3633] + + # adjust for dotted column names. SQLite + # in the case of UNION may store col names as + # "tablename.colname", or if using an attached database, + # "database.tablename.colname", in cursor.description + if not self._preserve_raw_colnames and "." in colname: + return colname.split(".")[-1], colname + else: + return colname, None + + +class SQLiteDialect(default.DefaultDialect): + name = "sqlite" + supports_alter = False + + # SQlite supports "DEFAULT VALUES" but *does not* support + # "VALUES (DEFAULT)" + supports_default_values = True + supports_default_metavalue = False + + # sqlite issue: + # https://github.com/python/cpython/issues/93421 + # note this parameter is no longer used by the ORM or default dialect + # see #9414 + supports_sane_rowcount_returning = False + + supports_empty_insert = False + supports_cast = True + supports_multivalues_insert = True + use_insertmanyvalues = True + tuple_in_values = True + supports_statement_cache = True + insert_null_pk_still_autoincrements = True + insert_returning = True + update_returning = True + update_returning_multifrom = True + delete_returning = True + update_returning_multifrom = True + + supports_default_metavalue = True + """dialect supports INSERT... VALUES (DEFAULT) syntax""" + + default_metavalue_token = "NULL" + """for INSERT... VALUES (DEFAULT) syntax, the token to put in the + parenthesis.""" + + default_paramstyle = "qmark" + execution_ctx_cls = SQLiteExecutionContext + statement_compiler = SQLiteCompiler + ddl_compiler = SQLiteDDLCompiler + type_compiler_cls = SQLiteTypeCompiler + preparer = SQLiteIdentifierPreparer + ischema_names = ischema_names + colspecs = colspecs + + construct_arguments = [ + ( + sa_schema.Table, + { + "autoincrement": False, + "with_rowid": True, + }, + ), + (sa_schema.Index, {"where": None}), + ( + sa_schema.Column, + { + "on_conflict_primary_key": None, + "on_conflict_not_null": None, + "on_conflict_unique": None, + }, + ), + (sa_schema.Constraint, {"on_conflict": None}), + ] + + _broken_fk_pragma_quotes = False + _broken_dotted_colnames = False + + @util.deprecated_params( + _json_serializer=( + "1.3.7", + "The _json_serializer argument to the SQLite dialect has " + "been renamed to the correct name of json_serializer. The old " + "argument name will be removed in a future release.", + ), + _json_deserializer=( + "1.3.7", + "The _json_deserializer argument to the SQLite dialect has " + "been renamed to the correct name of json_deserializer. The old " + "argument name will be removed in a future release.", + ), + ) + def __init__( + self, + native_datetime=False, + json_serializer=None, + json_deserializer=None, + _json_serializer=None, + _json_deserializer=None, + **kwargs, + ): + default.DefaultDialect.__init__(self, **kwargs) + + if _json_serializer: + json_serializer = _json_serializer + if _json_deserializer: + json_deserializer = _json_deserializer + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer + + # this flag used by pysqlite dialect, and perhaps others in the + # future, to indicate the driver is handling date/timestamp + # conversions (and perhaps datetime/time as well on some hypothetical + # driver ?) + self.native_datetime = native_datetime + + if self.dbapi is not None: + if self.dbapi.sqlite_version_info < (3, 7, 16): + util.warn( + "SQLite version %s is older than 3.7.16, and will not " + "support right nested joins, as are sometimes used in " + "more complex ORM scenarios. SQLAlchemy 1.4 and above " + "no longer tries to rewrite these joins." + % (self.dbapi.sqlite_version_info,) + ) + + # NOTE: python 3.7 on fedora for me has SQLite 3.34.1. These + # version checks are getting very stale. + self._broken_dotted_colnames = self.dbapi.sqlite_version_info < ( + 3, + 10, + 0, + ) + self.supports_default_values = self.dbapi.sqlite_version_info >= ( + 3, + 3, + 8, + ) + self.supports_cast = self.dbapi.sqlite_version_info >= (3, 2, 3) + self.supports_multivalues_insert = ( + # https://www.sqlite.org/releaselog/3_7_11.html + self.dbapi.sqlite_version_info + >= (3, 7, 11) + ) + # see https://www.sqlalchemy.org/trac/ticket/2568 + # as well as https://www.sqlite.org/src/info/600482d161 + self._broken_fk_pragma_quotes = self.dbapi.sqlite_version_info < ( + 3, + 6, + 14, + ) + + if self.dbapi.sqlite_version_info < (3, 35) or util.pypy: + self.update_returning = self.delete_returning = ( + self.insert_returning + ) = False + + if self.dbapi.sqlite_version_info < (3, 32, 0): + # https://www.sqlite.org/limits.html + self.insertmanyvalues_max_parameters = 999 + + _isolation_lookup = util.immutabledict( + {"READ UNCOMMITTED": 1, "SERIALIZABLE": 0} + ) + + def get_isolation_level_values(self, dbapi_connection): + return list(self._isolation_lookup) + + def set_isolation_level(self, dbapi_connection, level): + isolation_level = self._isolation_lookup[level] + + cursor = dbapi_connection.cursor() + cursor.execute(f"PRAGMA read_uncommitted = {isolation_level}") + cursor.close() + + def get_isolation_level(self, dbapi_connection): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA read_uncommitted") + res = cursor.fetchone() + if res: + value = res[0] + else: + # https://www.sqlite.org/changes.html#version_3_3_3 + # "Optional READ UNCOMMITTED isolation (instead of the + # default isolation level of SERIALIZABLE) and + # table level locking when database connections + # share a common cache."" + # pre-SQLite 3.3.0 default to 0 + value = 0 + cursor.close() + if value == 0: + return "SERIALIZABLE" + elif value == 1: + return "READ UNCOMMITTED" + else: + assert False, "Unknown isolation level %s" % value + + @reflection.cache + def get_schema_names(self, connection, **kw): + s = "PRAGMA database_list" + dl = connection.exec_driver_sql(s) + + return [db[1] for db in dl if db[1] != "temp"] + + def _format_schema(self, schema, table_name): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + name = f"{qschema}.{table_name}" + else: + name = table_name + return name + + def _sqlite_main_query( + self, + table: str, + type_: str, + schema: Optional[str], + sqlite_include_internal: bool, + ): + main = self._format_schema(schema, table) + if not sqlite_include_internal: + filter_table = " AND name NOT LIKE 'sqlite~_%' ESCAPE '~'" + else: + filter_table = "" + query = ( + f"SELECT name FROM {main} " + f"WHERE type='{type_}'{filter_table} " + "ORDER BY name" + ) + return query + + @reflection.cache + def get_table_names( + self, connection, schema=None, sqlite_include_internal=False, **kw + ): + query = self._sqlite_main_query( + "sqlite_master", "table", schema, sqlite_include_internal + ) + names = connection.exec_driver_sql(query).scalars().all() + return names + + @reflection.cache + def get_temp_table_names( + self, connection, sqlite_include_internal=False, **kw + ): + query = self._sqlite_main_query( + "sqlite_temp_master", "table", None, sqlite_include_internal + ) + names = connection.exec_driver_sql(query).scalars().all() + return names + + @reflection.cache + def get_temp_view_names( + self, connection, sqlite_include_internal=False, **kw + ): + query = self._sqlite_main_query( + "sqlite_temp_master", "view", None, sqlite_include_internal + ) + names = connection.exec_driver_sql(query).scalars().all() + return names + + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): + self._ensure_has_table_connection(connection) + + if schema is not None and schema not in self.get_schema_names( + connection, **kw + ): + return False + + info = self._get_table_pragma( + connection, "table_info", table_name, schema=schema + ) + return bool(info) + + def _get_default_schema_name(self, connection): + return "main" + + @reflection.cache + def get_view_names( + self, connection, schema=None, sqlite_include_internal=False, **kw + ): + query = self._sqlite_main_query( + "sqlite_master", "view", schema, sqlite_include_internal + ) + names = connection.exec_driver_sql(query).scalars().all() + return names + + @reflection.cache + def get_view_definition(self, connection, view_name, schema=None, **kw): + if schema is not None: + qschema = self.identifier_preparer.quote_identifier(schema) + master = f"{qschema}.sqlite_master" + s = ("SELECT sql FROM %s WHERE name = ? AND type='view'") % ( + master, + ) + rs = connection.exec_driver_sql(s, (view_name,)) + else: + try: + s = ( + "SELECT sql FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE name = ? " + "AND type='view'" + ) + rs = connection.exec_driver_sql(s, (view_name,)) + except exc.DBAPIError: + s = ( + "SELECT sql FROM sqlite_master WHERE name = ? " + "AND type='view'" + ) + rs = connection.exec_driver_sql(s, (view_name,)) + + result = rs.fetchall() + if result: + return result[0].sql + else: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) + + @reflection.cache + def get_columns(self, connection, table_name, schema=None, **kw): + pragma = "table_info" + # computed columns are threaded as hidden, they require table_xinfo + if self.server_version_info >= (3, 31): + pragma = "table_xinfo" + info = self._get_table_pragma( + connection, pragma, table_name, schema=schema + ) + columns = [] + tablesql = None + for row in info: + name = row[1] + type_ = row[2].upper() + nullable = not row[3] + default = row[4] + primary_key = row[5] + hidden = row[6] if pragma == "table_xinfo" else 0 + + # hidden has value 0 for normal columns, 1 for hidden columns, + # 2 for computed virtual columns and 3 for computed stored columns + # https://www.sqlite.org/src/info/069351b85f9a706f60d3e98fbc8aaf40c374356b967c0464aede30ead3d9d18b + if hidden == 1: + continue + + generated = bool(hidden) + persisted = hidden == 3 + + if tablesql is None and generated: + tablesql = self._get_table_sql( + connection, table_name, schema, **kw + ) + + columns.append( + self._get_column_info( + name, + type_, + nullable, + default, + primary_key, + generated, + persisted, + tablesql, + ) + ) + if columns: + return columns + elif not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + else: + return ReflectionDefaults.columns() + + def _get_column_info( + self, + name, + type_, + nullable, + default, + primary_key, + generated, + persisted, + tablesql, + ): + if generated: + # the type of a column "cc INTEGER GENERATED ALWAYS AS (1 + 42)" + # somehow is "INTEGER GENERATED ALWAYS" + type_ = re.sub("generated", "", type_, flags=re.IGNORECASE) + type_ = re.sub("always", "", type_, flags=re.IGNORECASE).strip() + + coltype = self._resolve_type_affinity(type_) + + if default is not None: + default = str(default) + + colspec = { + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "primary_key": primary_key, + } + if generated: + sqltext = "" + if tablesql: + pattern = r"[^,]*\s+AS\s+\(([^,]*)\)\s*(?:virtual|stored)?" + match = re.search( + re.escape(name) + pattern, tablesql, re.IGNORECASE + ) + if match: + sqltext = match.group(1) + colspec["computed"] = {"sqltext": sqltext, "persisted": persisted} + return colspec + + def _resolve_type_affinity(self, type_): + """Return a data type from a reflected column, using affinity rules. + + SQLite's goal for universal compatibility introduces some complexity + during reflection, as a column's defined type might not actually be a + type that SQLite understands - or indeed, my not be defined *at all*. + Internally, SQLite handles this with a 'data type affinity' for each + column definition, mapping to one of 'TEXT', 'NUMERIC', 'INTEGER', + 'REAL', or 'NONE' (raw bits). The algorithm that determines this is + listed in https://www.sqlite.org/datatype3.html section 2.1. + + This method allows SQLAlchemy to support that algorithm, while still + providing access to smarter reflection utilities by recognizing + column definitions that SQLite only supports through affinity (like + DATE and DOUBLE). + + """ + match = re.match(r"([\w ]+)(\(.*?\))?", type_) + if match: + coltype = match.group(1) + args = match.group(2) + else: + coltype = "" + args = "" + + if coltype in self.ischema_names: + coltype = self.ischema_names[coltype] + elif "INT" in coltype: + coltype = sqltypes.INTEGER + elif "CHAR" in coltype or "CLOB" in coltype or "TEXT" in coltype: + coltype = sqltypes.TEXT + elif "BLOB" in coltype or not coltype: + coltype = sqltypes.NullType + elif "REAL" in coltype or "FLOA" in coltype or "DOUB" in coltype: + coltype = sqltypes.REAL + else: + coltype = sqltypes.NUMERIC + + if args is not None: + args = re.findall(r"(\d+)", args) + try: + coltype = coltype(*[int(a) for a in args]) + except TypeError: + util.warn( + "Could not instantiate type %s with " + "reflected arguments %s; using no arguments." + % (coltype, args) + ) + coltype = coltype() + else: + coltype = coltype() + + return coltype + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + constraint_name = None + table_data = self._get_table_sql(connection, table_name, schema=schema) + if table_data: + PK_PATTERN = r"CONSTRAINT (\w+) PRIMARY KEY" + result = re.search(PK_PATTERN, table_data, re.I) + constraint_name = result.group(1) if result else None + + cols = self.get_columns(connection, table_name, schema, **kw) + # consider only pk columns. This also avoids sorting the cached + # value returned by get_columns + cols = [col for col in cols if col.get("primary_key", 0) > 0] + cols.sort(key=lambda col: col.get("primary_key")) + pkeys = [col["name"] for col in cols] + + if pkeys: + return {"constrained_columns": pkeys, "name": constraint_name} + else: + return ReflectionDefaults.pk_constraint() + + @reflection.cache + def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # sqlite makes this *extremely difficult*. + # First, use the pragma to get the actual FKs. + pragma_fks = self._get_table_pragma( + connection, "foreign_key_list", table_name, schema=schema + ) + + fks = {} + + for row in pragma_fks: + (numerical_id, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4]) + + if not rcol: + # no referred column, which means it was not named in the + # original DDL. The referred columns of the foreign key + # constraint are therefore the primary key of the referred + # table. + try: + referred_pk = self.get_pk_constraint( + connection, rtbl, schema=schema, **kw + ) + referred_columns = referred_pk["constrained_columns"] + except exc.NoSuchTableError: + # ignore not existing parents + referred_columns = [] + else: + # note we use this list only if this is the first column + # in the constraint. for subsequent columns we ignore the + # list and append "rcol" if present. + referred_columns = [] + + if self._broken_fk_pragma_quotes: + rtbl = re.sub(r"^[\"\[`\']|[\"\]`\']$", "", rtbl) + + if numerical_id in fks: + fk = fks[numerical_id] + else: + fk = fks[numerical_id] = { + "name": None, + "constrained_columns": [], + "referred_schema": schema, + "referred_table": rtbl, + "referred_columns": referred_columns, + "options": {}, + } + fks[numerical_id] = fk + + fk["constrained_columns"].append(lcol) + + if rcol: + fk["referred_columns"].append(rcol) + + def fk_sig(constrained_columns, referred_table, referred_columns): + return ( + tuple(constrained_columns) + + (referred_table,) + + tuple(referred_columns) + ) + + # then, parse the actual SQL and attempt to find DDL that matches + # the names as well. SQLite saves the DDL in whatever format + # it was typed in as, so need to be liberal here. + + keys_by_signature = { + fk_sig( + fk["constrained_columns"], + fk["referred_table"], + fk["referred_columns"], + ): fk + for fk in fks.values() + } + + table_data = self._get_table_sql(connection, table_name, schema=schema) + + def parse_fks(): + if table_data is None: + # system tables, etc. + return + + # note that we already have the FKs from PRAGMA above. This whole + # regexp thing is trying to locate additional detail about the + # FKs, namely the name of the constraint and other options. + # so parsing the columns is really about matching it up to what + # we already have. + FK_PATTERN = ( + r"(?:CONSTRAINT (\w+) +)?" + r"FOREIGN KEY *\( *(.+?) *\) +" + r'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\( *((?:(?:"[^"]+"|[a-z0-9_]+) *(?:, *)?)+)\) *' # noqa: E501 + r"((?:ON (?:DELETE|UPDATE) " + r"(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)" + r"((?:NOT +)?DEFERRABLE)?" + r"(?: +INITIALLY +(DEFERRED|IMMEDIATE))?" + ) + for match in re.finditer(FK_PATTERN, table_data, re.I): + ( + constraint_name, + constrained_columns, + referred_quoted_name, + referred_name, + referred_columns, + onupdatedelete, + deferrable, + initially, + ) = match.group(1, 2, 3, 4, 5, 6, 7, 8) + constrained_columns = list( + self._find_cols_in_sig(constrained_columns) + ) + if not referred_columns: + referred_columns = constrained_columns + else: + referred_columns = list( + self._find_cols_in_sig(referred_columns) + ) + referred_name = referred_quoted_name or referred_name + options = {} + + for token in re.split(r" *\bON\b *", onupdatedelete.upper()): + if token.startswith("DELETE"): + ondelete = token[6:].strip() + if ondelete and ondelete != "NO ACTION": + options["ondelete"] = ondelete + elif token.startswith("UPDATE"): + onupdate = token[6:].strip() + if onupdate and onupdate != "NO ACTION": + options["onupdate"] = onupdate + + if deferrable: + options["deferrable"] = "NOT" not in deferrable.upper() + if initially: + options["initially"] = initially.upper() + + yield ( + constraint_name, + constrained_columns, + referred_name, + referred_columns, + options, + ) + + fkeys = [] + + for ( + constraint_name, + constrained_columns, + referred_name, + referred_columns, + options, + ) in parse_fks(): + sig = fk_sig(constrained_columns, referred_name, referred_columns) + if sig not in keys_by_signature: + util.warn( + "WARNING: SQL-parsed foreign key constraint " + "'%s' could not be located in PRAGMA " + "foreign_keys for table %s" % (sig, table_name) + ) + continue + key = keys_by_signature.pop(sig) + key["name"] = constraint_name + key["options"] = options + fkeys.append(key) + # assume the remainders are the unnamed, inline constraints, just + # use them as is as it's extremely difficult to parse inline + # constraints + fkeys.extend(keys_by_signature.values()) + if fkeys: + return fkeys + else: + return ReflectionDefaults.foreign_keys() + + def _find_cols_in_sig(self, sig): + for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I): + yield match.group(1) or match.group(2) + + @reflection.cache + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + auto_index_by_sig = {} + for idx in self.get_indexes( + connection, + table_name, + schema=schema, + include_auto_indexes=True, + **kw, + ): + if not idx["name"].startswith("sqlite_autoindex"): + continue + sig = tuple(idx["column_names"]) + auto_index_by_sig[sig] = idx + + table_data = self._get_table_sql( + connection, table_name, schema=schema, **kw + ) + unique_constraints = [] + + def parse_uqs(): + if table_data is None: + return + UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)' + INLINE_UNIQUE_PATTERN = ( + r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) ' + r"+[a-z0-9_ ]+? +UNIQUE" + ) + + for match in re.finditer(UNIQUE_PATTERN, table_data, re.I): + name, cols = match.group(1, 2) + yield name, list(self._find_cols_in_sig(cols)) + + # we need to match inlines as well, as we seek to differentiate + # a UNIQUE constraint from a UNIQUE INDEX, even though these + # are kind of the same thing :) + for match in re.finditer(INLINE_UNIQUE_PATTERN, table_data, re.I): + cols = list( + self._find_cols_in_sig(match.group(1) or match.group(2)) + ) + yield None, cols + + for name, cols in parse_uqs(): + sig = tuple(cols) + if sig in auto_index_by_sig: + auto_index_by_sig.pop(sig) + parsed_constraint = {"name": name, "column_names": cols} + unique_constraints.append(parsed_constraint) + # NOTE: auto_index_by_sig might not be empty here, + # the PRIMARY KEY may have an entry. + if unique_constraints: + return unique_constraints + else: + return ReflectionDefaults.unique_constraints() + + @reflection.cache + def get_check_constraints(self, connection, table_name, schema=None, **kw): + table_data = self._get_table_sql( + connection, table_name, schema=schema, **kw + ) + + CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *" + cks = [] + # NOTE: we aren't using re.S here because we actually are + # taking advantage of each CHECK constraint being all on one + # line in the table definition in order to delineate. This + # necessarily makes assumptions as to how the CREATE TABLE + # was emitted. + + for match in re.finditer(CHECK_PATTERN, table_data or "", re.I): + name = match.group(1) + + if name: + name = re.sub(r'^"|"$', "", name) + + cks.append({"sqltext": match.group(2), "name": name}) + cks.sort(key=lambda d: d["name"] or "~") # sort None as last + if cks: + return cks + else: + return ReflectionDefaults.check_constraints() + + @reflection.cache + def get_indexes(self, connection, table_name, schema=None, **kw): + pragma_indexes = self._get_table_pragma( + connection, "index_list", table_name, schema=schema + ) + indexes = [] + + # regular expression to extract the filter predicate of a partial + # index. this could fail to extract the predicate correctly on + # indexes created like + # CREATE INDEX i ON t (col || ') where') WHERE col <> '' + # but as this function does not support expression-based indexes + # this case does not occur. + partial_pred_re = re.compile(r"\)\s+where\s+(.+)", re.IGNORECASE) + + if schema: + schema_expr = "%s." % self.identifier_preparer.quote_identifier( + schema + ) + else: + schema_expr = "" + + include_auto_indexes = kw.pop("include_auto_indexes", False) + for row in pragma_indexes: + # ignore implicit primary key index. + # https://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html + if not include_auto_indexes and row[1].startswith( + "sqlite_autoindex" + ): + continue + indexes.append( + dict( + name=row[1], + column_names=[], + unique=row[2], + dialect_options={}, + ) + ) + + # check partial indexes + if len(row) >= 5 and row[4]: + s = ( + "SELECT sql FROM %(schema)ssqlite_master " + "WHERE name = ? " + "AND type = 'index'" % {"schema": schema_expr} + ) + rs = connection.exec_driver_sql(s, (row[1],)) + index_sql = rs.scalar() + predicate_match = partial_pred_re.search(index_sql) + if predicate_match is None: + # unless the regex is broken this case shouldn't happen + # because we know this is a partial index, so the + # definition sql should match the regex + util.warn( + "Failed to look up filter predicate of " + "partial index %s" % row[1] + ) + else: + predicate = predicate_match.group(1) + indexes[-1]["dialect_options"]["sqlite_where"] = text( + predicate + ) + + # loop thru unique indexes to get the column names. + for idx in list(indexes): + pragma_index = self._get_table_pragma( + connection, "index_info", idx["name"], schema=schema + ) + + for row in pragma_index: + if row[2] is None: + util.warn( + "Skipped unsupported reflection of " + "expression-based index %s" % idx["name"] + ) + indexes.remove(idx) + break + else: + idx["column_names"].append(row[2]) + + indexes.sort(key=lambda d: d["name"] or "~") # sort None as last + if indexes: + return indexes + elif not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + else: + return ReflectionDefaults.indexes() + + def _is_sys_table(self, table_name): + return table_name in { + "sqlite_schema", + "sqlite_master", + "sqlite_temp_schema", + "sqlite_temp_master", + } + + @reflection.cache + def _get_table_sql(self, connection, table_name, schema=None, **kw): + if schema: + schema_expr = "%s." % ( + self.identifier_preparer.quote_identifier(schema) + ) + else: + schema_expr = "" + try: + s = ( + "SELECT sql FROM " + " (SELECT * FROM %(schema)ssqlite_master UNION ALL " + " SELECT * FROM %(schema)ssqlite_temp_master) " + "WHERE name = ? " + "AND type in ('table', 'view')" % {"schema": schema_expr} + ) + rs = connection.exec_driver_sql(s, (table_name,)) + except exc.DBAPIError: + s = ( + "SELECT sql FROM %(schema)ssqlite_master " + "WHERE name = ? " + "AND type in ('table', 'view')" % {"schema": schema_expr} + ) + rs = connection.exec_driver_sql(s, (table_name,)) + value = rs.scalar() + if value is None and not self._is_sys_table(table_name): + raise exc.NoSuchTableError(f"{schema_expr}{table_name}") + return value + + def _get_table_pragma(self, connection, pragma, table_name, schema=None): + quote = self.identifier_preparer.quote_identifier + if schema is not None: + statements = [f"PRAGMA {quote(schema)}."] + else: + # because PRAGMA looks in all attached databases if no schema + # given, need to specify "main" schema, however since we want + # 'temp' tables in the same namespace as 'main', need to run + # the PRAGMA twice + statements = ["PRAGMA main.", "PRAGMA temp."] + + qtable = quote(table_name) + for statement in statements: + statement = f"{statement}{pragma}({qtable})" + cursor = connection.exec_driver_sql(statement) + if not cursor._soft_closed: + # work around SQLite issue whereby cursor.description + # is blank when PRAGMA returns no rows: + # https://www.sqlite.org/cvstrac/tktview?tn=1884 + result = cursor.fetchall() + else: + result = [] + if result: + return result + else: + return [] diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/dml.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/dml.py new file mode 100644 index 0000000..dcf5e44 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/dml.py @@ -0,0 +1,240 @@ +# dialects/sqlite/dml.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +from typing import Any + +from .._typing import _OnConflictIndexElementsT +from .._typing import _OnConflictIndexWhereT +from .._typing import _OnConflictSetT +from .._typing import _OnConflictWhereT +from ... import util +from ...sql import coercions +from ...sql import roles +from ...sql._typing import _DMLTableArgument +from ...sql.base import _exclusive_against +from ...sql.base import _generative +from ...sql.base import ColumnCollection +from ...sql.base import ReadOnlyColumnCollection +from ...sql.dml import Insert as StandardInsert +from ...sql.elements import ClauseElement +from ...sql.elements import KeyedColumnElement +from ...sql.expression import alias +from ...util.typing import Self + +__all__ = ("Insert", "insert") + + +def insert(table: _DMLTableArgument) -> Insert: + """Construct a sqlite-specific variant :class:`_sqlite.Insert` + construct. + + .. container:: inherited_member + + The :func:`sqlalchemy.dialects.sqlite.insert` function creates + a :class:`sqlalchemy.dialects.sqlite.Insert`. This class is based + on the dialect-agnostic :class:`_sql.Insert` construct which may + be constructed using the :func:`_sql.insert` function in + SQLAlchemy Core. + + The :class:`_sqlite.Insert` construct includes additional methods + :meth:`_sqlite.Insert.on_conflict_do_update`, + :meth:`_sqlite.Insert.on_conflict_do_nothing`. + + """ + return Insert(table) + + +class Insert(StandardInsert): + """SQLite-specific implementation of INSERT. + + Adds methods for SQLite-specific syntaxes such as ON CONFLICT. + + The :class:`_sqlite.Insert` object is created using the + :func:`sqlalchemy.dialects.sqlite.insert` function. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`sqlite_on_conflict_insert` + + """ + + stringify_dialect = "sqlite" + inherit_cache = False + + @util.memoized_property + def excluded( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: + """Provide the ``excluded`` namespace for an ON CONFLICT statement + + SQLite's ON CONFLICT clause allows reference to the row that would + be inserted, known as ``excluded``. This attribute provides + all columns in this row to be referenceable. + + .. tip:: The :attr:`_sqlite.Insert.excluded` attribute is an instance + of :class:`_expression.ColumnCollection`, which provides an + interface the same as that of the :attr:`_schema.Table.c` + collection described at :ref:`metadata_tables_and_columns`. + With this collection, ordinary names are accessible like attributes + (e.g. ``stmt.excluded.some_column``), but special names and + dictionary method names should be accessed using indexed access, + such as ``stmt.excluded["column name"]`` or + ``stmt.excluded["values"]``. See the docstring for + :class:`_expression.ColumnCollection` for further examples. + + """ + return alias(self.table, name="excluded").columns + + _on_conflict_exclusive = _exclusive_against( + "_post_values_clause", + msgs={ + "_post_values_clause": "This Insert construct already has " + "an ON CONFLICT clause established" + }, + ) + + @_generative + @_on_conflict_exclusive + def on_conflict_do_update( + self, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + set_: _OnConflictSetT = None, + where: _OnConflictWhereT = None, + ) -> Self: + r""" + Specifies a DO UPDATE SET action for ON CONFLICT clause. + + :param index_elements: + A sequence consisting of string column names, :class:`_schema.Column` + objects, or other column expression objects that will be used + to infer a target index or unique constraint. + + :param index_where: + Additional WHERE criterion that can be used to infer a + conditional target index. + + :param set\_: + A dictionary or other mapping object + where the keys are either names of columns in the target table, + or :class:`_schema.Column` objects or other ORM-mapped columns + matching that of the target table, and expressions or literals + as values, specifying the ``SET`` actions to take. + + .. versionadded:: 1.4 The + :paramref:`_sqlite.Insert.on_conflict_do_update.set_` + parameter supports :class:`_schema.Column` objects from the target + :class:`_schema.Table` as keys. + + .. warning:: This dictionary does **not** take into account + Python-specified default UPDATE values or generation functions, + e.g. those specified using :paramref:`_schema.Column.onupdate`. + These values will not be exercised for an ON CONFLICT style of + UPDATE, unless they are manually specified in the + :paramref:`.Insert.on_conflict_do_update.set_` dictionary. + + :param where: + Optional argument. If present, can be a literal SQL + string or an acceptable expression for a ``WHERE`` clause + that restricts the rows affected by ``DO UPDATE SET``. Rows + not meeting the ``WHERE`` condition will not be updated + (effectively a ``DO NOTHING`` for those rows). + + """ + + self._post_values_clause = OnConflictDoUpdate( + index_elements, index_where, set_, where + ) + return self + + @_generative + @_on_conflict_exclusive + def on_conflict_do_nothing( + self, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + ) -> Self: + """ + Specifies a DO NOTHING action for ON CONFLICT clause. + + :param index_elements: + A sequence consisting of string column names, :class:`_schema.Column` + objects, or other column expression objects that will be used + to infer a target index or unique constraint. + + :param index_where: + Additional WHERE criterion that can be used to infer a + conditional target index. + + """ + + self._post_values_clause = OnConflictDoNothing( + index_elements, index_where + ) + return self + + +class OnConflictClause(ClauseElement): + stringify_dialect = "sqlite" + + constraint_target: None + inferred_target_elements: _OnConflictIndexElementsT + inferred_target_whereclause: _OnConflictIndexWhereT + + def __init__( + self, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + ): + if index_elements is not None: + self.constraint_target = None + self.inferred_target_elements = index_elements + self.inferred_target_whereclause = index_where + else: + self.constraint_target = self.inferred_target_elements = ( + self.inferred_target_whereclause + ) = None + + +class OnConflictDoNothing(OnConflictClause): + __visit_name__ = "on_conflict_do_nothing" + + +class OnConflictDoUpdate(OnConflictClause): + __visit_name__ = "on_conflict_do_update" + + def __init__( + self, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + set_: _OnConflictSetT = None, + where: _OnConflictWhereT = None, + ): + super().__init__( + index_elements=index_elements, + index_where=index_where, + ) + + if isinstance(set_, dict): + if not set_: + raise ValueError("set parameter dictionary must not be empty") + elif isinstance(set_, ColumnCollection): + set_ = dict(set_) + else: + raise ValueError( + "set parameter must be a non-empty dictionary " + "or a ColumnCollection such as the `.c.` collection " + "of a Table object" + ) + self.update_values_to_set = [ + (coercions.expect(roles.DMLColumnRole, key), value) + for key, value in set_.items() + ] + self.update_whereclause = where diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/json.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/json.py new file mode 100644 index 0000000..ec29802 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/json.py @@ -0,0 +1,92 @@ +# dialects/sqlite/json.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from ... import types as sqltypes + + +class JSON(sqltypes.JSON): + """SQLite JSON type. + + SQLite supports JSON as of version 3.9 through its JSON1_ extension. Note + that JSON1_ is a + `loadable extension `_ and as such + may not be available, or may require run-time loading. + + :class:`_sqlite.JSON` is used automatically whenever the base + :class:`_types.JSON` datatype is used against a SQLite backend. + + .. seealso:: + + :class:`_types.JSON` - main documentation for the generic + cross-platform JSON datatype. + + The :class:`_sqlite.JSON` type supports persistence of JSON values + as well as the core index operations provided by :class:`_types.JSON` + datatype, by adapting the operations to render the ``JSON_EXTRACT`` + function wrapped in the ``JSON_QUOTE`` function at the database level. + Extracted values are quoted in order to ensure that the results are + always JSON string values. + + + .. versionadded:: 1.3 + + + .. _JSON1: https://www.sqlite.org/json1.html + + """ + + +# Note: these objects currently match exactly those of MySQL, however since +# these are not generalizable to all JSON implementations, remain separately +# implemented for each dialect. +class _FormatTypeMixin: + def _format_value(self, value): + raise NotImplementedError() + + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) + + def process(value): + value = self._format_value(value) + if super_proc: + value = super_proc(value) + return value + + return process + + +class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): + def _format_value(self, value): + if isinstance(value, int): + value = "$[%s]" % value + else: + value = '$."%s"' % value + return value + + +class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): + def _format_value(self, value): + return "$%s" % ( + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/provision.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/provision.py new file mode 100644 index 0000000..f18568b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/provision.py @@ -0,0 +1,198 @@ +# dialects/sqlite/provision.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import os +import re + +from ... import exc +from ...engine import url as sa_url +from ...testing.provision import create_db +from ...testing.provision import drop_db +from ...testing.provision import follower_url_from_main +from ...testing.provision import generate_driver_url +from ...testing.provision import log +from ...testing.provision import post_configure_engine +from ...testing.provision import run_reap_dbs +from ...testing.provision import stop_test_class_outside_fixtures +from ...testing.provision import temp_table_keyword_args +from ...testing.provision import upsert + + +# TODO: I can't get this to build dynamically with pytest-xdist procs +_drivernames = { + "pysqlite", + "aiosqlite", + "pysqlcipher", + "pysqlite_numeric", + "pysqlite_dollar", +} + + +def _format_url(url, driver, ident): + """given a sqlite url + desired driver + ident, make a canonical + URL out of it + + """ + url = sa_url.make_url(url) + + if driver is None: + driver = url.get_driver_name() + + filename = url.database + + needs_enc = driver == "pysqlcipher" + name_token = None + + if filename and filename != ":memory:": + assert "test_schema" not in filename + tokens = re.split(r"[_\.]", filename) + + new_filename = f"{driver}" + + for token in tokens: + if token in _drivernames: + if driver is None: + driver = token + continue + elif token in ("db", "enc"): + continue + elif name_token is None: + name_token = token.strip("_") + + assert name_token, f"sqlite filename has no name token: {url.database}" + + new_filename = f"{name_token}_{driver}" + if ident: + new_filename += f"_{ident}" + new_filename += ".db" + if needs_enc: + new_filename += ".enc" + url = url.set(database=new_filename) + + if needs_enc: + url = url.set(password="test") + + url = url.set(drivername="sqlite+%s" % (driver,)) + + return url + + +@generate_driver_url.for_db("sqlite") +def generate_driver_url(url, driver, query_str): + url = _format_url(url, driver, None) + + try: + url.get_dialect() + except exc.NoSuchModuleError: + return None + else: + return url + + +@follower_url_from_main.for_db("sqlite") +def _sqlite_follower_url_from_main(url, ident): + return _format_url(url, None, ident) + + +@post_configure_engine.for_db("sqlite") +def _sqlite_post_configure_engine(url, engine, follower_ident): + from sqlalchemy import event + + if follower_ident: + attach_path = f"{follower_ident}_{engine.driver}_test_schema.db" + else: + attach_path = f"{engine.driver}_test_schema.db" + + @event.listens_for(engine, "connect") + def connect(dbapi_connection, connection_record): + # use file DBs in all cases, memory acts kind of strangely + # as an attached + + # NOTE! this has to be done *per connection*. New sqlite connection, + # as we get with say, QueuePool, the attaches are gone. + # so schemes to delete those attached files have to be done at the + # filesystem level and not rely upon what attachments are in a + # particular SQLite connection + dbapi_connection.execute( + f'ATTACH DATABASE "{attach_path}" AS test_schema' + ) + + @event.listens_for(engine, "engine_disposed") + def dispose(engine): + """most databases should be dropped using + stop_test_class_outside_fixtures + + however a few tests like AttachedDBTest might not get triggered on + that main hook + + """ + + if os.path.exists(attach_path): + os.remove(attach_path) + + filename = engine.url.database + + if filename and filename != ":memory:" and os.path.exists(filename): + os.remove(filename) + + +@create_db.for_db("sqlite") +def _sqlite_create_db(cfg, eng, ident): + pass + + +@drop_db.for_db("sqlite") +def _sqlite_drop_db(cfg, eng, ident): + _drop_dbs_w_ident(eng.url.database, eng.driver, ident) + + +def _drop_dbs_w_ident(databasename, driver, ident): + for path in os.listdir("."): + fname, ext = os.path.split(path) + if ident in fname and ext in [".db", ".db.enc"]: + log.info("deleting SQLite database file: %s", path) + os.remove(path) + + +@stop_test_class_outside_fixtures.for_db("sqlite") +def stop_test_class_outside_fixtures(config, db, cls): + db.dispose() + + +@temp_table_keyword_args.for_db("sqlite") +def _sqlite_temp_table_keyword_args(cfg, eng): + return {"prefixes": ["TEMPORARY"]} + + +@run_reap_dbs.for_db("sqlite") +def _reap_sqlite_dbs(url, idents): + log.info("db reaper connecting to %r", url) + log.info("identifiers in file: %s", ", ".join(idents)) + url = sa_url.make_url(url) + for ident in idents: + for drivername in _drivernames: + _drop_dbs_w_ident(url.database, drivername, ident) + + +@upsert.for_db("sqlite") +def _upsert( + cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False +): + from sqlalchemy.dialects.sqlite import insert + + stmt = insert(table) + + if set_lambda: + stmt = stmt.on_conflict_do_update(set_=set_lambda(stmt.excluded)) + else: + stmt = stmt.on_conflict_do_nothing() + + stmt = stmt.returning( + *returning, sort_by_parameter_order=sort_by_parameter_order + ) + return stmt diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/pysqlcipher.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/pysqlcipher.py new file mode 100644 index 0000000..388a4df --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -0,0 +1,155 @@ +# dialects/sqlite/pysqlcipher.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +""" +.. dialect:: sqlite+pysqlcipher + :name: pysqlcipher + :dbapi: sqlcipher 3 or pysqlcipher + :connectstring: sqlite+pysqlcipher://:passphrase@/file_path[?kdf_iter=] + + Dialect for support of DBAPIs that make use of the + `SQLCipher `_ backend. + + +Driver +------ + +Current dialect selection logic is: + +* If the :paramref:`_sa.create_engine.module` parameter supplies a DBAPI module, + that module is used. +* Otherwise for Python 3, choose https://pypi.org/project/sqlcipher3/ +* If not available, fall back to https://pypi.org/project/pysqlcipher3/ +* For Python 2, https://pypi.org/project/pysqlcipher/ is used. + +.. warning:: The ``pysqlcipher3`` and ``pysqlcipher`` DBAPI drivers are no + longer maintained; the ``sqlcipher3`` driver as of this writing appears + to be current. For future compatibility, any pysqlcipher-compatible DBAPI + may be used as follows:: + + import sqlcipher_compatible_driver + + from sqlalchemy import create_engine + + e = create_engine( + "sqlite+pysqlcipher://:password@/dbname.db", + module=sqlcipher_compatible_driver + ) + +These drivers make use of the SQLCipher engine. This system essentially +introduces new PRAGMA commands to SQLite which allows the setting of a +passphrase and other encryption parameters, allowing the database file to be +encrypted. + + +Connect Strings +--------------- + +The format of the connect string is in every way the same as that +of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the +"password" field is now accepted, which should contain a passphrase:: + + e = create_engine('sqlite+pysqlcipher://:testing@/foo.db') + +For an absolute file path, two leading slashes should be used for the +database name:: + + e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db') + +A selection of additional encryption-related pragmas supported by SQLCipher +as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed +in the query string, and will result in that PRAGMA being called for each +new connection. Currently, ``cipher``, ``kdf_iter`` +``cipher_page_size`` and ``cipher_use_hmac`` are supported:: + + e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000') + +.. warning:: Previous versions of sqlalchemy did not take into consideration + the encryption-related pragmas passed in the url string, that were silently + ignored. This may cause errors when opening files saved by a + previous sqlalchemy version if the encryption options do not match. + + +Pooling Behavior +---------------- + +The driver makes a change to the default pool behavior of pysqlite +as described in :ref:`pysqlite_threading_pooling`. The pysqlcipher driver +has been observed to be significantly slower on connection than the +pysqlite driver, most likely due to the encryption overhead, so the +dialect here defaults to using the :class:`.SingletonThreadPool` +implementation, +instead of the :class:`.NullPool` pool used by pysqlite. As always, the pool +implementation is entirely configurable using the +:paramref:`_sa.create_engine.poolclass` parameter; the :class:`. +StaticPool` may +be more feasible for single-threaded use, or :class:`.NullPool` may be used +to prevent unencrypted connections from being held open for long periods of +time, at the expense of slower startup time for new connections. + + +""" # noqa + +from .pysqlite import SQLiteDialect_pysqlite +from ... import pool + + +class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): + driver = "pysqlcipher" + supports_statement_cache = True + + pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac") + + @classmethod + def import_dbapi(cls): + try: + import sqlcipher3 as sqlcipher + except ImportError: + pass + else: + return sqlcipher + + from pysqlcipher3 import dbapi2 as sqlcipher + + return sqlcipher + + @classmethod + def get_pool_class(cls, url): + return pool.SingletonThreadPool + + def on_connect_url(self, url): + super_on_connect = super().on_connect_url(url) + + # pull the info we need from the URL early. Even though URL + # is immutable, we don't want any in-place changes to the URL + # to affect things + passphrase = url.password or "" + url_query = dict(url.query) + + def on_connect(conn): + cursor = conn.cursor() + cursor.execute('pragma key="%s"' % passphrase) + for prag in self.pragmas: + value = url_query.get(prag, None) + if value is not None: + cursor.execute('pragma %s="%s"' % (prag, value)) + cursor.close() + + if super_on_connect: + super_on_connect(conn) + + return on_connect + + def create_connect_args(self, url): + plain_url = url._replace(password=None) + plain_url = plain_url.difference_update_query(self.pragmas) + return super().create_connect_args(plain_url) + + +dialect = SQLiteDialect_pysqlcipher diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/pysqlite.py b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/pysqlite.py new file mode 100644 index 0000000..f39baf3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/sqlite/pysqlite.py @@ -0,0 +1,756 @@ +# dialects/sqlite/pysqlite.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +r""" +.. dialect:: sqlite+pysqlite + :name: pysqlite + :dbapi: sqlite3 + :connectstring: sqlite+pysqlite:///file_path + :url: https://docs.python.org/library/sqlite3.html + + Note that ``pysqlite`` is the same driver as the ``sqlite3`` + module included with the Python distribution. + +Driver +------ + +The ``sqlite3`` Python DBAPI is standard on all modern Python versions; +for cPython and Pypy, no additional installation is necessary. + + +Connect Strings +--------------- + +The file specification for the SQLite database is taken as the "database" +portion of the URL. Note that the format of a SQLAlchemy url is:: + + driver://user:pass@host/database + +This means that the actual filename to be used starts with the characters to +the **right** of the third slash. So connecting to a relative filepath +looks like:: + + # relative path + e = create_engine('sqlite:///path/to/database.db') + +An absolute path, which is denoted by starting with a slash, means you +need **four** slashes:: + + # absolute path + e = create_engine('sqlite:////path/to/database.db') + +To use a Windows path, regular drive specifications and backslashes can be +used. Double backslashes are probably needed:: + + # absolute path on Windows + e = create_engine('sqlite:///C:\\path\\to\\database.db') + +To use sqlite ``:memory:`` database specify it as the filename using +``sqlite://:memory:``. It's also the default if no filepath is +present, specifying only ``sqlite://`` and nothing else:: + + # in-memory database + e = create_engine('sqlite://:memory:') + # also in-memory database + e2 = create_engine('sqlite://') + +.. _pysqlite_uri_connections: + +URI Connections +^^^^^^^^^^^^^^^ + +Modern versions of SQLite support an alternative system of connecting using a +`driver level URI `_, which has the advantage +that additional driver-level arguments can be passed including options such as +"read only". The Python sqlite3 driver supports this mode under modern Python +3 versions. The SQLAlchemy pysqlite driver supports this mode of use by +specifying "uri=true" in the URL query string. The SQLite-level "URI" is kept +as the "database" portion of the SQLAlchemy url (that is, following a slash):: + + e = create_engine("sqlite:///file:path/to/database?mode=ro&uri=true") + +.. note:: The "uri=true" parameter must appear in the **query string** + of the URL. It will not currently work as expected if it is only + present in the :paramref:`_sa.create_engine.connect_args` + parameter dictionary. + +The logic reconciles the simultaneous presence of SQLAlchemy's query string and +SQLite's query string by separating out the parameters that belong to the +Python sqlite3 driver vs. those that belong to the SQLite URI. This is +achieved through the use of a fixed list of parameters known to be accepted by +the Python side of the driver. For example, to include a URL that indicates +the Python sqlite3 "timeout" and "check_same_thread" parameters, along with the +SQLite "mode" and "nolock" parameters, they can all be passed together on the +query string:: + + e = create_engine( + "sqlite:///file:path/to/database?" + "check_same_thread=true&timeout=10&mode=ro&nolock=1&uri=true" + ) + +Above, the pysqlite / sqlite3 DBAPI would be passed arguments as:: + + sqlite3.connect( + "file:path/to/database?mode=ro&nolock=1", + check_same_thread=True, timeout=10, uri=True + ) + +Regarding future parameters added to either the Python or native drivers. new +parameter names added to the SQLite URI scheme should be automatically +accommodated by this scheme. New parameter names added to the Python driver +side can be accommodated by specifying them in the +:paramref:`_sa.create_engine.connect_args` dictionary, +until dialect support is +added by SQLAlchemy. For the less likely case that the native SQLite driver +adds a new parameter name that overlaps with one of the existing, known Python +driver parameters (such as "timeout" perhaps), SQLAlchemy's dialect would +require adjustment for the URL scheme to continue to support this. + +As is always the case for all SQLAlchemy dialects, the entire "URL" process +can be bypassed in :func:`_sa.create_engine` through the use of the +:paramref:`_sa.create_engine.creator` +parameter which allows for a custom callable +that creates a Python sqlite3 driver level connection directly. + +.. versionadded:: 1.3.9 + +.. seealso:: + + `Uniform Resource Identifiers `_ - in + the SQLite documentation + +.. _pysqlite_regexp: + +Regular Expression Support +--------------------------- + +.. versionadded:: 1.4 + +Support for the :meth:`_sql.ColumnOperators.regexp_match` operator is provided +using Python's re.search_ function. SQLite itself does not include a working +regular expression operator; instead, it includes a non-implemented placeholder +operator ``REGEXP`` that calls a user-defined function that must be provided. + +SQLAlchemy's implementation makes use of the pysqlite create_function_ hook +as follows:: + + + def regexp(a, b): + return re.search(a, b) is not None + + sqlite_connection.create_function( + "regexp", 2, regexp, + ) + +There is currently no support for regular expression flags as a separate +argument, as these are not supported by SQLite's REGEXP operator, however these +may be included inline within the regular expression string. See `Python regular expressions`_ for +details. + +.. seealso:: + + `Python regular expressions`_: Documentation for Python's regular expression syntax. + +.. _create_function: https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function + +.. _re.search: https://docs.python.org/3/library/re.html#re.search + +.. _Python regular expressions: https://docs.python.org/3/library/re.html#re.search + + + +Compatibility with sqlite3 "native" date and datetime types +----------------------------------------------------------- + +The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and +sqlite3.PARSE_COLNAMES options, which have the effect of any column +or expression explicitly cast as "date" or "timestamp" will be converted +to a Python date or datetime object. The date and datetime types provided +with the pysqlite dialect are not currently compatible with these options, +since they render the ISO date/datetime including microseconds, which +pysqlite's driver does not. Additionally, SQLAlchemy does not at +this time automatically render the "cast" syntax required for the +freestanding functions "current_timestamp" and "current_date" to return +datetime/date types natively. Unfortunately, pysqlite +does not provide the standard DBAPI types in ``cursor.description``, +leaving SQLAlchemy with no way to detect these types on the fly +without expensive per-row type checks. + +Keeping in mind that pysqlite's parsing option is not recommended, +nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES +can be forced if one configures "native_datetime=True" on create_engine():: + + engine = create_engine('sqlite://', + connect_args={'detect_types': + sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES}, + native_datetime=True + ) + +With this flag enabled, the DATE and TIMESTAMP types (but note - not the +DATETIME or TIME types...confused yet ?) will not perform any bind parameter +or result processing. Execution of "func.current_date()" will return a string. +"func.current_timestamp()" is registered as returning a DATETIME type in +SQLAlchemy, so this function still receives SQLAlchemy-level result +processing. + +.. _pysqlite_threading_pooling: + +Threading/Pooling Behavior +--------------------------- + +The ``sqlite3`` DBAPI by default prohibits the use of a particular connection +in a thread which is not the one in which it was created. As SQLite has +matured, it's behavior under multiple threads has improved, and even includes +options for memory only databases to be used in multiple threads. + +The thread prohibition is known as "check same thread" and may be controlled +using the ``sqlite3`` parameter ``check_same_thread``, which will disable or +enable this check. SQLAlchemy's default behavior here is to set +``check_same_thread`` to ``False`` automatically whenever a file-based database +is in use, to establish compatibility with the default pool class +:class:`.QueuePool`. + +The SQLAlchemy ``pysqlite`` DBAPI establishes the connection pool differently +based on the kind of SQLite database that's requested: + +* When a ``:memory:`` SQLite database is specified, the dialect by default + will use :class:`.SingletonThreadPool`. This pool maintains a single + connection per thread, so that all access to the engine within the current + thread use the same ``:memory:`` database - other threads would access a + different ``:memory:`` database. The ``check_same_thread`` parameter + defaults to ``True``. +* When a file-based database is specified, the dialect will use + :class:`.QueuePool` as the source of connections. at the same time, + the ``check_same_thread`` flag is set to False by default unless overridden. + + .. versionchanged:: 2.0 + + SQLite file database engines now use :class:`.QueuePool` by default. + Previously, :class:`.NullPool` were used. The :class:`.NullPool` class + may be used by specifying it via the + :paramref:`_sa.create_engine.poolclass` parameter. + +Disabling Connection Pooling for File Databases +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Pooling may be disabled for a file based database by specifying the +:class:`.NullPool` implementation for the :func:`_sa.create_engine.poolclass` +parameter:: + + from sqlalchemy import NullPool + engine = create_engine("sqlite:///myfile.db", poolclass=NullPool) + +It's been observed that the :class:`.NullPool` implementation incurs an +extremely small performance overhead for repeated checkouts due to the lack of +connection re-use implemented by :class:`.QueuePool`. However, it still +may be beneficial to use this class if the application is experiencing +issues with files being locked. + +Using a Memory Database in Multiple Threads +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To use a ``:memory:`` database in a multithreaded scenario, the same +connection object must be shared among threads, since the database exists +only within the scope of that connection. The +:class:`.StaticPool` implementation will maintain a single connection +globally, and the ``check_same_thread`` flag can be passed to Pysqlite +as ``False``:: + + from sqlalchemy.pool import StaticPool + engine = create_engine('sqlite://', + connect_args={'check_same_thread':False}, + poolclass=StaticPool) + +Note that using a ``:memory:`` database in multiple threads requires a recent +version of SQLite. + +Using Temporary Tables with SQLite +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Due to the way SQLite deals with temporary tables, if you wish to use a +temporary table in a file-based SQLite database across multiple checkouts +from the connection pool, such as when using an ORM :class:`.Session` where +the temporary table should continue to remain after :meth:`.Session.commit` or +:meth:`.Session.rollback` is called, a pool which maintains a single +connection must be used. Use :class:`.SingletonThreadPool` if the scope is +only needed within the current thread, or :class:`.StaticPool` is scope is +needed within multiple threads for this case:: + + # maintain the same connection per thread + from sqlalchemy.pool import SingletonThreadPool + engine = create_engine('sqlite:///mydb.db', + poolclass=SingletonThreadPool) + + + # maintain the same connection across all threads + from sqlalchemy.pool import StaticPool + engine = create_engine('sqlite:///mydb.db', + poolclass=StaticPool) + +Note that :class:`.SingletonThreadPool` should be configured for the number +of threads that are to be used; beyond that number, connections will be +closed out in a non deterministic way. + + +Dealing with Mixed String / Binary Columns +------------------------------------------------------ + +The SQLite database is weakly typed, and as such it is possible when using +binary values, which in Python are represented as ``b'some string'``, that a +particular SQLite database can have data values within different rows where +some of them will be returned as a ``b''`` value by the Pysqlite driver, and +others will be returned as Python strings, e.g. ``''`` values. This situation +is not known to occur if the SQLAlchemy :class:`.LargeBinary` datatype is used +consistently, however if a particular SQLite database has data that was +inserted using the Pysqlite driver directly, or when using the SQLAlchemy +:class:`.String` type which was later changed to :class:`.LargeBinary`, the +table will not be consistently readable because SQLAlchemy's +:class:`.LargeBinary` datatype does not handle strings so it has no way of +"encoding" a value that is in string format. + +To deal with a SQLite table that has mixed string / binary data in the +same column, use a custom type that will check each row individually:: + + from sqlalchemy import String + from sqlalchemy import TypeDecorator + + class MixedBinary(TypeDecorator): + impl = String + cache_ok = True + + def process_result_value(self, value, dialect): + if isinstance(value, str): + value = bytes(value, 'utf-8') + elif value is not None: + value = bytes(value) + + return value + +Then use the above ``MixedBinary`` datatype in the place where +:class:`.LargeBinary` would normally be used. + +.. _pysqlite_serializable: + +Serializable isolation / Savepoints / Transactional DDL +------------------------------------------------------- + +In the section :ref:`sqlite_concurrency`, we refer to the pysqlite +driver's assortment of issues that prevent several features of SQLite +from working correctly. The pysqlite DBAPI driver has several +long-standing bugs which impact the correctness of its transactional +behavior. In its default mode of operation, SQLite features such as +SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are +non-functional, and in order to use these features, workarounds must +be taken. + +The issue is essentially that the driver attempts to second-guess the user's +intent, failing to start transactions and sometimes ending them prematurely, in +an effort to minimize the SQLite databases's file locking behavior, even +though SQLite itself uses "shared" locks for read-only activities. + +SQLAlchemy chooses to not alter this behavior by default, as it is the +long-expected behavior of the pysqlite driver; if and when the pysqlite +driver attempts to repair these issues, that will be more of a driver towards +defaults for SQLAlchemy. + +The good news is that with a few events, we can implement transactional +support fully, by disabling pysqlite's feature entirely and emitting BEGIN +ourselves. This is achieved using two event listeners:: + + from sqlalchemy import create_engine, event + + engine = create_engine("sqlite:///myfile.db") + + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + # disable pysqlite's emitting of the BEGIN statement entirely. + # also stops it from emitting COMMIT before any DDL. + dbapi_connection.isolation_level = None + + @event.listens_for(engine, "begin") + def do_begin(conn): + # emit our own BEGIN + conn.exec_driver_sql("BEGIN") + +.. warning:: When using the above recipe, it is advised to not use the + :paramref:`.Connection.execution_options.isolation_level` setting on + :class:`_engine.Connection` and :func:`_sa.create_engine` + with the SQLite driver, + as this function necessarily will also alter the ".isolation_level" setting. + + +Above, we intercept a new pysqlite connection and disable any transactional +integration. Then, at the point at which SQLAlchemy knows that transaction +scope is to begin, we emit ``"BEGIN"`` ourselves. + +When we take control of ``"BEGIN"``, we can also control directly SQLite's +locking modes, introduced at +`BEGIN TRANSACTION `_, +by adding the desired locking mode to our ``"BEGIN"``:: + + @event.listens_for(engine, "begin") + def do_begin(conn): + conn.exec_driver_sql("BEGIN EXCLUSIVE") + +.. seealso:: + + `BEGIN TRANSACTION `_ - + on the SQLite site + + `sqlite3 SELECT does not BEGIN a transaction `_ - + on the Python bug tracker + + `sqlite3 module breaks transactions and potentially corrupts data `_ - + on the Python bug tracker + +.. _pysqlite_udfs: + +User-Defined Functions +---------------------- + +pysqlite supports a `create_function() `_ +method that allows us to create our own user-defined functions (UDFs) in Python and use them directly in SQLite queries. +These functions are registered with a specific DBAPI Connection. + +SQLAlchemy uses connection pooling with file-based SQLite databases, so we need to ensure that the UDF is attached to the +connection when it is created. That is accomplished with an event listener:: + + from sqlalchemy import create_engine + from sqlalchemy import event + from sqlalchemy import text + + + def udf(): + return "udf-ok" + + + engine = create_engine("sqlite:///./db_file") + + + @event.listens_for(engine, "connect") + def connect(conn, rec): + conn.create_function("udf", 0, udf) + + + for i in range(5): + with engine.connect() as conn: + print(conn.scalar(text("SELECT UDF()"))) + + +""" # noqa + +import math +import os +import re + +from .base import DATE +from .base import DATETIME +from .base import SQLiteDialect +from ... import exc +from ... import pool +from ... import types as sqltypes +from ... import util + + +class _SQLite_pysqliteTimeStamp(DATETIME): + def bind_processor(self, dialect): + if dialect.native_datetime: + return None + else: + return DATETIME.bind_processor(self, dialect) + + def result_processor(self, dialect, coltype): + if dialect.native_datetime: + return None + else: + return DATETIME.result_processor(self, dialect, coltype) + + +class _SQLite_pysqliteDate(DATE): + def bind_processor(self, dialect): + if dialect.native_datetime: + return None + else: + return DATE.bind_processor(self, dialect) + + def result_processor(self, dialect, coltype): + if dialect.native_datetime: + return None + else: + return DATE.result_processor(self, dialect, coltype) + + +class SQLiteDialect_pysqlite(SQLiteDialect): + default_paramstyle = "qmark" + supports_statement_cache = True + returns_native_bytes = True + + colspecs = util.update_copy( + SQLiteDialect.colspecs, + { + sqltypes.Date: _SQLite_pysqliteDate, + sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp, + }, + ) + + description_encoding = None + + driver = "pysqlite" + + @classmethod + def import_dbapi(cls): + from sqlite3 import dbapi2 as sqlite + + return sqlite + + @classmethod + def _is_url_file_db(cls, url): + if (url.database and url.database != ":memory:") and ( + url.query.get("mode", None) != "memory" + ): + return True + else: + return False + + @classmethod + def get_pool_class(cls, url): + if cls._is_url_file_db(url): + return pool.QueuePool + else: + return pool.SingletonThreadPool + + def _get_server_version_info(self, connection): + return self.dbapi.sqlite_version_info + + _isolation_lookup = SQLiteDialect._isolation_lookup.union( + { + "AUTOCOMMIT": None, + } + ) + + def set_isolation_level(self, dbapi_connection, level): + if level == "AUTOCOMMIT": + dbapi_connection.isolation_level = None + else: + dbapi_connection.isolation_level = "" + return super().set_isolation_level(dbapi_connection, level) + + def on_connect(self): + def regexp(a, b): + if b is None: + return None + return re.search(a, b) is not None + + if util.py38 and self._get_server_version_info(None) >= (3, 9): + # sqlite must be greater than 3.8.3 for deterministic=True + # https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function + # the check is more conservative since there were still issues + # with following 3.8 sqlite versions + create_func_kw = {"deterministic": True} + else: + create_func_kw = {} + + def set_regexp(dbapi_connection): + dbapi_connection.create_function( + "regexp", 2, regexp, **create_func_kw + ) + + def floor_func(dbapi_connection): + # NOTE: floor is optionally present in sqlite 3.35+ , however + # as it is normally non-present we deliver floor() unconditionally + # for now. + # https://www.sqlite.org/lang_mathfunc.html + dbapi_connection.create_function( + "floor", 1, math.floor, **create_func_kw + ) + + fns = [set_regexp, floor_func] + + def connect(conn): + for fn in fns: + fn(conn) + + return connect + + def create_connect_args(self, url): + if url.username or url.password or url.host or url.port: + raise exc.ArgumentError( + "Invalid SQLite URL: %s\n" + "Valid SQLite URL forms are:\n" + " sqlite:///:memory: (or, sqlite://)\n" + " sqlite:///relative/path/to/file.db\n" + " sqlite:////absolute/path/to/file.db" % (url,) + ) + + # theoretically, this list can be augmented, at least as far as + # parameter names accepted by sqlite3/pysqlite, using + # inspect.getfullargspec(). for the moment this seems like overkill + # as these parameters don't change very often, and as always, + # parameters passed to connect_args will always go to the + # sqlite3/pysqlite driver. + pysqlite_args = [ + ("uri", bool), + ("timeout", float), + ("isolation_level", str), + ("detect_types", int), + ("check_same_thread", bool), + ("cached_statements", int), + ] + opts = url.query + pysqlite_opts = {} + for key, type_ in pysqlite_args: + util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts) + + if pysqlite_opts.get("uri", False): + uri_opts = dict(opts) + # here, we are actually separating the parameters that go to + # sqlite3/pysqlite vs. those that go the SQLite URI. What if + # two names conflict? again, this seems to be not the case right + # now, and in the case that new names are added to + # either side which overlap, again the sqlite3/pysqlite parameters + # can be passed through connect_args instead of in the URL. + # If SQLite native URIs add a parameter like "timeout" that + # we already have listed here for the python driver, then we need + # to adjust for that here. + for key, type_ in pysqlite_args: + uri_opts.pop(key, None) + filename = url.database + if uri_opts: + # sorting of keys is for unit test support + filename += "?" + ( + "&".join( + "%s=%s" % (key, uri_opts[key]) + for key in sorted(uri_opts) + ) + ) + else: + filename = url.database or ":memory:" + if filename != ":memory:": + filename = os.path.abspath(filename) + + pysqlite_opts.setdefault( + "check_same_thread", not self._is_url_file_db(url) + ) + + return ([filename], pysqlite_opts) + + def is_disconnect(self, e, connection, cursor): + return isinstance( + e, self.dbapi.ProgrammingError + ) and "Cannot operate on a closed database." in str(e) + + +dialect = SQLiteDialect_pysqlite + + +class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite): + """numeric dialect for testing only + + internal use only. This dialect is **NOT** supported by SQLAlchemy + and may change at any time. + + """ + + supports_statement_cache = True + default_paramstyle = "numeric" + driver = "pysqlite_numeric" + + _first_bind = ":1" + _not_in_statement_regexp = None + + def __init__(self, *arg, **kw): + kw.setdefault("paramstyle", "numeric") + super().__init__(*arg, **kw) + + def create_connect_args(self, url): + arg, opts = super().create_connect_args(url) + opts["factory"] = self._fix_sqlite_issue_99953() + return arg, opts + + def _fix_sqlite_issue_99953(self): + import sqlite3 + + first_bind = self._first_bind + if self._not_in_statement_regexp: + nis = self._not_in_statement_regexp + + def _test_sql(sql): + m = nis.search(sql) + assert not m, f"Found {nis.pattern!r} in {sql!r}" + + else: + + def _test_sql(sql): + pass + + def _numeric_param_as_dict(parameters): + if parameters: + assert isinstance(parameters, tuple) + return { + str(idx): value for idx, value in enumerate(parameters, 1) + } + else: + return () + + class SQLiteFix99953Cursor(sqlite3.Cursor): + def execute(self, sql, parameters=()): + _test_sql(sql) + if first_bind in sql: + parameters = _numeric_param_as_dict(parameters) + return super().execute(sql, parameters) + + def executemany(self, sql, parameters): + _test_sql(sql) + if first_bind in sql: + parameters = [ + _numeric_param_as_dict(p) for p in parameters + ] + return super().executemany(sql, parameters) + + class SQLiteFix99953Connection(sqlite3.Connection): + def cursor(self, factory=None): + if factory is None: + factory = SQLiteFix99953Cursor + return super().cursor(factory=factory) + + def execute(self, sql, parameters=()): + _test_sql(sql) + if first_bind in sql: + parameters = _numeric_param_as_dict(parameters) + return super().execute(sql, parameters) + + def executemany(self, sql, parameters): + _test_sql(sql) + if first_bind in sql: + parameters = [ + _numeric_param_as_dict(p) for p in parameters + ] + return super().executemany(sql, parameters) + + return SQLiteFix99953Connection + + +class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric): + """numeric dialect that uses $ for testing only + + internal use only. This dialect is **NOT** supported by SQLAlchemy + and may change at any time. + + """ + + supports_statement_cache = True + default_paramstyle = "numeric_dollar" + driver = "pysqlite_dollar" + + _first_bind = "$1" + _not_in_statement_regexp = re.compile(r"[^\d]:\d+") + + def __init__(self, *arg, **kw): + kw.setdefault("paramstyle", "numeric_dollar") + super().__init__(*arg, **kw) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/dialects/type_migration_guidelines.txt b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/type_migration_guidelines.txt new file mode 100644 index 0000000..e6be205 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/dialects/type_migration_guidelines.txt @@ -0,0 +1,145 @@ +Rules for Migrating TypeEngine classes to 0.6 +--------------------------------------------- + +1. the TypeEngine classes are used for: + + a. Specifying behavior which needs to occur for bind parameters + or result row columns. + + b. Specifying types that are entirely specific to the database + in use and have no analogue in the sqlalchemy.types package. + + c. Specifying types where there is an analogue in sqlalchemy.types, + but the database in use takes vendor-specific flags for those + types. + + d. If a TypeEngine class doesn't provide any of this, it should be + *removed* from the dialect. + +2. the TypeEngine classes are *no longer* used for generating DDL. Dialects +now have a TypeCompiler subclass which uses the same visit_XXX model as +other compilers. + +3. the "ischema_names" and "colspecs" dictionaries are now required members on +the Dialect class. + +4. The names of types within dialects are now important. If a dialect-specific type +is a subclass of an existing generic type and is only provided for bind/result behavior, +the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this case, +end users would never need to use _PGNumeric directly. However, if a dialect-specific +type is specifying a type *or* arguments that are not present generically, it should +match the real name of the type on that backend, in uppercase. E.g. postgresql.INET, +mysql.ENUM, postgresql.ARRAY. + +Or follow this handy flowchart: + + is the type meant to provide bind/result is the type the same name as an + behavior to a generic type (i.e. MixedCase) ---- no ---> UPPERCASE type in types.py ? + type in types.py ? | | + | no yes + yes | | + | | does your type need special + | +<--- yes --- behavior or arguments ? + | | | + | | no + name the type using | | + _MixedCase, i.e. v V + _OracleBoolean. it name the type don't make a + stays private to the dialect identically as that type, make sure the dialect's + and is invoked *only* via within the DB, base.py imports the types.py + the colspecs dict. using UPPERCASE UPPERCASE name into its namespace + | (i.e. BIT, NCHAR, INTERVAL). + | Users can import it. + | | + v v + subclass the closest is the name of this type + MixedCase type types.py, identical to an UPPERCASE + i.e. <--- no ------- name in types.py ? + class _DateTime(types.DateTime), + class DATETIME2(types.DateTime), | + class BIT(types.TypeEngine). yes + | + v + the type should + subclass the + UPPERCASE + type in types.py + (i.e. class BLOB(types.BLOB)) + + +Example 1. pysqlite needs bind/result processing for the DateTime type in types.py, +which applies to all DateTimes and subclasses. It's named _SLDateTime and +subclasses types.DateTime. + +Example 2. MS-SQL has a TIME type which takes a non-standard "precision" argument +that is rendered within DDL. So it's named TIME in the MS-SQL dialect's base.py, +and subclasses types.TIME. Users can then say mssql.TIME(precision=10). + +Example 3. MS-SQL dialects also need special bind/result processing for date +But its DATE type doesn't render DDL differently than that of a plain +DATE, i.e. it takes no special arguments. Therefore we are just adding behavior +to types.Date, so it's named _MSDate in the MS-SQL dialect's base.py, and subclasses +types.Date. + +Example 4. MySQL has a SET type, there's no analogue for this in types.py. So +MySQL names it SET in the dialect's base.py, and it subclasses types.String, since +it ultimately deals with strings. + +Example 5. PostgreSQL has a DATETIME type. The DBAPIs handle dates correctly, +and no special arguments are used in PG's DDL beyond what types.py provides. +PostgreSQL dialect therefore imports types.DATETIME into its base.py. + +Ideally one should be able to specify a schema using names imported completely from a +dialect, all matching the real name on that backend: + + from sqlalchemy.dialects.postgresql import base as pg + + t = Table('mytable', metadata, + Column('id', pg.INTEGER, primary_key=True), + Column('name', pg.VARCHAR(300)), + Column('inetaddr', pg.INET) + ) + +where above, the INTEGER and VARCHAR types are ultimately from sqlalchemy.types, +but the PG dialect makes them available in its own namespace. + +5. "colspecs" now is a dictionary of generic or uppercased types from sqlalchemy.types +linked to types specified in the dialect. Again, if a type in the dialect does not +specify any special behavior for bind_processor() or result_processor() and does not +indicate a special type only available in this database, it must be *removed* from the +module and from this dictionary. + +6. "ischema_names" indicates string descriptions of types as returned from the database +linked to TypeEngine classes. + + a. The string name should be matched to the most specific type possible within + sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which + case it points to a dialect type. *It doesn't matter* if the dialect has its + own subclass of that type with special bind/result behavior - reflect to the types.py + UPPERCASE type as much as possible. With very few exceptions, all types + should reflect to an UPPERCASE type. + + b. If the dialect contains a matching dialect-specific type that takes extra arguments + which the generic one does not, then point to the dialect-specific type. E.g. + mssql.VARCHAR takes a "collation" parameter which should be preserved. + +5. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by +a subclass of compiler.GenericTypeCompiler. + + a. your TypeCompiler class will receive generic and uppercase types from + sqlalchemy.types. Do not assume the presence of dialect-specific attributes on + these types. + + b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with + methods that produce a different DDL name. Uppercase types don't do any kind of + "guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in + all cases, regardless of whether or not that type is legal on the backend database. + + c. the visit_UPPERCASE methods *should* be overridden with methods that add additional + arguments and flags to those types. + + d. the visit_lowercase methods are overridden to provide an interpretation of a generic + type. E.g. visit_large_binary() might be overridden to say "return self.visit_BIT(type_)". + + e. visit_lowercase methods should *never* render strings directly - it should always + be via calling a visit_UPPERCASE() method. diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__init__.py new file mode 100644 index 0000000..af0f7ee --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__init__.py @@ -0,0 +1,62 @@ +# engine/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""SQL connections, SQL execution and high-level DB-API interface. + +The engine package defines the basic components used to interface +DB-API modules with higher-level statement construction, +connection-management, execution and result contexts. The primary +"entry point" class into this package is the Engine and its public +constructor ``create_engine()``. + +""" + +from . import events as events +from . import util as util +from .base import Connection as Connection +from .base import Engine as Engine +from .base import NestedTransaction as NestedTransaction +from .base import RootTransaction as RootTransaction +from .base import Transaction as Transaction +from .base import TwoPhaseTransaction as TwoPhaseTransaction +from .create import create_engine as create_engine +from .create import create_pool_from_url as create_pool_from_url +from .create import engine_from_config as engine_from_config +from .cursor import CursorResult as CursorResult +from .cursor import ResultProxy as ResultProxy +from .interfaces import AdaptedConnection as AdaptedConnection +from .interfaces import BindTyping as BindTyping +from .interfaces import Compiled as Compiled +from .interfaces import Connectable as Connectable +from .interfaces import ConnectArgsType as ConnectArgsType +from .interfaces import ConnectionEventsTarget as ConnectionEventsTarget +from .interfaces import CreateEnginePlugin as CreateEnginePlugin +from .interfaces import Dialect as Dialect +from .interfaces import ExceptionContext as ExceptionContext +from .interfaces import ExecutionContext as ExecutionContext +from .interfaces import TypeCompiler as TypeCompiler +from .mock import create_mock_engine as create_mock_engine +from .reflection import Inspector as Inspector +from .reflection import ObjectKind as ObjectKind +from .reflection import ObjectScope as ObjectScope +from .result import ChunkedIteratorResult as ChunkedIteratorResult +from .result import FilterResult as FilterResult +from .result import FrozenResult as FrozenResult +from .result import IteratorResult as IteratorResult +from .result import MappingResult as MappingResult +from .result import MergedResult as MergedResult +from .result import Result as Result +from .result import result_tuple as result_tuple +from .result import ScalarResult as ScalarResult +from .result import TupleResult as TupleResult +from .row import BaseRow as BaseRow +from .row import Row as Row +from .row import RowMapping as RowMapping +from .url import make_url as make_url +from .url import URL as URL +from .util import connection_memoize as connection_memoize +from ..sql import ddl as ddl diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..d1e58d7 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/_py_processors.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/_py_processors.cpython-311.pyc new file mode 100644 index 0000000..07c56b0 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/_py_processors.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/_py_row.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/_py_row.cpython-311.pyc new file mode 100644 index 0000000..7540b73 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/_py_row.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/_py_util.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/_py_util.cpython-311.pyc new file mode 100644 index 0000000..7b90d87 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/_py_util.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..599160b Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/base.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/characteristics.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/characteristics.cpython-311.pyc new file mode 100644 index 0000000..1f5071d Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/characteristics.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/create.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/create.cpython-311.pyc new file mode 100644 index 0000000..b45c678 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/create.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/cursor.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/cursor.cpython-311.pyc new file mode 100644 index 0000000..632ceaa Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/cursor.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/default.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/default.cpython-311.pyc new file mode 100644 index 0000000..f047237 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/default.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/events.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/events.cpython-311.pyc new file mode 100644 index 0000000..5b57d88 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/events.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/interfaces.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/interfaces.cpython-311.pyc new file mode 100644 index 0000000..3393705 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/interfaces.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/mock.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/mock.cpython-311.pyc new file mode 100644 index 0000000..55058f9 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/mock.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/processors.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/processors.cpython-311.pyc new file mode 100644 index 0000000..cebace5 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/processors.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/reflection.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/reflection.cpython-311.pyc new file mode 100644 index 0000000..e7bf33b Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/reflection.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/result.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/result.cpython-311.pyc new file mode 100644 index 0000000..1e4691f Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/result.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/row.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/row.cpython-311.pyc new file mode 100644 index 0000000..9eadbb3 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/row.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/strategies.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/strategies.cpython-311.pyc new file mode 100644 index 0000000..5729f6e Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/strategies.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/url.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/url.cpython-311.pyc new file mode 100644 index 0000000..5d15a72 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/url.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/util.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000..89e1582 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/engine/__pycache__/util.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/_py_processors.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/_py_processors.py new file mode 100644 index 0000000..2cc35b5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/_py_processors.py @@ -0,0 +1,136 @@ +# engine/_py_processors.py +# Copyright (C) 2010-2024 the SQLAlchemy authors and contributors +# +# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""defines generic type conversion functions, as used in bind and result +processors. + +They all share one common characteristic: None is passed through unchanged. + +""" + +from __future__ import annotations + +import datetime +from datetime import date as date_cls +from datetime import datetime as datetime_cls +from datetime import time as time_cls +from decimal import Decimal +import typing +from typing import Any +from typing import Callable +from typing import Optional +from typing import Type +from typing import TypeVar +from typing import Union + + +_DT = TypeVar( + "_DT", bound=Union[datetime.datetime, datetime.time, datetime.date] +) + + +def str_to_datetime_processor_factory( + regexp: typing.Pattern[str], type_: Callable[..., _DT] +) -> Callable[[Optional[str]], Optional[_DT]]: + rmatch = regexp.match + # Even on python2.6 datetime.strptime is both slower than this code + # and it does not support microseconds. + has_named_groups = bool(regexp.groupindex) + + def process(value: Optional[str]) -> Optional[_DT]: + if value is None: + return None + else: + try: + m = rmatch(value) + except TypeError as err: + raise ValueError( + "Couldn't parse %s string '%r' " + "- value is not a string." % (type_.__name__, value) + ) from err + + if m is None: + raise ValueError( + "Couldn't parse %s string: " + "'%s'" % (type_.__name__, value) + ) + if has_named_groups: + groups = m.groupdict(0) + return type_( + **dict( + list( + zip( + iter(groups.keys()), + list(map(int, iter(groups.values()))), + ) + ) + ) + ) + else: + return type_(*list(map(int, m.groups(0)))) + + return process + + +def to_decimal_processor_factory( + target_class: Type[Decimal], scale: int +) -> Callable[[Optional[float]], Optional[Decimal]]: + fstring = "%%.%df" % scale + + def process(value: Optional[float]) -> Optional[Decimal]: + if value is None: + return None + else: + return target_class(fstring % value) + + return process + + +def to_float(value: Optional[Union[int, float]]) -> Optional[float]: + if value is None: + return None + else: + return float(value) + + +def to_str(value: Optional[Any]) -> Optional[str]: + if value is None: + return None + else: + return str(value) + + +def int_to_boolean(value: Optional[int]) -> Optional[bool]: + if value is None: + return None + else: + return bool(value) + + +def str_to_datetime(value: Optional[str]) -> Optional[datetime.datetime]: + if value is not None: + dt_value = datetime_cls.fromisoformat(value) + else: + dt_value = None + return dt_value + + +def str_to_time(value: Optional[str]) -> Optional[datetime.time]: + if value is not None: + dt_value = time_cls.fromisoformat(value) + else: + dt_value = None + return dt_value + + +def str_to_date(value: Optional[str]) -> Optional[datetime.date]: + if value is not None: + dt_value = date_cls.fromisoformat(value) + else: + dt_value = None + return dt_value diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/_py_row.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/_py_row.py new file mode 100644 index 0000000..4e1dd7d --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/_py_row.py @@ -0,0 +1,128 @@ +# engine/_py_row.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import operator +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import Tuple +from typing import Type + +if typing.TYPE_CHECKING: + from .result import _KeyType + from .result import _ProcessorsType + from .result import _RawRowType + from .result import _TupleGetterType + from .result import ResultMetaData + +MD_INDEX = 0 # integer index in cursor.description + + +class BaseRow: + __slots__ = ("_parent", "_data", "_key_to_index") + + _parent: ResultMetaData + _key_to_index: Mapping[_KeyType, int] + _data: _RawRowType + + def __init__( + self, + parent: ResultMetaData, + processors: Optional[_ProcessorsType], + key_to_index: Mapping[_KeyType, int], + data: _RawRowType, + ): + """Row objects are constructed by CursorResult objects.""" + object.__setattr__(self, "_parent", parent) + + object.__setattr__(self, "_key_to_index", key_to_index) + + if processors: + object.__setattr__( + self, + "_data", + tuple( + [ + proc(value) if proc else value + for proc, value in zip(processors, data) + ] + ), + ) + else: + object.__setattr__(self, "_data", tuple(data)) + + def __reduce__(self) -> Tuple[Callable[..., BaseRow], Tuple[Any, ...]]: + return ( + rowproxy_reconstructor, + (self.__class__, self.__getstate__()), + ) + + def __getstate__(self) -> Dict[str, Any]: + return {"_parent": self._parent, "_data": self._data} + + def __setstate__(self, state: Dict[str, Any]) -> None: + parent = state["_parent"] + object.__setattr__(self, "_parent", parent) + object.__setattr__(self, "_data", state["_data"]) + object.__setattr__(self, "_key_to_index", parent._key_to_index) + + def _values_impl(self) -> List[Any]: + return list(self) + + def __iter__(self) -> Iterator[Any]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def __hash__(self) -> int: + return hash(self._data) + + def __getitem__(self, key: Any) -> Any: + return self._data[key] + + def _get_by_key_impl_mapping(self, key: str) -> Any: + try: + return self._data[self._key_to_index[key]] + except KeyError: + pass + self._parent._key_not_found(key, False) + + def __getattr__(self, name: str) -> Any: + try: + return self._data[self._key_to_index[name]] + except KeyError: + pass + self._parent._key_not_found(name, True) + + def _to_tuple_instance(self) -> Tuple[Any, ...]: + return self._data + + +# This reconstructor is necessary so that pickles with the Cy extension or +# without use the same Binary format. +def rowproxy_reconstructor( + cls: Type[BaseRow], state: Dict[str, Any] +) -> BaseRow: + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + + +def tuplegetter(*indexes: int) -> _TupleGetterType: + if len(indexes) != 1: + for i in range(1, len(indexes)): + if indexes[i - 1] != indexes[i] - 1: + return operator.itemgetter(*indexes) + # slice form is faster but returns a list if input is list + return operator.itemgetter(slice(indexes[0], indexes[-1] + 1)) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/_py_util.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/_py_util.py new file mode 100644 index 0000000..2be4322 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/_py_util.py @@ -0,0 +1,74 @@ +# engine/_py_util.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import typing +from typing import Any +from typing import Mapping +from typing import Optional +from typing import Tuple + +from .. import exc + +if typing.TYPE_CHECKING: + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams + + +_no_tuple: Tuple[Any, ...] = () + + +def _distill_params_20( + params: Optional[_CoreAnyExecuteParams], +) -> _CoreMultiExecuteParams: + if params is None: + return _no_tuple + # Assume list is more likely than tuple + elif isinstance(params, list) or isinstance(params, tuple): + # collections_abc.MutableSequence): # avoid abc.__instancecheck__ + if params and not isinstance(params[0], (tuple, Mapping)): + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + + return params + elif isinstance(params, dict) or isinstance( + # only do immutabledict or abc.__instancecheck__ for Mapping after + # we've checked for plain dictionaries and would otherwise raise + params, + Mapping, + ): + return [params] + else: + raise exc.ArgumentError("mapping or list expected for parameters") + + +def _distill_raw_params( + params: Optional[_DBAPIAnyExecuteParams], +) -> _DBAPIMultiExecuteParams: + if params is None: + return _no_tuple + elif isinstance(params, list): + # collections_abc.MutableSequence): # avoid abc.__instancecheck__ + if params and not isinstance(params[0], (tuple, Mapping)): + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + + return params + elif isinstance(params, (tuple, dict)) or isinstance( + # only do abc.__instancecheck__ for Mapping after we've checked + # for plain dictionaries and would otherwise raise + params, + Mapping, + ): + # cast("Union[List[Mapping[str, Any]], Tuple[Any, ...]]", [params]) + return [params] # type: ignore + else: + raise exc.ArgumentError("mapping or sequence expected for parameters") diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/base.py new file mode 100644 index 0000000..403ec45 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/base.py @@ -0,0 +1,3377 @@ +# engine/base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`. + +""" +from __future__ import annotations + +import contextlib +import sys +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union + +from .interfaces import BindTyping +from .interfaces import ConnectionEventsTarget +from .interfaces import DBAPICursor +from .interfaces import ExceptionContext +from .interfaces import ExecuteStyle +from .interfaces import ExecutionContext +from .interfaces import IsolationLevel +from .util import _distill_params_20 +from .util import _distill_raw_params +from .util import TransactionalContext +from .. import exc +from .. import inspection +from .. import log +from .. import util +from ..sql import compiler +from ..sql import util as sql_util + +if typing.TYPE_CHECKING: + from . import CursorResult + from . import ScalarResult + from .interfaces import _AnyExecuteParams + from .interfaces import _AnyMultiExecuteParams + from .interfaces import _CoreAnyExecuteParams + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPISingleExecuteParams + from .interfaces import _ExecuteOptions + from .interfaces import CompiledCacheType + from .interfaces import CoreExecuteOptionsParameter + from .interfaces import Dialect + from .interfaces import SchemaTranslateMapType + from .reflection import Inspector # noqa + from .url import URL + from ..event import dispatcher + from ..log import _EchoFlagType + from ..pool import _ConnectionFairy + from ..pool import Pool + from ..pool import PoolProxiedConnection + from ..sql import Executable + from ..sql._typing import _InfoType + from ..sql.compiler import Compiled + from ..sql.ddl import ExecutableDDLElement + from ..sql.ddl import SchemaDropper + from ..sql.ddl import SchemaGenerator + from ..sql.functions import FunctionElement + from ..sql.schema import DefaultGenerator + from ..sql.schema import HasSchemaAttr + from ..sql.schema import SchemaItem + from ..sql.selectable import TypedReturnsRows + + +_T = TypeVar("_T", bound=Any) +_EMPTY_EXECUTION_OPTS: _ExecuteOptions = util.EMPTY_DICT +NO_OPTIONS: Mapping[str, Any] = util.EMPTY_DICT + + +class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): + """Provides high-level functionality for a wrapped DB-API connection. + + The :class:`_engine.Connection` object is procured by calling the + :meth:`_engine.Engine.connect` method of the :class:`_engine.Engine` + object, and provides services for execution of SQL statements as well + as transaction control. + + The Connection object is **not** thread-safe. While a Connection can be + shared among threads using properly synchronized access, it is still + possible that the underlying DBAPI connection may not support shared + access between threads. Check the DBAPI documentation for details. + + The Connection object represents a single DBAPI connection checked out + from the connection pool. In this state, the connection pool has no + affect upon the connection, including its expiration or timeout state. + For the connection pool to properly manage connections, connections + should be returned to the connection pool (i.e. ``connection.close()``) + whenever the connection is not in use. + + .. index:: + single: thread safety; Connection + + """ + + dialect: Dialect + dispatch: dispatcher[ConnectionEventsTarget] + + _sqla_logger_namespace = "sqlalchemy.engine.Connection" + + # used by sqlalchemy.engine.util.TransactionalContext + _trans_context_manager: Optional[TransactionalContext] = None + + # legacy as of 2.0, should be eventually deprecated and + # removed. was used in the "pre_ping" recipe that's been in the docs + # a long time + should_close_with_result = False + + _dbapi_connection: Optional[PoolProxiedConnection] + + _execution_options: _ExecuteOptions + + _transaction: Optional[RootTransaction] + _nested_transaction: Optional[NestedTransaction] + + def __init__( + self, + engine: Engine, + connection: Optional[PoolProxiedConnection] = None, + _has_events: Optional[bool] = None, + _allow_revalidate: bool = True, + _allow_autobegin: bool = True, + ): + """Construct a new Connection.""" + self.engine = engine + self.dialect = dialect = engine.dialect + + if connection is None: + try: + self._dbapi_connection = engine.raw_connection() + except dialect.loaded_dbapi.Error as err: + Connection._handle_dbapi_exception_noconnection( + err, dialect, engine + ) + raise + else: + self._dbapi_connection = connection + + self._transaction = self._nested_transaction = None + self.__savepoint_seq = 0 + self.__in_begin = False + + self.__can_reconnect = _allow_revalidate + self._allow_autobegin = _allow_autobegin + self._echo = self.engine._should_log_info() + + if _has_events is None: + # if _has_events is sent explicitly as False, + # then don't join the dispatch of the engine; we don't + # want to handle any of the engine's events in that case. + self.dispatch = self.dispatch._join(engine.dispatch) + self._has_events = _has_events or ( + _has_events is None and engine._has_events + ) + + self._execution_options = engine._execution_options + + if self._has_events or self.engine._has_events: + self.dispatch.engine_connect(self) + + @util.memoized_property + def _message_formatter(self) -> Any: + if "logging_token" in self._execution_options: + token = self._execution_options["logging_token"] + return lambda msg: "[%s] %s" % (token, msg) + else: + return None + + def _log_info(self, message: str, *arg: Any, **kw: Any) -> None: + fmt = self._message_formatter + + if fmt: + message = fmt(message) + + if log.STACKLEVEL: + kw["stacklevel"] = 1 + log.STACKLEVEL_OFFSET + + self.engine.logger.info(message, *arg, **kw) + + def _log_debug(self, message: str, *arg: Any, **kw: Any) -> None: + fmt = self._message_formatter + + if fmt: + message = fmt(message) + + if log.STACKLEVEL: + kw["stacklevel"] = 1 + log.STACKLEVEL_OFFSET + + self.engine.logger.debug(message, *arg, **kw) + + @property + def _schema_translate_map(self) -> Optional[SchemaTranslateMapType]: + schema_translate_map: Optional[SchemaTranslateMapType] = ( + self._execution_options.get("schema_translate_map", None) + ) + + return schema_translate_map + + def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]: + """Return the schema name for the given schema item taking into + account current schema translate map. + + """ + + name = obj.schema + schema_translate_map: Optional[SchemaTranslateMapType] = ( + self._execution_options.get("schema_translate_map", None) + ) + + if ( + schema_translate_map + and name in schema_translate_map + and obj._use_schema_map + ): + return schema_translate_map[name] + else: + return name + + def __enter__(self) -> Connection: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self.close() + + @overload + def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + no_parameters: bool = False, + stream_results: bool = False, + max_row_buffer: int = ..., + yield_per: int = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + preserve_rowcount: bool = False, + **opt: Any, + ) -> Connection: ... + + @overload + def execution_options(self, **opt: Any) -> Connection: ... + + def execution_options(self, **opt: Any) -> Connection: + r"""Set non-SQL options for the connection which take effect + during execution. + + This method modifies this :class:`_engine.Connection` **in-place**; + the return value is the same :class:`_engine.Connection` object + upon which the method is called. Note that this is in contrast + to the behavior of the ``execution_options`` methods on other + objects such as :meth:`_engine.Engine.execution_options` and + :meth:`_sql.Executable.execution_options`. The rationale is that many + such execution options necessarily modify the state of the base + DBAPI connection in any case so there is no feasible means of + keeping the effect of such an option localized to a "sub" connection. + + .. versionchanged:: 2.0 The :meth:`_engine.Connection.execution_options` + method, in contrast to other objects with this method, modifies + the connection in-place without creating copy of it. + + As discussed elsewhere, the :meth:`_engine.Connection.execution_options` + method accepts any arbitrary parameters including user defined names. + All parameters given are consumable in a number of ways including + by using the :meth:`_engine.Connection.get_execution_options` method. + See the examples at :meth:`_sql.Executable.execution_options` + and :meth:`_engine.Engine.execution_options`. + + The keywords that are currently recognized by SQLAlchemy itself + include all those listed under :meth:`.Executable.execution_options`, + as well as others that are specific to :class:`_engine.Connection`. + + :param compiled_cache: Available on: :class:`_engine.Connection`, + :class:`_engine.Engine`. + + A dictionary where :class:`.Compiled` objects + will be cached when the :class:`_engine.Connection` + compiles a clause + expression into a :class:`.Compiled` object. This dictionary will + supersede the statement cache that may be configured on the + :class:`_engine.Engine` itself. If set to None, caching + is disabled, even if the engine has a configured cache size. + + Note that the ORM makes use of its own "compiled" caches for + some operations, including flush operations. The caching + used by the ORM internally supersedes a cache dictionary + specified here. + + :param logging_token: Available on: :class:`_engine.Connection`, + :class:`_engine.Engine`, :class:`_sql.Executable`. + + Adds the specified string token surrounded by brackets in log + messages logged by the connection, i.e. the logging that's enabled + either via the :paramref:`_sa.create_engine.echo` flag or via the + ``logging.getLogger("sqlalchemy.engine")`` logger. This allows a + per-connection or per-sub-engine token to be available which is + useful for debugging concurrent connection scenarios. + + .. versionadded:: 1.4.0b2 + + .. seealso:: + + :ref:`dbengine_logging_tokens` - usage example + + :paramref:`_sa.create_engine.logging_name` - adds a name to the + name used by the Python logger object itself. + + :param isolation_level: Available on: :class:`_engine.Connection`, + :class:`_engine.Engine`. + + Set the transaction isolation level for the lifespan of this + :class:`_engine.Connection` object. + Valid values include those string + values accepted by the :paramref:`_sa.create_engine.isolation_level` + parameter passed to :func:`_sa.create_engine`. These levels are + semi-database specific; see individual dialect documentation for + valid levels. + + The isolation level option applies the isolation level by emitting + statements on the DBAPI connection, and **necessarily affects the + original Connection object overall**. The isolation level will remain + at the given setting until explicitly changed, or when the DBAPI + connection itself is :term:`released` to the connection pool, i.e. the + :meth:`_engine.Connection.close` method is called, at which time an + event handler will emit additional statements on the DBAPI connection + in order to revert the isolation level change. + + .. note:: The ``isolation_level`` execution option may only be + established before the :meth:`_engine.Connection.begin` method is + called, as well as before any SQL statements are emitted which + would otherwise trigger "autobegin", or directly after a call to + :meth:`_engine.Connection.commit` or + :meth:`_engine.Connection.rollback`. A database cannot change the + isolation level on a transaction in progress. + + .. note:: The ``isolation_level`` execution option is implicitly + reset if the :class:`_engine.Connection` is invalidated, e.g. via + the :meth:`_engine.Connection.invalidate` method, or if a + disconnection error occurs. The new connection produced after the + invalidation will **not** have the selected isolation level + re-applied to it automatically. + + .. seealso:: + + :ref:`dbapi_autocommit` + + :meth:`_engine.Connection.get_isolation_level` + - view current actual level + + :param no_parameters: Available on: :class:`_engine.Connection`, + :class:`_sql.Executable`. + + When ``True``, if the final parameter + list or dictionary is totally empty, will invoke the + statement on the cursor as ``cursor.execute(statement)``, + not passing the parameter collection at all. + Some DBAPIs such as psycopg2 and mysql-python consider + percent signs as significant only when parameters are + present; this option allows code to generate SQL + containing percent signs (and possibly other characters) + that is neutral regarding whether it's executed by the DBAPI + or piped into a script that's later invoked by + command line tools. + + :param stream_results: Available on: :class:`_engine.Connection`, + :class:`_sql.Executable`. + + Indicate to the dialect that results should be + "streamed" and not pre-buffered, if possible. For backends + such as PostgreSQL, MySQL and MariaDB, this indicates the use of + a "server side cursor" as opposed to a client side cursor. + Other backends such as that of Oracle may already use server + side cursors by default. + + The usage of + :paramref:`_engine.Connection.execution_options.stream_results` is + usually combined with setting a fixed number of rows to to be fetched + in batches, to allow for efficient iteration of database rows while + at the same time not loading all result rows into memory at once; + this can be configured on a :class:`_engine.Result` object using the + :meth:`_engine.Result.yield_per` method, after execution has + returned a new :class:`_engine.Result`. If + :meth:`_engine.Result.yield_per` is not used, + the :paramref:`_engine.Connection.execution_options.stream_results` + mode of operation will instead use a dynamically sized buffer + which buffers sets of rows at a time, growing on each batch + based on a fixed growth size up until a limit which may + be configured using the + :paramref:`_engine.Connection.execution_options.max_row_buffer` + parameter. + + When using the ORM to fetch ORM mapped objects from a result, + :meth:`_engine.Result.yield_per` should always be used with + :paramref:`_engine.Connection.execution_options.stream_results`, + so that the ORM does not fetch all rows into new ORM objects at once. + + For typical use, the + :paramref:`_engine.Connection.execution_options.yield_per` execution + option should be preferred, which sets up both + :paramref:`_engine.Connection.execution_options.stream_results` and + :meth:`_engine.Result.yield_per` at once. This option is supported + both at a core level by :class:`_engine.Connection` as well as by the + ORM :class:`_engine.Session`; the latter is described at + :ref:`orm_queryguide_yield_per`. + + .. seealso:: + + :ref:`engine_stream_results` - background on + :paramref:`_engine.Connection.execution_options.stream_results` + + :paramref:`_engine.Connection.execution_options.max_row_buffer` + + :paramref:`_engine.Connection.execution_options.yield_per` + + :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel` + describing the ORM version of ``yield_per`` + + :param max_row_buffer: Available on: :class:`_engine.Connection`, + :class:`_sql.Executable`. Sets a maximum + buffer size to use when the + :paramref:`_engine.Connection.execution_options.stream_results` + execution option is used on a backend that supports server side + cursors. The default value if not specified is 1000. + + .. seealso:: + + :paramref:`_engine.Connection.execution_options.stream_results` + + :ref:`engine_stream_results` + + + :param yield_per: Available on: :class:`_engine.Connection`, + :class:`_sql.Executable`. Integer value applied which will + set the :paramref:`_engine.Connection.execution_options.stream_results` + execution option and invoke :meth:`_engine.Result.yield_per` + automatically at once. Allows equivalent functionality as + is present when using this parameter with the ORM. + + .. versionadded:: 1.4.40 + + .. seealso:: + + :ref:`engine_stream_results` - background and examples + on using server side cursors with Core. + + :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel` + describing the ORM version of ``yield_per`` + + :param insertmanyvalues_page_size: Available on: :class:`_engine.Connection`, + :class:`_engine.Engine`. Number of rows to format into an + INSERT statement when the statement uses "insertmanyvalues" mode, + which is a paged form of bulk insert that is used for many backends + when using :term:`executemany` execution typically in conjunction + with RETURNING. Defaults to 1000. May also be modified on a + per-engine basis using the + :paramref:`_sa.create_engine.insertmanyvalues_page_size` parameter. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`engine_insertmanyvalues` + + :param schema_translate_map: Available on: :class:`_engine.Connection`, + :class:`_engine.Engine`, :class:`_sql.Executable`. + + A dictionary mapping schema names to schema names, that will be + applied to the :paramref:`_schema.Table.schema` element of each + :class:`_schema.Table` + encountered when SQL or DDL expression elements + are compiled into strings; the resulting schema name will be + converted based on presence in the map of the original name. + + .. seealso:: + + :ref:`schema_translating` + + :param preserve_rowcount: Boolean; when True, the ``cursor.rowcount`` + attribute will be unconditionally memoized within the result and + made available via the :attr:`.CursorResult.rowcount` attribute. + Normally, this attribute is only preserved for UPDATE and DELETE + statements. Using this option, the DBAPIs rowcount value can + be accessed for other kinds of statements such as INSERT and SELECT, + to the degree that the DBAPI supports these statements. See + :attr:`.CursorResult.rowcount` for notes regarding the behavior + of this attribute. + + .. versionadded:: 2.0.28 + + .. seealso:: + + :meth:`_engine.Engine.execution_options` + + :meth:`.Executable.execution_options` + + :meth:`_engine.Connection.get_execution_options` + + :ref:`orm_queryguide_execution_options` - documentation on all + ORM-specific execution options + + """ # noqa + if self._has_events or self.engine._has_events: + self.dispatch.set_connection_execution_options(self, opt) + self._execution_options = self._execution_options.union(opt) + self.dialect.set_connection_execution_options(self, opt) + return self + + def get_execution_options(self) -> _ExecuteOptions: + """Get the non-SQL options which will take effect during execution. + + .. versionadded:: 1.3 + + .. seealso:: + + :meth:`_engine.Connection.execution_options` + """ + return self._execution_options + + @property + def _still_open_and_dbapi_connection_is_valid(self) -> bool: + pool_proxied_connection = self._dbapi_connection + return ( + pool_proxied_connection is not None + and pool_proxied_connection.is_valid + ) + + @property + def closed(self) -> bool: + """Return True if this connection is closed.""" + + return self._dbapi_connection is None and not self.__can_reconnect + + @property + def invalidated(self) -> bool: + """Return True if this connection was invalidated. + + This does not indicate whether or not the connection was + invalidated at the pool level, however + + """ + + # prior to 1.4, "invalid" was stored as a state independent of + # "closed", meaning an invalidated connection could be "closed", + # the _dbapi_connection would be None and closed=True, yet the + # "invalid" flag would stay True. This meant that there were + # three separate states (open/valid, closed/valid, closed/invalid) + # when there is really no reason for that; a connection that's + # "closed" does not need to be "invalid". So the state is now + # represented by the two facts alone. + + pool_proxied_connection = self._dbapi_connection + return pool_proxied_connection is None and self.__can_reconnect + + @property + def connection(self) -> PoolProxiedConnection: + """The underlying DB-API connection managed by this Connection. + + This is a SQLAlchemy connection-pool proxied connection + which then has the attribute + :attr:`_pool._ConnectionFairy.dbapi_connection` that refers to the + actual driver connection. + + .. seealso:: + + + :ref:`dbapi_connections` + + """ + + if self._dbapi_connection is None: + try: + return self._revalidate_connection() + except (exc.PendingRollbackError, exc.ResourceClosedError): + raise + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + else: + return self._dbapi_connection + + def get_isolation_level(self) -> IsolationLevel: + """Return the current **actual** isolation level that's present on + the database within the scope of this connection. + + This attribute will perform a live SQL operation against the database + in order to procure the current isolation level, so the value returned + is the actual level on the underlying DBAPI connection regardless of + how this state was set. This will be one of the four actual isolation + modes ``READ UNCOMMITTED``, ``READ COMMITTED``, ``REPEATABLE READ``, + ``SERIALIZABLE``. It will **not** include the ``AUTOCOMMIT`` isolation + level setting. Third party dialects may also feature additional + isolation level settings. + + .. note:: This method **will not report** on the ``AUTOCOMMIT`` + isolation level, which is a separate :term:`dbapi` setting that's + independent of **actual** isolation level. When ``AUTOCOMMIT`` is + in use, the database connection still has a "traditional" isolation + mode in effect, that is typically one of the four values + ``READ UNCOMMITTED``, ``READ COMMITTED``, ``REPEATABLE READ``, + ``SERIALIZABLE``. + + Compare to the :attr:`_engine.Connection.default_isolation_level` + accessor which returns the isolation level that is present on the + database at initial connection time. + + .. seealso:: + + :attr:`_engine.Connection.default_isolation_level` + - view default level + + :paramref:`_sa.create_engine.isolation_level` + - set per :class:`_engine.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`_engine.Connection` isolation level + + """ + dbapi_connection = self.connection.dbapi_connection + assert dbapi_connection is not None + try: + return self.dialect.get_isolation_level(dbapi_connection) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + @property + def default_isolation_level(self) -> Optional[IsolationLevel]: + """The initial-connection time isolation level associated with the + :class:`_engine.Dialect` in use. + + This value is independent of the + :paramref:`.Connection.execution_options.isolation_level` and + :paramref:`.Engine.execution_options.isolation_level` execution + options, and is determined by the :class:`_engine.Dialect` when the + first connection is created, by performing a SQL query against the + database for the current isolation level before any additional commands + have been emitted. + + Calling this accessor does not invoke any new SQL queries. + + .. seealso:: + + :meth:`_engine.Connection.get_isolation_level` + - view current actual isolation level + + :paramref:`_sa.create_engine.isolation_level` + - set per :class:`_engine.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`_engine.Connection` isolation level + + """ + return self.dialect.default_isolation_level + + def _invalid_transaction(self) -> NoReturn: + raise exc.PendingRollbackError( + "Can't reconnect until invalid %stransaction is rolled " + "back. Please rollback() fully before proceeding" + % ("savepoint " if self._nested_transaction is not None else ""), + code="8s2b", + ) + + def _revalidate_connection(self) -> PoolProxiedConnection: + if self.__can_reconnect and self.invalidated: + if self._transaction is not None: + self._invalid_transaction() + self._dbapi_connection = self.engine.raw_connection() + return self._dbapi_connection + raise exc.ResourceClosedError("This Connection is closed") + + @property + def info(self) -> _InfoType: + """Info dictionary associated with the underlying DBAPI connection + referred to by this :class:`_engine.Connection`, allowing user-defined + data to be associated with the connection. + + The data here will follow along with the DBAPI connection including + after it is returned to the connection pool and used again + in subsequent instances of :class:`_engine.Connection`. + + """ + + return self.connection.info + + def invalidate(self, exception: Optional[BaseException] = None) -> None: + """Invalidate the underlying DBAPI connection associated with + this :class:`_engine.Connection`. + + An attempt will be made to close the underlying DBAPI connection + immediately; however if this operation fails, the error is logged + but not raised. The connection is then discarded whether or not + close() succeeded. + + Upon the next use (where "use" typically means using the + :meth:`_engine.Connection.execute` method or similar), + this :class:`_engine.Connection` will attempt to + procure a new DBAPI connection using the services of the + :class:`_pool.Pool` as a source of connectivity (e.g. + a "reconnection"). + + If a transaction was in progress (e.g. the + :meth:`_engine.Connection.begin` method has been called) when + :meth:`_engine.Connection.invalidate` method is called, at the DBAPI + level all state associated with this transaction is lost, as + the DBAPI connection is closed. The :class:`_engine.Connection` + will not allow a reconnection to proceed until the + :class:`.Transaction` object is ended, by calling the + :meth:`.Transaction.rollback` method; until that point, any attempt at + continuing to use the :class:`_engine.Connection` will raise an + :class:`~sqlalchemy.exc.InvalidRequestError`. + This is to prevent applications from accidentally + continuing an ongoing transactional operations despite the + fact that the transaction has been lost due to an + invalidation. + + The :meth:`_engine.Connection.invalidate` method, + just like auto-invalidation, + will at the connection pool level invoke the + :meth:`_events.PoolEvents.invalidate` event. + + :param exception: an optional ``Exception`` instance that's the + reason for the invalidation. is passed along to event handlers + and logging functions. + + .. seealso:: + + :ref:`pool_connection_invalidation` + + """ + + if self.invalidated: + return + + if self.closed: + raise exc.ResourceClosedError("This Connection is closed") + + if self._still_open_and_dbapi_connection_is_valid: + pool_proxied_connection = self._dbapi_connection + assert pool_proxied_connection is not None + pool_proxied_connection.invalidate(exception) + + self._dbapi_connection = None + + def detach(self) -> None: + """Detach the underlying DB-API connection from its connection pool. + + E.g.:: + + with engine.connect() as conn: + conn.detach() + conn.execute(text("SET search_path TO schema1, schema2")) + + # work with connection + + # connection is fully closed (since we used "with:", can + # also call .close()) + + This :class:`_engine.Connection` instance will remain usable. + When closed + (or exited from a context manager context as above), + the DB-API connection will be literally closed and not + returned to its originating pool. + + This method can be used to insulate the rest of an application + from a modified state on a connection (such as a transaction + isolation level or similar). + + """ + + if self.closed: + raise exc.ResourceClosedError("This Connection is closed") + + pool_proxied_connection = self._dbapi_connection + if pool_proxied_connection is None: + raise exc.InvalidRequestError( + "Can't detach an invalidated Connection" + ) + pool_proxied_connection.detach() + + def _autobegin(self) -> None: + if self._allow_autobegin and not self.__in_begin: + self.begin() + + def begin(self) -> RootTransaction: + """Begin a transaction prior to autobegin occurring. + + E.g.:: + + with engine.connect() as conn: + with conn.begin() as trans: + conn.execute(table.insert(), {"username": "sandy"}) + + + The returned object is an instance of :class:`_engine.RootTransaction`. + This object represents the "scope" of the transaction, + which completes when either the :meth:`_engine.Transaction.rollback` + or :meth:`_engine.Transaction.commit` method is called; the object + also works as a context manager as illustrated above. + + The :meth:`_engine.Connection.begin` method begins a + transaction that normally will be begun in any case when the connection + is first used to execute a statement. The reason this method might be + used would be to invoke the :meth:`_events.ConnectionEvents.begin` + event at a specific time, or to organize code within the scope of a + connection checkout in terms of context managed blocks, such as:: + + with engine.connect() as conn: + with conn.begin(): + conn.execute(...) + conn.execute(...) + + with conn.begin(): + conn.execute(...) + conn.execute(...) + + The above code is not fundamentally any different in its behavior than + the following code which does not use + :meth:`_engine.Connection.begin`; the below style is known + as "commit as you go" style:: + + with engine.connect() as conn: + conn.execute(...) + conn.execute(...) + conn.commit() + + conn.execute(...) + conn.execute(...) + conn.commit() + + From a database point of view, the :meth:`_engine.Connection.begin` + method does not emit any SQL or change the state of the underlying + DBAPI connection in any way; the Python DBAPI does not have any + concept of explicit transaction begin. + + .. seealso:: + + :ref:`tutorial_working_with_transactions` - in the + :ref:`unified_tutorial` + + :meth:`_engine.Connection.begin_nested` - use a SAVEPOINT + + :meth:`_engine.Connection.begin_twophase` - + use a two phase /XID transaction + + :meth:`_engine.Engine.begin` - context manager available from + :class:`_engine.Engine` + + """ + if self._transaction is None: + self._transaction = RootTransaction(self) + return self._transaction + else: + raise exc.InvalidRequestError( + "This connection has already initialized a SQLAlchemy " + "Transaction() object via begin() or autobegin; can't " + "call begin() here unless rollback() or commit() " + "is called first." + ) + + def begin_nested(self) -> NestedTransaction: + """Begin a nested transaction (i.e. SAVEPOINT) and return a transaction + handle that controls the scope of the SAVEPOINT. + + E.g.:: + + with engine.begin() as connection: + with connection.begin_nested(): + connection.execute(table.insert(), {"username": "sandy"}) + + The returned object is an instance of + :class:`_engine.NestedTransaction`, which includes transactional + methods :meth:`_engine.NestedTransaction.commit` and + :meth:`_engine.NestedTransaction.rollback`; for a nested transaction, + these methods correspond to the operations "RELEASE SAVEPOINT " + and "ROLLBACK TO SAVEPOINT ". The name of the savepoint is local + to the :class:`_engine.NestedTransaction` object and is generated + automatically. Like any other :class:`_engine.Transaction`, the + :class:`_engine.NestedTransaction` may be used as a context manager as + illustrated above which will "release" or "rollback" corresponding to + if the operation within the block were successful or raised an + exception. + + Nested transactions require SAVEPOINT support in the underlying + database, else the behavior is undefined. SAVEPOINT is commonly used to + run operations within a transaction that may fail, while continuing the + outer transaction. E.g.:: + + from sqlalchemy import exc + + with engine.begin() as connection: + trans = connection.begin_nested() + try: + connection.execute(table.insert(), {"username": "sandy"}) + trans.commit() + except exc.IntegrityError: # catch for duplicate username + trans.rollback() # rollback to savepoint + + # outer transaction continues + connection.execute( ... ) + + If :meth:`_engine.Connection.begin_nested` is called without first + calling :meth:`_engine.Connection.begin` or + :meth:`_engine.Engine.begin`, the :class:`_engine.Connection` object + will "autobegin" the outer transaction first. This outer transaction + may be committed using "commit-as-you-go" style, e.g.:: + + with engine.connect() as connection: # begin() wasn't called + + with connection.begin_nested(): will auto-"begin()" first + connection.execute( ... ) + # savepoint is released + + connection.execute( ... ) + + # explicitly commit outer transaction + connection.commit() + + # can continue working with connection here + + .. versionchanged:: 2.0 + + :meth:`_engine.Connection.begin_nested` will now participate + in the connection "autobegin" behavior that is new as of + 2.0 / "future" style connections in 1.4. + + .. seealso:: + + :meth:`_engine.Connection.begin` + + :ref:`session_begin_nested` - ORM support for SAVEPOINT + + """ + if self._transaction is None: + self._autobegin() + + return NestedTransaction(self) + + def begin_twophase(self, xid: Optional[Any] = None) -> TwoPhaseTransaction: + """Begin a two-phase or XA transaction and return a transaction + handle. + + The returned object is an instance of :class:`.TwoPhaseTransaction`, + which in addition to the methods provided by + :class:`.Transaction`, also provides a + :meth:`~.TwoPhaseTransaction.prepare` method. + + :param xid: the two phase transaction id. If not supplied, a + random id will be generated. + + .. seealso:: + + :meth:`_engine.Connection.begin` + + :meth:`_engine.Connection.begin_twophase` + + """ + + if self._transaction is not None: + raise exc.InvalidRequestError( + "Cannot start a two phase transaction when a transaction " + "is already in progress." + ) + if xid is None: + xid = self.engine.dialect.create_xid() + return TwoPhaseTransaction(self, xid) + + def commit(self) -> None: + """Commit the transaction that is currently in progress. + + This method commits the current transaction if one has been started. + If no transaction was started, the method has no effect, assuming + the connection is in a non-invalidated state. + + A transaction is begun on a :class:`_engine.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_engine.Connection.begin` method is called. + + .. note:: The :meth:`_engine.Connection.commit` method only acts upon + the primary database transaction that is linked to the + :class:`_engine.Connection` object. It does not operate upon a + SAVEPOINT that would have been invoked from the + :meth:`_engine.Connection.begin_nested` method; for control of a + SAVEPOINT, call :meth:`_engine.NestedTransaction.commit` on the + :class:`_engine.NestedTransaction` that is returned by the + :meth:`_engine.Connection.begin_nested` method itself. + + + """ + if self._transaction: + self._transaction.commit() + + def rollback(self) -> None: + """Roll back the transaction that is currently in progress. + + This method rolls back the current transaction if one has been started. + If no transaction was started, the method has no effect. If a + transaction was started and the connection is in an invalidated state, + the transaction is cleared using this method. + + A transaction is begun on a :class:`_engine.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_engine.Connection.begin` method is called. + + .. note:: The :meth:`_engine.Connection.rollback` method only acts + upon the primary database transaction that is linked to the + :class:`_engine.Connection` object. It does not operate upon a + SAVEPOINT that would have been invoked from the + :meth:`_engine.Connection.begin_nested` method; for control of a + SAVEPOINT, call :meth:`_engine.NestedTransaction.rollback` on the + :class:`_engine.NestedTransaction` that is returned by the + :meth:`_engine.Connection.begin_nested` method itself. + + + """ + if self._transaction: + self._transaction.rollback() + + def recover_twophase(self) -> List[Any]: + return self.engine.dialect.do_recover_twophase(self) + + def rollback_prepared(self, xid: Any, recover: bool = False) -> None: + self.engine.dialect.do_rollback_twophase(self, xid, recover=recover) + + def commit_prepared(self, xid: Any, recover: bool = False) -> None: + self.engine.dialect.do_commit_twophase(self, xid, recover=recover) + + def in_transaction(self) -> bool: + """Return True if a transaction is in progress.""" + return self._transaction is not None and self._transaction.is_active + + def in_nested_transaction(self) -> bool: + """Return True if a transaction is in progress.""" + return ( + self._nested_transaction is not None + and self._nested_transaction.is_active + ) + + def _is_autocommit_isolation(self) -> bool: + opt_iso = self._execution_options.get("isolation_level", None) + return bool( + opt_iso == "AUTOCOMMIT" + or ( + opt_iso is None + and self.engine.dialect._on_connect_isolation_level + == "AUTOCOMMIT" + ) + ) + + def _get_required_transaction(self) -> RootTransaction: + trans = self._transaction + if trans is None: + raise exc.InvalidRequestError("connection is not in a transaction") + return trans + + def _get_required_nested_transaction(self) -> NestedTransaction: + trans = self._nested_transaction + if trans is None: + raise exc.InvalidRequestError( + "connection is not in a nested transaction" + ) + return trans + + def get_transaction(self) -> Optional[RootTransaction]: + """Return the current root transaction in progress, if any. + + .. versionadded:: 1.4 + + """ + + return self._transaction + + def get_nested_transaction(self) -> Optional[NestedTransaction]: + """Return the current nested transaction in progress, if any. + + .. versionadded:: 1.4 + + """ + return self._nested_transaction + + def _begin_impl(self, transaction: RootTransaction) -> None: + if self._echo: + if self._is_autocommit_isolation(): + self._log_info( + "BEGIN (implicit; DBAPI should not BEGIN due to " + "autocommit mode)" + ) + else: + self._log_info("BEGIN (implicit)") + + self.__in_begin = True + + if self._has_events or self.engine._has_events: + self.dispatch.begin(self) + + try: + self.engine.dialect.do_begin(self.connection) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + finally: + self.__in_begin = False + + def _rollback_impl(self) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.rollback(self) + + if self._still_open_and_dbapi_connection_is_valid: + if self._echo: + if self._is_autocommit_isolation(): + self._log_info( + "ROLLBACK using DBAPI connection.rollback(), " + "DBAPI should ignore due to autocommit mode" + ) + else: + self._log_info("ROLLBACK") + try: + self.engine.dialect.do_rollback(self.connection) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def _commit_impl(self) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.commit(self) + + if self._echo: + if self._is_autocommit_isolation(): + self._log_info( + "COMMIT using DBAPI connection.commit(), " + "DBAPI should ignore due to autocommit mode" + ) + else: + self._log_info("COMMIT") + try: + self.engine.dialect.do_commit(self.connection) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def _savepoint_impl(self, name: Optional[str] = None) -> str: + if self._has_events or self.engine._has_events: + self.dispatch.savepoint(self, name) + + if name is None: + self.__savepoint_seq += 1 + name = "sa_savepoint_%s" % self.__savepoint_seq + self.engine.dialect.do_savepoint(self, name) + return name + + def _rollback_to_savepoint_impl(self, name: str) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.rollback_savepoint(self, name, None) + + if self._still_open_and_dbapi_connection_is_valid: + self.engine.dialect.do_rollback_to_savepoint(self, name) + + def _release_savepoint_impl(self, name: str) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.release_savepoint(self, name, None) + + self.engine.dialect.do_release_savepoint(self, name) + + def _begin_twophase_impl(self, transaction: TwoPhaseTransaction) -> None: + if self._echo: + self._log_info("BEGIN TWOPHASE (implicit)") + if self._has_events or self.engine._has_events: + self.dispatch.begin_twophase(self, transaction.xid) + + self.__in_begin = True + try: + self.engine.dialect.do_begin_twophase(self, transaction.xid) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + finally: + self.__in_begin = False + + def _prepare_twophase_impl(self, xid: Any) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.prepare_twophase(self, xid) + + assert isinstance(self._transaction, TwoPhaseTransaction) + try: + self.engine.dialect.do_prepare_twophase(self, xid) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def _rollback_twophase_impl(self, xid: Any, is_prepared: bool) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.rollback_twophase(self, xid, is_prepared) + + if self._still_open_and_dbapi_connection_is_valid: + assert isinstance(self._transaction, TwoPhaseTransaction) + try: + self.engine.dialect.do_rollback_twophase( + self, xid, is_prepared + ) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def _commit_twophase_impl(self, xid: Any, is_prepared: bool) -> None: + if self._has_events or self.engine._has_events: + self.dispatch.commit_twophase(self, xid, is_prepared) + + assert isinstance(self._transaction, TwoPhaseTransaction) + try: + self.engine.dialect.do_commit_twophase(self, xid, is_prepared) + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + def close(self) -> None: + """Close this :class:`_engine.Connection`. + + This results in a release of the underlying database + resources, that is, the DBAPI connection referenced + internally. The DBAPI connection is typically restored + back to the connection-holding :class:`_pool.Pool` referenced + by the :class:`_engine.Engine` that produced this + :class:`_engine.Connection`. Any transactional state present on + the DBAPI connection is also unconditionally released via + the DBAPI connection's ``rollback()`` method, regardless + of any :class:`.Transaction` object that may be + outstanding with regards to this :class:`_engine.Connection`. + + This has the effect of also calling :meth:`_engine.Connection.rollback` + if any transaction is in place. + + After :meth:`_engine.Connection.close` is called, the + :class:`_engine.Connection` is permanently in a closed state, + and will allow no further operations. + + """ + + if self._transaction: + self._transaction.close() + skip_reset = True + else: + skip_reset = False + + if self._dbapi_connection is not None: + conn = self._dbapi_connection + + # as we just closed the transaction, close the connection + # pool connection without doing an additional reset + if skip_reset: + cast("_ConnectionFairy", conn)._close_special( + transaction_reset=True + ) + else: + conn.close() + + # There is a slight chance that conn.close() may have + # triggered an invalidation here in which case + # _dbapi_connection would already be None, however usually + # it will be non-None here and in a "closed" state. + self._dbapi_connection = None + self.__can_reconnect = False + + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Optional[_T]: ... + + @overload + def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Any: ... + + def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Any: + r"""Executes a SQL statement construct and returns a scalar object. + + This method is shorthand for invoking the + :meth:`_engine.Result.scalar` method after invoking the + :meth:`_engine.Connection.execute` method. Parameters are equivalent. + + :return: a scalar Python value representing the first column of the + first row returned. + + """ + distilled_parameters = _distill_params_20(parameters) + try: + meth = statement._execute_on_scalar + except AttributeError as err: + raise exc.ObjectNotExecutableError(statement) from err + else: + return meth( + self, + distilled_parameters, + execution_options or NO_OPTIONS, + ) + + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[_T]: ... + + @overload + def scalars( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: ... + + def scalars( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: + """Executes and returns a scalar result set, which yields scalar values + from the first column of each row. + + This method is equivalent to calling :meth:`_engine.Connection.execute` + to receive a :class:`_result.Result` object, then invoking the + :meth:`_result.Result.scalars` method to produce a + :class:`_result.ScalarResult` instance. + + :return: a :class:`_result.ScalarResult` + + .. versionadded:: 1.4.24 + + """ + + return self.execute( + statement, parameters, execution_options=execution_options + ).scalars() + + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[_T]: ... + + @overload + def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: ... + + def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + r"""Executes a SQL statement construct and returns a + :class:`_engine.CursorResult`. + + :param statement: The statement to be executed. This is always + an object that is in both the :class:`_expression.ClauseElement` and + :class:`_expression.Executable` hierarchies, including: + + * :class:`_expression.Select` + * :class:`_expression.Insert`, :class:`_expression.Update`, + :class:`_expression.Delete` + * :class:`_expression.TextClause` and + :class:`_expression.TextualSelect` + * :class:`_schema.DDL` and objects which inherit from + :class:`_schema.ExecutableDDLElement` + + :param parameters: parameters which will be bound into the statement. + This may be either a dictionary of parameter names to values, + or a mutable sequence (e.g. a list) of dictionaries. When a + list of dictionaries is passed, the underlying statement execution + will make use of the DBAPI ``cursor.executemany()`` method. + When a single dictionary is passed, the DBAPI ``cursor.execute()`` + method will be used. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. This + dictionary can provide a subset of the options that are accepted + by :meth:`_engine.Connection.execution_options`. + + :return: a :class:`_engine.Result` object. + + """ + distilled_parameters = _distill_params_20(parameters) + try: + meth = statement._execute_on_connection + except AttributeError as err: + raise exc.ObjectNotExecutableError(statement) from err + else: + return meth( + self, + distilled_parameters, + execution_options or NO_OPTIONS, + ) + + def _execute_function( + self, + func: FunctionElement[Any], + distilled_parameters: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> CursorResult[Any]: + """Execute a sql.FunctionElement object.""" + + return self._execute_clauseelement( + func.select(), distilled_parameters, execution_options + ) + + def _execute_default( + self, + default: DefaultGenerator, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> Any: + """Execute a schema.ColumnDefault object.""" + + execution_options = self._execution_options.merge_with( + execution_options + ) + + event_multiparams: Optional[_CoreMultiExecuteParams] + event_params: Optional[_CoreAnyExecuteParams] + + # note for event handlers, the "distilled parameters" which is always + # a list of dicts is broken out into separate "multiparams" and + # "params" collections, which allows the handler to distinguish + # between an executemany and execute style set of parameters. + if self._has_events or self.engine._has_events: + ( + default, + distilled_parameters, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + default, distilled_parameters, execution_options + ) + else: + event_multiparams = event_params = None + + try: + conn = self._dbapi_connection + if conn is None: + conn = self._revalidate_connection() + + dialect = self.dialect + ctx = dialect.execution_ctx_cls._init_default( + dialect, self, conn, execution_options + ) + except (exc.PendingRollbackError, exc.ResourceClosedError): + raise + except BaseException as e: + self._handle_dbapi_exception(e, None, None, None, None) + + ret = ctx._exec_default(None, default, None) + + if self._has_events or self.engine._has_events: + self.dispatch.after_execute( + self, + default, + event_multiparams, + event_params, + execution_options, + ret, + ) + + return ret + + def _execute_ddl( + self, + ddl: ExecutableDDLElement, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> CursorResult[Any]: + """Execute a schema.DDL object.""" + + exec_opts = ddl._execution_options.merge_with( + self._execution_options, execution_options + ) + + event_multiparams: Optional[_CoreMultiExecuteParams] + event_params: Optional[_CoreSingleExecuteParams] + + if self._has_events or self.engine._has_events: + ( + ddl, + distilled_parameters, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + ddl, distilled_parameters, exec_opts + ) + else: + event_multiparams = event_params = None + + schema_translate_map = exec_opts.get("schema_translate_map", None) + + dialect = self.dialect + + compiled = ddl.compile( + dialect=dialect, schema_translate_map=schema_translate_map + ) + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_ddl, + compiled, + None, + exec_opts, + compiled, + ) + if self._has_events or self.engine._has_events: + self.dispatch.after_execute( + self, + ddl, + event_multiparams, + event_params, + exec_opts, + ret, + ) + return ret + + def _invoke_before_exec_event( + self, + elem: Any, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Tuple[ + Any, + _CoreMultiExecuteParams, + _CoreMultiExecuteParams, + _CoreSingleExecuteParams, + ]: + event_multiparams: _CoreMultiExecuteParams + event_params: _CoreSingleExecuteParams + + if len(distilled_params) == 1: + event_multiparams, event_params = [], distilled_params[0] + else: + event_multiparams, event_params = distilled_params, {} + + for fn in self.dispatch.before_execute: + elem, event_multiparams, event_params = fn( + self, + elem, + event_multiparams, + event_params, + execution_options, + ) + + if event_multiparams: + distilled_params = list(event_multiparams) + if event_params: + raise exc.InvalidRequestError( + "Event handler can't return non-empty multiparams " + "and params at the same time" + ) + elif event_params: + distilled_params = [event_params] + else: + distilled_params = [] + + return elem, distilled_params, event_multiparams, event_params + + def _execute_clauseelement( + self, + elem: Executable, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> CursorResult[Any]: + """Execute a sql.ClauseElement object.""" + + execution_options = elem._execution_options.merge_with( + self._execution_options, execution_options + ) + + has_events = self._has_events or self.engine._has_events + if has_events: + ( + elem, + distilled_parameters, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + elem, distilled_parameters, execution_options + ) + + if distilled_parameters: + # ensure we don't retain a link to the view object for keys() + # which links to the values, which we don't want to cache + keys = sorted(distilled_parameters[0]) + for_executemany = len(distilled_parameters) > 1 + else: + keys = [] + for_executemany = False + + dialect = self.dialect + + schema_translate_map = execution_options.get( + "schema_translate_map", None + ) + + compiled_cache: Optional[CompiledCacheType] = execution_options.get( + "compiled_cache", self.engine._compiled_cache + ) + + compiled_sql, extracted_params, cache_hit = elem._compile_w_cache( + dialect=dialect, + compiled_cache=compiled_cache, + column_keys=keys, + for_executemany=for_executemany, + schema_translate_map=schema_translate_map, + linting=self.dialect.compiler_linting | compiler.WARN_LINTING, + ) + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_compiled, + compiled_sql, + distilled_parameters, + execution_options, + compiled_sql, + distilled_parameters, + elem, + extracted_params, + cache_hit=cache_hit, + ) + if has_events: + self.dispatch.after_execute( + self, + elem, + event_multiparams, + event_params, + execution_options, + ret, + ) + return ret + + def _execute_compiled( + self, + compiled: Compiled, + distilled_parameters: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter = _EMPTY_EXECUTION_OPTS, + ) -> CursorResult[Any]: + """Execute a sql.Compiled object. + + TODO: why do we have this? likely deprecate or remove + + """ + + execution_options = compiled.execution_options.merge_with( + self._execution_options, execution_options + ) + + if self._has_events or self.engine._has_events: + ( + compiled, + distilled_parameters, + event_multiparams, + event_params, + ) = self._invoke_before_exec_event( + compiled, distilled_parameters, execution_options + ) + + dialect = self.dialect + + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_compiled, + compiled, + distilled_parameters, + execution_options, + compiled, + distilled_parameters, + None, + None, + ) + if self._has_events or self.engine._has_events: + self.dispatch.after_execute( + self, + compiled, + event_multiparams, + event_params, + execution_options, + ret, + ) + return ret + + def exec_driver_sql( + self, + statement: str, + parameters: Optional[_DBAPIAnyExecuteParams] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + r"""Executes a string SQL statement on the DBAPI cursor directly, + without any SQL compilation steps. + + This can be used to pass any string directly to the + ``cursor.execute()`` method of the DBAPI in use. + + :param statement: The statement str to be executed. Bound parameters + must use the underlying DBAPI's paramstyle, such as "qmark", + "pyformat", "format", etc. + + :param parameters: represent bound parameter values to be used in the + execution. The format is one of: a dictionary of named parameters, + a tuple of positional parameters, or a list containing either + dictionaries or tuples for multiple-execute support. + + :return: a :class:`_engine.CursorResult`. + + E.g. multiple dictionaries:: + + + conn.exec_driver_sql( + "INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)", + [{"id":1, "value":"v1"}, {"id":2, "value":"v2"}] + ) + + Single dictionary:: + + conn.exec_driver_sql( + "INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)", + dict(id=1, value="v1") + ) + + Single tuple:: + + conn.exec_driver_sql( + "INSERT INTO table (id, value) VALUES (?, ?)", + (1, 'v1') + ) + + .. note:: The :meth:`_engine.Connection.exec_driver_sql` method does + not participate in the + :meth:`_events.ConnectionEvents.before_execute` and + :meth:`_events.ConnectionEvents.after_execute` events. To + intercept calls to :meth:`_engine.Connection.exec_driver_sql`, use + :meth:`_events.ConnectionEvents.before_cursor_execute` and + :meth:`_events.ConnectionEvents.after_cursor_execute`. + + .. seealso:: + + :pep:`249` + + """ + + distilled_parameters = _distill_raw_params(parameters) + + execution_options = self._execution_options.merge_with( + execution_options + ) + + dialect = self.dialect + ret = self._execute_context( + dialect, + dialect.execution_ctx_cls._init_statement, + statement, + None, + execution_options, + statement, + distilled_parameters, + ) + + return ret + + def _execute_context( + self, + dialect: Dialect, + constructor: Callable[..., ExecutionContext], + statement: Union[str, Compiled], + parameters: Optional[_AnyMultiExecuteParams], + execution_options: _ExecuteOptions, + *args: Any, + **kw: Any, + ) -> CursorResult[Any]: + """Create an :class:`.ExecutionContext` and execute, returning + a :class:`_engine.CursorResult`.""" + + if execution_options: + yp = execution_options.get("yield_per", None) + if yp: + execution_options = execution_options.union( + {"stream_results": True, "max_row_buffer": yp} + ) + try: + conn = self._dbapi_connection + if conn is None: + conn = self._revalidate_connection() + + context = constructor( + dialect, self, conn, execution_options, *args, **kw + ) + except (exc.PendingRollbackError, exc.ResourceClosedError): + raise + except BaseException as e: + self._handle_dbapi_exception( + e, str(statement), parameters, None, None + ) + + if ( + self._transaction + and not self._transaction.is_active + or ( + self._nested_transaction + and not self._nested_transaction.is_active + ) + ): + self._invalid_transaction() + + elif self._trans_context_manager: + TransactionalContext._trans_ctx_check(self) + + if self._transaction is None: + self._autobegin() + + context.pre_exec() + + if context.execute_style is ExecuteStyle.INSERTMANYVALUES: + return self._exec_insertmany_context(dialect, context) + else: + return self._exec_single_context( + dialect, context, statement, parameters + ) + + def _exec_single_context( + self, + dialect: Dialect, + context: ExecutionContext, + statement: Union[str, Compiled], + parameters: Optional[_AnyMultiExecuteParams], + ) -> CursorResult[Any]: + """continue the _execute_context() method for a single DBAPI + cursor.execute() or cursor.executemany() call. + + """ + if dialect.bind_typing is BindTyping.SETINPUTSIZES: + generic_setinputsizes = context._prepare_set_input_sizes() + + if generic_setinputsizes: + try: + dialect.do_set_input_sizes( + context.cursor, generic_setinputsizes, context + ) + except BaseException as e: + self._handle_dbapi_exception( + e, str(statement), parameters, None, context + ) + + cursor, str_statement, parameters = ( + context.cursor, + context.statement, + context.parameters, + ) + + effective_parameters: Optional[_AnyExecuteParams] + + if not context.executemany: + effective_parameters = parameters[0] + else: + effective_parameters = parameters + + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_cursor_execute: + str_statement, effective_parameters = fn( + self, + cursor, + str_statement, + effective_parameters, + context, + context.executemany, + ) + + if self._echo: + self._log_info(str_statement) + + stats = context._get_cache_stats() + + if not self.engine.hide_parameters: + self._log_info( + "[%s] %r", + stats, + sql_util._repr_params( + effective_parameters, + batches=10, + ismulti=context.executemany, + ), + ) + else: + self._log_info( + "[%s] [SQL parameters hidden due to hide_parameters=True]", + stats, + ) + + evt_handled: bool = False + try: + if context.execute_style is ExecuteStyle.EXECUTEMANY: + effective_parameters = cast( + "_CoreMultiExecuteParams", effective_parameters + ) + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_executemany: + if fn( + cursor, + str_statement, + effective_parameters, + context, + ): + evt_handled = True + break + if not evt_handled: + self.dialect.do_executemany( + cursor, + str_statement, + effective_parameters, + context, + ) + elif not effective_parameters and context.no_parameters: + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_execute_no_params: + if fn(cursor, str_statement, context): + evt_handled = True + break + if not evt_handled: + self.dialect.do_execute_no_params( + cursor, str_statement, context + ) + else: + effective_parameters = cast( + "_CoreSingleExecuteParams", effective_parameters + ) + if self.dialect._has_events: + for fn in self.dialect.dispatch.do_execute: + if fn( + cursor, + str_statement, + effective_parameters, + context, + ): + evt_handled = True + break + if not evt_handled: + self.dialect.do_execute( + cursor, str_statement, effective_parameters, context + ) + + if self._has_events or self.engine._has_events: + self.dispatch.after_cursor_execute( + self, + cursor, + str_statement, + effective_parameters, + context, + context.executemany, + ) + + context.post_exec() + + result = context._setup_result_proxy() + + except BaseException as e: + self._handle_dbapi_exception( + e, str_statement, effective_parameters, cursor, context + ) + + return result + + def _exec_insertmany_context( + self, + dialect: Dialect, + context: ExecutionContext, + ) -> CursorResult[Any]: + """continue the _execute_context() method for an "insertmanyvalues" + operation, which will invoke DBAPI + cursor.execute() one or more times with individual log and + event hook calls. + + """ + + if dialect.bind_typing is BindTyping.SETINPUTSIZES: + generic_setinputsizes = context._prepare_set_input_sizes() + else: + generic_setinputsizes = None + + cursor, str_statement, parameters = ( + context.cursor, + context.statement, + context.parameters, + ) + + effective_parameters = parameters + + engine_events = self._has_events or self.engine._has_events + if self.dialect._has_events: + do_execute_dispatch: Iterable[Any] = ( + self.dialect.dispatch.do_execute + ) + else: + do_execute_dispatch = () + + if self._echo: + stats = context._get_cache_stats() + " (insertmanyvalues)" + + preserve_rowcount = context.execution_options.get( + "preserve_rowcount", False + ) + rowcount = 0 + + for imv_batch in dialect._deliver_insertmanyvalues_batches( + cursor, + str_statement, + effective_parameters, + generic_setinputsizes, + context, + ): + if imv_batch.processed_setinputsizes: + try: + dialect.do_set_input_sizes( + context.cursor, + imv_batch.processed_setinputsizes, + context, + ) + except BaseException as e: + self._handle_dbapi_exception( + e, + sql_util._long_statement(imv_batch.replaced_statement), + imv_batch.replaced_parameters, + None, + context, + ) + + sub_stmt = imv_batch.replaced_statement + sub_params = imv_batch.replaced_parameters + + if engine_events: + for fn in self.dispatch.before_cursor_execute: + sub_stmt, sub_params = fn( + self, + cursor, + sub_stmt, + sub_params, + context, + True, + ) + + if self._echo: + self._log_info(sql_util._long_statement(sub_stmt)) + + imv_stats = f""" {imv_batch.batchnum}/{ + imv_batch.total_batches + } ({ + 'ordered' + if imv_batch.rows_sorted else 'unordered' + }{ + '; batch not supported' + if imv_batch.is_downgraded + else '' + })""" + + if imv_batch.batchnum == 1: + stats += imv_stats + else: + stats = f"insertmanyvalues{imv_stats}" + + if not self.engine.hide_parameters: + self._log_info( + "[%s] %r", + stats, + sql_util._repr_params( + sub_params, + batches=10, + ismulti=False, + ), + ) + else: + self._log_info( + "[%s] [SQL parameters hidden due to " + "hide_parameters=True]", + stats, + ) + + try: + for fn in do_execute_dispatch: + if fn( + cursor, + sub_stmt, + sub_params, + context, + ): + break + else: + dialect.do_execute( + cursor, + sub_stmt, + sub_params, + context, + ) + + except BaseException as e: + self._handle_dbapi_exception( + e, + sql_util._long_statement(sub_stmt), + sub_params, + cursor, + context, + is_sub_exec=True, + ) + + if engine_events: + self.dispatch.after_cursor_execute( + self, + cursor, + str_statement, + effective_parameters, + context, + context.executemany, + ) + + if preserve_rowcount: + rowcount += imv_batch.current_batch_size + + try: + context.post_exec() + + if preserve_rowcount: + context._rowcount = rowcount # type: ignore[attr-defined] + + result = context._setup_result_proxy() + + except BaseException as e: + self._handle_dbapi_exception( + e, str_statement, effective_parameters, cursor, context + ) + + return result + + def _cursor_execute( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPISingleExecuteParams, + context: Optional[ExecutionContext] = None, + ) -> None: + """Execute a statement + params on the given cursor. + + Adds appropriate logging and exception handling. + + This method is used by DefaultDialect for special-case + executions, such as for sequences and column defaults. + The path of statement execution in the majority of cases + terminates at _execute_context(). + + """ + if self._has_events or self.engine._has_events: + for fn in self.dispatch.before_cursor_execute: + statement, parameters = fn( + self, cursor, statement, parameters, context, False + ) + + if self._echo: + self._log_info(statement) + self._log_info("[raw sql] %r", parameters) + try: + for fn in ( + () + if not self.dialect._has_events + else self.dialect.dispatch.do_execute + ): + if fn(cursor, statement, parameters, context): + break + else: + self.dialect.do_execute(cursor, statement, parameters, context) + except BaseException as e: + self._handle_dbapi_exception( + e, statement, parameters, cursor, context + ) + + if self._has_events or self.engine._has_events: + self.dispatch.after_cursor_execute( + self, cursor, statement, parameters, context, False + ) + + def _safe_close_cursor(self, cursor: DBAPICursor) -> None: + """Close the given cursor, catching exceptions + and turning into log warnings. + + """ + try: + cursor.close() + except Exception: + # log the error through the connection pool's logger. + self.engine.pool.logger.error( + "Error closing cursor", exc_info=True + ) + + _reentrant_error = False + _is_disconnect = False + + def _handle_dbapi_exception( + self, + e: BaseException, + statement: Optional[str], + parameters: Optional[_AnyExecuteParams], + cursor: Optional[DBAPICursor], + context: Optional[ExecutionContext], + is_sub_exec: bool = False, + ) -> NoReturn: + exc_info = sys.exc_info() + + is_exit_exception = util.is_exit_exception(e) + + if not self._is_disconnect: + self._is_disconnect = ( + isinstance(e, self.dialect.loaded_dbapi.Error) + and not self.closed + and self.dialect.is_disconnect( + e, + self._dbapi_connection if not self.invalidated else None, + cursor, + ) + ) or (is_exit_exception and not self.closed) + + invalidate_pool_on_disconnect = not is_exit_exception + + ismulti: bool = ( + not is_sub_exec and context.executemany + if context is not None + else False + ) + if self._reentrant_error: + raise exc.DBAPIError.instance( + statement, + parameters, + e, + self.dialect.loaded_dbapi.Error, + hide_parameters=self.engine.hide_parameters, + dialect=self.dialect, + ismulti=ismulti, + ).with_traceback(exc_info[2]) from e + self._reentrant_error = True + try: + # non-DBAPI error - if we already got a context, + # or there's no string statement, don't wrap it + should_wrap = isinstance(e, self.dialect.loaded_dbapi.Error) or ( + statement is not None + and context is None + and not is_exit_exception + ) + + if should_wrap: + sqlalchemy_exception = exc.DBAPIError.instance( + statement, + parameters, + cast(Exception, e), + self.dialect.loaded_dbapi.Error, + hide_parameters=self.engine.hide_parameters, + connection_invalidated=self._is_disconnect, + dialect=self.dialect, + ismulti=ismulti, + ) + else: + sqlalchemy_exception = None + + newraise = None + + if (self.dialect._has_events) and not self._execution_options.get( + "skip_user_error_events", False + ): + ctx = ExceptionContextImpl( + e, + sqlalchemy_exception, + self.engine, + self.dialect, + self, + cursor, + statement, + parameters, + context, + self._is_disconnect, + invalidate_pool_on_disconnect, + False, + ) + + for fn in self.dialect.dispatch.handle_error: + try: + # handler returns an exception; + # call next handler in a chain + per_fn = fn(ctx) + if per_fn is not None: + ctx.chained_exception = newraise = per_fn + except Exception as _raised: + # handler raises an exception - stop processing + newraise = _raised + break + + if self._is_disconnect != ctx.is_disconnect: + self._is_disconnect = ctx.is_disconnect + if sqlalchemy_exception: + sqlalchemy_exception.connection_invalidated = ( + ctx.is_disconnect + ) + + # set up potentially user-defined value for + # invalidate pool. + invalidate_pool_on_disconnect = ( + ctx.invalidate_pool_on_disconnect + ) + + if should_wrap and context: + context.handle_dbapi_exception(e) + + if not self._is_disconnect: + if cursor: + self._safe_close_cursor(cursor) + # "autorollback" was mostly relevant in 1.x series. + # It's very unlikely to reach here, as the connection + # does autobegin so when we are here, we are usually + # in an explicit / semi-explicit transaction. + # however we have a test which manufactures this + # scenario in any case using an event handler. + # test/engine/test_execute.py-> test_actual_autorollback + if not self.in_transaction(): + self._rollback_impl() + + if newraise: + raise newraise.with_traceback(exc_info[2]) from e + elif should_wrap: + assert sqlalchemy_exception is not None + raise sqlalchemy_exception.with_traceback(exc_info[2]) from e + else: + assert exc_info[1] is not None + raise exc_info[1].with_traceback(exc_info[2]) + finally: + del self._reentrant_error + if self._is_disconnect: + del self._is_disconnect + if not self.invalidated: + dbapi_conn_wrapper = self._dbapi_connection + assert dbapi_conn_wrapper is not None + if invalidate_pool_on_disconnect: + self.engine.pool._invalidate(dbapi_conn_wrapper, e) + self.invalidate(e) + + @classmethod + def _handle_dbapi_exception_noconnection( + cls, + e: BaseException, + dialect: Dialect, + engine: Optional[Engine] = None, + is_disconnect: Optional[bool] = None, + invalidate_pool_on_disconnect: bool = True, + is_pre_ping: bool = False, + ) -> NoReturn: + exc_info = sys.exc_info() + + if is_disconnect is None: + is_disconnect = isinstance( + e, dialect.loaded_dbapi.Error + ) and dialect.is_disconnect(e, None, None) + + should_wrap = isinstance(e, dialect.loaded_dbapi.Error) + + if should_wrap: + sqlalchemy_exception = exc.DBAPIError.instance( + None, + None, + cast(Exception, e), + dialect.loaded_dbapi.Error, + hide_parameters=( + engine.hide_parameters if engine is not None else False + ), + connection_invalidated=is_disconnect, + dialect=dialect, + ) + else: + sqlalchemy_exception = None + + newraise = None + + if dialect._has_events: + ctx = ExceptionContextImpl( + e, + sqlalchemy_exception, + engine, + dialect, + None, + None, + None, + None, + None, + is_disconnect, + invalidate_pool_on_disconnect, + is_pre_ping, + ) + for fn in dialect.dispatch.handle_error: + try: + # handler returns an exception; + # call next handler in a chain + per_fn = fn(ctx) + if per_fn is not None: + ctx.chained_exception = newraise = per_fn + except Exception as _raised: + # handler raises an exception - stop processing + newraise = _raised + break + + if sqlalchemy_exception and is_disconnect != ctx.is_disconnect: + sqlalchemy_exception.connection_invalidated = is_disconnect = ( + ctx.is_disconnect + ) + + if newraise: + raise newraise.with_traceback(exc_info[2]) from e + elif should_wrap: + assert sqlalchemy_exception is not None + raise sqlalchemy_exception.with_traceback(exc_info[2]) from e + else: + assert exc_info[1] is not None + raise exc_info[1].with_traceback(exc_info[2]) + + def _run_ddl_visitor( + self, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: SchemaItem, + **kwargs: Any, + ) -> None: + """run a DDL visitor. + + This method is only here so that the MockConnection can change the + options given to the visitor so that "checkfirst" is skipped. + + """ + visitorcallable(self.dialect, self, **kwargs).traverse_single(element) + + +class ExceptionContextImpl(ExceptionContext): + """Implement the :class:`.ExceptionContext` interface.""" + + __slots__ = ( + "connection", + "engine", + "dialect", + "cursor", + "statement", + "parameters", + "original_exception", + "sqlalchemy_exception", + "chained_exception", + "execution_context", + "is_disconnect", + "invalidate_pool_on_disconnect", + "is_pre_ping", + ) + + def __init__( + self, + exception: BaseException, + sqlalchemy_exception: Optional[exc.StatementError], + engine: Optional[Engine], + dialect: Dialect, + connection: Optional[Connection], + cursor: Optional[DBAPICursor], + statement: Optional[str], + parameters: Optional[_DBAPIAnyExecuteParams], + context: Optional[ExecutionContext], + is_disconnect: bool, + invalidate_pool_on_disconnect: bool, + is_pre_ping: bool, + ): + self.engine = engine + self.dialect = dialect + self.connection = connection + self.sqlalchemy_exception = sqlalchemy_exception + self.original_exception = exception + self.execution_context = context + self.statement = statement + self.parameters = parameters + self.is_disconnect = is_disconnect + self.invalidate_pool_on_disconnect = invalidate_pool_on_disconnect + self.is_pre_ping = is_pre_ping + + +class Transaction(TransactionalContext): + """Represent a database transaction in progress. + + The :class:`.Transaction` object is procured by + calling the :meth:`_engine.Connection.begin` method of + :class:`_engine.Connection`:: + + from sqlalchemy import create_engine + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") + connection = engine.connect() + trans = connection.begin() + connection.execute(text("insert into x (a, b) values (1, 2)")) + trans.commit() + + The object provides :meth:`.rollback` and :meth:`.commit` + methods in order to control transaction boundaries. It + also implements a context manager interface so that + the Python ``with`` statement can be used with the + :meth:`_engine.Connection.begin` method:: + + with connection.begin(): + connection.execute(text("insert into x (a, b) values (1, 2)")) + + The Transaction object is **not** threadsafe. + + .. seealso:: + + :meth:`_engine.Connection.begin` + + :meth:`_engine.Connection.begin_twophase` + + :meth:`_engine.Connection.begin_nested` + + .. index:: + single: thread safety; Transaction + """ # noqa + + __slots__ = () + + _is_root: bool = False + is_active: bool + connection: Connection + + def __init__(self, connection: Connection): + raise NotImplementedError() + + @property + def _deactivated_from_connection(self) -> bool: + """True if this transaction is totally deactivated from the connection + and therefore can no longer affect its state. + + """ + raise NotImplementedError() + + def _do_close(self) -> None: + raise NotImplementedError() + + def _do_rollback(self) -> None: + raise NotImplementedError() + + def _do_commit(self) -> None: + raise NotImplementedError() + + @property + def is_valid(self) -> bool: + return self.is_active and not self.connection.invalidated + + def close(self) -> None: + """Close this :class:`.Transaction`. + + If this transaction is the base transaction in a begin/commit + nesting, the transaction will rollback(). Otherwise, the + method returns. + + This is used to cancel a Transaction without affecting the scope of + an enclosing transaction. + + """ + try: + self._do_close() + finally: + assert not self.is_active + + def rollback(self) -> None: + """Roll back this :class:`.Transaction`. + + The implementation of this may vary based on the type of transaction in + use: + + * For a simple database transaction (e.g. :class:`.RootTransaction`), + it corresponds to a ROLLBACK. + + * For a :class:`.NestedTransaction`, it corresponds to a + "ROLLBACK TO SAVEPOINT" operation. + + * For a :class:`.TwoPhaseTransaction`, DBAPI-specific methods for two + phase transactions may be used. + + + """ + try: + self._do_rollback() + finally: + assert not self.is_active + + def commit(self) -> None: + """Commit this :class:`.Transaction`. + + The implementation of this may vary based on the type of transaction in + use: + + * For a simple database transaction (e.g. :class:`.RootTransaction`), + it corresponds to a COMMIT. + + * For a :class:`.NestedTransaction`, it corresponds to a + "RELEASE SAVEPOINT" operation. + + * For a :class:`.TwoPhaseTransaction`, DBAPI-specific methods for two + phase transactions may be used. + + """ + try: + self._do_commit() + finally: + assert not self.is_active + + def _get_subject(self) -> Connection: + return self.connection + + def _transaction_is_active(self) -> bool: + return self.is_active + + def _transaction_is_closed(self) -> bool: + return not self._deactivated_from_connection + + def _rollback_can_be_called(self) -> bool: + # for RootTransaction / NestedTransaction, it's safe to call + # rollback() even if the transaction is deactive and no warnings + # will be emitted. tested in + # test_transaction.py -> test_no_rollback_in_deactive(?:_savepoint)? + return True + + +class RootTransaction(Transaction): + """Represent the "root" transaction on a :class:`_engine.Connection`. + + This corresponds to the current "BEGIN/COMMIT/ROLLBACK" that's occurring + for the :class:`_engine.Connection`. The :class:`_engine.RootTransaction` + is created by calling upon the :meth:`_engine.Connection.begin` method, and + remains associated with the :class:`_engine.Connection` throughout its + active span. The current :class:`_engine.RootTransaction` in use is + accessible via the :attr:`_engine.Connection.get_transaction` method of + :class:`_engine.Connection`. + + In :term:`2.0 style` use, the :class:`_engine.Connection` also employs + "autobegin" behavior that will create a new + :class:`_engine.RootTransaction` whenever a connection in a + non-transactional state is used to emit commands on the DBAPI connection. + The scope of the :class:`_engine.RootTransaction` in 2.0 style + use can be controlled using the :meth:`_engine.Connection.commit` and + :meth:`_engine.Connection.rollback` methods. + + + """ + + _is_root = True + + __slots__ = ("connection", "is_active") + + def __init__(self, connection: Connection): + assert connection._transaction is None + if connection._trans_context_manager: + TransactionalContext._trans_ctx_check(connection) + self.connection = connection + self._connection_begin_impl() + connection._transaction = self + + self.is_active = True + + def _deactivate_from_connection(self) -> None: + if self.is_active: + assert self.connection._transaction is self + self.is_active = False + + elif self.connection._transaction is not self: + util.warn("transaction already deassociated from connection") + + @property + def _deactivated_from_connection(self) -> bool: + return self.connection._transaction is not self + + def _connection_begin_impl(self) -> None: + self.connection._begin_impl(self) + + def _connection_rollback_impl(self) -> None: + self.connection._rollback_impl() + + def _connection_commit_impl(self) -> None: + self.connection._commit_impl() + + def _close_impl(self, try_deactivate: bool = False) -> None: + try: + if self.is_active: + self._connection_rollback_impl() + + if self.connection._nested_transaction: + self.connection._nested_transaction._cancel() + finally: + if self.is_active or try_deactivate: + self._deactivate_from_connection() + if self.connection._transaction is self: + self.connection._transaction = None + + assert not self.is_active + assert self.connection._transaction is not self + + def _do_close(self) -> None: + self._close_impl() + + def _do_rollback(self) -> None: + self._close_impl(try_deactivate=True) + + def _do_commit(self) -> None: + if self.is_active: + assert self.connection._transaction is self + + try: + self._connection_commit_impl() + finally: + # whether or not commit succeeds, cancel any + # nested transactions, make this transaction "inactive" + # and remove it as a reset agent + if self.connection._nested_transaction: + self.connection._nested_transaction._cancel() + + self._deactivate_from_connection() + + # ...however only remove as the connection's current transaction + # if commit succeeded. otherwise it stays on so that a rollback + # needs to occur. + self.connection._transaction = None + else: + if self.connection._transaction is self: + self.connection._invalid_transaction() + else: + raise exc.InvalidRequestError("This transaction is inactive") + + assert not self.is_active + assert self.connection._transaction is not self + + +class NestedTransaction(Transaction): + """Represent a 'nested', or SAVEPOINT transaction. + + The :class:`.NestedTransaction` object is created by calling the + :meth:`_engine.Connection.begin_nested` method of + :class:`_engine.Connection`. + + When using :class:`.NestedTransaction`, the semantics of "begin" / + "commit" / "rollback" are as follows: + + * the "begin" operation corresponds to the "BEGIN SAVEPOINT" command, where + the savepoint is given an explicit name that is part of the state + of this object. + + * The :meth:`.NestedTransaction.commit` method corresponds to a + "RELEASE SAVEPOINT" operation, using the savepoint identifier associated + with this :class:`.NestedTransaction`. + + * The :meth:`.NestedTransaction.rollback` method corresponds to a + "ROLLBACK TO SAVEPOINT" operation, using the savepoint identifier + associated with this :class:`.NestedTransaction`. + + The rationale for mimicking the semantics of an outer transaction in + terms of savepoints so that code may deal with a "savepoint" transaction + and an "outer" transaction in an agnostic way. + + .. seealso:: + + :ref:`session_begin_nested` - ORM version of the SAVEPOINT API. + + """ + + __slots__ = ("connection", "is_active", "_savepoint", "_previous_nested") + + _savepoint: str + + def __init__(self, connection: Connection): + assert connection._transaction is not None + if connection._trans_context_manager: + TransactionalContext._trans_ctx_check(connection) + self.connection = connection + self._savepoint = self.connection._savepoint_impl() + self.is_active = True + self._previous_nested = connection._nested_transaction + connection._nested_transaction = self + + def _deactivate_from_connection(self, warn: bool = True) -> None: + if self.connection._nested_transaction is self: + self.connection._nested_transaction = self._previous_nested + elif warn: + util.warn( + "nested transaction already deassociated from connection" + ) + + @property + def _deactivated_from_connection(self) -> bool: + return self.connection._nested_transaction is not self + + def _cancel(self) -> None: + # called by RootTransaction when the outer transaction is + # committed, rolled back, or closed to cancel all savepoints + # without any action being taken + self.is_active = False + self._deactivate_from_connection() + if self._previous_nested: + self._previous_nested._cancel() + + def _close_impl( + self, deactivate_from_connection: bool, warn_already_deactive: bool + ) -> None: + try: + if ( + self.is_active + and self.connection._transaction + and self.connection._transaction.is_active + ): + self.connection._rollback_to_savepoint_impl(self._savepoint) + finally: + self.is_active = False + + if deactivate_from_connection: + self._deactivate_from_connection(warn=warn_already_deactive) + + assert not self.is_active + if deactivate_from_connection: + assert self.connection._nested_transaction is not self + + def _do_close(self) -> None: + self._close_impl(True, False) + + def _do_rollback(self) -> None: + self._close_impl(True, True) + + def _do_commit(self) -> None: + if self.is_active: + try: + self.connection._release_savepoint_impl(self._savepoint) + finally: + # nested trans becomes inactive on failed release + # unconditionally. this prevents it from trying to + # emit SQL when it rolls back. + self.is_active = False + + # but only de-associate from connection if it succeeded + self._deactivate_from_connection() + else: + if self.connection._nested_transaction is self: + self.connection._invalid_transaction() + else: + raise exc.InvalidRequestError( + "This nested transaction is inactive" + ) + + +class TwoPhaseTransaction(RootTransaction): + """Represent a two-phase transaction. + + A new :class:`.TwoPhaseTransaction` object may be procured + using the :meth:`_engine.Connection.begin_twophase` method. + + The interface is the same as that of :class:`.Transaction` + with the addition of the :meth:`prepare` method. + + """ + + __slots__ = ("xid", "_is_prepared") + + xid: Any + + def __init__(self, connection: Connection, xid: Any): + self._is_prepared = False + self.xid = xid + super().__init__(connection) + + def prepare(self) -> None: + """Prepare this :class:`.TwoPhaseTransaction`. + + After a PREPARE, the transaction can be committed. + + """ + if not self.is_active: + raise exc.InvalidRequestError("This transaction is inactive") + self.connection._prepare_twophase_impl(self.xid) + self._is_prepared = True + + def _connection_begin_impl(self) -> None: + self.connection._begin_twophase_impl(self) + + def _connection_rollback_impl(self) -> None: + self.connection._rollback_twophase_impl(self.xid, self._is_prepared) + + def _connection_commit_impl(self) -> None: + self.connection._commit_twophase_impl(self.xid, self._is_prepared) + + +class Engine( + ConnectionEventsTarget, log.Identified, inspection.Inspectable["Inspector"] +): + """ + Connects a :class:`~sqlalchemy.pool.Pool` and + :class:`~sqlalchemy.engine.interfaces.Dialect` together to provide a + source of database connectivity and behavior. + + An :class:`_engine.Engine` object is instantiated publicly using the + :func:`~sqlalchemy.create_engine` function. + + .. seealso:: + + :doc:`/core/engines` + + :ref:`connections_toplevel` + + """ + + dispatch: dispatcher[ConnectionEventsTarget] + + _compiled_cache: Optional[CompiledCacheType] + + _execution_options: _ExecuteOptions = _EMPTY_EXECUTION_OPTS + _has_events: bool = False + _connection_cls: Type[Connection] = Connection + _sqla_logger_namespace: str = "sqlalchemy.engine.Engine" + _is_future: bool = False + + _schema_translate_map: Optional[SchemaTranslateMapType] = None + _option_cls: Type[OptionEngine] + + dialect: Dialect + pool: Pool + url: URL + hide_parameters: bool + + def __init__( + self, + pool: Pool, + dialect: Dialect, + url: URL, + logging_name: Optional[str] = None, + echo: Optional[_EchoFlagType] = None, + query_cache_size: int = 500, + execution_options: Optional[Mapping[str, Any]] = None, + hide_parameters: bool = False, + ): + self.pool = pool + self.url = url + self.dialect = dialect + if logging_name: + self.logging_name = logging_name + self.echo = echo + self.hide_parameters = hide_parameters + if query_cache_size != 0: + self._compiled_cache = util.LRUCache( + query_cache_size, size_alert=self._lru_size_alert + ) + else: + self._compiled_cache = None + log.instance_logger(self, echoflag=echo) + if execution_options: + self.update_execution_options(**execution_options) + + def _lru_size_alert(self, cache: util.LRUCache[Any, Any]) -> None: + if self._should_log_info(): + self.logger.info( + "Compiled cache size pruning from %d items to %d. " + "Increase cache size to reduce the frequency of pruning.", + len(cache), + cache.capacity, + ) + + @property + def engine(self) -> Engine: + """Returns this :class:`.Engine`. + + Used for legacy schemes that accept :class:`.Connection` / + :class:`.Engine` objects within the same variable. + + """ + return self + + def clear_compiled_cache(self) -> None: + """Clear the compiled cache associated with the dialect. + + This applies **only** to the built-in cache that is established + via the :paramref:`_engine.create_engine.query_cache_size` parameter. + It will not impact any dictionary caches that were passed via the + :paramref:`.Connection.execution_options.compiled_cache` parameter. + + .. versionadded:: 1.4 + + """ + if self._compiled_cache: + self._compiled_cache.clear() + + def update_execution_options(self, **opt: Any) -> None: + r"""Update the default execution_options dictionary + of this :class:`_engine.Engine`. + + The given keys/values in \**opt are added to the + default execution options that will be used for + all connections. The initial contents of this dictionary + can be sent via the ``execution_options`` parameter + to :func:`_sa.create_engine`. + + .. seealso:: + + :meth:`_engine.Connection.execution_options` + + :meth:`_engine.Engine.execution_options` + + """ + self.dispatch.set_engine_execution_options(self, opt) + self._execution_options = self._execution_options.union(opt) + self.dialect.set_engine_execution_options(self, opt) + + @overload + def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + **opt: Any, + ) -> OptionEngine: ... + + @overload + def execution_options(self, **opt: Any) -> OptionEngine: ... + + def execution_options(self, **opt: Any) -> OptionEngine: + """Return a new :class:`_engine.Engine` that will provide + :class:`_engine.Connection` objects with the given execution options. + + The returned :class:`_engine.Engine` remains related to the original + :class:`_engine.Engine` in that it shares the same connection pool and + other state: + + * The :class:`_pool.Pool` used by the new :class:`_engine.Engine` + is the + same instance. The :meth:`_engine.Engine.dispose` + method will replace + the connection pool instance for the parent engine as well + as this one. + * Event listeners are "cascaded" - meaning, the new + :class:`_engine.Engine` + inherits the events of the parent, and new events can be associated + with the new :class:`_engine.Engine` individually. + * The logging configuration and logging_name is copied from the parent + :class:`_engine.Engine`. + + The intent of the :meth:`_engine.Engine.execution_options` method is + to implement schemes where multiple :class:`_engine.Engine` + objects refer to the same connection pool, but are differentiated + by options that affect some execution-level behavior for each + engine. One such example is breaking into separate "reader" and + "writer" :class:`_engine.Engine` instances, where one + :class:`_engine.Engine` + has a lower :term:`isolation level` setting configured or is even + transaction-disabled using "autocommit". An example of this + configuration is at :ref:`dbapi_autocommit_multiple`. + + Another example is one that + uses a custom option ``shard_id`` which is consumed by an event + to change the current schema on a database connection:: + + from sqlalchemy import event + from sqlalchemy.engine import Engine + + primary_engine = create_engine("mysql+mysqldb://") + shard1 = primary_engine.execution_options(shard_id="shard1") + shard2 = primary_engine.execution_options(shard_id="shard2") + + shards = {"default": "base", "shard_1": "db1", "shard_2": "db2"} + + @event.listens_for(Engine, "before_cursor_execute") + def _switch_shard(conn, cursor, stmt, + params, context, executemany): + shard_id = conn.get_execution_options().get('shard_id', "default") + current_shard = conn.info.get("current_shard", None) + + if current_shard != shard_id: + cursor.execute("use %s" % shards[shard_id]) + conn.info["current_shard"] = shard_id + + The above recipe illustrates two :class:`_engine.Engine` objects that + will each serve as factories for :class:`_engine.Connection` objects + that have pre-established "shard_id" execution options present. A + :meth:`_events.ConnectionEvents.before_cursor_execute` event handler + then interprets this execution option to emit a MySQL ``use`` statement + to switch databases before a statement execution, while at the same + time keeping track of which database we've established using the + :attr:`_engine.Connection.info` dictionary. + + .. seealso:: + + :meth:`_engine.Connection.execution_options` + - update execution options + on a :class:`_engine.Connection` object. + + :meth:`_engine.Engine.update_execution_options` + - update the execution + options for a given :class:`_engine.Engine` in place. + + :meth:`_engine.Engine.get_execution_options` + + + """ # noqa: E501 + return self._option_cls(self, opt) + + def get_execution_options(self) -> _ExecuteOptions: + """Get the non-SQL options which will take effect during execution. + + .. versionadded: 1.3 + + .. seealso:: + + :meth:`_engine.Engine.execution_options` + """ + return self._execution_options + + @property + def name(self) -> str: + """String name of the :class:`~sqlalchemy.engine.interfaces.Dialect` + in use by this :class:`Engine`. + + """ + + return self.dialect.name + + @property + def driver(self) -> str: + """Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect` + in use by this :class:`Engine`. + + """ + + return self.dialect.driver + + echo = log.echo_property() + + def __repr__(self) -> str: + return "Engine(%r)" % (self.url,) + + def dispose(self, close: bool = True) -> None: + """Dispose of the connection pool used by this + :class:`_engine.Engine`. + + A new connection pool is created immediately after the old one has been + disposed. The previous connection pool is disposed either actively, by + closing out all currently checked-in connections in that pool, or + passively, by losing references to it but otherwise not closing any + connections. The latter strategy is more appropriate for an initializer + in a forked Python process. + + :param close: if left at its default of ``True``, has the + effect of fully closing all **currently checked in** + database connections. Connections that are still checked out + will **not** be closed, however they will no longer be associated + with this :class:`_engine.Engine`, + so when they are closed individually, eventually the + :class:`_pool.Pool` which they are associated with will + be garbage collected and they will be closed out fully, if + not already closed on checkin. + + If set to ``False``, the previous connection pool is de-referenced, + and otherwise not touched in any way. + + .. versionadded:: 1.4.33 Added the :paramref:`.Engine.dispose.close` + parameter to allow the replacement of a connection pool in a child + process without interfering with the connections used by the parent + process. + + + .. seealso:: + + :ref:`engine_disposal` + + :ref:`pooling_multiprocessing` + + """ + if close: + self.pool.dispose() + self.pool = self.pool.recreate() + self.dispatch.engine_disposed(self) + + @contextlib.contextmanager + def _optional_conn_ctx_manager( + self, connection: Optional[Connection] = None + ) -> Iterator[Connection]: + if connection is None: + with self.connect() as conn: + yield conn + else: + yield connection + + @contextlib.contextmanager + def begin(self) -> Iterator[Connection]: + """Return a context manager delivering a :class:`_engine.Connection` + with a :class:`.Transaction` established. + + E.g.:: + + with engine.begin() as conn: + conn.execute( + text("insert into table (x, y, z) values (1, 2, 3)") + ) + conn.execute(text("my_special_procedure(5)")) + + Upon successful operation, the :class:`.Transaction` + is committed. If an error is raised, the :class:`.Transaction` + is rolled back. + + .. seealso:: + + :meth:`_engine.Engine.connect` - procure a + :class:`_engine.Connection` from + an :class:`_engine.Engine`. + + :meth:`_engine.Connection.begin` - start a :class:`.Transaction` + for a particular :class:`_engine.Connection`. + + """ + with self.connect() as conn: + with conn.begin(): + yield conn + + def _run_ddl_visitor( + self, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: SchemaItem, + **kwargs: Any, + ) -> None: + with self.begin() as conn: + conn._run_ddl_visitor(visitorcallable, element, **kwargs) + + def connect(self) -> Connection: + """Return a new :class:`_engine.Connection` object. + + The :class:`_engine.Connection` acts as a Python context manager, so + the typical use of this method looks like:: + + with engine.connect() as connection: + connection.execute(text("insert into table values ('foo')")) + connection.commit() + + Where above, after the block is completed, the connection is "closed" + and its underlying DBAPI resources are returned to the connection pool. + This also has the effect of rolling back any transaction that + was explicitly begun or was begun via autobegin, and will + emit the :meth:`_events.ConnectionEvents.rollback` event if one was + started and is still in progress. + + .. seealso:: + + :meth:`_engine.Engine.begin` + + """ + + return self._connection_cls(self) + + def raw_connection(self) -> PoolProxiedConnection: + """Return a "raw" DBAPI connection from the connection pool. + + The returned object is a proxied version of the DBAPI + connection object used by the underlying driver in use. + The object will have all the same behavior as the real DBAPI + connection, except that its ``close()`` method will result in the + connection being returned to the pool, rather than being closed + for real. + + This method provides direct DBAPI connection access for + special situations when the API provided by + :class:`_engine.Connection` + is not needed. When a :class:`_engine.Connection` object is already + present, the DBAPI connection is available using + the :attr:`_engine.Connection.connection` accessor. + + .. seealso:: + + :ref:`dbapi_connections` + + """ + return self.pool.connect() + + +class OptionEngineMixin(log.Identified): + _sa_propagate_class_events = False + + dispatch: dispatcher[ConnectionEventsTarget] + _compiled_cache: Optional[CompiledCacheType] + dialect: Dialect + pool: Pool + url: URL + hide_parameters: bool + echo: log.echo_property + + def __init__( + self, proxied: Engine, execution_options: CoreExecuteOptionsParameter + ): + self._proxied = proxied + self.url = proxied.url + self.dialect = proxied.dialect + self.logging_name = proxied.logging_name + self.echo = proxied.echo + self._compiled_cache = proxied._compiled_cache + self.hide_parameters = proxied.hide_parameters + log.instance_logger(self, echoflag=self.echo) + + # note: this will propagate events that are assigned to the parent + # engine after this OptionEngine is created. Since we share + # the events of the parent we also disallow class-level events + # to apply to the OptionEngine class directly. + # + # the other way this can work would be to transfer existing + # events only, using: + # self.dispatch._update(proxied.dispatch) + # + # that might be more appropriate however it would be a behavioral + # change for logic that assigns events to the parent engine and + # would like it to take effect for the already-created sub-engine. + self.dispatch = self.dispatch._join(proxied.dispatch) + + self._execution_options = proxied._execution_options + self.update_execution_options(**execution_options) + + def update_execution_options(self, **opt: Any) -> None: + raise NotImplementedError() + + if not typing.TYPE_CHECKING: + # https://github.com/python/typing/discussions/1095 + + @property + def pool(self) -> Pool: + return self._proxied.pool + + @pool.setter + def pool(self, pool: Pool) -> None: + self._proxied.pool = pool + + @property + def _has_events(self) -> bool: + return self._proxied._has_events or self.__dict__.get( + "_has_events", False + ) + + @_has_events.setter + def _has_events(self, value: bool) -> None: + self.__dict__["_has_events"] = value + + +class OptionEngine(OptionEngineMixin, Engine): + def update_execution_options(self, **opt: Any) -> None: + Engine.update_execution_options(self, **opt) + + +Engine._option_cls = OptionEngine diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/characteristics.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/characteristics.py new file mode 100644 index 0000000..7dd3a2f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/characteristics.py @@ -0,0 +1,81 @@ +# engine/characteristics.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import abc +import typing +from typing import Any +from typing import ClassVar + +if typing.TYPE_CHECKING: + from .interfaces import DBAPIConnection + from .interfaces import Dialect + + +class ConnectionCharacteristic(abc.ABC): + """An abstract base for an object that can set, get and reset a + per-connection characteristic, typically one that gets reset when the + connection is returned to the connection pool. + + transaction isolation is the canonical example, and the + ``IsolationLevelCharacteristic`` implementation provides this for the + ``DefaultDialect``. + + The ``ConnectionCharacteristic`` class should call upon the ``Dialect`` for + the implementation of each method. The object exists strictly to serve as + a dialect visitor that can be placed into the + ``DefaultDialect.connection_characteristics`` dictionary where it will take + effect for calls to :meth:`_engine.Connection.execution_options` and + related APIs. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + transactional: ClassVar[bool] = False + + @abc.abstractmethod + def reset_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> None: + """Reset the characteristic on the connection to its default value.""" + + @abc.abstractmethod + def set_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any + ) -> None: + """set characteristic on the connection to a given value.""" + + @abc.abstractmethod + def get_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> Any: + """Given a DBAPI connection, get the current value of the + characteristic. + + """ + + +class IsolationLevelCharacteristic(ConnectionCharacteristic): + transactional: ClassVar[bool] = True + + def reset_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> None: + dialect.reset_isolation_level(dbapi_conn) + + def set_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any + ) -> None: + dialect._assert_and_set_isolation_level(dbapi_conn, value) + + def get_characteristic( + self, dialect: Dialect, dbapi_conn: DBAPIConnection + ) -> Any: + return dialect.get_isolation_level(dbapi_conn) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/create.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/create.py new file mode 100644 index 0000000..74a3cf8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/create.py @@ -0,0 +1,875 @@ +# engine/create.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import inspect +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import overload +from typing import Type +from typing import Union + +from . import base +from . import url as _url +from .interfaces import DBAPIConnection +from .mock import create_mock_engine +from .. import event +from .. import exc +from .. import util +from ..pool import _AdhocProxiedConnection +from ..pool import ConnectionPoolEntry +from ..sql import compiler +from ..util import immutabledict + +if typing.TYPE_CHECKING: + from .base import Engine + from .interfaces import _ExecuteOptions + from .interfaces import _ParamStyle + from .interfaces import IsolationLevel + from .url import URL + from ..log import _EchoFlagType + from ..pool import _CreatorFnType + from ..pool import _CreatorWRecFnType + from ..pool import _ResetStyleArgType + from ..pool import Pool + from ..util.typing import Literal + + +@overload +def create_engine( + url: Union[str, URL], + *, + connect_args: Dict[Any, Any] = ..., + convert_unicode: bool = ..., + creator: Union[_CreatorFnType, _CreatorWRecFnType] = ..., + echo: _EchoFlagType = ..., + echo_pool: _EchoFlagType = ..., + enable_from_linting: bool = ..., + execution_options: _ExecuteOptions = ..., + future: Literal[True], + hide_parameters: bool = ..., + implicit_returning: Literal[True] = ..., + insertmanyvalues_page_size: int = ..., + isolation_level: IsolationLevel = ..., + json_deserializer: Callable[..., Any] = ..., + json_serializer: Callable[..., Any] = ..., + label_length: Optional[int] = ..., + logging_name: str = ..., + max_identifier_length: Optional[int] = ..., + max_overflow: int = ..., + module: Optional[Any] = ..., + paramstyle: Optional[_ParamStyle] = ..., + pool: Optional[Pool] = ..., + poolclass: Optional[Type[Pool]] = ..., + pool_logging_name: str = ..., + pool_pre_ping: bool = ..., + pool_size: int = ..., + pool_recycle: int = ..., + pool_reset_on_return: Optional[_ResetStyleArgType] = ..., + pool_timeout: float = ..., + pool_use_lifo: bool = ..., + plugins: List[str] = ..., + query_cache_size: int = ..., + use_insertmanyvalues: bool = ..., + **kwargs: Any, +) -> Engine: ... + + +@overload +def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: ... + + +@util.deprecated_params( + strategy=( + "1.4", + "The :paramref:`_sa.create_engine.strategy` keyword is deprecated, " + "and the only argument accepted is 'mock'; please use " + ":func:`.create_mock_engine` going forward. For general " + "customization of create_engine which may have been accomplished " + "using strategies, see :class:`.CreateEnginePlugin`.", + ), + empty_in_strategy=( + "1.4", + "The :paramref:`_sa.create_engine.empty_in_strategy` keyword is " + "deprecated, and no longer has any effect. All IN expressions " + "are now rendered using " + 'the "expanding parameter" strategy which renders a set of bound' + 'expressions, or an "empty set" SELECT, at statement execution' + "time.", + ), + implicit_returning=( + "2.0", + "The :paramref:`_sa.create_engine.implicit_returning` parameter " + "is deprecated and will be removed in a future release. ", + ), +) +def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine: + """Create a new :class:`_engine.Engine` instance. + + The standard calling form is to send the :ref:`URL ` as the + first positional argument, usually a string + that indicates database dialect and connection arguments:: + + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") + + .. note:: + + Please review :ref:`database_urls` for general guidelines in composing + URL strings. In particular, special characters, such as those often + part of passwords, must be URL encoded to be properly parsed. + + Additional keyword arguments may then follow it which + establish various options on the resulting :class:`_engine.Engine` + and its underlying :class:`.Dialect` and :class:`_pool.Pool` + constructs:: + + engine = create_engine("mysql+mysqldb://scott:tiger@hostname/dbname", + pool_recycle=3600, echo=True) + + The string form of the URL is + ``dialect[+driver]://user:password@host/dbname[?key=value..]``, where + ``dialect`` is a database name such as ``mysql``, ``oracle``, + ``postgresql``, etc., and ``driver`` the name of a DBAPI, such as + ``psycopg2``, ``pyodbc``, ``cx_oracle``, etc. Alternatively, + the URL can be an instance of :class:`~sqlalchemy.engine.url.URL`. + + ``**kwargs`` takes a wide variety of options which are routed + towards their appropriate components. Arguments may be specific to + the :class:`_engine.Engine`, the underlying :class:`.Dialect`, + as well as the + :class:`_pool.Pool`. Specific dialects also accept keyword arguments that + are unique to that dialect. Here, we describe the parameters + that are common to most :func:`_sa.create_engine()` usage. + + Once established, the newly resulting :class:`_engine.Engine` will + request a connection from the underlying :class:`_pool.Pool` once + :meth:`_engine.Engine.connect` is called, or a method which depends on it + such as :meth:`_engine.Engine.execute` is invoked. The + :class:`_pool.Pool` in turn + will establish the first actual DBAPI connection when this request + is received. The :func:`_sa.create_engine` call itself does **not** + establish any actual DBAPI connections directly. + + .. seealso:: + + :doc:`/core/engines` + + :doc:`/dialects/index` + + :ref:`connections_toplevel` + + :param connect_args: a dictionary of options which will be + passed directly to the DBAPI's ``connect()`` method as + additional keyword arguments. See the example + at :ref:`custom_dbapi_args`. + + :param creator: a callable which returns a DBAPI connection. + This creation function will be passed to the underlying + connection pool and will be used to create all new database + connections. Usage of this function causes connection + parameters specified in the URL argument to be bypassed. + + This hook is not as flexible as the newer + :meth:`_events.DialectEvents.do_connect` hook which allows complete + control over how a connection is made to the database, given the full + set of URL arguments and state beforehand. + + .. seealso:: + + :meth:`_events.DialectEvents.do_connect` - event hook that allows + full control over DBAPI connection mechanics. + + :ref:`custom_dbapi_args` + + :param echo=False: if True, the Engine will log all statements + as well as a ``repr()`` of their parameter lists to the default log + handler, which defaults to ``sys.stdout`` for output. If set to the + string ``"debug"``, result rows will be printed to the standard output + as well. The ``echo`` attribute of ``Engine`` can be modified at any + time to turn logging on and off; direct control of logging is also + available using the standard Python ``logging`` module. + + .. seealso:: + + :ref:`dbengine_logging` - further detail on how to configure + logging. + + + :param echo_pool=False: if True, the connection pool will log + informational output such as when connections are invalidated + as well as when connections are recycled to the default log handler, + which defaults to ``sys.stdout`` for output. If set to the string + ``"debug"``, the logging will include pool checkouts and checkins. + Direct control of logging is also available using the standard Python + ``logging`` module. + + .. seealso:: + + :ref:`dbengine_logging` - further detail on how to configure + logging. + + + :param empty_in_strategy: No longer used; SQLAlchemy now uses + "empty set" behavior for IN in all cases. + + :param enable_from_linting: defaults to True. Will emit a warning + if a given SELECT statement is found to have un-linked FROM elements + which would cause a cartesian product. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`change_4737` + + :param execution_options: Dictionary execution options which will + be applied to all connections. See + :meth:`~sqlalchemy.engine.Connection.execution_options` + + :param future: Use the 2.0 style :class:`_engine.Engine` and + :class:`_engine.Connection` API. + + As of SQLAlchemy 2.0, this parameter is present for backwards + compatibility only and must remain at its default value of ``True``. + + The :paramref:`_sa.create_engine.future` parameter will be + deprecated in a subsequent 2.x release and eventually removed. + + .. versionadded:: 1.4 + + .. versionchanged:: 2.0 All :class:`_engine.Engine` objects are + "future" style engines and there is no longer a ``future=False`` + mode of operation. + + .. seealso:: + + :ref:`migration_20_toplevel` + + :param hide_parameters: Boolean, when set to True, SQL statement parameters + will not be displayed in INFO logging nor will they be formatted into + the string representation of :class:`.StatementError` objects. + + .. versionadded:: 1.3.8 + + .. seealso:: + + :ref:`dbengine_logging` - further detail on how to configure + logging. + + :param implicit_returning=True: Legacy parameter that may only be set + to True. In SQLAlchemy 2.0, this parameter does nothing. In order to + disable "implicit returning" for statements invoked by the ORM, + configure this on a per-table basis using the + :paramref:`.Table.implicit_returning` parameter. + + + :param insertmanyvalues_page_size: number of rows to format into an + INSERT statement when the statement uses "insertmanyvalues" mode, which is + a paged form of bulk insert that is used for many backends when using + :term:`executemany` execution typically in conjunction with RETURNING. + Defaults to 1000, but may also be subject to dialect-specific limiting + factors which may override this value on a per-statement basis. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`engine_insertmanyvalues` + + :ref:`engine_insertmanyvalues_page_size` + + :paramref:`_engine.Connection.execution_options.insertmanyvalues_page_size` + + :param isolation_level: optional string name of an isolation level + which will be set on all new connections unconditionally. + Isolation levels are typically some subset of the string names + ``"SERIALIZABLE"``, ``"REPEATABLE READ"``, + ``"READ COMMITTED"``, ``"READ UNCOMMITTED"`` and ``"AUTOCOMMIT"`` + based on backend. + + The :paramref:`_sa.create_engine.isolation_level` parameter is + in contrast to the + :paramref:`.Connection.execution_options.isolation_level` + execution option, which may be set on an individual + :class:`.Connection`, as well as the same parameter passed to + :meth:`.Engine.execution_options`, where it may be used to create + multiple engines with different isolation levels that share a common + connection pool and dialect. + + .. versionchanged:: 2.0 The + :paramref:`_sa.create_engine.isolation_level` + parameter has been generalized to work on all dialects which support + the concept of isolation level, and is provided as a more succinct, + up front configuration switch in contrast to the execution option + which is more of an ad-hoc programmatic option. + + .. seealso:: + + :ref:`dbapi_autocommit` + + :param json_deserializer: for dialects that support the + :class:`_types.JSON` + datatype, this is a Python callable that will convert a JSON string + to a Python object. By default, the Python ``json.loads`` function is + used. + + .. versionchanged:: 1.3.7 The SQLite dialect renamed this from + ``_json_deserializer``. + + :param json_serializer: for dialects that support the :class:`_types.JSON` + datatype, this is a Python callable that will render a given object + as JSON. By default, the Python ``json.dumps`` function is used. + + .. versionchanged:: 1.3.7 The SQLite dialect renamed this from + ``_json_serializer``. + + + :param label_length=None: optional integer value which limits + the size of dynamically generated column labels to that many + characters. If less than 6, labels are generated as + "_(counter)". If ``None``, the value of + ``dialect.max_identifier_length``, which may be affected via the + :paramref:`_sa.create_engine.max_identifier_length` parameter, + is used instead. The value of + :paramref:`_sa.create_engine.label_length` + may not be larger than that of + :paramref:`_sa.create_engine.max_identfier_length`. + + .. seealso:: + + :paramref:`_sa.create_engine.max_identifier_length` + + :param logging_name: String identifier which will be used within + the "name" field of logging records generated within the + "sqlalchemy.engine" logger. Defaults to a hexstring of the + object's id. + + .. seealso:: + + :ref:`dbengine_logging` - further detail on how to configure + logging. + + :paramref:`_engine.Connection.execution_options.logging_token` + + :param max_identifier_length: integer; override the max_identifier_length + determined by the dialect. if ``None`` or zero, has no effect. This + is the database's configured maximum number of characters that may be + used in a SQL identifier such as a table name, column name, or label + name. All dialects determine this value automatically, however in the + case of a new database version for which this value has changed but + SQLAlchemy's dialect has not been adjusted, the value may be passed + here. + + .. versionadded:: 1.3.9 + + .. seealso:: + + :paramref:`_sa.create_engine.label_length` + + :param max_overflow=10: the number of connections to allow in + connection pool "overflow", that is connections that can be + opened above and beyond the pool_size setting, which defaults + to five. this is only used with :class:`~sqlalchemy.pool.QueuePool`. + + :param module=None: reference to a Python module object (the module + itself, not its string name). Specifies an alternate DBAPI module to + be used by the engine's dialect. Each sub-dialect references a + specific DBAPI which will be imported before first connect. This + parameter causes the import to be bypassed, and the given module to + be used instead. Can be used for testing of DBAPIs as well as to + inject "mock" DBAPI implementations into the :class:`_engine.Engine`. + + :param paramstyle=None: The `paramstyle `_ + to use when rendering bound parameters. This style defaults to the + one recommended by the DBAPI itself, which is retrieved from the + ``.paramstyle`` attribute of the DBAPI. However, most DBAPIs accept + more than one paramstyle, and in particular it may be desirable + to change a "named" paramstyle into a "positional" one, or vice versa. + When this attribute is passed, it should be one of the values + ``"qmark"``, ``"numeric"``, ``"named"``, ``"format"`` or + ``"pyformat"``, and should correspond to a parameter style known + to be supported by the DBAPI in use. + + :param pool=None: an already-constructed instance of + :class:`~sqlalchemy.pool.Pool`, such as a + :class:`~sqlalchemy.pool.QueuePool` instance. If non-None, this + pool will be used directly as the underlying connection pool + for the engine, bypassing whatever connection parameters are + present in the URL argument. For information on constructing + connection pools manually, see :ref:`pooling_toplevel`. + + :param poolclass=None: a :class:`~sqlalchemy.pool.Pool` + subclass, which will be used to create a connection pool + instance using the connection parameters given in the URL. Note + this differs from ``pool`` in that you don't actually + instantiate the pool in this case, you just indicate what type + of pool to be used. + + :param pool_logging_name: String identifier which will be used within + the "name" field of logging records generated within the + "sqlalchemy.pool" logger. Defaults to a hexstring of the object's + id. + + .. seealso:: + + :ref:`dbengine_logging` - further detail on how to configure + logging. + + :param pool_pre_ping: boolean, if True will enable the connection pool + "pre-ping" feature that tests connections for liveness upon + each checkout. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`pool_disconnects_pessimistic` + + :param pool_size=5: the number of connections to keep open + inside the connection pool. This used with + :class:`~sqlalchemy.pool.QueuePool` as + well as :class:`~sqlalchemy.pool.SingletonThreadPool`. With + :class:`~sqlalchemy.pool.QueuePool`, a ``pool_size`` setting + of 0 indicates no limit; to disable pooling, set ``poolclass`` to + :class:`~sqlalchemy.pool.NullPool` instead. + + :param pool_recycle=-1: this setting causes the pool to recycle + connections after the given number of seconds has passed. It + defaults to -1, or no timeout. For example, setting to 3600 + means connections will be recycled after one hour. Note that + MySQL in particular will disconnect automatically if no + activity is detected on a connection for eight hours (although + this is configurable with the MySQLDB connection itself and the + server configuration as well). + + .. seealso:: + + :ref:`pool_setting_recycle` + + :param pool_reset_on_return='rollback': set the + :paramref:`_pool.Pool.reset_on_return` parameter of the underlying + :class:`_pool.Pool` object, which can be set to the values + ``"rollback"``, ``"commit"``, or ``None``. + + .. seealso:: + + :ref:`pool_reset_on_return` + + :param pool_timeout=30: number of seconds to wait before giving + up on getting a connection from the pool. This is only used + with :class:`~sqlalchemy.pool.QueuePool`. This can be a float but is + subject to the limitations of Python time functions which may not be + reliable in the tens of milliseconds. + + .. note: don't use 30.0 above, it seems to break with the :param tag + + :param pool_use_lifo=False: use LIFO (last-in-first-out) when retrieving + connections from :class:`.QueuePool` instead of FIFO + (first-in-first-out). Using LIFO, a server-side timeout scheme can + reduce the number of connections used during non- peak periods of + use. When planning for server-side timeouts, ensure that a recycle or + pre-ping strategy is in use to gracefully handle stale connections. + + .. versionadded:: 1.3 + + .. seealso:: + + :ref:`pool_use_lifo` + + :ref:`pool_disconnects` + + :param plugins: string list of plugin names to load. See + :class:`.CreateEnginePlugin` for background. + + .. versionadded:: 1.2.3 + + :param query_cache_size: size of the cache used to cache the SQL string + form of queries. Set to zero to disable caching. + + The cache is pruned of its least recently used items when its size reaches + N * 1.5. Defaults to 500, meaning the cache will always store at least + 500 SQL statements when filled, and will grow up to 750 items at which + point it is pruned back down to 500 by removing the 250 least recently + used items. + + Caching is accomplished on a per-statement basis by generating a + cache key that represents the statement's structure, then generating + string SQL for the current dialect only if that key is not present + in the cache. All statements support caching, however some features + such as an INSERT with a large set of parameters will intentionally + bypass the cache. SQL logging will indicate statistics for each + statement whether or not it were pull from the cache. + + .. note:: some ORM functions related to unit-of-work persistence as well + as some attribute loading strategies will make use of individual + per-mapper caches outside of the main cache. + + + .. seealso:: + + :ref:`sql_caching` + + .. versionadded:: 1.4 + + :param use_insertmanyvalues: True by default, use the "insertmanyvalues" + execution style for INSERT..RETURNING statements by default. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`engine_insertmanyvalues` + + """ # noqa + + if "strategy" in kwargs: + strat = kwargs.pop("strategy") + if strat == "mock": + # this case is deprecated + return create_mock_engine(url, **kwargs) # type: ignore + else: + raise exc.ArgumentError("unknown strategy: %r" % strat) + + kwargs.pop("empty_in_strategy", None) + + # create url.URL object + u = _url.make_url(url) + + u, plugins, kwargs = u._instantiate_plugins(kwargs) + + entrypoint = u._get_entrypoint() + _is_async = kwargs.pop("_is_async", False) + if _is_async: + dialect_cls = entrypoint.get_async_dialect_cls(u) + else: + dialect_cls = entrypoint.get_dialect_cls(u) + + if kwargs.pop("_coerce_config", False): + + def pop_kwarg(key: str, default: Optional[Any] = None) -> Any: + value = kwargs.pop(key, default) + if key in dialect_cls.engine_config_types: + value = dialect_cls.engine_config_types[key](value) + return value + + else: + pop_kwarg = kwargs.pop # type: ignore + + dialect_args = {} + # consume dialect arguments from kwargs + for k in util.get_cls_kwargs(dialect_cls): + if k in kwargs: + dialect_args[k] = pop_kwarg(k) + + dbapi = kwargs.pop("module", None) + if dbapi is None: + dbapi_args = {} + + if "import_dbapi" in dialect_cls.__dict__: + dbapi_meth = dialect_cls.import_dbapi + + elif hasattr(dialect_cls, "dbapi") and inspect.ismethod( + dialect_cls.dbapi + ): + util.warn_deprecated( + "The dbapi() classmethod on dialect classes has been " + "renamed to import_dbapi(). Implement an import_dbapi() " + f"classmethod directly on class {dialect_cls} to remove this " + "warning; the old .dbapi() classmethod may be maintained for " + "backwards compatibility.", + "2.0", + ) + dbapi_meth = dialect_cls.dbapi + else: + dbapi_meth = dialect_cls.import_dbapi + + for k in util.get_func_kwargs(dbapi_meth): + if k in kwargs: + dbapi_args[k] = pop_kwarg(k) + dbapi = dbapi_meth(**dbapi_args) + + dialect_args["dbapi"] = dbapi + + dialect_args.setdefault("compiler_linting", compiler.NO_LINTING) + enable_from_linting = kwargs.pop("enable_from_linting", True) + if enable_from_linting: + dialect_args["compiler_linting"] ^= compiler.COLLECT_CARTESIAN_PRODUCTS + + for plugin in plugins: + plugin.handle_dialect_kwargs(dialect_cls, dialect_args) + + # create dialect + dialect = dialect_cls(**dialect_args) + + # assemble connection arguments + (cargs_tup, cparams) = dialect.create_connect_args(u) + cparams.update(pop_kwarg("connect_args", {})) + + if "async_fallback" in cparams and util.asbool(cparams["async_fallback"]): + util.warn_deprecated( + "The async_fallback dialect argument is deprecated and will be " + "removed in SQLAlchemy 2.1.", + "2.0", + ) + + cargs = list(cargs_tup) # allow mutability + + # look for existing pool or create + pool = pop_kwarg("pool", None) + if pool is None: + + def connect( + connection_record: Optional[ConnectionPoolEntry] = None, + ) -> DBAPIConnection: + if dialect._has_events: + for fn in dialect.dispatch.do_connect: + connection = cast( + DBAPIConnection, + fn(dialect, connection_record, cargs, cparams), + ) + if connection is not None: + return connection + + return dialect.connect(*cargs, **cparams) + + creator = pop_kwarg("creator", connect) + + poolclass = pop_kwarg("poolclass", None) + if poolclass is None: + poolclass = dialect.get_dialect_pool_class(u) + pool_args = {"dialect": dialect} + + # consume pool arguments from kwargs, translating a few of + # the arguments + for k in util.get_cls_kwargs(poolclass): + tk = _pool_translate_kwargs.get(k, k) + if tk in kwargs: + pool_args[k] = pop_kwarg(tk) + + for plugin in plugins: + plugin.handle_pool_kwargs(poolclass, pool_args) + + pool = poolclass(creator, **pool_args) + else: + pool._dialect = dialect + + if ( + hasattr(pool, "_is_asyncio") + and pool._is_asyncio is not dialect.is_async + ): + raise exc.ArgumentError( + f"Pool class {pool.__class__.__name__} cannot be " + f"used with {'non-' if not dialect.is_async else ''}" + "asyncio engine", + code="pcls", + ) + + # create engine. + if not pop_kwarg("future", True): + raise exc.ArgumentError( + "The 'future' parameter passed to " + "create_engine() may only be set to True." + ) + + engineclass = base.Engine + + engine_args = {} + for k in util.get_cls_kwargs(engineclass): + if k in kwargs: + engine_args[k] = pop_kwarg(k) + + # internal flags used by the test suite for instrumenting / proxying + # engines with mocks etc. + _initialize = kwargs.pop("_initialize", True) + + # all kwargs should be consumed + if kwargs: + raise TypeError( + "Invalid argument(s) %s sent to create_engine(), " + "using configuration %s/%s/%s. Please check that the " + "keyword arguments are appropriate for this combination " + "of components." + % ( + ",".join("'%s'" % k for k in kwargs), + dialect.__class__.__name__, + pool.__class__.__name__, + engineclass.__name__, + ) + ) + + engine = engineclass(pool, dialect, u, **engine_args) + + if _initialize: + do_on_connect = dialect.on_connect_url(u) + if do_on_connect: + + def on_connect( + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + assert do_on_connect is not None + do_on_connect(dbapi_connection) + + event.listen(pool, "connect", on_connect) + + builtin_on_connect = dialect._builtin_onconnect() + if builtin_on_connect: + event.listen(pool, "connect", builtin_on_connect) + + def first_connect( + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + c = base.Connection( + engine, + connection=_AdhocProxiedConnection( + dbapi_connection, connection_record + ), + _has_events=False, + # reconnecting will be a reentrant condition, so if the + # connection goes away, Connection is then closed + _allow_revalidate=False, + # dont trigger the autobegin sequence + # within the up front dialect checks + _allow_autobegin=False, + ) + c._execution_options = util.EMPTY_DICT + + try: + dialect.initialize(c) + finally: + # note that "invalidated" and "closed" are mutually + # exclusive in 1.4 Connection. + if not c.invalidated and not c.closed: + # transaction is rolled back otherwise, tested by + # test/dialect/postgresql/test_dialect.py + # ::MiscBackendTest::test_initial_transaction_state + dialect.do_rollback(c.connection) + + # previously, the "first_connect" event was used here, which was then + # scaled back if the "on_connect" handler were present. now, + # since "on_connect" is virtually always present, just use + # "connect" event with once_unless_exception in all cases so that + # the connection event flow is consistent in all cases. + event.listen( + pool, "connect", first_connect, _once_unless_exception=True + ) + + dialect_cls.engine_created(engine) + if entrypoint is not dialect_cls: + entrypoint.engine_created(engine) + + for plugin in plugins: + plugin.engine_created(engine) + + return engine + + +def engine_from_config( + configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any +) -> Engine: + """Create a new Engine instance using a configuration dictionary. + + The dictionary is typically produced from a config file. + + The keys of interest to ``engine_from_config()`` should be prefixed, e.g. + ``sqlalchemy.url``, ``sqlalchemy.echo``, etc. The 'prefix' argument + indicates the prefix to be searched for. Each matching key (after the + prefix is stripped) is treated as though it were the corresponding keyword + argument to a :func:`_sa.create_engine` call. + + The only required key is (assuming the default prefix) ``sqlalchemy.url``, + which provides the :ref:`database URL `. + + A select set of keyword arguments will be "coerced" to their + expected type based on string values. The set of arguments + is extensible per-dialect using the ``engine_config_types`` accessor. + + :param configuration: A dictionary (typically produced from a config file, + but this is not a requirement). Items whose keys start with the value + of 'prefix' will have that prefix stripped, and will then be passed to + :func:`_sa.create_engine`. + + :param prefix: Prefix to match and then strip from keys + in 'configuration'. + + :param kwargs: Each keyword argument to ``engine_from_config()`` itself + overrides the corresponding item taken from the 'configuration' + dictionary. Keyword arguments should *not* be prefixed. + + """ + + options = { + key[len(prefix) :]: configuration[key] + for key in configuration + if key.startswith(prefix) + } + options["_coerce_config"] = True + options.update(kwargs) + url = options.pop("url") + return create_engine(url, **options) + + +@overload +def create_pool_from_url( + url: Union[str, URL], + *, + poolclass: Optional[Type[Pool]] = ..., + logging_name: str = ..., + pre_ping: bool = ..., + size: int = ..., + recycle: int = ..., + reset_on_return: Optional[_ResetStyleArgType] = ..., + timeout: float = ..., + use_lifo: bool = ..., + **kwargs: Any, +) -> Pool: ... + + +@overload +def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: ... + + +def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: + """Create a pool instance from the given url. + + If ``poolclass`` is not provided the pool class used + is selected using the dialect specified in the URL. + + The arguments passed to :func:`_sa.create_pool_from_url` are + identical to the pool argument passed to the :func:`_sa.create_engine` + function. + + .. versionadded:: 2.0.10 + """ + + for key in _pool_translate_kwargs: + if key in kwargs: + kwargs[_pool_translate_kwargs[key]] = kwargs.pop(key) + + engine = create_engine(url, **kwargs, _initialize=False) + return engine.pool + + +_pool_translate_kwargs = immutabledict( + { + "logging_name": "pool_logging_name", + "echo": "echo_pool", + "timeout": "pool_timeout", + "recycle": "pool_recycle", + "events": "pool_events", # deprecated + "reset_on_return": "pool_reset_on_return", + "pre_ping": "pool_pre_ping", + "use_lifo": "pool_use_lifo", + } +) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/cursor.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/cursor.py new file mode 100644 index 0000000..71767db --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/cursor.py @@ -0,0 +1,2178 @@ +# engine/cursor.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Define cursor-specific result set constructs including +:class:`.CursorResult`.""" + + +from __future__ import annotations + +import collections +import functools +import operator +import typing +from typing import Any +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import Iterator +from typing import List +from typing import Mapping +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .result import IteratorResult +from .result import MergedResult +from .result import Result +from .result import ResultMetaData +from .result import SimpleResultMetaData +from .result import tuplegetter +from .row import Row +from .. import exc +from .. import util +from ..sql import elements +from ..sql import sqltypes +from ..sql import util as sql_util +from ..sql.base import _generative +from ..sql.compiler import ResultColumnsEntry +from ..sql.compiler import RM_NAME +from ..sql.compiler import RM_OBJECTS +from ..sql.compiler import RM_RENDERED_NAME +from ..sql.compiler import RM_TYPE +from ..sql.type_api import TypeEngine +from ..util import compat +from ..util.typing import Literal +from ..util.typing import Self + + +if typing.TYPE_CHECKING: + from .base import Connection + from .default import DefaultExecutionContext + from .interfaces import _DBAPICursorDescription + from .interfaces import DBAPICursor + from .interfaces import Dialect + from .interfaces import ExecutionContext + from .result import _KeyIndexType + from .result import _KeyMapRecType + from .result import _KeyMapType + from .result import _KeyType + from .result import _ProcessorsType + from .result import _TupleGetterType + from ..sql.type_api import _ResultProcessorType + + +_T = TypeVar("_T", bound=Any) + + +# metadata entry tuple indexes. +# using raw tuple is faster than namedtuple. +# these match up to the positions in +# _CursorKeyMapRecType +MD_INDEX: Literal[0] = 0 +"""integer index in cursor.description + +""" + +MD_RESULT_MAP_INDEX: Literal[1] = 1 +"""integer index in compiled._result_columns""" + +MD_OBJECTS: Literal[2] = 2 +"""other string keys and ColumnElement obj that can match. + +This comes from compiler.RM_OBJECTS / compiler.ResultColumnsEntry.objects + +""" + +MD_LOOKUP_KEY: Literal[3] = 3 +"""string key we usually expect for key-based lookup + +this comes from compiler.RM_NAME / compiler.ResultColumnsEntry.name +""" + + +MD_RENDERED_NAME: Literal[4] = 4 +"""name that is usually in cursor.description + +this comes from compiler.RENDERED_NAME / compiler.ResultColumnsEntry.keyname +""" + + +MD_PROCESSOR: Literal[5] = 5 +"""callable to process a result value into a row""" + +MD_UNTRANSLATED: Literal[6] = 6 +"""raw name from cursor.description""" + + +_CursorKeyMapRecType = Tuple[ + Optional[int], # MD_INDEX, None means the record is ambiguously named + int, # MD_RESULT_MAP_INDEX + List[Any], # MD_OBJECTS + str, # MD_LOOKUP_KEY + str, # MD_RENDERED_NAME + Optional["_ResultProcessorType[Any]"], # MD_PROCESSOR + Optional[str], # MD_UNTRANSLATED +] + +_CursorKeyMapType = Mapping["_KeyType", _CursorKeyMapRecType] + +# same as _CursorKeyMapRecType except the MD_INDEX value is definitely +# not None +_NonAmbigCursorKeyMapRecType = Tuple[ + int, + int, + List[Any], + str, + str, + Optional["_ResultProcessorType[Any]"], + str, +] + + +class CursorResultMetaData(ResultMetaData): + """Result metadata for DBAPI cursors.""" + + __slots__ = ( + "_keymap", + "_processors", + "_keys", + "_keymap_by_result_column_idx", + "_tuplefilter", + "_translated_indexes", + "_safe_for_cache", + "_unpickled", + "_key_to_index", + # don't need _unique_filters support here for now. Can be added + # if a need arises. + ) + + _keymap: _CursorKeyMapType + _processors: _ProcessorsType + _keymap_by_result_column_idx: Optional[Dict[int, _KeyMapRecType]] + _unpickled: bool + _safe_for_cache: bool + _translated_indexes: Optional[List[int]] + + returns_rows: ClassVar[bool] = True + + def _has_key(self, key: Any) -> bool: + return key in self._keymap + + def _for_freeze(self) -> ResultMetaData: + return SimpleResultMetaData( + self._keys, + extra=[self._keymap[key][MD_OBJECTS] for key in self._keys], + ) + + def _make_new_metadata( + self, + *, + unpickled: bool, + processors: _ProcessorsType, + keys: Sequence[str], + keymap: _KeyMapType, + tuplefilter: Optional[_TupleGetterType], + translated_indexes: Optional[List[int]], + safe_for_cache: bool, + keymap_by_result_column_idx: Any, + ) -> CursorResultMetaData: + new_obj = self.__class__.__new__(self.__class__) + new_obj._unpickled = unpickled + new_obj._processors = processors + new_obj._keys = keys + new_obj._keymap = keymap + new_obj._tuplefilter = tuplefilter + new_obj._translated_indexes = translated_indexes + new_obj._safe_for_cache = safe_for_cache + new_obj._keymap_by_result_column_idx = keymap_by_result_column_idx + new_obj._key_to_index = self._make_key_to_index(keymap, MD_INDEX) + return new_obj + + def _remove_processors(self) -> CursorResultMetaData: + assert not self._tuplefilter + return self._make_new_metadata( + unpickled=self._unpickled, + processors=[None] * len(self._processors), + tuplefilter=None, + translated_indexes=None, + keymap={ + key: value[0:5] + (None,) + value[6:] + for key, value in self._keymap.items() + }, + keys=self._keys, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, + ) + + def _splice_horizontally( + self, other: CursorResultMetaData + ) -> CursorResultMetaData: + assert not self._tuplefilter + + keymap = dict(self._keymap) + offset = len(self._keys) + keymap.update( + { + key: ( + # int index should be None for ambiguous key + ( + value[0] + offset + if value[0] is not None and key not in keymap + else None + ), + value[1] + offset, + *value[2:], + ) + for key, value in other._keymap.items() + } + ) + return self._make_new_metadata( + unpickled=self._unpickled, + processors=self._processors + other._processors, # type: ignore + tuplefilter=None, + translated_indexes=None, + keys=self._keys + other._keys, # type: ignore + keymap=keymap, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx={ + metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry + for metadata_entry in keymap.values() + }, + ) + + def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: + recs = list(self._metadata_for_keys(keys)) + + indexes = [rec[MD_INDEX] for rec in recs] + new_keys: List[str] = [rec[MD_LOOKUP_KEY] for rec in recs] + + if self._translated_indexes: + indexes = [self._translated_indexes[idx] for idx in indexes] + tup = tuplegetter(*indexes) + new_recs = [(index,) + rec[1:] for index, rec in enumerate(recs)] + + keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs} + # TODO: need unit test for: + # result = connection.execute("raw sql, no columns").scalars() + # without the "or ()" it's failing because MD_OBJECTS is None + keymap.update( + (e, new_rec) + for new_rec in new_recs + for e in new_rec[MD_OBJECTS] or () + ) + + return self._make_new_metadata( + unpickled=self._unpickled, + processors=self._processors, + keys=new_keys, + tuplefilter=tup, + translated_indexes=indexes, + keymap=keymap, # type: ignore[arg-type] + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, + ) + + def _adapt_to_context(self, context: ExecutionContext) -> ResultMetaData: + """When using a cached Compiled construct that has a _result_map, + for a new statement that used the cached Compiled, we need to ensure + the keymap has the Column objects from our new statement as keys. + So here we rewrite keymap with new entries for the new columns + as matched to those of the cached statement. + + """ + + if not context.compiled or not context.compiled._result_columns: + return self + + compiled_statement = context.compiled.statement + invoked_statement = context.invoked_statement + + if TYPE_CHECKING: + assert isinstance(invoked_statement, elements.ClauseElement) + + if compiled_statement is invoked_statement: + return self + + assert invoked_statement is not None + + # this is the most common path for Core statements when + # caching is used. In ORM use, this codepath is not really used + # as the _result_disable_adapt_to_context execution option is + # set by the ORM. + + # make a copy and add the columns from the invoked statement + # to the result map. + + keymap_by_position = self._keymap_by_result_column_idx + + if keymap_by_position is None: + # first retrival from cache, this map will not be set up yet, + # initialize lazily + keymap_by_position = self._keymap_by_result_column_idx = { + metadata_entry[MD_RESULT_MAP_INDEX]: metadata_entry + for metadata_entry in self._keymap.values() + } + + assert not self._tuplefilter + return self._make_new_metadata( + keymap=compat.dict_union( + self._keymap, + { + new: keymap_by_position[idx] + for idx, new in enumerate( + invoked_statement._all_selected_columns + ) + if idx in keymap_by_position + }, + ), + unpickled=self._unpickled, + processors=self._processors, + tuplefilter=None, + translated_indexes=None, + keys=self._keys, + safe_for_cache=self._safe_for_cache, + keymap_by_result_column_idx=self._keymap_by_result_column_idx, + ) + + def __init__( + self, + parent: CursorResult[Any], + cursor_description: _DBAPICursorDescription, + ): + context = parent.context + self._tuplefilter = None + self._translated_indexes = None + self._safe_for_cache = self._unpickled = False + + if context.result_column_struct: + ( + result_columns, + cols_are_ordered, + textual_ordered, + ad_hoc_textual, + loose_column_name_matching, + ) = context.result_column_struct + num_ctx_cols = len(result_columns) + else: + result_columns = cols_are_ordered = ( # type: ignore + num_ctx_cols + ) = ad_hoc_textual = loose_column_name_matching = ( + textual_ordered + ) = False + + # merge cursor.description with the column info + # present in the compiled structure, if any + raw = self._merge_cursor_description( + context, + cursor_description, + result_columns, + num_ctx_cols, + cols_are_ordered, + textual_ordered, + ad_hoc_textual, + loose_column_name_matching, + ) + + # processors in key order which are used when building up + # a row + self._processors = [ + metadata_entry[MD_PROCESSOR] for metadata_entry in raw + ] + + # this is used when using this ResultMetaData in a Core-only cache + # retrieval context. it's initialized on first cache retrieval + # when the _result_disable_adapt_to_context execution option + # (which the ORM generally sets) is not set. + self._keymap_by_result_column_idx = None + + # for compiled SQL constructs, copy additional lookup keys into + # the key lookup map, such as Column objects, labels, + # column keys and other names + if num_ctx_cols: + # keymap by primary string... + by_key = { + metadata_entry[MD_LOOKUP_KEY]: metadata_entry + for metadata_entry in raw + } + + if len(by_key) != num_ctx_cols: + # if by-primary-string dictionary smaller than + # number of columns, assume we have dupes; (this check + # is also in place if string dictionary is bigger, as + # can occur when '*' was used as one of the compiled columns, + # which may or may not be suggestive of dupes), rewrite + # dupe records with "None" for index which results in + # ambiguous column exception when accessed. + # + # this is considered to be the less common case as it is not + # common to have dupe column keys in a SELECT statement. + # + # new in 1.4: get the complete set of all possible keys, + # strings, objects, whatever, that are dupes across two + # different records, first. + index_by_key: Dict[Any, Any] = {} + dupes = set() + for metadata_entry in raw: + for key in (metadata_entry[MD_RENDERED_NAME],) + ( + metadata_entry[MD_OBJECTS] or () + ): + idx = metadata_entry[MD_INDEX] + # if this key has been associated with more than one + # positional index, it's a dupe + if index_by_key.setdefault(key, idx) != idx: + dupes.add(key) + + # then put everything we have into the keymap excluding only + # those keys that are dupes. + self._keymap = { + obj_elem: metadata_entry + for metadata_entry in raw + if metadata_entry[MD_OBJECTS] + for obj_elem in metadata_entry[MD_OBJECTS] + if obj_elem not in dupes + } + + # then for the dupe keys, put the "ambiguous column" + # record into by_key. + by_key.update( + { + key: (None, None, [], key, key, None, None) + for key in dupes + } + ) + + else: + # no dupes - copy secondary elements from compiled + # columns into self._keymap. this is the most common + # codepath for Core / ORM statement executions before the + # result metadata is cached + self._keymap = { + obj_elem: metadata_entry + for metadata_entry in raw + if metadata_entry[MD_OBJECTS] + for obj_elem in metadata_entry[MD_OBJECTS] + } + # update keymap with primary string names taking + # precedence + self._keymap.update(by_key) + else: + # no compiled objects to map, just create keymap by primary string + self._keymap = { + metadata_entry[MD_LOOKUP_KEY]: metadata_entry + for metadata_entry in raw + } + + # update keymap with "translated" names. In SQLAlchemy this is a + # sqlite only thing, and in fact impacting only extremely old SQLite + # versions unlikely to be present in modern Python versions. + # however, the pyhive third party dialect is + # also using this hook, which means others still might use it as well. + # I dislike having this awkward hook here but as long as we need + # to use names in cursor.description in some cases we need to have + # some hook to accomplish this. + if not num_ctx_cols and context._translate_colname: + self._keymap.update( + { + metadata_entry[MD_UNTRANSLATED]: self._keymap[ + metadata_entry[MD_LOOKUP_KEY] + ] + for metadata_entry in raw + if metadata_entry[MD_UNTRANSLATED] + } + ) + + self._key_to_index = self._make_key_to_index(self._keymap, MD_INDEX) + + def _merge_cursor_description( + self, + context, + cursor_description, + result_columns, + num_ctx_cols, + cols_are_ordered, + textual_ordered, + ad_hoc_textual, + loose_column_name_matching, + ): + """Merge a cursor.description with compiled result column information. + + There are at least four separate strategies used here, selected + depending on the type of SQL construct used to start with. + + The most common case is that of the compiled SQL expression construct, + which generated the column names present in the raw SQL string and + which has the identical number of columns as were reported by + cursor.description. In this case, we assume a 1-1 positional mapping + between the entries in cursor.description and the compiled object. + This is also the most performant case as we disregard extracting / + decoding the column names present in cursor.description since we + already have the desired name we generated in the compiled SQL + construct. + + The next common case is that of the completely raw string SQL, + such as passed to connection.execute(). In this case we have no + compiled construct to work with, so we extract and decode the + names from cursor.description and index those as the primary + result row target keys. + + The remaining fairly common case is that of the textual SQL + that includes at least partial column information; this is when + we use a :class:`_expression.TextualSelect` construct. + This construct may have + unordered or ordered column information. In the ordered case, we + merge the cursor.description and the compiled construct's information + positionally, and warn if there are additional description names + present, however we still decode the names in cursor.description + as we don't have a guarantee that the names in the columns match + on these. In the unordered case, we match names in cursor.description + to that of the compiled construct based on name matching. + In both of these cases, the cursor.description names and the column + expression objects and names are indexed as result row target keys. + + The final case is much less common, where we have a compiled + non-textual SQL expression construct, but the number of columns + in cursor.description doesn't match what's in the compiled + construct. We make the guess here that there might be textual + column expressions in the compiled construct that themselves include + a comma in them causing them to split. We do the same name-matching + as with textual non-ordered columns. + + The name-matched system of merging is the same as that used by + SQLAlchemy for all cases up through the 0.9 series. Positional + matching for compiled SQL expressions was introduced in 1.0 as a + major performance feature, and positional matching for textual + :class:`_expression.TextualSelect` objects in 1.1. + As name matching is no longer + a common case, it was acceptable to factor it into smaller generator- + oriented methods that are easier to understand, but incur slightly + more performance overhead. + + """ + + if ( + num_ctx_cols + and cols_are_ordered + and not textual_ordered + and num_ctx_cols == len(cursor_description) + ): + self._keys = [elem[0] for elem in result_columns] + # pure positional 1-1 case; doesn't need to read + # the names from cursor.description + + # most common case for Core and ORM + + # this metadata is safe to cache because we are guaranteed + # to have the columns in the same order for new executions + self._safe_for_cache = True + return [ + ( + idx, + idx, + rmap_entry[RM_OBJECTS], + rmap_entry[RM_NAME], + rmap_entry[RM_RENDERED_NAME], + context.get_result_processor( + rmap_entry[RM_TYPE], + rmap_entry[RM_RENDERED_NAME], + cursor_description[idx][1], + ), + None, + ) + for idx, rmap_entry in enumerate(result_columns) + ] + else: + # name-based or text-positional cases, where we need + # to read cursor.description names + + if textual_ordered or ( + ad_hoc_textual and len(cursor_description) == num_ctx_cols + ): + self._safe_for_cache = True + # textual positional case + raw_iterator = self._merge_textual_cols_by_position( + context, cursor_description, result_columns + ) + elif num_ctx_cols: + # compiled SQL with a mismatch of description cols + # vs. compiled cols, or textual w/ unordered columns + # the order of columns can change if the query is + # against a "select *", so not safe to cache + self._safe_for_cache = False + raw_iterator = self._merge_cols_by_name( + context, + cursor_description, + result_columns, + loose_column_name_matching, + ) + else: + # no compiled SQL, just a raw string, order of columns + # can change for "select *" + self._safe_for_cache = False + raw_iterator = self._merge_cols_by_none( + context, cursor_description + ) + + return [ + ( + idx, + ridx, + obj, + cursor_colname, + cursor_colname, + context.get_result_processor( + mapped_type, cursor_colname, coltype + ), + untranslated, + ) + for ( + idx, + ridx, + cursor_colname, + mapped_type, + coltype, + obj, + untranslated, + ) in raw_iterator + ] + + def _colnames_from_description(self, context, cursor_description): + """Extract column names and data types from a cursor.description. + + Applies unicode decoding, column translation, "normalization", + and case sensitivity rules to the names based on the dialect. + + """ + + dialect = context.dialect + translate_colname = context._translate_colname + normalize_name = ( + dialect.normalize_name if dialect.requires_name_normalize else None + ) + untranslated = None + + self._keys = [] + + for idx, rec in enumerate(cursor_description): + colname = rec[0] + coltype = rec[1] + + if translate_colname: + colname, untranslated = translate_colname(colname) + + if normalize_name: + colname = normalize_name(colname) + + self._keys.append(colname) + + yield idx, colname, untranslated, coltype + + def _merge_textual_cols_by_position( + self, context, cursor_description, result_columns + ): + num_ctx_cols = len(result_columns) + + if num_ctx_cols > len(cursor_description): + util.warn( + "Number of columns in textual SQL (%d) is " + "smaller than number of columns requested (%d)" + % (num_ctx_cols, len(cursor_description)) + ) + seen = set() + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): + if idx < num_ctx_cols: + ctx_rec = result_columns[idx] + obj = ctx_rec[RM_OBJECTS] + ridx = idx + mapped_type = ctx_rec[RM_TYPE] + if obj[0] in seen: + raise exc.InvalidRequestError( + "Duplicate column expression requested " + "in textual SQL: %r" % obj[0] + ) + seen.add(obj[0]) + else: + mapped_type = sqltypes.NULLTYPE + obj = None + ridx = None + yield idx, ridx, colname, mapped_type, coltype, obj, untranslated + + def _merge_cols_by_name( + self, + context, + cursor_description, + result_columns, + loose_column_name_matching, + ): + match_map = self._create_description_match_map( + result_columns, loose_column_name_matching + ) + mapped_type: TypeEngine[Any] + + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): + try: + ctx_rec = match_map[colname] + except KeyError: + mapped_type = sqltypes.NULLTYPE + obj = None + result_columns_idx = None + else: + obj = ctx_rec[1] + mapped_type = ctx_rec[2] + result_columns_idx = ctx_rec[3] + yield ( + idx, + result_columns_idx, + colname, + mapped_type, + coltype, + obj, + untranslated, + ) + + @classmethod + def _create_description_match_map( + cls, + result_columns: List[ResultColumnsEntry], + loose_column_name_matching: bool = False, + ) -> Dict[ + Union[str, object], Tuple[str, Tuple[Any, ...], TypeEngine[Any], int] + ]: + """when matching cursor.description to a set of names that are present + in a Compiled object, as is the case with TextualSelect, get all the + names we expect might match those in cursor.description. + """ + + d: Dict[ + Union[str, object], + Tuple[str, Tuple[Any, ...], TypeEngine[Any], int], + ] = {} + for ridx, elem in enumerate(result_columns): + key = elem[RM_RENDERED_NAME] + if key in d: + # conflicting keyname - just add the column-linked objects + # to the existing record. if there is a duplicate column + # name in the cursor description, this will allow all of those + # objects to raise an ambiguous column error + e_name, e_obj, e_type, e_ridx = d[key] + d[key] = e_name, e_obj + elem[RM_OBJECTS], e_type, ridx + else: + d[key] = (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx) + + if loose_column_name_matching: + # when using a textual statement with an unordered set + # of columns that line up, we are expecting the user + # to be using label names in the SQL that match to the column + # expressions. Enable more liberal matching for this case; + # duplicate keys that are ambiguous will be fixed later. + for r_key in elem[RM_OBJECTS]: + d.setdefault( + r_key, + (elem[RM_NAME], elem[RM_OBJECTS], elem[RM_TYPE], ridx), + ) + return d + + def _merge_cols_by_none(self, context, cursor_description): + for ( + idx, + colname, + untranslated, + coltype, + ) in self._colnames_from_description(context, cursor_description): + yield ( + idx, + None, + colname, + sqltypes.NULLTYPE, + coltype, + None, + untranslated, + ) + + if not TYPE_CHECKING: + + def _key_fallback( + self, key: Any, err: Optional[Exception], raiseerr: bool = True + ) -> Optional[NoReturn]: + if raiseerr: + if self._unpickled and isinstance(key, elements.ColumnElement): + raise exc.NoSuchColumnError( + "Row was unpickled; lookup by ColumnElement " + "is unsupported" + ) from err + else: + raise exc.NoSuchColumnError( + "Could not locate column in row for column '%s'" + % util.string_or_unprintable(key) + ) from err + else: + return None + + def _raise_for_ambiguous_column_name(self, rec): + raise exc.InvalidRequestError( + "Ambiguous column name '%s' in " + "result set column descriptions" % rec[MD_LOOKUP_KEY] + ) + + def _index_for_key(self, key: Any, raiseerr: bool = True) -> Optional[int]: + # TODO: can consider pre-loading ints and negative ints + # into _keymap - also no coverage here + if isinstance(key, int): + key = self._keys[key] + + try: + rec = self._keymap[key] + except KeyError as ke: + x = self._key_fallback(key, ke, raiseerr) + assert x is None + return None + + index = rec[0] + + if index is None: + self._raise_for_ambiguous_column_name(rec) + return index + + def _indexes_for_keys(self, keys): + try: + return [self._keymap[key][0] for key in keys] + except KeyError as ke: + # ensure it raises + CursorResultMetaData._key_fallback(self, ke.args[0], ke) + + def _metadata_for_keys( + self, keys: Sequence[Any] + ) -> Iterator[_NonAmbigCursorKeyMapRecType]: + for key in keys: + if int in key.__class__.__mro__: + key = self._keys[key] + + try: + rec = self._keymap[key] + except KeyError as ke: + # ensure it raises + CursorResultMetaData._key_fallback(self, ke.args[0], ke) + + index = rec[MD_INDEX] + + if index is None: + self._raise_for_ambiguous_column_name(rec) + + yield cast(_NonAmbigCursorKeyMapRecType, rec) + + def __getstate__(self): + # TODO: consider serializing this as SimpleResultMetaData + return { + "_keymap": { + key: ( + rec[MD_INDEX], + rec[MD_RESULT_MAP_INDEX], + [], + key, + rec[MD_RENDERED_NAME], + None, + None, + ) + for key, rec in self._keymap.items() + if isinstance(key, (str, int)) + }, + "_keys": self._keys, + "_translated_indexes": self._translated_indexes, + } + + def __setstate__(self, state): + self._processors = [None for _ in range(len(state["_keys"]))] + self._keymap = state["_keymap"] + self._keymap_by_result_column_idx = None + self._key_to_index = self._make_key_to_index(self._keymap, MD_INDEX) + self._keys = state["_keys"] + self._unpickled = True + if state["_translated_indexes"]: + self._translated_indexes = cast( + "List[int]", state["_translated_indexes"] + ) + self._tuplefilter = tuplegetter(*self._translated_indexes) + else: + self._translated_indexes = self._tuplefilter = None + + +class ResultFetchStrategy: + """Define a fetching strategy for a result object. + + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + alternate_cursor_description: Optional[_DBAPICursorDescription] = None + + def soft_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: + raise NotImplementedError() + + def hard_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: + raise NotImplementedError() + + def yield_per( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + num: int, + ) -> None: + return + + def fetchone( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: + raise NotImplementedError() + + def fetchmany( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: + raise NotImplementedError() + + def fetchall( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + ) -> Any: + raise NotImplementedError() + + def handle_exception( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + err: BaseException, + ) -> NoReturn: + raise err + + +class NoCursorFetchStrategy(ResultFetchStrategy): + """Cursor strategy for a result that has no open cursor. + + There are two varieties of this strategy, one for DQL and one for + DML (and also DDL), each of which represent a result that had a cursor + but no longer has one. + + """ + + __slots__ = () + + def soft_close(self, result, dbapi_cursor): + pass + + def hard_close(self, result, dbapi_cursor): + pass + + def fetchone(self, result, dbapi_cursor, hard_close=False): + return self._non_result(result, None) + + def fetchmany(self, result, dbapi_cursor, size=None): + return self._non_result(result, []) + + def fetchall(self, result, dbapi_cursor): + return self._non_result(result, []) + + def _non_result(self, result, default, err=None): + raise NotImplementedError() + + +class NoCursorDQLFetchStrategy(NoCursorFetchStrategy): + """Cursor strategy for a DQL result that has no open cursor. + + This is a result set that can return rows, i.e. for a SELECT, or for an + INSERT, UPDATE, DELETE that includes RETURNING. However it is in the state + where the cursor is closed and no rows remain available. The owning result + object may or may not be "hard closed", which determines if the fetch + methods send empty results or raise for closed result. + + """ + + __slots__ = () + + def _non_result(self, result, default, err=None): + if result.closed: + raise exc.ResourceClosedError( + "This result object is closed." + ) from err + else: + return default + + +_NO_CURSOR_DQL = NoCursorDQLFetchStrategy() + + +class NoCursorDMLFetchStrategy(NoCursorFetchStrategy): + """Cursor strategy for a DML result that has no open cursor. + + This is a result set that does not return rows, i.e. for an INSERT, + UPDATE, DELETE that does not include RETURNING. + + """ + + __slots__ = () + + def _non_result(self, result, default, err=None): + # we only expect to have a _NoResultMetaData() here right now. + assert not result._metadata.returns_rows + result._metadata._we_dont_return_rows(err) + + +_NO_CURSOR_DML = NoCursorDMLFetchStrategy() + + +class CursorFetchStrategy(ResultFetchStrategy): + """Call fetch methods from a DBAPI cursor. + + Alternate versions of this class may instead buffer the rows from + cursors or not use cursors at all. + + """ + + __slots__ = () + + def soft_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: + result.cursor_strategy = _NO_CURSOR_DQL + + def hard_close( + self, result: CursorResult[Any], dbapi_cursor: Optional[DBAPICursor] + ) -> None: + result.cursor_strategy = _NO_CURSOR_DQL + + def handle_exception( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + err: BaseException, + ) -> NoReturn: + result.connection._handle_dbapi_exception( + err, None, None, dbapi_cursor, result.context + ) + + def yield_per( + self, + result: CursorResult[Any], + dbapi_cursor: Optional[DBAPICursor], + num: int, + ) -> None: + result.cursor_strategy = BufferedRowCursorFetchStrategy( + dbapi_cursor, + {"max_row_buffer": num}, + initial_buffer=collections.deque(), + growth_factor=0, + ) + + def fetchone( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + hard_close: bool = False, + ) -> Any: + try: + row = dbapi_cursor.fetchone() + if row is None: + result._soft_close(hard=hard_close) + return row + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + + def fetchmany( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + size: Optional[int] = None, + ) -> Any: + try: + if size is None: + l = dbapi_cursor.fetchmany() + else: + l = dbapi_cursor.fetchmany(size) + + if not l: + result._soft_close() + return l + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + + def fetchall( + self, + result: CursorResult[Any], + dbapi_cursor: DBAPICursor, + ) -> Any: + try: + rows = dbapi_cursor.fetchall() + result._soft_close() + return rows + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + + +_DEFAULT_FETCH = CursorFetchStrategy() + + +class BufferedRowCursorFetchStrategy(CursorFetchStrategy): + """A cursor fetch strategy with row buffering behavior. + + This strategy buffers the contents of a selection of rows + before ``fetchone()`` is called. This is to allow the results of + ``cursor.description`` to be available immediately, when + interfacing with a DB-API that requires rows to be consumed before + this information is available (currently psycopg2, when used with + server-side cursors). + + The pre-fetching behavior fetches only one row initially, and then + grows its buffer size by a fixed amount with each successive need + for additional rows up the ``max_row_buffer`` size, which defaults + to 1000:: + + with psycopg2_engine.connect() as conn: + + result = conn.execution_options( + stream_results=True, max_row_buffer=50 + ).execute(text("select * from table")) + + .. versionadded:: 1.4 ``max_row_buffer`` may now exceed 1000 rows. + + .. seealso:: + + :ref:`psycopg2_execution_options` + """ + + __slots__ = ("_max_row_buffer", "_rowbuffer", "_bufsize", "_growth_factor") + + def __init__( + self, + dbapi_cursor, + execution_options, + growth_factor=5, + initial_buffer=None, + ): + self._max_row_buffer = execution_options.get("max_row_buffer", 1000) + + if initial_buffer is not None: + self._rowbuffer = initial_buffer + else: + self._rowbuffer = collections.deque(dbapi_cursor.fetchmany(1)) + self._growth_factor = growth_factor + + if growth_factor: + self._bufsize = min(self._max_row_buffer, self._growth_factor) + else: + self._bufsize = self._max_row_buffer + + @classmethod + def create(cls, result): + return BufferedRowCursorFetchStrategy( + result.cursor, + result.context.execution_options, + ) + + def _buffer_rows(self, result, dbapi_cursor): + """this is currently used only by fetchone().""" + + size = self._bufsize + try: + if size < 1: + new_rows = dbapi_cursor.fetchall() + else: + new_rows = dbapi_cursor.fetchmany(size) + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + + if not new_rows: + return + self._rowbuffer = collections.deque(new_rows) + if self._growth_factor and size < self._max_row_buffer: + self._bufsize = min( + self._max_row_buffer, size * self._growth_factor + ) + + def yield_per(self, result, dbapi_cursor, num): + self._growth_factor = 0 + self._max_row_buffer = self._bufsize = num + + def soft_close(self, result, dbapi_cursor): + self._rowbuffer.clear() + super().soft_close(result, dbapi_cursor) + + def hard_close(self, result, dbapi_cursor): + self._rowbuffer.clear() + super().hard_close(result, dbapi_cursor) + + def fetchone(self, result, dbapi_cursor, hard_close=False): + if not self._rowbuffer: + self._buffer_rows(result, dbapi_cursor) + if not self._rowbuffer: + try: + result._soft_close(hard=hard_close) + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + return None + return self._rowbuffer.popleft() + + def fetchmany(self, result, dbapi_cursor, size=None): + if size is None: + return self.fetchall(result, dbapi_cursor) + + buf = list(self._rowbuffer) + lb = len(buf) + if size > lb: + try: + new = dbapi_cursor.fetchmany(size - lb) + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + else: + if not new: + result._soft_close() + else: + buf.extend(new) + + result = buf[0:size] + self._rowbuffer = collections.deque(buf[size:]) + return result + + def fetchall(self, result, dbapi_cursor): + try: + ret = list(self._rowbuffer) + list(dbapi_cursor.fetchall()) + self._rowbuffer.clear() + result._soft_close() + return ret + except BaseException as e: + self.handle_exception(result, dbapi_cursor, e) + + +class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): + """A cursor strategy that buffers rows fully upon creation. + + Used for operations where a result is to be delivered + after the database conversation can not be continued, + such as MSSQL INSERT...OUTPUT after an autocommit. + + """ + + __slots__ = ("_rowbuffer", "alternate_cursor_description") + + def __init__( + self, dbapi_cursor, alternate_description=None, initial_buffer=None + ): + self.alternate_cursor_description = alternate_description + if initial_buffer is not None: + self._rowbuffer = collections.deque(initial_buffer) + else: + self._rowbuffer = collections.deque(dbapi_cursor.fetchall()) + + def yield_per(self, result, dbapi_cursor, num): + pass + + def soft_close(self, result, dbapi_cursor): + self._rowbuffer.clear() + super().soft_close(result, dbapi_cursor) + + def hard_close(self, result, dbapi_cursor): + self._rowbuffer.clear() + super().hard_close(result, dbapi_cursor) + + def fetchone(self, result, dbapi_cursor, hard_close=False): + if self._rowbuffer: + return self._rowbuffer.popleft() + else: + result._soft_close(hard=hard_close) + return None + + def fetchmany(self, result, dbapi_cursor, size=None): + if size is None: + return self.fetchall(result, dbapi_cursor) + + buf = list(self._rowbuffer) + rows = buf[0:size] + self._rowbuffer = collections.deque(buf[size:]) + if not rows: + result._soft_close() + return rows + + def fetchall(self, result, dbapi_cursor): + ret = self._rowbuffer + self._rowbuffer = collections.deque() + result._soft_close() + return ret + + +class _NoResultMetaData(ResultMetaData): + __slots__ = () + + returns_rows = False + + def _we_dont_return_rows(self, err=None): + raise exc.ResourceClosedError( + "This result object does not return rows. " + "It has been closed automatically." + ) from err + + def _index_for_key(self, keys, raiseerr): + self._we_dont_return_rows() + + def _metadata_for_keys(self, key): + self._we_dont_return_rows() + + def _reduce(self, keys): + self._we_dont_return_rows() + + @property + def _keymap(self): + self._we_dont_return_rows() + + @property + def _key_to_index(self): + self._we_dont_return_rows() + + @property + def _processors(self): + self._we_dont_return_rows() + + @property + def keys(self): + self._we_dont_return_rows() + + +_NO_RESULT_METADATA = _NoResultMetaData() + + +def null_dml_result() -> IteratorResult[Any]: + it: IteratorResult[Any] = IteratorResult(_NoResultMetaData(), iter([])) + it._soft_close() + return it + + +class CursorResult(Result[_T]): + """A Result that is representing state from a DBAPI cursor. + + .. versionchanged:: 1.4 The :class:`.CursorResult`` + class replaces the previous :class:`.ResultProxy` interface. + This classes are based on the :class:`.Result` calling API + which provides an updated usage model and calling facade for + SQLAlchemy Core and SQLAlchemy ORM. + + Returns database rows via the :class:`.Row` class, which provides + additional API features and behaviors on top of the raw data returned by + the DBAPI. Through the use of filters such as the :meth:`.Result.scalars` + method, other kinds of objects may also be returned. + + .. seealso:: + + :ref:`tutorial_selecting_data` - introductory material for accessing + :class:`_engine.CursorResult` and :class:`.Row` objects. + + """ + + __slots__ = ( + "context", + "dialect", + "cursor", + "cursor_strategy", + "_echo", + "connection", + ) + + _metadata: Union[CursorResultMetaData, _NoResultMetaData] + _no_result_metadata = _NO_RESULT_METADATA + _soft_closed: bool = False + closed: bool = False + _is_cursor = True + + context: DefaultExecutionContext + dialect: Dialect + cursor_strategy: ResultFetchStrategy + connection: Connection + + def __init__( + self, + context: DefaultExecutionContext, + cursor_strategy: ResultFetchStrategy, + cursor_description: Optional[_DBAPICursorDescription], + ): + self.context = context + self.dialect = context.dialect + self.cursor = context.cursor + self.cursor_strategy = cursor_strategy + self.connection = context.root_connection + self._echo = echo = ( + self.connection._echo and context.engine._should_log_debug() + ) + + if cursor_description is not None: + # inline of Result._row_getter(), set up an initial row + # getter assuming no transformations will be called as this + # is the most common case + + metadata = self._init_metadata(context, cursor_description) + + _make_row: Any + _make_row = functools.partial( + Row, + metadata, + metadata._effective_processors, + metadata._key_to_index, + ) + + if context._num_sentinel_cols: + sentinel_filter = operator.itemgetter( + slice(-context._num_sentinel_cols) + ) + + def _sliced_row(raw_data): + return _make_row(sentinel_filter(raw_data)) + + sliced_row = _sliced_row + else: + sliced_row = _make_row + + if echo: + log = self.context.connection._log_debug + + def _log_row(row): + log("Row %r", sql_util._repr_row(row)) + return row + + self._row_logging_fn = _log_row + + def _make_row_2(row): + return _log_row(sliced_row(row)) + + make_row = _make_row_2 + else: + make_row = sliced_row + self._set_memoized_attribute("_row_getter", make_row) + + else: + assert context._num_sentinel_cols == 0 + self._metadata = self._no_result_metadata + + def _init_metadata(self, context, cursor_description): + if context.compiled: + compiled = context.compiled + + if compiled._cached_metadata: + metadata = compiled._cached_metadata + else: + metadata = CursorResultMetaData(self, cursor_description) + if metadata._safe_for_cache: + compiled._cached_metadata = metadata + + # result rewrite/ adapt step. this is to suit the case + # when we are invoked against a cached Compiled object, we want + # to rewrite the ResultMetaData to reflect the Column objects + # that are in our current SQL statement object, not the one + # that is associated with the cached Compiled object. + # the Compiled object may also tell us to not + # actually do this step; this is to support the ORM where + # it is to produce a new Result object in any case, and will + # be using the cached Column objects against this database result + # so we don't want to rewrite them. + # + # Basically this step suits the use case where the end user + # is using Core SQL expressions and is accessing columns in the + # result row using row._mapping[table.c.column]. + if ( + not context.execution_options.get( + "_result_disable_adapt_to_context", False + ) + and compiled._result_columns + and context.cache_hit is context.dialect.CACHE_HIT + and compiled.statement is not context.invoked_statement + ): + metadata = metadata._adapt_to_context(context) + + self._metadata = metadata + + else: + self._metadata = metadata = CursorResultMetaData( + self, cursor_description + ) + if self._echo: + context.connection._log_debug( + "Col %r", tuple(x[0] for x in cursor_description) + ) + return metadata + + def _soft_close(self, hard=False): + """Soft close this :class:`_engine.CursorResult`. + + This releases all DBAPI cursor resources, but leaves the + CursorResult "open" from a semantic perspective, meaning the + fetchXXX() methods will continue to return empty results. + + This method is called automatically when: + + * all result rows are exhausted using the fetchXXX() methods. + * cursor.description is None. + + This method is **not public**, but is documented in order to clarify + the "autoclose" process used. + + .. seealso:: + + :meth:`_engine.CursorResult.close` + + + """ + + if (not hard and self._soft_closed) or (hard and self.closed): + return + + if hard: + self.closed = True + self.cursor_strategy.hard_close(self, self.cursor) + else: + self.cursor_strategy.soft_close(self, self.cursor) + + if not self._soft_closed: + cursor = self.cursor + self.cursor = None # type: ignore + self.connection._safe_close_cursor(cursor) + self._soft_closed = True + + @property + def inserted_primary_key_rows(self): + """Return the value of + :attr:`_engine.CursorResult.inserted_primary_key` + as a row contained within a list; some dialects may support a + multiple row form as well. + + .. note:: As indicated below, in current SQLAlchemy versions this + accessor is only useful beyond what's already supplied by + :attr:`_engine.CursorResult.inserted_primary_key` when using the + :ref:`postgresql_psycopg2` dialect. Future versions hope to + generalize this feature to more dialects. + + This accessor is added to support dialects that offer the feature + that is currently implemented by the :ref:`psycopg2_executemany_mode` + feature, currently **only the psycopg2 dialect**, which provides + for many rows to be INSERTed at once while still retaining the + behavior of being able to return server-generated primary key values. + + * **When using the psycopg2 dialect, or other dialects that may support + "fast executemany" style inserts in upcoming releases** : When + invoking an INSERT statement while passing a list of rows as the + second argument to :meth:`_engine.Connection.execute`, this accessor + will then provide a list of rows, where each row contains the primary + key value for each row that was INSERTed. + + * **When using all other dialects / backends that don't yet support + this feature**: This accessor is only useful for **single row INSERT + statements**, and returns the same information as that of the + :attr:`_engine.CursorResult.inserted_primary_key` within a + single-element list. When an INSERT statement is executed in + conjunction with a list of rows to be INSERTed, the list will contain + one row per row inserted in the statement, however it will contain + ``None`` for any server-generated values. + + Future releases of SQLAlchemy will further generalize the + "fast execution helper" feature of psycopg2 to suit other dialects, + thus allowing this accessor to be of more general use. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_engine.CursorResult.inserted_primary_key` + + """ + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled expression construct." + ) + elif not self.context.isinsert: + raise exc.InvalidRequestError( + "Statement is not an insert() expression construct." + ) + elif self.context._is_explicit_returning: + raise exc.InvalidRequestError( + "Can't call inserted_primary_key " + "when returning() " + "is used." + ) + return self.context.inserted_primary_key_rows + + @property + def inserted_primary_key(self): + """Return the primary key for the row just inserted. + + The return value is a :class:`_result.Row` object representing + a named tuple of primary key values in the order in which the + primary key columns are configured in the source + :class:`_schema.Table`. + + .. versionchanged:: 1.4.8 - the + :attr:`_engine.CursorResult.inserted_primary_key` + value is now a named tuple via the :class:`_result.Row` class, + rather than a plain tuple. + + This accessor only applies to single row :func:`_expression.insert` + constructs which did not explicitly specify + :meth:`_expression.Insert.returning`. Support for multirow inserts, + while not yet available for most backends, would be accessed using + the :attr:`_engine.CursorResult.inserted_primary_key_rows` accessor. + + Note that primary key columns which specify a server_default clause, or + otherwise do not qualify as "autoincrement" columns (see the notes at + :class:`_schema.Column`), and were generated using the database-side + default, will appear in this list as ``None`` unless the backend + supports "returning" and the insert statement executed with the + "implicit returning" enabled. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an insert() construct. + + """ + + if self.context.executemany: + raise exc.InvalidRequestError( + "This statement was an executemany call; if primary key " + "returning is supported, please " + "use .inserted_primary_key_rows." + ) + + ikp = self.inserted_primary_key_rows + if ikp: + return ikp[0] + else: + return None + + def last_updated_params(self): + """Return the collection of updated parameters from this + execution. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an update() construct. + + """ + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled expression construct." + ) + elif not self.context.isupdate: + raise exc.InvalidRequestError( + "Statement is not an update() expression construct." + ) + elif self.context.executemany: + return self.context.compiled_parameters + else: + return self.context.compiled_parameters[0] + + def last_inserted_params(self): + """Return the collection of inserted parameters from this + execution. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an insert() construct. + + """ + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled expression construct." + ) + elif not self.context.isinsert: + raise exc.InvalidRequestError( + "Statement is not an insert() expression construct." + ) + elif self.context.executemany: + return self.context.compiled_parameters + else: + return self.context.compiled_parameters[0] + + @property + def returned_defaults_rows(self): + """Return a list of rows each containing the values of default + columns that were fetched using + the :meth:`.ValuesBase.return_defaults` feature. + + The return value is a list of :class:`.Row` objects. + + .. versionadded:: 1.4 + + """ + return self.context.returned_default_rows + + def splice_horizontally(self, other): + """Return a new :class:`.CursorResult` that "horizontally splices" + together the rows of this :class:`.CursorResult` with that of another + :class:`.CursorResult`. + + .. tip:: This method is for the benefit of the SQLAlchemy ORM and is + not intended for general use. + + "horizontally splices" means that for each row in the first and second + result sets, a new row that concatenates the two rows together is + produced, which then becomes the new row. The incoming + :class:`.CursorResult` must have the identical number of rows. It is + typically expected that the two result sets come from the same sort + order as well, as the result rows are spliced together based on their + position in the result. + + The expected use case here is so that multiple INSERT..RETURNING + statements (which definitely need to be sorted) against different + tables can produce a single result that looks like a JOIN of those two + tables. + + E.g.:: + + r1 = connection.execute( + users.insert().returning( + users.c.user_name, + users.c.user_id, + sort_by_parameter_order=True + ), + user_values + ) + + r2 = connection.execute( + addresses.insert().returning( + addresses.c.address_id, + addresses.c.address, + addresses.c.user_id, + sort_by_parameter_order=True + ), + address_values + ) + + rows = r1.splice_horizontally(r2).all() + assert ( + rows == + [ + ("john", 1, 1, "foo@bar.com", 1), + ("jack", 2, 2, "bar@bat.com", 2), + ] + ) + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.CursorResult.splice_vertically` + + + """ + + clone = self._generate() + total_rows = [ + tuple(r1) + tuple(r2) + for r1, r2 in zip( + list(self._raw_row_iterator()), + list(other._raw_row_iterator()), + ) + ] + + clone._metadata = clone._metadata._splice_horizontally(other._metadata) + + clone.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + initial_buffer=total_rows, + ) + clone._reset_memoizations() + return clone + + def splice_vertically(self, other): + """Return a new :class:`.CursorResult` that "vertically splices", + i.e. "extends", the rows of this :class:`.CursorResult` with that of + another :class:`.CursorResult`. + + .. tip:: This method is for the benefit of the SQLAlchemy ORM and is + not intended for general use. + + "vertically splices" means the rows of the given result are appended to + the rows of this cursor result. The incoming :class:`.CursorResult` + must have rows that represent the identical list of columns in the + identical order as they are in this :class:`.CursorResult`. + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.CursorResult.splice_horizontally` + + """ + clone = self._generate() + total_rows = list(self._raw_row_iterator()) + list( + other._raw_row_iterator() + ) + + clone.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + initial_buffer=total_rows, + ) + clone._reset_memoizations() + return clone + + def _rewind(self, rows): + """rewind this result back to the given rowset. + + this is used internally for the case where an :class:`.Insert` + construct combines the use of + :meth:`.Insert.return_defaults` along with the + "supplemental columns" feature. + + """ + + if self._echo: + self.context.connection._log_debug( + "CursorResult rewound %d row(s)", len(rows) + ) + + # the rows given are expected to be Row objects, so we + # have to clear out processors which have already run on these + # rows + self._metadata = cast( + CursorResultMetaData, self._metadata + )._remove_processors() + + self.cursor_strategy = FullyBufferedCursorFetchStrategy( + None, + # TODO: if these are Row objects, can we save on not having to + # re-make new Row objects out of them a second time? is that + # what's actually happening right now? maybe look into this + initial_buffer=rows, + ) + self._reset_memoizations() + return self + + @property + def returned_defaults(self): + """Return the values of default columns that were fetched using + the :meth:`.ValuesBase.return_defaults` feature. + + The value is an instance of :class:`.Row`, or ``None`` + if :meth:`.ValuesBase.return_defaults` was not used or if the + backend does not support RETURNING. + + .. seealso:: + + :meth:`.ValuesBase.return_defaults` + + """ + + if self.context.executemany: + raise exc.InvalidRequestError( + "This statement was an executemany call; if return defaults " + "is supported, please use .returned_defaults_rows." + ) + + rows = self.context.returned_default_rows + if rows: + return rows[0] + else: + return None + + def lastrow_has_defaults(self): + """Return ``lastrow_has_defaults()`` from the underlying + :class:`.ExecutionContext`. + + See :class:`.ExecutionContext` for details. + + """ + + return self.context.lastrow_has_defaults() + + def postfetch_cols(self): + """Return ``postfetch_cols()`` from the underlying + :class:`.ExecutionContext`. + + See :class:`.ExecutionContext` for details. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an insert() or update() construct. + + """ + + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled expression construct." + ) + elif not self.context.isinsert and not self.context.isupdate: + raise exc.InvalidRequestError( + "Statement is not an insert() or update() " + "expression construct." + ) + return self.context.postfetch_cols + + def prefetch_cols(self): + """Return ``prefetch_cols()`` from the underlying + :class:`.ExecutionContext`. + + See :class:`.ExecutionContext` for details. + + Raises :class:`~sqlalchemy.exc.InvalidRequestError` if the executed + statement is not a compiled expression construct + or is not an insert() or update() construct. + + """ + + if not self.context.compiled: + raise exc.InvalidRequestError( + "Statement is not a compiled expression construct." + ) + elif not self.context.isinsert and not self.context.isupdate: + raise exc.InvalidRequestError( + "Statement is not an insert() or update() " + "expression construct." + ) + return self.context.prefetch_cols + + def supports_sane_rowcount(self): + """Return ``supports_sane_rowcount`` from the dialect. + + See :attr:`_engine.CursorResult.rowcount` for background. + + """ + + return self.dialect.supports_sane_rowcount + + def supports_sane_multi_rowcount(self): + """Return ``supports_sane_multi_rowcount`` from the dialect. + + See :attr:`_engine.CursorResult.rowcount` for background. + + """ + + return self.dialect.supports_sane_multi_rowcount + + @util.memoized_property + def rowcount(self) -> int: + """Return the 'rowcount' for this result. + + The primary purpose of 'rowcount' is to report the number of rows + matched by the WHERE criterion of an UPDATE or DELETE statement + executed once (i.e. for a single parameter set), which may then be + compared to the number of rows expected to be updated or deleted as a + means of asserting data integrity. + + This attribute is transferred from the ``cursor.rowcount`` attribute + of the DBAPI before the cursor is closed, to support DBAPIs that + don't make this value available after cursor close. Some DBAPIs may + offer meaningful values for other kinds of statements, such as INSERT + and SELECT statements as well. In order to retrieve ``cursor.rowcount`` + for these statements, set the + :paramref:`.Connection.execution_options.preserve_rowcount` + execution option to True, which will cause the ``cursor.rowcount`` + value to be unconditionally memoized before any results are returned + or the cursor is closed, regardless of statement type. + + For cases where the DBAPI does not support rowcount for a particular + kind of statement and/or execution, the returned value will be ``-1``, + which is delivered directly from the DBAPI and is part of :pep:`249`. + All DBAPIs should support rowcount for single-parameter-set + UPDATE and DELETE statements, however. + + .. note:: + + Notes regarding :attr:`_engine.CursorResult.rowcount`: + + + * This attribute returns the number of rows *matched*, + which is not necessarily the same as the number of rows + that were actually *modified*. For example, an UPDATE statement + may have no net change on a given row if the SET values + given are the same as those present in the row already. + Such a row would be matched but not modified. + On backends that feature both styles, such as MySQL, + rowcount is configured to return the match + count in all cases. + + * :attr:`_engine.CursorResult.rowcount` in the default case is + *only* useful in conjunction with an UPDATE or DELETE statement, + and only with a single set of parameters. For other kinds of + statements, SQLAlchemy will not attempt to pre-memoize the value + unless the + :paramref:`.Connection.execution_options.preserve_rowcount` + execution option is used. Note that contrary to :pep:`249`, many + DBAPIs do not support rowcount values for statements that are not + UPDATE or DELETE, particularly when rows are being returned which + are not fully pre-buffered. DBAPIs that dont support rowcount + for a particular kind of statement should return the value ``-1`` + for such statements. + + * :attr:`_engine.CursorResult.rowcount` may not be meaningful + when executing a single statement with multiple parameter sets + (i.e. an :term:`executemany`). Most DBAPIs do not sum "rowcount" + values across multiple parameter sets and will return ``-1`` + when accessed. + + * SQLAlchemy's :ref:`engine_insertmanyvalues` feature does support + a correct population of :attr:`_engine.CursorResult.rowcount` + when the :paramref:`.Connection.execution_options.preserve_rowcount` + execution option is set to True. + + * Statements that use RETURNING may not support rowcount, returning + a ``-1`` value instead. + + .. seealso:: + + :ref:`tutorial_update_delete_rowcount` - in the :ref:`unified_tutorial` + + :paramref:`.Connection.execution_options.preserve_rowcount` + + """ # noqa: E501 + try: + return self.context.rowcount + except BaseException as e: + self.cursor_strategy.handle_exception(self, self.cursor, e) + raise # not called + + @property + def lastrowid(self): + """Return the 'lastrowid' accessor on the DBAPI cursor. + + This is a DBAPI specific method and is only functional + for those backends which support it, for statements + where it is appropriate. It's behavior is not + consistent across backends. + + Usage of this method is normally unnecessary when + using insert() expression constructs; the + :attr:`~CursorResult.inserted_primary_key` attribute provides a + tuple of primary key values for a newly inserted row, + regardless of database backend. + + """ + try: + return self.context.get_lastrowid() + except BaseException as e: + self.cursor_strategy.handle_exception(self, self.cursor, e) + + @property + def returns_rows(self): + """True if this :class:`_engine.CursorResult` returns zero or more + rows. + + I.e. if it is legal to call the methods + :meth:`_engine.CursorResult.fetchone`, + :meth:`_engine.CursorResult.fetchmany` + :meth:`_engine.CursorResult.fetchall`. + + Overall, the value of :attr:`_engine.CursorResult.returns_rows` should + always be synonymous with whether or not the DBAPI cursor had a + ``.description`` attribute, indicating the presence of result columns, + noting that a cursor that returns zero rows still has a + ``.description`` if a row-returning statement was emitted. + + This attribute should be True for all results that are against + SELECT statements, as well as for DML statements INSERT/UPDATE/DELETE + that use RETURNING. For INSERT/UPDATE/DELETE statements that were + not using RETURNING, the value will usually be False, however + there are some dialect-specific exceptions to this, such as when + using the MSSQL / pyodbc dialect a SELECT is emitted inline in + order to retrieve an inserted primary key value. + + + """ + return self._metadata.returns_rows + + @property + def is_insert(self): + """True if this :class:`_engine.CursorResult` is the result + of a executing an expression language compiled + :func:`_expression.insert` construct. + + When True, this implies that the + :attr:`inserted_primary_key` attribute is accessible, + assuming the statement did not include + a user defined "returning" construct. + + """ + return self.context.isinsert + + def _fetchiter_impl(self): + fetchone = self.cursor_strategy.fetchone + + while True: + row = fetchone(self, self.cursor) + if row is None: + break + yield row + + def _fetchone_impl(self, hard_close=False): + return self.cursor_strategy.fetchone(self, self.cursor, hard_close) + + def _fetchall_impl(self): + return self.cursor_strategy.fetchall(self, self.cursor) + + def _fetchmany_impl(self, size=None): + return self.cursor_strategy.fetchmany(self, self.cursor, size) + + def _raw_row_iterator(self): + return self._fetchiter_impl() + + def merge(self, *others: Result[Any]) -> MergedResult[Any]: + merged_result = super().merge(*others) + if self.context._has_rowcount: + merged_result.rowcount = sum( + cast("CursorResult[Any]", result).rowcount + for result in (self,) + others + ) + return merged_result + + def close(self) -> Any: + """Close this :class:`_engine.CursorResult`. + + This closes out the underlying DBAPI cursor corresponding to the + statement execution, if one is still present. Note that the DBAPI + cursor is automatically released when the :class:`_engine.CursorResult` + exhausts all available rows. :meth:`_engine.CursorResult.close` is + generally an optional method except in the case when discarding a + :class:`_engine.CursorResult` that still has additional rows pending + for fetch. + + After this method is called, it is no longer valid to call upon + the fetch methods, which will raise a :class:`.ResourceClosedError` + on subsequent use. + + .. seealso:: + + :ref:`connections_toplevel` + + """ + self._soft_close(hard=True) + + @_generative + def yield_per(self, num: int) -> Self: + self._yield_per = num + self.cursor_strategy.yield_per(self, self.cursor, num) + return self + + +ResultProxy = CursorResult diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/default.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/default.py new file mode 100644 index 0000000..90cafe4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/default.py @@ -0,0 +1,2343 @@ +# engine/default.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Default implementations of per-dialect sqlalchemy.engine classes. + +These are semi-private implementation classes which are only of importance +to database dialect authors; dialects will usually use the classes here +as the base class for their own corresponding classes. + +""" + +from __future__ import annotations + +import functools +import operator +import random +import re +from time import perf_counter +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import MutableSequence +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union +import weakref + +from . import characteristics +from . import cursor as _cursor +from . import interfaces +from .base import Connection +from .interfaces import CacheStats +from .interfaces import DBAPICursor +from .interfaces import Dialect +from .interfaces import ExecuteStyle +from .interfaces import ExecutionContext +from .reflection import ObjectKind +from .reflection import ObjectScope +from .. import event +from .. import exc +from .. import pool +from .. import util +from ..sql import compiler +from ..sql import dml +from ..sql import expression +from ..sql import type_api +from ..sql._typing import is_tuple_type +from ..sql.base import _NoArg +from ..sql.compiler import DDLCompiler +from ..sql.compiler import InsertmanyvaluesSentinelOpts +from ..sql.compiler import SQLCompiler +from ..sql.elements import quoted_name +from ..util.typing import Final +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from types import ModuleType + + from .base import Engine + from .cursor import ResultFetchStrategy + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPICursorDescription + from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _ExecuteOptions + from .interfaces import _MutableCoreSingleExecuteParams + from .interfaces import _ParamStyle + from .interfaces import DBAPIConnection + from .interfaces import IsolationLevel + from .row import Row + from .url import URL + from ..event import _ListenerFnType + from ..pool import Pool + from ..pool import PoolProxiedConnection + from ..sql import Executable + from ..sql.compiler import Compiled + from ..sql.compiler import Linting + from ..sql.compiler import ResultColumnsEntry + from ..sql.dml import DMLState + from ..sql.dml import UpdateBase + from ..sql.elements import BindParameter + from ..sql.schema import Column + from ..sql.type_api import _BindProcessorType + from ..sql.type_api import _ResultProcessorType + from ..sql.type_api import TypeEngine + +# When we're handed literal SQL, ensure it's a SELECT query +SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE) + + +( + CACHE_HIT, + CACHE_MISS, + CACHING_DISABLED, + NO_CACHE_KEY, + NO_DIALECT_SUPPORT, +) = list(CacheStats) + + +class DefaultDialect(Dialect): + """Default implementation of Dialect""" + + statement_compiler = compiler.SQLCompiler + ddl_compiler = compiler.DDLCompiler + type_compiler_cls = compiler.GenericTypeCompiler + + preparer = compiler.IdentifierPreparer + supports_alter = True + supports_comments = False + supports_constraint_comments = False + inline_comments = False + supports_statement_cache = True + + div_is_floordiv = True + + bind_typing = interfaces.BindTyping.NONE + + include_set_input_sizes: Optional[Set[Any]] = None + exclude_set_input_sizes: Optional[Set[Any]] = None + + # the first value we'd get for an autoincrement column. + default_sequence_base = 1 + + # most DBAPIs happy with this for execute(). + # not cx_oracle. + execute_sequence_format = tuple + + supports_schemas = True + supports_views = True + supports_sequences = False + sequences_optional = False + preexecute_autoincrement_sequences = False + supports_identity_columns = False + postfetch_lastrowid = True + favor_returning_over_lastrowid = False + insert_null_pk_still_autoincrements = False + update_returning = False + delete_returning = False + update_returning_multifrom = False + delete_returning_multifrom = False + insert_returning = False + + cte_follows_insert = False + + supports_native_enum = False + supports_native_boolean = False + supports_native_uuid = False + returns_native_bytes = False + + non_native_boolean_check_constraint = True + + supports_simple_order_by_label = True + + tuple_in_values = False + + connection_characteristics = util.immutabledict( + {"isolation_level": characteristics.IsolationLevelCharacteristic()} + ) + + engine_config_types: Mapping[str, Any] = util.immutabledict( + { + "pool_timeout": util.asint, + "echo": util.bool_or_str("debug"), + "echo_pool": util.bool_or_str("debug"), + "pool_recycle": util.asint, + "pool_size": util.asint, + "max_overflow": util.asint, + "future": util.asbool, + } + ) + + # if the NUMERIC type + # returns decimal.Decimal. + # *not* the FLOAT type however. + supports_native_decimal = False + + name = "default" + + # length at which to truncate + # any identifier. + max_identifier_length = 9999 + _user_defined_max_identifier_length: Optional[int] = None + + isolation_level: Optional[str] = None + + # sub-categories of max_identifier_length. + # currently these accommodate for MySQL which allows alias names + # of 255 but DDL names only of 64. + max_index_name_length: Optional[int] = None + max_constraint_name_length: Optional[int] = None + + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + colspecs: MutableMapping[Type[TypeEngine[Any]], Type[TypeEngine[Any]]] = {} + default_paramstyle = "named" + + supports_default_values = False + """dialect supports INSERT... DEFAULT VALUES syntax""" + + supports_default_metavalue = False + """dialect supports INSERT... VALUES (DEFAULT) syntax""" + + default_metavalue_token = "DEFAULT" + """for INSERT... VALUES (DEFAULT) syntax, the token to put in the + parenthesis.""" + + # not sure if this is a real thing but the compiler will deliver it + # if this is the only flag enabled. + supports_empty_insert = True + """dialect supports INSERT () VALUES ()""" + + supports_multivalues_insert = False + + use_insertmanyvalues: bool = False + + use_insertmanyvalues_wo_returning: bool = False + + insertmanyvalues_implicit_sentinel: InsertmanyvaluesSentinelOpts = ( + InsertmanyvaluesSentinelOpts.NOT_SUPPORTED + ) + + insertmanyvalues_page_size: int = 1000 + insertmanyvalues_max_parameters = 32700 + + supports_is_distinct_from = True + + supports_server_side_cursors = False + + server_side_cursors = False + + # extra record-level locking features (#4860) + supports_for_update_of = False + + server_version_info = None + + default_schema_name: Optional[str] = None + + # indicates symbol names are + # UPPERCASEd if they are case insensitive + # within the database. + # if this is True, the methods normalize_name() + # and denormalize_name() must be provided. + requires_name_normalize = False + + is_async = False + + has_terminate = False + + # TODO: this is not to be part of 2.0. implement rudimentary binary + # literals for SQLite, PostgreSQL, MySQL only within + # _Binary.literal_processor + _legacy_binary_type_literal_encoding = "utf-8" + + @util.deprecated_params( + empty_in_strategy=( + "1.4", + "The :paramref:`_sa.create_engine.empty_in_strategy` keyword is " + "deprecated, and no longer has any effect. All IN expressions " + "are now rendered using " + 'the "expanding parameter" strategy which renders a set of bound' + 'expressions, or an "empty set" SELECT, at statement execution' + "time.", + ), + server_side_cursors=( + "1.4", + "The :paramref:`_sa.create_engine.server_side_cursors` parameter " + "is deprecated and will be removed in a future release. Please " + "use the " + ":paramref:`_engine.Connection.execution_options.stream_results` " + "parameter.", + ), + ) + def __init__( + self, + paramstyle: Optional[_ParamStyle] = None, + isolation_level: Optional[IsolationLevel] = None, + dbapi: Optional[ModuleType] = None, + implicit_returning: Literal[True] = True, + supports_native_boolean: Optional[bool] = None, + max_identifier_length: Optional[int] = None, + label_length: Optional[int] = None, + insertmanyvalues_page_size: Union[_NoArg, int] = _NoArg.NO_ARG, + use_insertmanyvalues: Optional[bool] = None, + # util.deprecated_params decorator cannot render the + # Linting.NO_LINTING constant + compiler_linting: Linting = int(compiler.NO_LINTING), # type: ignore + server_side_cursors: bool = False, + **kwargs: Any, + ): + if server_side_cursors: + if not self.supports_server_side_cursors: + raise exc.ArgumentError( + "Dialect %s does not support server side cursors" % self + ) + else: + self.server_side_cursors = True + + if getattr(self, "use_setinputsizes", False): + util.warn_deprecated( + "The dialect-level use_setinputsizes attribute is " + "deprecated. Please use " + "bind_typing = BindTyping.SETINPUTSIZES", + "2.0", + ) + self.bind_typing = interfaces.BindTyping.SETINPUTSIZES + + self.positional = False + self._ischema = None + + self.dbapi = dbapi + + if paramstyle is not None: + self.paramstyle = paramstyle + elif self.dbapi is not None: + self.paramstyle = self.dbapi.paramstyle + else: + self.paramstyle = self.default_paramstyle + self.positional = self.paramstyle in ( + "qmark", + "format", + "numeric", + "numeric_dollar", + ) + self.identifier_preparer = self.preparer(self) + self._on_connect_isolation_level = isolation_level + + legacy_tt_callable = getattr(self, "type_compiler", None) + if legacy_tt_callable is not None: + tt_callable = cast( + Type[compiler.GenericTypeCompiler], + self.type_compiler, + ) + else: + tt_callable = self.type_compiler_cls + + self.type_compiler_instance = self.type_compiler = tt_callable(self) + + if supports_native_boolean is not None: + self.supports_native_boolean = supports_native_boolean + + self._user_defined_max_identifier_length = max_identifier_length + if self._user_defined_max_identifier_length: + self.max_identifier_length = ( + self._user_defined_max_identifier_length + ) + self.label_length = label_length + self.compiler_linting = compiler_linting + + if use_insertmanyvalues is not None: + self.use_insertmanyvalues = use_insertmanyvalues + + if insertmanyvalues_page_size is not _NoArg.NO_ARG: + self.insertmanyvalues_page_size = insertmanyvalues_page_size + + @property + @util.deprecated( + "2.0", + "full_returning is deprecated, please use insert_returning, " + "update_returning, delete_returning", + ) + def full_returning(self): + return ( + self.insert_returning + and self.update_returning + and self.delete_returning + ) + + @util.memoized_property + def insert_executemany_returning(self): + """Default implementation for insert_executemany_returning, if not + otherwise overridden by the specific dialect. + + The default dialect determines "insert_executemany_returning" is + available if the dialect in use has opted into using the + "use_insertmanyvalues" feature. If they haven't opted into that, then + this attribute is False, unless the dialect in question overrides this + and provides some other implementation (such as the Oracle dialect). + + """ + return self.insert_returning and self.use_insertmanyvalues + + @util.memoized_property + def insert_executemany_returning_sort_by_parameter_order(self): + """Default implementation for + insert_executemany_returning_deterministic_order, if not otherwise + overridden by the specific dialect. + + The default dialect determines "insert_executemany_returning" can have + deterministic order only if the dialect in use has opted into using the + "use_insertmanyvalues" feature, which implements deterministic ordering + using client side sentinel columns only by default. The + "insertmanyvalues" feature also features alternate forms that can + use server-generated PK values as "sentinels", but those are only + used if the :attr:`.Dialect.insertmanyvalues_implicit_sentinel` + bitflag enables those alternate SQL forms, which are disabled + by default. + + If the dialect in use hasn't opted into that, then this attribute is + False, unless the dialect in question overrides this and provides some + other implementation (such as the Oracle dialect). + + """ + return self.insert_returning and self.use_insertmanyvalues + + update_executemany_returning = False + delete_executemany_returning = False + + @util.memoized_property + def loaded_dbapi(self) -> ModuleType: + if self.dbapi is None: + raise exc.InvalidRequestError( + f"Dialect {self} does not have a Python DBAPI established " + "and cannot be used for actual database interaction" + ) + return self.dbapi + + @util.memoized_property + def _bind_typing_render_casts(self): + return self.bind_typing is interfaces.BindTyping.RENDER_CASTS + + def _ensure_has_table_connection(self, arg): + if not isinstance(arg, Connection): + raise exc.ArgumentError( + "The argument passed to Dialect.has_table() should be a " + "%s, got %s. " + "Additionally, the Dialect.has_table() method is for " + "internal dialect " + "use only; please use " + "``inspect(some_engine).has_table(>)`` " + "for public API use." % (Connection, type(arg)) + ) + + @util.memoized_property + def _supports_statement_cache(self): + ssc = self.__class__.__dict__.get("supports_statement_cache", None) + if ssc is None: + util.warn( + "Dialect %s:%s will not make use of SQL compilation caching " + "as it does not set the 'supports_statement_cache' attribute " + "to ``True``. This can have " + "significant performance implications including some " + "performance degradations in comparison to prior SQLAlchemy " + "versions. Dialect maintainers should seek to set this " + "attribute to True after appropriate development and testing " + "for SQLAlchemy 1.4 caching support. Alternatively, this " + "attribute may be set to False which will disable this " + "warning." % (self.name, self.driver), + code="cprf", + ) + + return bool(ssc) + + @util.memoized_property + def _type_memos(self): + return weakref.WeakKeyDictionary() + + @property + def dialect_description(self): + return self.name + "+" + self.driver + + @property + def supports_sane_rowcount_returning(self): + """True if this dialect supports sane rowcount even if RETURNING is + in use. + + For dialects that don't support RETURNING, this is synonymous with + ``supports_sane_rowcount``. + + """ + return self.supports_sane_rowcount + + @classmethod + def get_pool_class(cls, url: URL) -> Type[Pool]: + return getattr(cls, "poolclass", pool.QueuePool) + + def get_dialect_pool_class(self, url: URL) -> Type[Pool]: + return self.get_pool_class(url) + + @classmethod + def load_provisioning(cls): + package = ".".join(cls.__module__.split(".")[0:-1]) + try: + __import__(package + ".provision") + except ImportError: + pass + + def _builtin_onconnect(self) -> Optional[_ListenerFnType]: + if self._on_connect_isolation_level is not None: + + def builtin_connect(dbapi_conn, conn_rec): + self._assert_and_set_isolation_level( + dbapi_conn, self._on_connect_isolation_level + ) + + return builtin_connect + else: + return None + + def initialize(self, connection): + try: + self.server_version_info = self._get_server_version_info( + connection + ) + except NotImplementedError: + self.server_version_info = None + try: + self.default_schema_name = self._get_default_schema_name( + connection + ) + except NotImplementedError: + self.default_schema_name = None + + try: + self.default_isolation_level = self.get_default_isolation_level( + connection.connection.dbapi_connection + ) + except NotImplementedError: + self.default_isolation_level = None + + if not self._user_defined_max_identifier_length: + max_ident_length = self._check_max_identifier_length(connection) + if max_ident_length: + self.max_identifier_length = max_ident_length + + if ( + self.label_length + and self.label_length > self.max_identifier_length + ): + raise exc.ArgumentError( + "Label length of %d is greater than this dialect's" + " maximum identifier length of %d" + % (self.label_length, self.max_identifier_length) + ) + + def on_connect(self): + # inherits the docstring from interfaces.Dialect.on_connect + return None + + def _check_max_identifier_length(self, connection): + """Perform a connection / server version specific check to determine + the max_identifier_length. + + If the dialect's class level max_identifier_length should be used, + can return None. + + .. versionadded:: 1.3.9 + + """ + return None + + def get_default_isolation_level(self, dbapi_conn): + """Given a DBAPI connection, return its isolation level, or + a default isolation level if one cannot be retrieved. + + May be overridden by subclasses in order to provide a + "fallback" isolation level for databases that cannot reliably + retrieve the actual isolation level. + + By default, calls the :meth:`_engine.Interfaces.get_isolation_level` + method, propagating any exceptions raised. + + .. versionadded:: 1.3.22 + + """ + return self.get_isolation_level(dbapi_conn) + + def type_descriptor(self, typeobj): + """Provide a database-specific :class:`.TypeEngine` object, given + the generic object which comes from the types module. + + This method looks for a dictionary called + ``colspecs`` as a class or instance-level variable, + and passes on to :func:`_types.adapt_type`. + + """ + return type_api.adapt_type(typeobj, self.colspecs) + + def has_index(self, connection, table_name, index_name, schema=None, **kw): + if not self.has_table(connection, table_name, schema=schema, **kw): + return False + for idx in self.get_indexes( + connection, table_name, schema=schema, **kw + ): + if idx["name"] == index_name: + return True + else: + return False + + def has_schema( + self, connection: Connection, schema_name: str, **kw: Any + ) -> bool: + return schema_name in self.get_schema_names(connection, **kw) + + def validate_identifier(self, ident): + if len(ident) > self.max_identifier_length: + raise exc.IdentifierError( + "Identifier '%s' exceeds maximum length of %d characters" + % (ident, self.max_identifier_length) + ) + + def connect(self, *cargs, **cparams): + # inherits the docstring from interfaces.Dialect.connect + return self.loaded_dbapi.connect(*cargs, **cparams) + + def create_connect_args(self, url): + # inherits the docstring from interfaces.Dialect.create_connect_args + opts = url.translate_connect_args() + opts.update(url.query) + return ([], opts) + + def set_engine_execution_options( + self, engine: Engine, opts: Mapping[str, Any] + ) -> None: + supported_names = set(self.connection_characteristics).intersection( + opts + ) + if supported_names: + characteristics: Mapping[str, Any] = util.immutabledict( + (name, opts[name]) for name in supported_names + ) + + @event.listens_for(engine, "engine_connect") + def set_connection_characteristics(connection): + self._set_connection_characteristics( + connection, characteristics + ) + + def set_connection_execution_options( + self, connection: Connection, opts: Mapping[str, Any] + ) -> None: + supported_names = set(self.connection_characteristics).intersection( + opts + ) + if supported_names: + characteristics: Mapping[str, Any] = util.immutabledict( + (name, opts[name]) for name in supported_names + ) + self._set_connection_characteristics(connection, characteristics) + + def _set_connection_characteristics(self, connection, characteristics): + characteristic_values = [ + (name, self.connection_characteristics[name], value) + for name, value in characteristics.items() + ] + + if connection.in_transaction(): + trans_objs = [ + (name, obj) + for name, obj, value in characteristic_values + if obj.transactional + ] + if trans_objs: + raise exc.InvalidRequestError( + "This connection has already initialized a SQLAlchemy " + "Transaction() object via begin() or autobegin; " + "%s may not be altered unless rollback() or commit() " + "is called first." + % (", ".join(name for name, obj in trans_objs)) + ) + + dbapi_connection = connection.connection.dbapi_connection + for name, characteristic, value in characteristic_values: + characteristic.set_characteristic(self, dbapi_connection, value) + connection.connection._connection_record.finalize_callback.append( + functools.partial(self._reset_characteristics, characteristics) + ) + + def _reset_characteristics(self, characteristics, dbapi_connection): + for characteristic_name in characteristics: + characteristic = self.connection_characteristics[ + characteristic_name + ] + characteristic.reset_characteristic(self, dbapi_connection) + + def do_begin(self, dbapi_connection): + pass + + def do_rollback(self, dbapi_connection): + dbapi_connection.rollback() + + def do_commit(self, dbapi_connection): + dbapi_connection.commit() + + def do_terminate(self, dbapi_connection): + self.do_close(dbapi_connection) + + def do_close(self, dbapi_connection): + dbapi_connection.close() + + @util.memoized_property + def _dialect_specific_select_one(self): + return str(expression.select(1).compile(dialect=self)) + + def _do_ping_w_event(self, dbapi_connection: DBAPIConnection) -> bool: + try: + return self.do_ping(dbapi_connection) + except self.loaded_dbapi.Error as err: + is_disconnect = self.is_disconnect(err, dbapi_connection, None) + + if self._has_events: + try: + Connection._handle_dbapi_exception_noconnection( + err, + self, + is_disconnect=is_disconnect, + invalidate_pool_on_disconnect=False, + is_pre_ping=True, + ) + except exc.StatementError as new_err: + is_disconnect = new_err.connection_invalidated + + if is_disconnect: + return False + else: + raise + + def do_ping(self, dbapi_connection: DBAPIConnection) -> bool: + cursor = None + + cursor = dbapi_connection.cursor() + try: + cursor.execute(self._dialect_specific_select_one) + finally: + cursor.close() + return True + + def create_xid(self): + """Create a random two-phase transaction ID. + + This id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). Its format is unspecified. + """ + + return "_sa_%032x" % random.randint(0, 2**128) + + def do_savepoint(self, connection, name): + connection.execute(expression.SavepointClause(name)) + + def do_rollback_to_savepoint(self, connection, name): + connection.execute(expression.RollbackToSavepointClause(name)) + + def do_release_savepoint(self, connection, name): + connection.execute(expression.ReleaseSavepointClause(name)) + + def _deliver_insertmanyvalues_batches( + self, cursor, statement, parameters, generic_setinputsizes, context + ): + context = cast(DefaultExecutionContext, context) + compiled = cast(SQLCompiler, context.compiled) + + _composite_sentinel_proc: Sequence[ + Optional[_ResultProcessorType[Any]] + ] = () + _scalar_sentinel_proc: Optional[_ResultProcessorType[Any]] = None + _sentinel_proc_initialized: bool = False + + compiled_parameters = context.compiled_parameters + + imv = compiled._insertmanyvalues + assert imv is not None + + is_returning: Final[bool] = bool(compiled.effective_returning) + batch_size = context.execution_options.get( + "insertmanyvalues_page_size", self.insertmanyvalues_page_size + ) + + if compiled.schema_translate_map: + schema_translate_map = context.execution_options.get( + "schema_translate_map", {} + ) + else: + schema_translate_map = None + + if is_returning: + result: Optional[List[Any]] = [] + context._insertmanyvalues_rows = result + + sort_by_parameter_order = imv.sort_by_parameter_order + + else: + sort_by_parameter_order = False + result = None + + for imv_batch in compiled._deliver_insertmanyvalues_batches( + statement, + parameters, + compiled_parameters, + generic_setinputsizes, + batch_size, + sort_by_parameter_order, + schema_translate_map, + ): + yield imv_batch + + if is_returning: + + rows = context.fetchall_for_returning(cursor) + + # I would have thought "is_returning: Final[bool]" + # would have assured this but pylance thinks not + assert result is not None + + if imv.num_sentinel_columns and not imv_batch.is_downgraded: + composite_sentinel = imv.num_sentinel_columns > 1 + if imv.implicit_sentinel: + # for implicit sentinel, which is currently single-col + # integer autoincrement, do a simple sort. + assert not composite_sentinel + result.extend( + sorted(rows, key=operator.itemgetter(-1)) + ) + continue + + # otherwise, create dictionaries to match up batches + # with parameters + assert imv.sentinel_param_keys + assert imv.sentinel_columns + + _nsc = imv.num_sentinel_columns + + if not _sentinel_proc_initialized: + if composite_sentinel: + _composite_sentinel_proc = [ + col.type._cached_result_processor( + self, cursor_desc[1] + ) + for col, cursor_desc in zip( + imv.sentinel_columns, + cursor.description[-_nsc:], + ) + ] + else: + _scalar_sentinel_proc = ( + imv.sentinel_columns[0] + ).type._cached_result_processor( + self, cursor.description[-1][1] + ) + _sentinel_proc_initialized = True + + rows_by_sentinel: Union[ + Dict[Tuple[Any, ...], Any], + Dict[Any, Any], + ] + if composite_sentinel: + rows_by_sentinel = { + tuple( + (proc(val) if proc else val) + for val, proc in zip( + row[-_nsc:], _composite_sentinel_proc + ) + ): row + for row in rows + } + elif _scalar_sentinel_proc: + rows_by_sentinel = { + _scalar_sentinel_proc(row[-1]): row for row in rows + } + else: + rows_by_sentinel = {row[-1]: row for row in rows} + + if len(rows_by_sentinel) != len(imv_batch.batch): + # see test_insert_exec.py:: + # IMVSentinelTest::test_sentinel_incorrect_rowcount + # for coverage / demonstration + raise exc.InvalidRequestError( + f"Sentinel-keyed result set did not produce " + f"correct number of rows {len(imv_batch.batch)}; " + "produced " + f"{len(rows_by_sentinel)}. Please ensure the " + "sentinel column is fully unique and populated in " + "all cases." + ) + + try: + ordered_rows = [ + rows_by_sentinel[sentinel_keys] + for sentinel_keys in imv_batch.sentinel_values + ] + except KeyError as ke: + # see test_insert_exec.py:: + # IMVSentinelTest::test_sentinel_cant_match_keys + # for coverage / demonstration + raise exc.InvalidRequestError( + f"Can't match sentinel values in result set to " + f"parameter sets; key {ke.args[0]!r} was not " + "found. " + "There may be a mismatch between the datatype " + "passed to the DBAPI driver vs. that which it " + "returns in a result row. Ensure the given " + "Python value matches the expected result type " + "*exactly*, taking care to not rely upon implicit " + "conversions which may occur such as when using " + "strings in place of UUID or integer values, etc. " + ) from ke + + result.extend(ordered_rows) + + else: + result.extend(rows) + + def do_executemany(self, cursor, statement, parameters, context=None): + cursor.executemany(statement, parameters) + + def do_execute(self, cursor, statement, parameters, context=None): + cursor.execute(statement, parameters) + + def do_execute_no_params(self, cursor, statement, context=None): + cursor.execute(statement) + + def is_disconnect(self, e, connection, cursor): + return False + + @util.memoized_instancemethod + def _gen_allowed_isolation_levels(self, dbapi_conn): + try: + raw_levels = list(self.get_isolation_level_values(dbapi_conn)) + except NotImplementedError: + return None + else: + normalized_levels = [ + level.replace("_", " ").upper() for level in raw_levels + ] + if raw_levels != normalized_levels: + raise ValueError( + f"Dialect {self.name!r} get_isolation_level_values() " + f"method should return names as UPPERCASE using spaces, " + f"not underscores; got " + f"{sorted(set(raw_levels).difference(normalized_levels))}" + ) + return tuple(normalized_levels) + + def _assert_and_set_isolation_level(self, dbapi_conn, level): + level = level.replace("_", " ").upper() + + _allowed_isolation_levels = self._gen_allowed_isolation_levels( + dbapi_conn + ) + if ( + _allowed_isolation_levels + and level not in _allowed_isolation_levels + ): + raise exc.ArgumentError( + f"Invalid value {level!r} for isolation_level. " + f"Valid isolation levels for {self.name!r} are " + f"{', '.join(_allowed_isolation_levels)}" + ) + + self.set_isolation_level(dbapi_conn, level) + + def reset_isolation_level(self, dbapi_conn): + if self._on_connect_isolation_level is not None: + assert ( + self._on_connect_isolation_level == "AUTOCOMMIT" + or self._on_connect_isolation_level + == self.default_isolation_level + ) + self._assert_and_set_isolation_level( + dbapi_conn, self._on_connect_isolation_level + ) + else: + assert self.default_isolation_level is not None + self._assert_and_set_isolation_level( + dbapi_conn, + self.default_isolation_level, + ) + + def normalize_name(self, name): + if name is None: + return None + + name_lower = name.lower() + name_upper = name.upper() + + if name_upper == name_lower: + # name has no upper/lower conversion, e.g. non-european characters. + # return unchanged + return name + elif name_upper == name and not ( + self.identifier_preparer._requires_quotes + )(name_lower): + # name is all uppercase and doesn't require quoting; normalize + # to all lower case + return name_lower + elif name_lower == name: + # name is all lower case, which if denormalized means we need to + # force quoting on it + return quoted_name(name, quote=True) + else: + # name is mixed case, means it will be quoted in SQL when used + # later, no normalizes + return name + + def denormalize_name(self, name): + if name is None: + return None + + name_lower = name.lower() + name_upper = name.upper() + + if name_upper == name_lower: + # name has no upper/lower conversion, e.g. non-european characters. + # return unchanged + return name + elif name_lower == name and not ( + self.identifier_preparer._requires_quotes + )(name_lower): + name = name_upper + return name + + def get_driver_connection(self, connection): + return connection + + def _overrides_default(self, method): + return ( + getattr(type(self), method).__code__ + is not getattr(DefaultDialect, method).__code__ + ) + + def _default_multi_reflect( + self, + single_tbl_method, + connection, + kind, + schema, + filter_names, + scope, + **kw, + ): + names_fns = [] + temp_names_fns = [] + if ObjectKind.TABLE in kind: + names_fns.append(self.get_table_names) + temp_names_fns.append(self.get_temp_table_names) + if ObjectKind.VIEW in kind: + names_fns.append(self.get_view_names) + temp_names_fns.append(self.get_temp_view_names) + if ObjectKind.MATERIALIZED_VIEW in kind: + names_fns.append(self.get_materialized_view_names) + # no temp materialized view at the moment + # temp_names_fns.append(self.get_temp_materialized_view_names) + + unreflectable = kw.pop("unreflectable", {}) + + if ( + filter_names + and scope is ObjectScope.ANY + and kind is ObjectKind.ANY + ): + # if names are given and no qualification on type of table + # (i.e. the Table(..., autoload) case), take the names as given, + # don't run names queries. If a table does not exit + # NoSuchTableError is raised and it's skipped + + # this also suits the case for mssql where we can reflect + # individual temp tables but there's no temp_names_fn + names = filter_names + else: + names = [] + name_kw = {"schema": schema, **kw} + fns = [] + if ObjectScope.DEFAULT in scope: + fns.extend(names_fns) + if ObjectScope.TEMPORARY in scope: + fns.extend(temp_names_fns) + + for fn in fns: + try: + names.extend(fn(connection, **name_kw)) + except NotImplementedError: + pass + + if filter_names: + filter_names = set(filter_names) + + # iterate over all the tables/views and call the single table method + for table in names: + if not filter_names or table in filter_names: + key = (schema, table) + try: + yield ( + key, + single_tbl_method( + connection, table, schema=schema, **kw + ), + ) + except exc.UnreflectableTableError as err: + if key not in unreflectable: + unreflectable[key] = err + except exc.NoSuchTableError: + pass + + def get_multi_table_options(self, connection, **kw): + return self._default_multi_reflect( + self.get_table_options, connection, **kw + ) + + def get_multi_columns(self, connection, **kw): + return self._default_multi_reflect(self.get_columns, connection, **kw) + + def get_multi_pk_constraint(self, connection, **kw): + return self._default_multi_reflect( + self.get_pk_constraint, connection, **kw + ) + + def get_multi_foreign_keys(self, connection, **kw): + return self._default_multi_reflect( + self.get_foreign_keys, connection, **kw + ) + + def get_multi_indexes(self, connection, **kw): + return self._default_multi_reflect(self.get_indexes, connection, **kw) + + def get_multi_unique_constraints(self, connection, **kw): + return self._default_multi_reflect( + self.get_unique_constraints, connection, **kw + ) + + def get_multi_check_constraints(self, connection, **kw): + return self._default_multi_reflect( + self.get_check_constraints, connection, **kw + ) + + def get_multi_table_comment(self, connection, **kw): + return self._default_multi_reflect( + self.get_table_comment, connection, **kw + ) + + +class StrCompileDialect(DefaultDialect): + statement_compiler = compiler.StrSQLCompiler + ddl_compiler = compiler.DDLCompiler + type_compiler_cls = compiler.StrSQLTypeCompiler + preparer = compiler.IdentifierPreparer + + insert_returning = True + update_returning = True + delete_returning = True + + supports_statement_cache = True + + supports_identity_columns = True + + supports_sequences = True + sequences_optional = True + preexecute_autoincrement_sequences = False + + supports_native_boolean = True + + supports_multivalues_insert = True + supports_simple_order_by_label = True + + +class DefaultExecutionContext(ExecutionContext): + isinsert = False + isupdate = False + isdelete = False + is_crud = False + is_text = False + isddl = False + + execute_style: ExecuteStyle = ExecuteStyle.EXECUTE + + compiled: Optional[Compiled] = None + result_column_struct: Optional[ + Tuple[List[ResultColumnsEntry], bool, bool, bool, bool] + ] = None + returned_default_rows: Optional[Sequence[Row[Any]]] = None + + execution_options: _ExecuteOptions = util.EMPTY_DICT + + cursor_fetch_strategy = _cursor._DEFAULT_FETCH + + invoked_statement: Optional[Executable] = None + + _is_implicit_returning = False + _is_explicit_returning = False + _is_supplemental_returning = False + _is_server_side = False + + _soft_closed = False + + _rowcount: Optional[int] = None + + # a hook for SQLite's translation of + # result column names + # NOTE: pyhive is using this hook, can't remove it :( + _translate_colname: Optional[Callable[[str], str]] = None + + _expanded_parameters: Mapping[str, List[str]] = util.immutabledict() + """used by set_input_sizes(). + + This collection comes from ``ExpandedState.parameter_expansion``. + + """ + + cache_hit = NO_CACHE_KEY + + root_connection: Connection + _dbapi_connection: PoolProxiedConnection + dialect: Dialect + unicode_statement: str + cursor: DBAPICursor + compiled_parameters: List[_MutableCoreSingleExecuteParams] + parameters: _DBAPIMultiExecuteParams + extracted_parameters: Optional[Sequence[BindParameter[Any]]] + + _empty_dict_params = cast("Mapping[str, Any]", util.EMPTY_DICT) + + _insertmanyvalues_rows: Optional[List[Tuple[Any, ...]]] = None + _num_sentinel_cols: int = 0 + + @classmethod + def _init_ddl( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled_ddl: DDLCompiler, + ) -> ExecutionContext: + """Initialize execution context for an ExecutableDDLElement + construct.""" + + self = cls.__new__(cls) + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.dialect = connection.dialect + + self.compiled = compiled = compiled_ddl + self.isddl = True + + self.execution_options = execution_options + + self.unicode_statement = str(compiled) + if compiled.schema_translate_map: + schema_translate_map = self.execution_options.get( + "schema_translate_map", {} + ) + + rst = compiled.preparer._render_schema_translates + self.unicode_statement = rst( + self.unicode_statement, schema_translate_map + ) + + self.statement = self.unicode_statement + + self.cursor = self.create_cursor() + self.compiled_parameters = [] + + if dialect.positional: + self.parameters = [dialect.execute_sequence_format()] + else: + self.parameters = [self._empty_dict_params] + + return self + + @classmethod + def _init_compiled( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled: SQLCompiler, + parameters: _CoreMultiExecuteParams, + invoked_statement: Executable, + extracted_parameters: Optional[Sequence[BindParameter[Any]]], + cache_hit: CacheStats = CacheStats.CACHING_DISABLED, + ) -> ExecutionContext: + """Initialize execution context for a Compiled construct.""" + + self = cls.__new__(cls) + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.dialect = connection.dialect + self.extracted_parameters = extracted_parameters + self.invoked_statement = invoked_statement + self.compiled = compiled + self.cache_hit = cache_hit + + self.execution_options = execution_options + + self.result_column_struct = ( + compiled._result_columns, + compiled._ordered_columns, + compiled._textual_ordered_columns, + compiled._ad_hoc_textual, + compiled._loose_column_name_matching, + ) + + self.isinsert = ii = compiled.isinsert + self.isupdate = iu = compiled.isupdate + self.isdelete = id_ = compiled.isdelete + self.is_text = compiled.isplaintext + + if ii or iu or id_: + dml_statement = compiled.compile_state.statement # type: ignore + if TYPE_CHECKING: + assert isinstance(dml_statement, UpdateBase) + self.is_crud = True + self._is_explicit_returning = ier = bool(dml_statement._returning) + self._is_implicit_returning = iir = bool( + compiled.implicit_returning + ) + if iir and dml_statement._supplemental_returning: + self._is_supplemental_returning = True + + # dont mix implicit and explicit returning + assert not (iir and ier) + + if (ier or iir) and compiled.for_executemany: + if ii and not self.dialect.insert_executemany_returning: + raise exc.InvalidRequestError( + f"Dialect {self.dialect.dialect_description} with " + f"current server capabilities does not support " + "INSERT..RETURNING when executemany is used" + ) + elif ( + ii + and dml_statement._sort_by_parameter_order + and not self.dialect.insert_executemany_returning_sort_by_parameter_order # noqa: E501 + ): + raise exc.InvalidRequestError( + f"Dialect {self.dialect.dialect_description} with " + f"current server capabilities does not support " + "INSERT..RETURNING with deterministic row ordering " + "when executemany is used" + ) + elif ( + ii + and self.dialect.use_insertmanyvalues + and not compiled._insertmanyvalues + ): + raise exc.InvalidRequestError( + 'Statement does not have "insertmanyvalues" ' + "enabled, can't use INSERT..RETURNING with " + "executemany in this case." + ) + elif iu and not self.dialect.update_executemany_returning: + raise exc.InvalidRequestError( + f"Dialect {self.dialect.dialect_description} with " + f"current server capabilities does not support " + "UPDATE..RETURNING when executemany is used" + ) + elif id_ and not self.dialect.delete_executemany_returning: + raise exc.InvalidRequestError( + f"Dialect {self.dialect.dialect_description} with " + f"current server capabilities does not support " + "DELETE..RETURNING when executemany is used" + ) + + if not parameters: + self.compiled_parameters = [ + compiled.construct_params( + extracted_parameters=extracted_parameters, + escape_names=False, + ) + ] + else: + self.compiled_parameters = [ + compiled.construct_params( + m, + escape_names=False, + _group_number=grp, + extracted_parameters=extracted_parameters, + ) + for grp, m in enumerate(parameters) + ] + + if len(parameters) > 1: + if self.isinsert and compiled._insertmanyvalues: + self.execute_style = ExecuteStyle.INSERTMANYVALUES + + imv = compiled._insertmanyvalues + if imv.sentinel_columns is not None: + self._num_sentinel_cols = imv.num_sentinel_columns + else: + self.execute_style = ExecuteStyle.EXECUTEMANY + + self.unicode_statement = compiled.string + + self.cursor = self.create_cursor() + + if self.compiled.insert_prefetch or self.compiled.update_prefetch: + self._process_execute_defaults() + + processors = compiled._bind_processors + + flattened_processors: Mapping[ + str, _BindProcessorType[Any] + ] = processors # type: ignore[assignment] + + if compiled.literal_execute_params or compiled.post_compile_params: + if self.executemany: + raise exc.InvalidRequestError( + "'literal_execute' or 'expanding' parameters can't be " + "used with executemany()" + ) + + expanded_state = compiled._process_parameters_for_postcompile( + self.compiled_parameters[0] + ) + + # re-assign self.unicode_statement + self.unicode_statement = expanded_state.statement + + self._expanded_parameters = expanded_state.parameter_expansion + + flattened_processors = dict(processors) # type: ignore + flattened_processors.update(expanded_state.processors) + positiontup = expanded_state.positiontup + elif compiled.positional: + positiontup = self.compiled.positiontup + else: + positiontup = None + + if compiled.schema_translate_map: + schema_translate_map = self.execution_options.get( + "schema_translate_map", {} + ) + rst = compiled.preparer._render_schema_translates + self.unicode_statement = rst( + self.unicode_statement, schema_translate_map + ) + + # final self.unicode_statement is now assigned, encode if needed + # by dialect + self.statement = self.unicode_statement + + # Convert the dictionary of bind parameter values + # into a dict or list to be sent to the DBAPI's + # execute() or executemany() method. + + if compiled.positional: + core_positional_parameters: MutableSequence[Sequence[Any]] = [] + assert positiontup is not None + for compiled_params in self.compiled_parameters: + l_param: List[Any] = [ + ( + flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] + ) + for key in positiontup + ] + core_positional_parameters.append( + dialect.execute_sequence_format(l_param) + ) + + self.parameters = core_positional_parameters + else: + core_dict_parameters: MutableSequence[Dict[str, Any]] = [] + escaped_names = compiled.escaped_bind_names + + # note that currently, "expanded" parameters will be present + # in self.compiled_parameters in their quoted form. This is + # slightly inconsistent with the approach taken as of + # #8056 where self.compiled_parameters is meant to contain unquoted + # param names. + d_param: Dict[str, Any] + for compiled_params in self.compiled_parameters: + if escaped_names: + d_param = { + escaped_names.get(key, key): ( + flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] + ) + for key in compiled_params + } + else: + d_param = { + key: ( + flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] + ) + for key in compiled_params + } + + core_dict_parameters.append(d_param) + + self.parameters = core_dict_parameters + + return self + + @classmethod + def _init_statement( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + statement: str, + parameters: _DBAPIMultiExecuteParams, + ) -> ExecutionContext: + """Initialize execution context for a string SQL statement.""" + + self = cls.__new__(cls) + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.dialect = connection.dialect + self.is_text = True + + self.execution_options = execution_options + + if not parameters: + if self.dialect.positional: + self.parameters = [dialect.execute_sequence_format()] + else: + self.parameters = [self._empty_dict_params] + elif isinstance(parameters[0], dialect.execute_sequence_format): + self.parameters = parameters + elif isinstance(parameters[0], dict): + self.parameters = parameters + else: + self.parameters = [ + dialect.execute_sequence_format(p) for p in parameters + ] + + if len(parameters) > 1: + self.execute_style = ExecuteStyle.EXECUTEMANY + + self.statement = self.unicode_statement = statement + + self.cursor = self.create_cursor() + return self + + @classmethod + def _init_default( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + ) -> ExecutionContext: + """Initialize execution context for a ColumnDefault construct.""" + + self = cls.__new__(cls) + self.root_connection = connection + self._dbapi_connection = dbapi_connection + self.dialect = connection.dialect + + self.execution_options = execution_options + + self.cursor = self.create_cursor() + return self + + def _get_cache_stats(self) -> str: + if self.compiled is None: + return "raw sql" + + now = perf_counter() + + ch = self.cache_hit + + gen_time = self.compiled._gen_time + assert gen_time is not None + + if ch is NO_CACHE_KEY: + return "no key %.5fs" % (now - gen_time,) + elif ch is CACHE_HIT: + return "cached since %.4gs ago" % (now - gen_time,) + elif ch is CACHE_MISS: + return "generated in %.5fs" % (now - gen_time,) + elif ch is CACHING_DISABLED: + if "_cache_disable_reason" in self.execution_options: + return "caching disabled (%s) %.5fs " % ( + self.execution_options["_cache_disable_reason"], + now - gen_time, + ) + else: + return "caching disabled %.5fs" % (now - gen_time,) + elif ch is NO_DIALECT_SUPPORT: + return "dialect %s+%s does not support caching %.5fs" % ( + self.dialect.name, + self.dialect.driver, + now - gen_time, + ) + else: + return "unknown" + + @property + def executemany(self): + return self.execute_style in ( + ExecuteStyle.EXECUTEMANY, + ExecuteStyle.INSERTMANYVALUES, + ) + + @util.memoized_property + def identifier_preparer(self): + if self.compiled: + return self.compiled.preparer + elif "schema_translate_map" in self.execution_options: + return self.dialect.identifier_preparer._with_schema_translate( + self.execution_options["schema_translate_map"] + ) + else: + return self.dialect.identifier_preparer + + @util.memoized_property + def engine(self): + return self.root_connection.engine + + @util.memoized_property + def postfetch_cols(self) -> Optional[Sequence[Column[Any]]]: + if TYPE_CHECKING: + assert isinstance(self.compiled, SQLCompiler) + return self.compiled.postfetch + + @util.memoized_property + def prefetch_cols(self) -> Optional[Sequence[Column[Any]]]: + if TYPE_CHECKING: + assert isinstance(self.compiled, SQLCompiler) + if self.isinsert: + return self.compiled.insert_prefetch + elif self.isupdate: + return self.compiled.update_prefetch + else: + return () + + @util.memoized_property + def no_parameters(self): + return self.execution_options.get("no_parameters", False) + + def _execute_scalar(self, stmt, type_, parameters=None): + """Execute a string statement on the current cursor, returning a + scalar result. + + Used to fire off sequences, default phrases, and "select lastrowid" + types of statements individually or in the context of a parent INSERT + or UPDATE statement. + + """ + + conn = self.root_connection + + if "schema_translate_map" in self.execution_options: + schema_translate_map = self.execution_options.get( + "schema_translate_map", {} + ) + + rst = self.identifier_preparer._render_schema_translates + stmt = rst(stmt, schema_translate_map) + + if not parameters: + if self.dialect.positional: + parameters = self.dialect.execute_sequence_format() + else: + parameters = {} + + conn._cursor_execute(self.cursor, stmt, parameters, context=self) + row = self.cursor.fetchone() + if row is not None: + r = row[0] + else: + r = None + if type_ is not None: + # apply type post processors to the result + proc = type_._cached_result_processor( + self.dialect, self.cursor.description[0][1] + ) + if proc: + return proc(r) + return r + + @util.memoized_property + def connection(self): + return self.root_connection + + def _use_server_side_cursor(self): + if not self.dialect.supports_server_side_cursors: + return False + + if self.dialect.server_side_cursors: + # this is deprecated + use_server_side = self.execution_options.get( + "stream_results", True + ) and ( + self.compiled + and isinstance(self.compiled.statement, expression.Selectable) + or ( + ( + not self.compiled + or isinstance( + self.compiled.statement, expression.TextClause + ) + ) + and self.unicode_statement + and SERVER_SIDE_CURSOR_RE.match(self.unicode_statement) + ) + ) + else: + use_server_side = self.execution_options.get( + "stream_results", False + ) + + return use_server_side + + def create_cursor(self): + if ( + # inlining initial preference checks for SS cursors + self.dialect.supports_server_side_cursors + and ( + self.execution_options.get("stream_results", False) + or ( + self.dialect.server_side_cursors + and self._use_server_side_cursor() + ) + ) + ): + self._is_server_side = True + return self.create_server_side_cursor() + else: + self._is_server_side = False + return self.create_default_cursor() + + def fetchall_for_returning(self, cursor): + return cursor.fetchall() + + def create_default_cursor(self): + return self._dbapi_connection.cursor() + + def create_server_side_cursor(self): + raise NotImplementedError() + + def pre_exec(self): + pass + + def get_out_parameter_values(self, names): + raise NotImplementedError( + "This dialect does not support OUT parameters" + ) + + def post_exec(self): + pass + + def get_result_processor(self, type_, colname, coltype): + """Return a 'result processor' for a given type as present in + cursor.description. + + This has a default implementation that dialects can override + for context-sensitive result type handling. + + """ + return type_._cached_result_processor(self.dialect, coltype) + + def get_lastrowid(self): + """return self.cursor.lastrowid, or equivalent, after an INSERT. + + This may involve calling special cursor functions, issuing a new SELECT + on the cursor (or a new one), or returning a stored value that was + calculated within post_exec(). + + This function will only be called for dialects which support "implicit" + primary key generation, keep preexecute_autoincrement_sequences set to + False, and when no explicit id value was bound to the statement. + + The function is called once for an INSERT statement that would need to + return the last inserted primary key for those dialects that make use + of the lastrowid concept. In these cases, it is called directly after + :meth:`.ExecutionContext.post_exec`. + + """ + return self.cursor.lastrowid + + def handle_dbapi_exception(self, e): + pass + + @util.non_memoized_property + def rowcount(self) -> int: + if self._rowcount is not None: + return self._rowcount + else: + return self.cursor.rowcount + + @property + def _has_rowcount(self): + return self._rowcount is not None + + def supports_sane_rowcount(self): + return self.dialect.supports_sane_rowcount + + def supports_sane_multi_rowcount(self): + return self.dialect.supports_sane_multi_rowcount + + def _setup_result_proxy(self): + exec_opt = self.execution_options + + if self._rowcount is None and exec_opt.get("preserve_rowcount", False): + self._rowcount = self.cursor.rowcount + + if self.is_crud or self.is_text: + result = self._setup_dml_or_text_result() + yp = sr = False + else: + yp = exec_opt.get("yield_per", None) + sr = self._is_server_side or exec_opt.get("stream_results", False) + strategy = self.cursor_fetch_strategy + if sr and strategy is _cursor._DEFAULT_FETCH: + strategy = _cursor.BufferedRowCursorFetchStrategy( + self.cursor, self.execution_options + ) + cursor_description: _DBAPICursorDescription = ( + strategy.alternate_cursor_description + or self.cursor.description + ) + if cursor_description is None: + strategy = _cursor._NO_CURSOR_DQL + + result = _cursor.CursorResult(self, strategy, cursor_description) + + compiled = self.compiled + + if ( + compiled + and not self.isddl + and cast(SQLCompiler, compiled).has_out_parameters + ): + self._setup_out_parameters(result) + + self._soft_closed = result._soft_closed + + if yp: + result = result.yield_per(yp) + + return result + + def _setup_out_parameters(self, result): + compiled = cast(SQLCompiler, self.compiled) + + out_bindparams = [ + (param, name) + for param, name in compiled.bind_names.items() + if param.isoutparam + ] + out_parameters = {} + + for bindparam, raw_value in zip( + [param for param, name in out_bindparams], + self.get_out_parameter_values( + [name for param, name in out_bindparams] + ), + ): + type_ = bindparam.type + impl_type = type_.dialect_impl(self.dialect) + dbapi_type = impl_type.get_dbapi_type(self.dialect.loaded_dbapi) + result_processor = impl_type.result_processor( + self.dialect, dbapi_type + ) + if result_processor is not None: + raw_value = result_processor(raw_value) + out_parameters[bindparam.key] = raw_value + + result.out_parameters = out_parameters + + def _setup_dml_or_text_result(self): + compiled = cast(SQLCompiler, self.compiled) + + strategy: ResultFetchStrategy = self.cursor_fetch_strategy + + if self.isinsert: + if ( + self.execute_style is ExecuteStyle.INSERTMANYVALUES + and compiled.effective_returning + ): + strategy = _cursor.FullyBufferedCursorFetchStrategy( + self.cursor, + initial_buffer=self._insertmanyvalues_rows, + # maintain alt cursor description if set by the + # dialect, e.g. mssql preserves it + alternate_description=( + strategy.alternate_cursor_description + ), + ) + + if compiled.postfetch_lastrowid: + self.inserted_primary_key_rows = ( + self._setup_ins_pk_from_lastrowid() + ) + # else if not self._is_implicit_returning, + # the default inserted_primary_key_rows accessor will + # return an "empty" primary key collection when accessed. + + if self._is_server_side and strategy is _cursor._DEFAULT_FETCH: + strategy = _cursor.BufferedRowCursorFetchStrategy( + self.cursor, self.execution_options + ) + + if strategy is _cursor._NO_CURSOR_DML: + cursor_description = None + else: + cursor_description = ( + strategy.alternate_cursor_description + or self.cursor.description + ) + + if cursor_description is None: + strategy = _cursor._NO_CURSOR_DML + elif self._num_sentinel_cols: + assert self.execute_style is ExecuteStyle.INSERTMANYVALUES + # strip out the sentinel columns from cursor description + # a similar logic is done to the rows only in CursorResult + cursor_description = cursor_description[ + 0 : -self._num_sentinel_cols + ] + + result: _cursor.CursorResult[Any] = _cursor.CursorResult( + self, strategy, cursor_description + ) + + if self.isinsert: + if self._is_implicit_returning: + rows = result.all() + + self.returned_default_rows = rows + + self.inserted_primary_key_rows = ( + self._setup_ins_pk_from_implicit_returning(result, rows) + ) + + # test that it has a cursor metadata that is accurate. the + # first row will have been fetched and current assumptions + # are that the result has only one row, until executemany() + # support is added here. + assert result._metadata.returns_rows + + # Insert statement has both return_defaults() and + # returning(). rewind the result on the list of rows + # we just used. + if self._is_supplemental_returning: + result._rewind(rows) + else: + result._soft_close() + elif not self._is_explicit_returning: + result._soft_close() + + # we assume here the result does not return any rows. + # *usually*, this will be true. However, some dialects + # such as that of MSSQL/pyodbc need to SELECT a post fetch + # function so this is not necessarily true. + # assert not result.returns_rows + + elif self._is_implicit_returning: + rows = result.all() + + if rows: + self.returned_default_rows = rows + self._rowcount = len(rows) + + if self._is_supplemental_returning: + result._rewind(rows) + else: + result._soft_close() + + # test that it has a cursor metadata that is accurate. + # the rows have all been fetched however. + assert result._metadata.returns_rows + + elif not result._metadata.returns_rows: + # no results, get rowcount + # (which requires open cursor on some drivers) + if self._rowcount is None: + self._rowcount = self.cursor.rowcount + result._soft_close() + elif self.isupdate or self.isdelete: + if self._rowcount is None: + self._rowcount = self.cursor.rowcount + return result + + @util.memoized_property + def inserted_primary_key_rows(self): + # if no specific "get primary key" strategy was set up + # during execution, return a "default" primary key based + # on what's in the compiled_parameters and nothing else. + return self._setup_ins_pk_from_empty() + + def _setup_ins_pk_from_lastrowid(self): + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_lastrowid_getter + lastrowid = self.get_lastrowid() + return [getter(lastrowid, self.compiled_parameters[0])] + + def _setup_ins_pk_from_empty(self): + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_lastrowid_getter + return [getter(None, param) for param in self.compiled_parameters] + + def _setup_ins_pk_from_implicit_returning(self, result, rows): + if not rows: + return [] + + getter = cast( + SQLCompiler, self.compiled + )._inserted_primary_key_from_returning_getter + compiled_params = self.compiled_parameters + + return [ + getter(row, param) for row, param in zip(rows, compiled_params) + ] + + def lastrow_has_defaults(self): + return (self.isinsert or self.isupdate) and bool( + cast(SQLCompiler, self.compiled).postfetch + ) + + def _prepare_set_input_sizes( + self, + ) -> Optional[List[Tuple[str, Any, TypeEngine[Any]]]]: + """Given a cursor and ClauseParameters, prepare arguments + in order to call the appropriate + style of ``setinputsizes()`` on the cursor, using DB-API types + from the bind parameter's ``TypeEngine`` objects. + + This method only called by those dialects which set + the :attr:`.Dialect.bind_typing` attribute to + :attr:`.BindTyping.SETINPUTSIZES`. cx_Oracle is the only DBAPI + that requires setinputsizes(), pyodbc offers it as an option. + + Prior to SQLAlchemy 2.0, the setinputsizes() approach was also used + for pg8000 and asyncpg, which has been changed to inline rendering + of casts. + + """ + if self.isddl or self.is_text: + return None + + compiled = cast(SQLCompiler, self.compiled) + + inputsizes = compiled._get_set_input_sizes_lookup() + + if inputsizes is None: + return None + + dialect = self.dialect + + # all of the rest of this... cython? + + if dialect._has_events: + inputsizes = dict(inputsizes) + dialect.dispatch.do_setinputsizes( + inputsizes, self.cursor, self.statement, self.parameters, self + ) + + if compiled.escaped_bind_names: + escaped_bind_names = compiled.escaped_bind_names + else: + escaped_bind_names = None + + if dialect.positional: + items = [ + (key, compiled.binds[key]) + for key in compiled.positiontup or () + ] + else: + items = [ + (key, bindparam) + for bindparam, key in compiled.bind_names.items() + ] + + generic_inputsizes: List[Tuple[str, Any, TypeEngine[Any]]] = [] + for key, bindparam in items: + if bindparam in compiled.literal_execute_params: + continue + + if key in self._expanded_parameters: + if is_tuple_type(bindparam.type): + num = len(bindparam.type.types) + dbtypes = inputsizes[bindparam] + generic_inputsizes.extend( + ( + ( + escaped_bind_names.get(paramname, paramname) + if escaped_bind_names is not None + else paramname + ), + dbtypes[idx % num], + bindparam.type.types[idx % num], + ) + for idx, paramname in enumerate( + self._expanded_parameters[key] + ) + ) + else: + dbtype = inputsizes.get(bindparam, None) + generic_inputsizes.extend( + ( + ( + escaped_bind_names.get(paramname, paramname) + if escaped_bind_names is not None + else paramname + ), + dbtype, + bindparam.type, + ) + for paramname in self._expanded_parameters[key] + ) + else: + dbtype = inputsizes.get(bindparam, None) + + escaped_name = ( + escaped_bind_names.get(key, key) + if escaped_bind_names is not None + else key + ) + + generic_inputsizes.append( + (escaped_name, dbtype, bindparam.type) + ) + + return generic_inputsizes + + def _exec_default(self, column, default, type_): + if default.is_sequence: + return self.fire_sequence(default, type_) + elif default.is_callable: + # this codepath is not normally used as it's inlined + # into _process_execute_defaults + self.current_column = column + return default.arg(self) + elif default.is_clause_element: + return self._exec_default_clause_element(column, default, type_) + else: + # this codepath is not normally used as it's inlined + # into _process_execute_defaults + return default.arg + + def _exec_default_clause_element(self, column, default, type_): + # execute a default that's a complete clause element. Here, we have + # to re-implement a miniature version of the compile->parameters-> + # cursor.execute() sequence, since we don't want to modify the state + # of the connection / result in progress or create new connection/ + # result objects etc. + # .. versionchanged:: 1.4 + + if not default._arg_is_typed: + default_arg = expression.type_coerce(default.arg, type_) + else: + default_arg = default.arg + compiled = expression.select(default_arg).compile(dialect=self.dialect) + compiled_params = compiled.construct_params() + processors = compiled._bind_processors + if compiled.positional: + parameters = self.dialect.execute_sequence_format( + [ + ( + processors[key](compiled_params[key]) # type: ignore + if key in processors + else compiled_params[key] + ) + for key in compiled.positiontup or () + ] + ) + else: + parameters = { + key: ( + processors[key](compiled_params[key]) # type: ignore + if key in processors + else compiled_params[key] + ) + for key in compiled_params + } + return self._execute_scalar( + str(compiled), type_, parameters=parameters + ) + + current_parameters: Optional[_CoreSingleExecuteParams] = None + """A dictionary of parameters applied to the current row. + + This attribute is only available in the context of a user-defined default + generation function, e.g. as described at :ref:`context_default_functions`. + It consists of a dictionary which includes entries for each column/value + pair that is to be part of the INSERT or UPDATE statement. The keys of the + dictionary will be the key value of each :class:`_schema.Column`, + which is usually + synonymous with the name. + + Note that the :attr:`.DefaultExecutionContext.current_parameters` attribute + does not accommodate for the "multi-values" feature of the + :meth:`_expression.Insert.values` method. The + :meth:`.DefaultExecutionContext.get_current_parameters` method should be + preferred. + + .. seealso:: + + :meth:`.DefaultExecutionContext.get_current_parameters` + + :ref:`context_default_functions` + + """ + + def get_current_parameters(self, isolate_multiinsert_groups=True): + """Return a dictionary of parameters applied to the current row. + + This method can only be used in the context of a user-defined default + generation function, e.g. as described at + :ref:`context_default_functions`. When invoked, a dictionary is + returned which includes entries for each column/value pair that is part + of the INSERT or UPDATE statement. The keys of the dictionary will be + the key value of each :class:`_schema.Column`, + which is usually synonymous + with the name. + + :param isolate_multiinsert_groups=True: indicates that multi-valued + INSERT constructs created using :meth:`_expression.Insert.values` + should be + handled by returning only the subset of parameters that are local + to the current column default invocation. When ``False``, the + raw parameters of the statement are returned including the + naming convention used in the case of multi-valued INSERT. + + .. versionadded:: 1.2 added + :meth:`.DefaultExecutionContext.get_current_parameters` + which provides more functionality over the existing + :attr:`.DefaultExecutionContext.current_parameters` + attribute. + + .. seealso:: + + :attr:`.DefaultExecutionContext.current_parameters` + + :ref:`context_default_functions` + + """ + try: + parameters = self.current_parameters + column = self.current_column + except AttributeError: + raise exc.InvalidRequestError( + "get_current_parameters() can only be invoked in the " + "context of a Python side column default function" + ) + else: + assert column is not None + assert parameters is not None + compile_state = cast( + "DMLState", cast(SQLCompiler, self.compiled).compile_state + ) + assert compile_state is not None + if ( + isolate_multiinsert_groups + and dml.isinsert(compile_state) + and compile_state._has_multi_parameters + ): + if column._is_multiparam_column: + index = column.index + 1 + d = {column.original.key: parameters[column.key]} + else: + d = {column.key: parameters[column.key]} + index = 0 + assert compile_state._dict_parameters is not None + keys = compile_state._dict_parameters.keys() + d.update( + (key, parameters["%s_m%d" % (key, index)]) for key in keys + ) + return d + else: + return parameters + + def get_insert_default(self, column): + if column.default is None: + return None + else: + return self._exec_default(column, column.default, column.type) + + def get_update_default(self, column): + if column.onupdate is None: + return None + else: + return self._exec_default(column, column.onupdate, column.type) + + def _process_execute_defaults(self): + compiled = cast(SQLCompiler, self.compiled) + + key_getter = compiled._within_exec_param_key_getter + + sentinel_counter = 0 + + if compiled.insert_prefetch: + prefetch_recs = [ + ( + c, + key_getter(c), + c._default_description_tuple, + self.get_insert_default, + ) + for c in compiled.insert_prefetch + ] + elif compiled.update_prefetch: + prefetch_recs = [ + ( + c, + key_getter(c), + c._onupdate_description_tuple, + self.get_update_default, + ) + for c in compiled.update_prefetch + ] + else: + prefetch_recs = [] + + for param in self.compiled_parameters: + self.current_parameters = param + + for ( + c, + param_key, + (arg, is_scalar, is_callable, is_sentinel), + fallback, + ) in prefetch_recs: + if is_sentinel: + param[param_key] = sentinel_counter + sentinel_counter += 1 + elif is_scalar: + param[param_key] = arg + elif is_callable: + self.current_column = c + param[param_key] = arg(self) + else: + val = fallback(c) + if val is not None: + param[param_key] = val + + del self.current_parameters + + +DefaultDialect.execution_ctx_cls = DefaultExecutionContext diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/events.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/events.py new file mode 100644 index 0000000..b8e8936 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/events.py @@ -0,0 +1,951 @@ +# engine/events.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +from __future__ import annotations + +import typing +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Type +from typing import Union + +from .base import Connection +from .base import Engine +from .interfaces import ConnectionEventsTarget +from .interfaces import DBAPIConnection +from .interfaces import DBAPICursor +from .interfaces import Dialect +from .. import event +from .. import exc +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .interfaces import _CoreMultiExecuteParams + from .interfaces import _CoreSingleExecuteParams + from .interfaces import _DBAPIAnyExecuteParams + from .interfaces import _DBAPIMultiExecuteParams + from .interfaces import _DBAPISingleExecuteParams + from .interfaces import _ExecuteOptions + from .interfaces import ExceptionContext + from .interfaces import ExecutionContext + from .result import Result + from ..pool import ConnectionPoolEntry + from ..sql import Executable + from ..sql.elements import BindParameter + + +class ConnectionEvents(event.Events[ConnectionEventsTarget]): + """Available events for + :class:`_engine.Connection` and :class:`_engine.Engine`. + + The methods here define the name of an event as well as the names of + members that are passed to listener functions. + + An event listener can be associated with any + :class:`_engine.Connection` or :class:`_engine.Engine` + class or instance, such as an :class:`_engine.Engine`, e.g.:: + + from sqlalchemy import event, create_engine + + def before_cursor_execute(conn, cursor, statement, parameters, context, + executemany): + log.info("Received statement: %s", statement) + + engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/test') + event.listen(engine, "before_cursor_execute", before_cursor_execute) + + or with a specific :class:`_engine.Connection`:: + + with engine.begin() as conn: + @event.listens_for(conn, 'before_cursor_execute') + def before_cursor_execute(conn, cursor, statement, parameters, + context, executemany): + log.info("Received statement: %s", statement) + + When the methods are called with a `statement` parameter, such as in + :meth:`.after_cursor_execute` or :meth:`.before_cursor_execute`, + the statement is the exact SQL string that was prepared for transmission + to the DBAPI ``cursor`` in the connection's :class:`.Dialect`. + + The :meth:`.before_execute` and :meth:`.before_cursor_execute` + events can also be established with the ``retval=True`` flag, which + allows modification of the statement and parameters to be sent + to the database. The :meth:`.before_cursor_execute` event is + particularly useful here to add ad-hoc string transformations, such + as comments, to all executions:: + + from sqlalchemy.engine import Engine + from sqlalchemy import event + + @event.listens_for(Engine, "before_cursor_execute", retval=True) + def comment_sql_calls(conn, cursor, statement, parameters, + context, executemany): + statement = statement + " -- some comment" + return statement, parameters + + .. note:: :class:`_events.ConnectionEvents` can be established on any + combination of :class:`_engine.Engine`, :class:`_engine.Connection`, + as well + as instances of each of those classes. Events across all + four scopes will fire off for a given instance of + :class:`_engine.Connection`. However, for performance reasons, the + :class:`_engine.Connection` object determines at instantiation time + whether or not its parent :class:`_engine.Engine` has event listeners + established. Event listeners added to the :class:`_engine.Engine` + class or to an instance of :class:`_engine.Engine` + *after* the instantiation + of a dependent :class:`_engine.Connection` instance will usually + *not* be available on that :class:`_engine.Connection` instance. + The newly + added listeners will instead take effect for + :class:`_engine.Connection` + instances created subsequent to those event listeners being + established on the parent :class:`_engine.Engine` class or instance. + + :param retval=False: Applies to the :meth:`.before_execute` and + :meth:`.before_cursor_execute` events only. When True, the + user-defined event function must have a return value, which + is a tuple of parameters that replace the given statement + and parameters. See those methods for a description of + specific return arguments. + + """ # noqa + + _target_class_doc = "SomeEngine" + _dispatch_target = ConnectionEventsTarget + + @classmethod + def _accept_with( + cls, + target: Union[ConnectionEventsTarget, Type[ConnectionEventsTarget]], + identifier: str, + ) -> Optional[Union[ConnectionEventsTarget, Type[ConnectionEventsTarget]]]: + default_dispatch = super()._accept_with(target, identifier) + if default_dispatch is None and hasattr( + target, "_no_async_engine_events" + ): + target._no_async_engine_events() + + return default_dispatch + + @classmethod + def _listen( + cls, + event_key: event._EventKey[ConnectionEventsTarget], + *, + retval: bool = False, + **kw: Any, + ) -> None: + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key._listen_fn, + ) + target._has_events = True + + if not retval: + if identifier == "before_execute": + orig_fn = fn + + def wrap_before_execute( # type: ignore + conn, clauseelement, multiparams, params, execution_options + ): + orig_fn( + conn, + clauseelement, + multiparams, + params, + execution_options, + ) + return clauseelement, multiparams, params + + fn = wrap_before_execute + elif identifier == "before_cursor_execute": + orig_fn = fn + + def wrap_before_cursor_execute( # type: ignore + conn, cursor, statement, parameters, context, executemany + ): + orig_fn( + conn, + cursor, + statement, + parameters, + context, + executemany, + ) + return statement, parameters + + fn = wrap_before_cursor_execute + elif retval and identifier not in ( + "before_execute", + "before_cursor_execute", + ): + raise exc.ArgumentError( + "Only the 'before_execute', " + "'before_cursor_execute' and 'handle_error' engine " + "event listeners accept the 'retval=True' " + "argument." + ) + event_key.with_wrapper(fn).base_listen() + + @event._legacy_signature( + "1.4", + ["conn", "clauseelement", "multiparams", "params"], + lambda conn, clauseelement, multiparams, params, execution_options: ( + conn, + clauseelement, + multiparams, + params, + ), + ) + def before_execute( + self, + conn: Connection, + clauseelement: Executable, + multiparams: _CoreMultiExecuteParams, + params: _CoreSingleExecuteParams, + execution_options: _ExecuteOptions, + ) -> Optional[ + Tuple[Executable, _CoreMultiExecuteParams, _CoreSingleExecuteParams] + ]: + """Intercept high level execute() events, receiving uncompiled + SQL constructs and other objects prior to rendering into SQL. + + This event is good for debugging SQL compilation issues as well + as early manipulation of the parameters being sent to the database, + as the parameter lists will be in a consistent format here. + + This event can be optionally established with the ``retval=True`` + flag. The ``clauseelement``, ``multiparams``, and ``params`` + arguments should be returned as a three-tuple in this case:: + + @event.listens_for(Engine, "before_execute", retval=True) + def before_execute(conn, clauseelement, multiparams, params): + # do something with clauseelement, multiparams, params + return clauseelement, multiparams, params + + :param conn: :class:`_engine.Connection` object + :param clauseelement: SQL expression construct, :class:`.Compiled` + instance, or string statement passed to + :meth:`_engine.Connection.execute`. + :param multiparams: Multiple parameter sets, a list of dictionaries. + :param params: Single parameter set, a single dictionary. + :param execution_options: dictionary of execution + options passed along with the statement, if any. This is a merge + of all options that will be used, including those of the statement, + the connection, and those passed in to the method itself for + the 2.0 style of execution. + + .. versionadded: 1.4 + + .. seealso:: + + :meth:`.before_cursor_execute` + + """ + + @event._legacy_signature( + "1.4", + ["conn", "clauseelement", "multiparams", "params", "result"], + lambda conn, clauseelement, multiparams, params, execution_options, result: ( # noqa + conn, + clauseelement, + multiparams, + params, + result, + ), + ) + def after_execute( + self, + conn: Connection, + clauseelement: Executable, + multiparams: _CoreMultiExecuteParams, + params: _CoreSingleExecuteParams, + execution_options: _ExecuteOptions, + result: Result[Any], + ) -> None: + """Intercept high level execute() events after execute. + + + :param conn: :class:`_engine.Connection` object + :param clauseelement: SQL expression construct, :class:`.Compiled` + instance, or string statement passed to + :meth:`_engine.Connection.execute`. + :param multiparams: Multiple parameter sets, a list of dictionaries. + :param params: Single parameter set, a single dictionary. + :param execution_options: dictionary of execution + options passed along with the statement, if any. This is a merge + of all options that will be used, including those of the statement, + the connection, and those passed in to the method itself for + the 2.0 style of execution. + + .. versionadded: 1.4 + + :param result: :class:`_engine.CursorResult` generated by the + execution. + + """ + + def before_cursor_execute( + self, + conn: Connection, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIAnyExecuteParams, + context: Optional[ExecutionContext], + executemany: bool, + ) -> Optional[Tuple[str, _DBAPIAnyExecuteParams]]: + """Intercept low-level cursor execute() events before execution, + receiving the string SQL statement and DBAPI-specific parameter list to + be invoked against a cursor. + + This event is a good choice for logging as well as late modifications + to the SQL string. It's less ideal for parameter modifications except + for those which are specific to a target backend. + + This event can be optionally established with the ``retval=True`` + flag. The ``statement`` and ``parameters`` arguments should be + returned as a two-tuple in this case:: + + @event.listens_for(Engine, "before_cursor_execute", retval=True) + def before_cursor_execute(conn, cursor, statement, + parameters, context, executemany): + # do something with statement, parameters + return statement, parameters + + See the example at :class:`_events.ConnectionEvents`. + + :param conn: :class:`_engine.Connection` object + :param cursor: DBAPI cursor object + :param statement: string SQL statement, as to be passed to the DBAPI + :param parameters: Dictionary, tuple, or list of parameters being + passed to the ``execute()`` or ``executemany()`` method of the + DBAPI ``cursor``. In some cases may be ``None``. + :param context: :class:`.ExecutionContext` object in use. May + be ``None``. + :param executemany: boolean, if ``True``, this is an ``executemany()`` + call, if ``False``, this is an ``execute()`` call. + + .. seealso:: + + :meth:`.before_execute` + + :meth:`.after_cursor_execute` + + """ + + def after_cursor_execute( + self, + conn: Connection, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIAnyExecuteParams, + context: Optional[ExecutionContext], + executemany: bool, + ) -> None: + """Intercept low-level cursor execute() events after execution. + + :param conn: :class:`_engine.Connection` object + :param cursor: DBAPI cursor object. Will have results pending + if the statement was a SELECT, but these should not be consumed + as they will be needed by the :class:`_engine.CursorResult`. + :param statement: string SQL statement, as passed to the DBAPI + :param parameters: Dictionary, tuple, or list of parameters being + passed to the ``execute()`` or ``executemany()`` method of the + DBAPI ``cursor``. In some cases may be ``None``. + :param context: :class:`.ExecutionContext` object in use. May + be ``None``. + :param executemany: boolean, if ``True``, this is an ``executemany()`` + call, if ``False``, this is an ``execute()`` call. + + """ + + @event._legacy_signature( + "2.0", ["conn", "branch"], converter=lambda conn: (conn, False) + ) + def engine_connect(self, conn: Connection) -> None: + """Intercept the creation of a new :class:`_engine.Connection`. + + This event is called typically as the direct result of calling + the :meth:`_engine.Engine.connect` method. + + It differs from the :meth:`_events.PoolEvents.connect` method, which + refers to the actual connection to a database at the DBAPI level; + a DBAPI connection may be pooled and reused for many operations. + In contrast, this event refers only to the production of a higher level + :class:`_engine.Connection` wrapper around such a DBAPI connection. + + It also differs from the :meth:`_events.PoolEvents.checkout` event + in that it is specific to the :class:`_engine.Connection` object, + not the + DBAPI connection that :meth:`_events.PoolEvents.checkout` deals with, + although + this DBAPI connection is available here via the + :attr:`_engine.Connection.connection` attribute. + But note there can in fact + be multiple :meth:`_events.PoolEvents.checkout` + events within the lifespan + of a single :class:`_engine.Connection` object, if that + :class:`_engine.Connection` + is invalidated and re-established. + + :param conn: :class:`_engine.Connection` object. + + .. seealso:: + + :meth:`_events.PoolEvents.checkout` + the lower-level pool checkout event + for an individual DBAPI connection + + """ + + def set_connection_execution_options( + self, conn: Connection, opts: Dict[str, Any] + ) -> None: + """Intercept when the :meth:`_engine.Connection.execution_options` + method is called. + + This method is called after the new :class:`_engine.Connection` + has been + produced, with the newly updated execution options collection, but + before the :class:`.Dialect` has acted upon any of those new options. + + Note that this method is not called when a new + :class:`_engine.Connection` + is produced which is inheriting execution options from its parent + :class:`_engine.Engine`; to intercept this condition, use the + :meth:`_events.ConnectionEvents.engine_connect` event. + + :param conn: The newly copied :class:`_engine.Connection` object + + :param opts: dictionary of options that were passed to the + :meth:`_engine.Connection.execution_options` method. + This dictionary may be modified in place to affect the ultimate + options which take effect. + + .. versionadded:: 2.0 the ``opts`` dictionary may be modified + in place. + + + .. seealso:: + + :meth:`_events.ConnectionEvents.set_engine_execution_options` + - event + which is called when :meth:`_engine.Engine.execution_options` + is called. + + + """ + + def set_engine_execution_options( + self, engine: Engine, opts: Dict[str, Any] + ) -> None: + """Intercept when the :meth:`_engine.Engine.execution_options` + method is called. + + The :meth:`_engine.Engine.execution_options` method produces a shallow + copy of the :class:`_engine.Engine` which stores the new options. + That new + :class:`_engine.Engine` is passed here. + A particular application of this + method is to add a :meth:`_events.ConnectionEvents.engine_connect` + event + handler to the given :class:`_engine.Engine` + which will perform some per- + :class:`_engine.Connection` task specific to these execution options. + + :param conn: The newly copied :class:`_engine.Engine` object + + :param opts: dictionary of options that were passed to the + :meth:`_engine.Connection.execution_options` method. + This dictionary may be modified in place to affect the ultimate + options which take effect. + + .. versionadded:: 2.0 the ``opts`` dictionary may be modified + in place. + + .. seealso:: + + :meth:`_events.ConnectionEvents.set_connection_execution_options` + - event + which is called when :meth:`_engine.Connection.execution_options` + is + called. + + """ + + def engine_disposed(self, engine: Engine) -> None: + """Intercept when the :meth:`_engine.Engine.dispose` method is called. + + The :meth:`_engine.Engine.dispose` method instructs the engine to + "dispose" of it's connection pool (e.g. :class:`_pool.Pool`), and + replaces it with a new one. Disposing of the old pool has the + effect that existing checked-in connections are closed. The new + pool does not establish any new connections until it is first used. + + This event can be used to indicate that resources related to the + :class:`_engine.Engine` should also be cleaned up, + keeping in mind that the + :class:`_engine.Engine` + can still be used for new requests in which case + it re-acquires connection resources. + + """ + + def begin(self, conn: Connection) -> None: + """Intercept begin() events. + + :param conn: :class:`_engine.Connection` object + + """ + + def rollback(self, conn: Connection) -> None: + """Intercept rollback() events, as initiated by a + :class:`.Transaction`. + + Note that the :class:`_pool.Pool` also "auto-rolls back" + a DBAPI connection upon checkin, if the ``reset_on_return`` + flag is set to its default value of ``'rollback'``. + To intercept this + rollback, use the :meth:`_events.PoolEvents.reset` hook. + + :param conn: :class:`_engine.Connection` object + + .. seealso:: + + :meth:`_events.PoolEvents.reset` + + """ + + def commit(self, conn: Connection) -> None: + """Intercept commit() events, as initiated by a + :class:`.Transaction`. + + Note that the :class:`_pool.Pool` may also "auto-commit" + a DBAPI connection upon checkin, if the ``reset_on_return`` + flag is set to the value ``'commit'``. To intercept this + commit, use the :meth:`_events.PoolEvents.reset` hook. + + :param conn: :class:`_engine.Connection` object + """ + + def savepoint(self, conn: Connection, name: str) -> None: + """Intercept savepoint() events. + + :param conn: :class:`_engine.Connection` object + :param name: specified name used for the savepoint. + + """ + + def rollback_savepoint( + self, conn: Connection, name: str, context: None + ) -> None: + """Intercept rollback_savepoint() events. + + :param conn: :class:`_engine.Connection` object + :param name: specified name used for the savepoint. + :param context: not used + + """ + # TODO: deprecate "context" + + def release_savepoint( + self, conn: Connection, name: str, context: None + ) -> None: + """Intercept release_savepoint() events. + + :param conn: :class:`_engine.Connection` object + :param name: specified name used for the savepoint. + :param context: not used + + """ + # TODO: deprecate "context" + + def begin_twophase(self, conn: Connection, xid: Any) -> None: + """Intercept begin_twophase() events. + + :param conn: :class:`_engine.Connection` object + :param xid: two-phase XID identifier + + """ + + def prepare_twophase(self, conn: Connection, xid: Any) -> None: + """Intercept prepare_twophase() events. + + :param conn: :class:`_engine.Connection` object + :param xid: two-phase XID identifier + """ + + def rollback_twophase( + self, conn: Connection, xid: Any, is_prepared: bool + ) -> None: + """Intercept rollback_twophase() events. + + :param conn: :class:`_engine.Connection` object + :param xid: two-phase XID identifier + :param is_prepared: boolean, indicates if + :meth:`.TwoPhaseTransaction.prepare` was called. + + """ + + def commit_twophase( + self, conn: Connection, xid: Any, is_prepared: bool + ) -> None: + """Intercept commit_twophase() events. + + :param conn: :class:`_engine.Connection` object + :param xid: two-phase XID identifier + :param is_prepared: boolean, indicates if + :meth:`.TwoPhaseTransaction.prepare` was called. + + """ + + +class DialectEvents(event.Events[Dialect]): + """event interface for execution-replacement functions. + + These events allow direct instrumentation and replacement + of key dialect functions which interact with the DBAPI. + + .. note:: + + :class:`.DialectEvents` hooks should be considered **semi-public** + and experimental. + These hooks are not for general use and are only for those situations + where intricate re-statement of DBAPI mechanics must be injected onto + an existing dialect. For general-use statement-interception events, + please use the :class:`_events.ConnectionEvents` interface. + + .. seealso:: + + :meth:`_events.ConnectionEvents.before_cursor_execute` + + :meth:`_events.ConnectionEvents.before_execute` + + :meth:`_events.ConnectionEvents.after_cursor_execute` + + :meth:`_events.ConnectionEvents.after_execute` + + """ + + _target_class_doc = "SomeEngine" + _dispatch_target = Dialect + + @classmethod + def _listen( + cls, + event_key: event._EventKey[Dialect], + *, + retval: bool = False, + **kw: Any, + ) -> None: + target = event_key.dispatch_target + + target._has_events = True + event_key.base_listen() + + @classmethod + def _accept_with( + cls, + target: Union[Engine, Type[Engine], Dialect, Type[Dialect]], + identifier: str, + ) -> Optional[Union[Dialect, Type[Dialect]]]: + if isinstance(target, type): + if issubclass(target, Engine): + return Dialect + elif issubclass(target, Dialect): + return target + elif isinstance(target, Engine): + return target.dialect + elif isinstance(target, Dialect): + return target + elif isinstance(target, Connection) and identifier == "handle_error": + raise exc.InvalidRequestError( + "The handle_error() event hook as of SQLAlchemy 2.0 is " + "established on the Dialect, and may only be applied to the " + "Engine as a whole or to a specific Dialect as a whole, " + "not on a per-Connection basis." + ) + elif hasattr(target, "_no_async_engine_events"): + target._no_async_engine_events() + else: + return None + + def handle_error( + self, exception_context: ExceptionContext + ) -> Optional[BaseException]: + r"""Intercept all exceptions processed by the + :class:`_engine.Dialect`, typically but not limited to those + emitted within the scope of a :class:`_engine.Connection`. + + .. versionchanged:: 2.0 the :meth:`.DialectEvents.handle_error` event + is moved to the :class:`.DialectEvents` class, moved from the + :class:`.ConnectionEvents` class, so that it may also participate in + the "pre ping" operation configured with the + :paramref:`_sa.create_engine.pool_pre_ping` parameter. The event + remains registered by using the :class:`_engine.Engine` as the event + target, however note that using the :class:`_engine.Connection` as + an event target for :meth:`.DialectEvents.handle_error` is no longer + supported. + + This includes all exceptions emitted by the DBAPI as well as + within SQLAlchemy's statement invocation process, including + encoding errors and other statement validation errors. Other areas + in which the event is invoked include transaction begin and end, + result row fetching, cursor creation. + + Note that :meth:`.handle_error` may support new kinds of exceptions + and new calling scenarios at *any time*. Code which uses this + event must expect new calling patterns to be present in minor + releases. + + To support the wide variety of members that correspond to an exception, + as well as to allow extensibility of the event without backwards + incompatibility, the sole argument received is an instance of + :class:`.ExceptionContext`. This object contains data members + representing detail about the exception. + + Use cases supported by this hook include: + + * read-only, low-level exception handling for logging and + debugging purposes + * Establishing whether a DBAPI connection error message indicates + that the database connection needs to be reconnected, including + for the "pre_ping" handler used by **some** dialects + * Establishing or disabling whether a connection or the owning + connection pool is invalidated or expired in response to a + specific exception + * exception re-writing + + The hook is called while the cursor from the failed operation + (if any) is still open and accessible. Special cleanup operations + can be called on this cursor; SQLAlchemy will attempt to close + this cursor subsequent to this hook being invoked. + + As of SQLAlchemy 2.0, the "pre_ping" handler enabled using the + :paramref:`_sa.create_engine.pool_pre_ping` parameter will also + participate in the :meth:`.handle_error` process, **for those dialects + that rely upon disconnect codes to detect database liveness**. Note + that some dialects such as psycopg, psycopg2, and most MySQL dialects + make use of a native ``ping()`` method supplied by the DBAPI which does + not make use of disconnect codes. + + .. versionchanged:: 2.0.0 The :meth:`.DialectEvents.handle_error` + event hook participates in connection pool "pre-ping" operations. + Within this usage, the :attr:`.ExceptionContext.engine` attribute + will be ``None``, however the :class:`.Dialect` in use is always + available via the :attr:`.ExceptionContext.dialect` attribute. + + .. versionchanged:: 2.0.5 Added :attr:`.ExceptionContext.is_pre_ping` + attribute which will be set to ``True`` when the + :meth:`.DialectEvents.handle_error` event hook is triggered within + a connection pool pre-ping operation. + + .. versionchanged:: 2.0.5 An issue was repaired that allows for the + PostgreSQL ``psycopg`` and ``psycopg2`` drivers, as well as all + MySQL drivers, to properly participate in the + :meth:`.DialectEvents.handle_error` event hook during + connection pool "pre-ping" operations; previously, the + implementation was non-working for these drivers. + + + A handler function has two options for replacing + the SQLAlchemy-constructed exception into one that is user + defined. It can either raise this new exception directly, in + which case all further event listeners are bypassed and the + exception will be raised, after appropriate cleanup as taken + place:: + + @event.listens_for(Engine, "handle_error") + def handle_exception(context): + if isinstance(context.original_exception, + psycopg2.OperationalError) and \ + "failed" in str(context.original_exception): + raise MySpecialException("failed operation") + + .. warning:: Because the + :meth:`_events.DialectEvents.handle_error` + event specifically provides for exceptions to be re-thrown as + the ultimate exception raised by the failed statement, + **stack traces will be misleading** if the user-defined event + handler itself fails and throws an unexpected exception; + the stack trace may not illustrate the actual code line that + failed! It is advised to code carefully here and use + logging and/or inline debugging if unexpected exceptions are + occurring. + + Alternatively, a "chained" style of event handling can be + used, by configuring the handler with the ``retval=True`` + modifier and returning the new exception instance from the + function. In this case, event handling will continue onto the + next handler. The "chained" exception is available using + :attr:`.ExceptionContext.chained_exception`:: + + @event.listens_for(Engine, "handle_error", retval=True) + def handle_exception(context): + if context.chained_exception is not None and \ + "special" in context.chained_exception.message: + return MySpecialException("failed", + cause=context.chained_exception) + + Handlers that return ``None`` may be used within the chain; when + a handler returns ``None``, the previous exception instance, + if any, is maintained as the current exception that is passed onto the + next handler. + + When a custom exception is raised or returned, SQLAlchemy raises + this new exception as-is, it is not wrapped by any SQLAlchemy + object. If the exception is not a subclass of + :class:`sqlalchemy.exc.StatementError`, + certain features may not be available; currently this includes + the ORM's feature of adding a detail hint about "autoflush" to + exceptions raised within the autoflush process. + + :param context: an :class:`.ExceptionContext` object. See this + class for details on all available members. + + + .. seealso:: + + :ref:`pool_new_disconnect_codes` + + """ + + def do_connect( + self, + dialect: Dialect, + conn_rec: ConnectionPoolEntry, + cargs: Tuple[Any, ...], + cparams: Dict[str, Any], + ) -> Optional[DBAPIConnection]: + """Receive connection arguments before a connection is made. + + This event is useful in that it allows the handler to manipulate the + cargs and/or cparams collections that control how the DBAPI + ``connect()`` function will be called. ``cargs`` will always be a + Python list that can be mutated in-place, and ``cparams`` a Python + dictionary that may also be mutated:: + + e = create_engine("postgresql+psycopg2://user@host/dbname") + + @event.listens_for(e, 'do_connect') + def receive_do_connect(dialect, conn_rec, cargs, cparams): + cparams["password"] = "some_password" + + The event hook may also be used to override the call to ``connect()`` + entirely, by returning a non-``None`` DBAPI connection object:: + + e = create_engine("postgresql+psycopg2://user@host/dbname") + + @event.listens_for(e, 'do_connect') + def receive_do_connect(dialect, conn_rec, cargs, cparams): + return psycopg2.connect(*cargs, **cparams) + + .. seealso:: + + :ref:`custom_dbapi_args` + + """ + + def do_executemany( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIMultiExecuteParams, + context: ExecutionContext, + ) -> Optional[Literal[True]]: + """Receive a cursor to have executemany() called. + + Return the value True to halt further events from invoking, + and to indicate that the cursor execution has already taken + place within the event handler. + + """ + + def do_execute_no_params( + self, cursor: DBAPICursor, statement: str, context: ExecutionContext + ) -> Optional[Literal[True]]: + """Receive a cursor to have execute() with no parameters called. + + Return the value True to halt further events from invoking, + and to indicate that the cursor execution has already taken + place within the event handler. + + """ + + def do_execute( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPISingleExecuteParams, + context: ExecutionContext, + ) -> Optional[Literal[True]]: + """Receive a cursor to have execute() called. + + Return the value True to halt further events from invoking, + and to indicate that the cursor execution has already taken + place within the event handler. + + """ + + def do_setinputsizes( + self, + inputsizes: Dict[BindParameter[Any], Any], + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIAnyExecuteParams, + context: ExecutionContext, + ) -> None: + """Receive the setinputsizes dictionary for possible modification. + + This event is emitted in the case where the dialect makes use of the + DBAPI ``cursor.setinputsizes()`` method which passes information about + parameter binding for a particular statement. The given + ``inputsizes`` dictionary will contain :class:`.BindParameter` objects + as keys, linked to DBAPI-specific type objects as values; for + parameters that are not bound, they are added to the dictionary with + ``None`` as the value, which means the parameter will not be included + in the ultimate setinputsizes call. The event may be used to inspect + and/or log the datatypes that are being bound, as well as to modify the + dictionary in place. Parameters can be added, modified, or removed + from this dictionary. Callers will typically want to inspect the + :attr:`.BindParameter.type` attribute of the given bind objects in + order to make decisions about the DBAPI object. + + After the event, the ``inputsizes`` dictionary is converted into + an appropriate datastructure to be passed to ``cursor.setinputsizes``; + either a list for a positional bound parameter execution style, + or a dictionary of string parameter keys to DBAPI type objects for + a named bound parameter execution style. + + The setinputsizes hook overall is only used for dialects which include + the flag ``use_setinputsizes=True``. Dialects which use this + include cx_Oracle, pg8000, asyncpg, and pyodbc dialects. + + .. note:: + + For use with pyodbc, the ``use_setinputsizes`` flag + must be passed to the dialect, e.g.:: + + create_engine("mssql+pyodbc://...", use_setinputsizes=True) + + .. seealso:: + + :ref:`mssql_pyodbc_setinputsizes` + + .. versionadded:: 1.2.9 + + .. seealso:: + + :ref:`cx_oracle_setinputsizes` + + """ + pass diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/interfaces.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/interfaces.py new file mode 100644 index 0000000..d1657b8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/interfaces.py @@ -0,0 +1,3395 @@ +# engine/interfaces.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Define core interfaces used by the engine system.""" + +from __future__ import annotations + +from enum import Enum +from types import ModuleType +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import ClassVar +from typing import Collection +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .. import util +from ..event import EventTarget +from ..pool import Pool +from ..pool import PoolProxiedConnection +from ..sql.compiler import Compiled as Compiled +from ..sql.compiler import Compiled # noqa +from ..sql.compiler import TypeCompiler as TypeCompiler +from ..sql.compiler import TypeCompiler # noqa +from ..util import immutabledict +from ..util.concurrency import await_only +from ..util.typing import Literal +from ..util.typing import NotRequired +from ..util.typing import Protocol +from ..util.typing import TypedDict + +if TYPE_CHECKING: + from .base import Connection + from .base import Engine + from .cursor import CursorResult + from .url import URL + from ..event import _ListenerFnType + from ..event import dispatcher + from ..exc import StatementError + from ..sql import Executable + from ..sql.compiler import _InsertManyValuesBatch + from ..sql.compiler import DDLCompiler + from ..sql.compiler import IdentifierPreparer + from ..sql.compiler import InsertmanyvaluesSentinelOpts + from ..sql.compiler import Linting + from ..sql.compiler import SQLCompiler + from ..sql.elements import BindParameter + from ..sql.elements import ClauseElement + from ..sql.schema import Column + from ..sql.schema import DefaultGenerator + from ..sql.schema import SchemaItem + from ..sql.schema import Sequence as Sequence_SchemaItem + from ..sql.sqltypes import Integer + from ..sql.type_api import _TypeMemoDict + from ..sql.type_api import TypeEngine + +ConnectArgsType = Tuple[Sequence[str], MutableMapping[str, Any]] + +_T = TypeVar("_T", bound="Any") + + +class CacheStats(Enum): + CACHE_HIT = 0 + CACHE_MISS = 1 + CACHING_DISABLED = 2 + NO_CACHE_KEY = 3 + NO_DIALECT_SUPPORT = 4 + + +class ExecuteStyle(Enum): + """indicates the :term:`DBAPI` cursor method that will be used to invoke + a statement.""" + + EXECUTE = 0 + """indicates cursor.execute() will be used""" + + EXECUTEMANY = 1 + """indicates cursor.executemany() will be used.""" + + INSERTMANYVALUES = 2 + """indicates cursor.execute() will be used with an INSERT where the + VALUES expression will be expanded to accommodate for multiple + parameter sets + + .. seealso:: + + :ref:`engine_insertmanyvalues` + + """ + + +class DBAPIConnection(Protocol): + """protocol representing a :pep:`249` database connection. + + .. versionadded:: 2.0 + + .. seealso:: + + `Connection Objects `_ + - in :pep:`249` + + """ # noqa: E501 + + def close(self) -> None: ... + + def commit(self) -> None: ... + + def cursor(self) -> DBAPICursor: ... + + def rollback(self) -> None: ... + + autocommit: bool + + +class DBAPIType(Protocol): + """protocol representing a :pep:`249` database type. + + .. versionadded:: 2.0 + + .. seealso:: + + `Type Objects `_ + - in :pep:`249` + + """ # noqa: E501 + + +class DBAPICursor(Protocol): + """protocol representing a :pep:`249` database cursor. + + .. versionadded:: 2.0 + + .. seealso:: + + `Cursor Objects `_ + - in :pep:`249` + + """ # noqa: E501 + + @property + def description( + self, + ) -> _DBAPICursorDescription: + """The description attribute of the Cursor. + + .. seealso:: + + `cursor.description `_ + - in :pep:`249` + + + """ # noqa: E501 + ... + + @property + def rowcount(self) -> int: ... + + arraysize: int + + lastrowid: int + + def close(self) -> None: ... + + def execute( + self, + operation: Any, + parameters: Optional[_DBAPISingleExecuteParams] = None, + ) -> Any: ... + + def executemany( + self, + operation: Any, + parameters: _DBAPIMultiExecuteParams, + ) -> Any: ... + + def fetchone(self) -> Optional[Any]: ... + + def fetchmany(self, size: int = ...) -> Sequence[Any]: ... + + def fetchall(self) -> Sequence[Any]: ... + + def setinputsizes(self, sizes: Sequence[Any]) -> None: ... + + def setoutputsize(self, size: Any, column: Any) -> None: ... + + def callproc( + self, procname: str, parameters: Sequence[Any] = ... + ) -> Any: ... + + def nextset(self) -> Optional[bool]: ... + + def __getattr__(self, key: str) -> Any: ... + + +_CoreSingleExecuteParams = Mapping[str, Any] +_MutableCoreSingleExecuteParams = MutableMapping[str, Any] +_CoreMultiExecuteParams = Sequence[_CoreSingleExecuteParams] +_CoreAnyExecuteParams = Union[ + _CoreMultiExecuteParams, _CoreSingleExecuteParams +] + +_DBAPISingleExecuteParams = Union[Sequence[Any], _CoreSingleExecuteParams] + +_DBAPIMultiExecuteParams = Union[ + Sequence[Sequence[Any]], _CoreMultiExecuteParams +] +_DBAPIAnyExecuteParams = Union[ + _DBAPIMultiExecuteParams, _DBAPISingleExecuteParams +] +_DBAPICursorDescription = Sequence[ + Tuple[ + str, + "DBAPIType", + Optional[int], + Optional[int], + Optional[int], + Optional[int], + Optional[bool], + ] +] + +_AnySingleExecuteParams = _DBAPISingleExecuteParams +_AnyMultiExecuteParams = _DBAPIMultiExecuteParams +_AnyExecuteParams = _DBAPIAnyExecuteParams + +CompiledCacheType = MutableMapping[Any, "Compiled"] +SchemaTranslateMapType = Mapping[Optional[str], Optional[str]] + +_ImmutableExecuteOptions = immutabledict[str, Any] + +_ParamStyle = Literal[ + "qmark", "numeric", "named", "format", "pyformat", "numeric_dollar" +] + +_GenericSetInputSizesType = List[Tuple[str, Any, "TypeEngine[Any]"]] + +IsolationLevel = Literal[ + "SERIALIZABLE", + "REPEATABLE READ", + "READ COMMITTED", + "READ UNCOMMITTED", + "AUTOCOMMIT", +] + + +class _CoreKnownExecutionOptions(TypedDict, total=False): + compiled_cache: Optional[CompiledCacheType] + logging_token: str + isolation_level: IsolationLevel + no_parameters: bool + stream_results: bool + max_row_buffer: int + yield_per: int + insertmanyvalues_page_size: int + schema_translate_map: Optional[SchemaTranslateMapType] + preserve_rowcount: bool + + +_ExecuteOptions = immutabledict[str, Any] +CoreExecuteOptionsParameter = Union[ + _CoreKnownExecutionOptions, Mapping[str, Any] +] + + +class ReflectedIdentity(TypedDict): + """represent the reflected IDENTITY structure of a column, corresponding + to the :class:`_schema.Identity` construct. + + The :class:`.ReflectedIdentity` structure is part of the + :class:`.ReflectedColumn` structure, which is returned by the + :meth:`.Inspector.get_columns` method. + + """ + + always: bool + """type of identity column""" + + on_null: bool + """indicates ON NULL""" + + start: int + """starting index of the sequence""" + + increment: int + """increment value of the sequence""" + + minvalue: int + """the minimum value of the sequence.""" + + maxvalue: int + """the maximum value of the sequence.""" + + nominvalue: bool + """no minimum value of the sequence.""" + + nomaxvalue: bool + """no maximum value of the sequence.""" + + cycle: bool + """allows the sequence to wrap around when the maxvalue + or minvalue has been reached.""" + + cache: Optional[int] + """number of future values in the + sequence which are calculated in advance.""" + + order: bool + """if true, renders the ORDER keyword.""" + + +class ReflectedComputed(TypedDict): + """Represent the reflected elements of a computed column, corresponding + to the :class:`_schema.Computed` construct. + + The :class:`.ReflectedComputed` structure is part of the + :class:`.ReflectedColumn` structure, which is returned by the + :meth:`.Inspector.get_columns` method. + + """ + + sqltext: str + """the expression used to generate this column returned + as a string SQL expression""" + + persisted: NotRequired[bool] + """indicates if the value is stored in the table or computed on demand""" + + +class ReflectedColumn(TypedDict): + """Dictionary representing the reflected elements corresponding to + a :class:`_schema.Column` object. + + The :class:`.ReflectedColumn` structure is returned by the + :class:`.Inspector.get_columns` method. + + """ + + name: str + """column name""" + + type: TypeEngine[Any] + """column type represented as a :class:`.TypeEngine` instance.""" + + nullable: bool + """boolean flag if the column is NULL or NOT NULL""" + + default: Optional[str] + """column default expression as a SQL string""" + + autoincrement: NotRequired[bool] + """database-dependent autoincrement flag. + + This flag indicates if the column has a database-side "autoincrement" + flag of some kind. Within SQLAlchemy, other kinds of columns may + also act as an "autoincrement" column without necessarily having + such a flag on them. + + See :paramref:`_schema.Column.autoincrement` for more background on + "autoincrement". + + """ + + comment: NotRequired[Optional[str]] + """comment for the column, if present. + Only some dialects return this key + """ + + computed: NotRequired[ReflectedComputed] + """indicates that this column is computed by the database. + Only some dialects return this key. + + .. versionadded:: 1.3.16 - added support for computed reflection. + """ + + identity: NotRequired[ReflectedIdentity] + """indicates this column is an IDENTITY column. + Only some dialects return this key. + + .. versionadded:: 1.4 - added support for identity column reflection. + """ + + dialect_options: NotRequired[Dict[str, Any]] + """Additional dialect-specific options detected for this reflected + object""" + + +class ReflectedConstraint(TypedDict): + """Dictionary representing the reflected elements corresponding to + :class:`.Constraint` + + A base class for all constraints + """ + + name: Optional[str] + """constraint name""" + + comment: NotRequired[Optional[str]] + """comment for the constraint, if present""" + + +class ReflectedCheckConstraint(ReflectedConstraint): + """Dictionary representing the reflected elements corresponding to + :class:`.CheckConstraint`. + + The :class:`.ReflectedCheckConstraint` structure is returned by the + :meth:`.Inspector.get_check_constraints` method. + + """ + + sqltext: str + """the check constraint's SQL expression""" + + dialect_options: NotRequired[Dict[str, Any]] + """Additional dialect-specific options detected for this check constraint + + .. versionadded:: 1.3.8 + """ + + +class ReflectedUniqueConstraint(ReflectedConstraint): + """Dictionary representing the reflected elements corresponding to + :class:`.UniqueConstraint`. + + The :class:`.ReflectedUniqueConstraint` structure is returned by the + :meth:`.Inspector.get_unique_constraints` method. + + """ + + column_names: List[str] + """column names which comprise the unique constraint""" + + duplicates_index: NotRequired[Optional[str]] + "Indicates if this unique constraint duplicates an index with this name" + + dialect_options: NotRequired[Dict[str, Any]] + """Additional dialect-specific options detected for this unique + constraint""" + + +class ReflectedPrimaryKeyConstraint(ReflectedConstraint): + """Dictionary representing the reflected elements corresponding to + :class:`.PrimaryKeyConstraint`. + + The :class:`.ReflectedPrimaryKeyConstraint` structure is returned by the + :meth:`.Inspector.get_pk_constraint` method. + + """ + + constrained_columns: List[str] + """column names which comprise the primary key""" + + dialect_options: NotRequired[Dict[str, Any]] + """Additional dialect-specific options detected for this primary key""" + + +class ReflectedForeignKeyConstraint(ReflectedConstraint): + """Dictionary representing the reflected elements corresponding to + :class:`.ForeignKeyConstraint`. + + The :class:`.ReflectedForeignKeyConstraint` structure is returned by + the :meth:`.Inspector.get_foreign_keys` method. + + """ + + constrained_columns: List[str] + """local column names which comprise the foreign key""" + + referred_schema: Optional[str] + """schema name of the table being referred""" + + referred_table: str + """name of the table being referred""" + + referred_columns: List[str] + """referred column names that correspond to ``constrained_columns``""" + + options: NotRequired[Dict[str, Any]] + """Additional options detected for this foreign key constraint""" + + +class ReflectedIndex(TypedDict): + """Dictionary representing the reflected elements corresponding to + :class:`.Index`. + + The :class:`.ReflectedIndex` structure is returned by the + :meth:`.Inspector.get_indexes` method. + + """ + + name: Optional[str] + """index name""" + + column_names: List[Optional[str]] + """column names which the index references. + An element of this list is ``None`` if it's an expression and is + returned in the ``expressions`` list. + """ + + expressions: NotRequired[List[str]] + """Expressions that compose the index. This list, when present, contains + both plain column names (that are also in ``column_names``) and + expressions (that are ``None`` in ``column_names``). + """ + + unique: bool + """whether or not the index has a unique flag""" + + duplicates_constraint: NotRequired[Optional[str]] + "Indicates if this index mirrors a constraint with this name" + + include_columns: NotRequired[List[str]] + """columns to include in the INCLUDE clause for supporting databases. + + .. deprecated:: 2.0 + + Legacy value, will be replaced with + ``index_dict["dialect_options"]["_include"]`` + + """ + + column_sorting: NotRequired[Dict[str, Tuple[str]]] + """optional dict mapping column names or expressions to tuple of sort + keywords, which may include ``asc``, ``desc``, ``nulls_first``, + ``nulls_last``. + + .. versionadded:: 1.3.5 + """ + + dialect_options: NotRequired[Dict[str, Any]] + """Additional dialect-specific options detected for this index""" + + +class ReflectedTableComment(TypedDict): + """Dictionary representing the reflected comment corresponding to + the :attr:`_schema.Table.comment` attribute. + + The :class:`.ReflectedTableComment` structure is returned by the + :meth:`.Inspector.get_table_comment` method. + + """ + + text: Optional[str] + """text of the comment""" + + +class BindTyping(Enum): + """Define different methods of passing typing information for + bound parameters in a statement to the database driver. + + .. versionadded:: 2.0 + + """ + + NONE = 1 + """No steps are taken to pass typing information to the database driver. + + This is the default behavior for databases such as SQLite, MySQL / MariaDB, + SQL Server. + + """ + + SETINPUTSIZES = 2 + """Use the pep-249 setinputsizes method. + + This is only implemented for DBAPIs that support this method and for which + the SQLAlchemy dialect has the appropriate infrastructure for that + dialect set up. Current dialects include cx_Oracle as well as + optional support for SQL Server using pyodbc. + + When using setinputsizes, dialects also have a means of only using the + method for certain datatypes using include/exclude lists. + + When SETINPUTSIZES is used, the :meth:`.Dialect.do_set_input_sizes` method + is called for each statement executed which has bound parameters. + + """ + + RENDER_CASTS = 3 + """Render casts or other directives in the SQL string. + + This method is used for all PostgreSQL dialects, including asyncpg, + pg8000, psycopg, psycopg2. Dialects which implement this can choose + which kinds of datatypes are explicitly cast in SQL statements and which + aren't. + + When RENDER_CASTS is used, the compiler will invoke the + :meth:`.SQLCompiler.render_bind_cast` method for the rendered + string representation of each :class:`.BindParameter` object whose + dialect-level type sets the :attr:`.TypeEngine.render_bind_cast` attribute. + + The :meth:`.SQLCompiler.render_bind_cast` is also used to render casts + for one form of "insertmanyvalues" query, when both + :attr:`.InsertmanyvaluesSentinelOpts.USE_INSERT_FROM_SELECT` and + :attr:`.InsertmanyvaluesSentinelOpts.RENDER_SELECT_COL_CASTS` are set, + where the casts are applied to the intermediary columns e.g. + "INSERT INTO t (a, b, c) SELECT p0::TYP, p1::TYP, p2::TYP " + "FROM (VALUES (?, ?), (?, ?), ...)". + + .. versionadded:: 2.0.10 - :meth:`.SQLCompiler.render_bind_cast` is now + used within some elements of the "insertmanyvalues" implementation. + + + """ + + +VersionInfoType = Tuple[Union[int, str], ...] +TableKey = Tuple[Optional[str], str] + + +class Dialect(EventTarget): + """Define the behavior of a specific database and DB-API combination. + + Any aspect of metadata definition, SQL query generation, + execution, result-set handling, or anything else which varies + between databases is defined under the general category of the + Dialect. The Dialect acts as a factory for other + database-specific object implementations including + ExecutionContext, Compiled, DefaultGenerator, and TypeEngine. + + .. note:: Third party dialects should not subclass :class:`.Dialect` + directly. Instead, subclass :class:`.default.DefaultDialect` or + descendant class. + + """ + + CACHE_HIT = CacheStats.CACHE_HIT + CACHE_MISS = CacheStats.CACHE_MISS + CACHING_DISABLED = CacheStats.CACHING_DISABLED + NO_CACHE_KEY = CacheStats.NO_CACHE_KEY + NO_DIALECT_SUPPORT = CacheStats.NO_DIALECT_SUPPORT + + dispatch: dispatcher[Dialect] + + name: str + """identifying name for the dialect from a DBAPI-neutral point of view + (i.e. 'sqlite') + """ + + driver: str + """identifying name for the dialect's DBAPI""" + + dialect_description: str + + dbapi: Optional[ModuleType] + """A reference to the DBAPI module object itself. + + SQLAlchemy dialects import DBAPI modules using the classmethod + :meth:`.Dialect.import_dbapi`. The rationale is so that any dialect + module can be imported and used to generate SQL statements without the + need for the actual DBAPI driver to be installed. Only when an + :class:`.Engine` is constructed using :func:`.create_engine` does the + DBAPI get imported; at that point, the creation process will assign + the DBAPI module to this attribute. + + Dialects should therefore implement :meth:`.Dialect.import_dbapi` + which will import the necessary module and return it, and then refer + to ``self.dbapi`` in dialect code in order to refer to the DBAPI module + contents. + + .. versionchanged:: The :attr:`.Dialect.dbapi` attribute is exclusively + used as the per-:class:`.Dialect`-instance reference to the DBAPI + module. The previous not-fully-documented ``.Dialect.dbapi()`` + classmethod is deprecated and replaced by :meth:`.Dialect.import_dbapi`. + + """ + + @util.non_memoized_property + def loaded_dbapi(self) -> ModuleType: + """same as .dbapi, but is never None; will raise an error if no + DBAPI was set up. + + .. versionadded:: 2.0 + + """ + raise NotImplementedError() + + positional: bool + """True if the paramstyle for this Dialect is positional.""" + + paramstyle: str + """the paramstyle to be used (some DB-APIs support multiple + paramstyles). + """ + + compiler_linting: Linting + + statement_compiler: Type[SQLCompiler] + """a :class:`.Compiled` class used to compile SQL statements""" + + ddl_compiler: Type[DDLCompiler] + """a :class:`.Compiled` class used to compile DDL statements""" + + type_compiler_cls: ClassVar[Type[TypeCompiler]] + """a :class:`.Compiled` class used to compile SQL type objects + + .. versionadded:: 2.0 + + """ + + type_compiler_instance: TypeCompiler + """instance of a :class:`.Compiled` class used to compile SQL type + objects + + .. versionadded:: 2.0 + + """ + + type_compiler: Any + """legacy; this is a TypeCompiler class at the class level, a + TypeCompiler instance at the instance level. + + Refer to type_compiler_instance instead. + + """ + + preparer: Type[IdentifierPreparer] + """a :class:`.IdentifierPreparer` class used to + quote identifiers. + """ + + identifier_preparer: IdentifierPreparer + """This element will refer to an instance of :class:`.IdentifierPreparer` + once a :class:`.DefaultDialect` has been constructed. + + """ + + server_version_info: Optional[Tuple[Any, ...]] + """a tuple containing a version number for the DB backend in use. + + This value is only available for supporting dialects, and is + typically populated during the initial connection to the database. + """ + + default_schema_name: Optional[str] + """the name of the default schema. This value is only available for + supporting dialects, and is typically populated during the + initial connection to the database. + + """ + + # NOTE: this does not take into effect engine-level isolation level. + # not clear if this should be changed, seems like it should + default_isolation_level: Optional[IsolationLevel] + """the isolation that is implicitly present on new connections""" + + # create_engine() -> isolation_level currently goes here + _on_connect_isolation_level: Optional[IsolationLevel] + + execution_ctx_cls: Type[ExecutionContext] + """a :class:`.ExecutionContext` class used to handle statement execution""" + + execute_sequence_format: Union[ + Type[Tuple[Any, ...]], Type[Tuple[List[Any]]] + ] + """either the 'tuple' or 'list' type, depending on what cursor.execute() + accepts for the second argument (they vary).""" + + supports_alter: bool + """``True`` if the database supports ``ALTER TABLE`` - used only for + generating foreign key constraints in certain circumstances + """ + + max_identifier_length: int + """The maximum length of identifier names.""" + + supports_server_side_cursors: bool + """indicates if the dialect supports server side cursors""" + + server_side_cursors: bool + """deprecated; indicates if the dialect should attempt to use server + side cursors by default""" + + supports_sane_rowcount: bool + """Indicate whether the dialect properly implements rowcount for + ``UPDATE`` and ``DELETE`` statements. + """ + + supports_sane_multi_rowcount: bool + """Indicate whether the dialect properly implements rowcount for + ``UPDATE`` and ``DELETE`` statements when executed via + executemany. + """ + + supports_empty_insert: bool + """dialect supports INSERT () VALUES (), i.e. a plain INSERT with no + columns in it. + + This is not usually supported; an "empty" insert is typically + suited using either "INSERT..DEFAULT VALUES" or + "INSERT ... (col) VALUES (DEFAULT)". + + """ + + supports_default_values: bool + """dialect supports INSERT... DEFAULT VALUES syntax""" + + supports_default_metavalue: bool + """dialect supports INSERT...(col) VALUES (DEFAULT) syntax. + + Most databases support this in some way, e.g. SQLite supports it using + ``VALUES (NULL)``. MS SQL Server supports the syntax also however + is the only included dialect where we have this disabled, as + MSSQL does not support the field for the IDENTITY column, which is + usually where we like to make use of the feature. + + """ + + default_metavalue_token: str = "DEFAULT" + """for INSERT... VALUES (DEFAULT) syntax, the token to put in the + parenthesis. + + E.g. for SQLite this is the keyword "NULL". + + """ + + supports_multivalues_insert: bool + """Target database supports INSERT...VALUES with multiple value + sets, i.e. INSERT INTO table (cols) VALUES (...), (...), (...), ... + + """ + + insert_executemany_returning: bool + """dialect / driver / database supports some means of providing + INSERT...RETURNING support when dialect.do_executemany() is used. + + """ + + insert_executemany_returning_sort_by_parameter_order: bool + """dialect / driver / database supports some means of providing + INSERT...RETURNING support when dialect.do_executemany() is used + along with the :paramref:`_dml.Insert.returning.sort_by_parameter_order` + parameter being set. + + """ + + update_executemany_returning: bool + """dialect supports UPDATE..RETURNING with executemany.""" + + delete_executemany_returning: bool + """dialect supports DELETE..RETURNING with executemany.""" + + use_insertmanyvalues: bool + """if True, indicates "insertmanyvalues" functionality should be used + to allow for ``insert_executemany_returning`` behavior, if possible. + + In practice, setting this to True means: + + if ``supports_multivalues_insert``, ``insert_returning`` and + ``use_insertmanyvalues`` are all True, the SQL compiler will produce + an INSERT that will be interpreted by the :class:`.DefaultDialect` + as an :attr:`.ExecuteStyle.INSERTMANYVALUES` execution that allows + for INSERT of many rows with RETURNING by rewriting a single-row + INSERT statement to have multiple VALUES clauses, also executing + the statement multiple times for a series of batches when large numbers + of rows are given. + + The parameter is False for the default dialect, and is set to + True for SQLAlchemy internal dialects SQLite, MySQL/MariaDB, PostgreSQL, + SQL Server. It remains at False for Oracle, which provides native + "executemany with RETURNING" support and also does not support + ``supports_multivalues_insert``. For MySQL/MariaDB, those MySQL + dialects that don't support RETURNING will not report + ``insert_executemany_returning`` as True. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`engine_insertmanyvalues` + + """ + + use_insertmanyvalues_wo_returning: bool + """if True, and use_insertmanyvalues is also True, INSERT statements + that don't include RETURNING will also use "insertmanyvalues". + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`engine_insertmanyvalues` + + """ + + insertmanyvalues_implicit_sentinel: InsertmanyvaluesSentinelOpts + """Options indicating the database supports a form of bulk INSERT where + the autoincrement integer primary key can be reliably used as an ordering + for INSERTed rows. + + .. versionadded:: 2.0.10 + + .. seealso:: + + :ref:`engine_insertmanyvalues_returning_order` + + """ + + insertmanyvalues_page_size: int + """Number of rows to render into an individual INSERT..VALUES() statement + for :attr:`.ExecuteStyle.INSERTMANYVALUES` executions. + + The default dialect defaults this to 1000. + + .. versionadded:: 2.0 + + .. seealso:: + + :paramref:`_engine.Connection.execution_options.insertmanyvalues_page_size` - + execution option available on :class:`_engine.Connection`, statements + + """ # noqa: E501 + + insertmanyvalues_max_parameters: int + """Alternate to insertmanyvalues_page_size, will additionally limit + page size based on number of parameters total in the statement. + + + """ + + preexecute_autoincrement_sequences: bool + """True if 'implicit' primary key functions must be executed separately + in order to get their value, if RETURNING is not used. + + This is currently oriented towards PostgreSQL when the + ``implicit_returning=False`` parameter is used on a :class:`.Table` + object. + + """ + + insert_returning: bool + """if the dialect supports RETURNING with INSERT + + .. versionadded:: 2.0 + + """ + + update_returning: bool + """if the dialect supports RETURNING with UPDATE + + .. versionadded:: 2.0 + + """ + + update_returning_multifrom: bool + """if the dialect supports RETURNING with UPDATE..FROM + + .. versionadded:: 2.0 + + """ + + delete_returning: bool + """if the dialect supports RETURNING with DELETE + + .. versionadded:: 2.0 + + """ + + delete_returning_multifrom: bool + """if the dialect supports RETURNING with DELETE..FROM + + .. versionadded:: 2.0 + + """ + + favor_returning_over_lastrowid: bool + """for backends that support both a lastrowid and a RETURNING insert + strategy, favor RETURNING for simple single-int pk inserts. + + cursor.lastrowid tends to be more performant on most backends. + + """ + + supports_identity_columns: bool + """target database supports IDENTITY""" + + cte_follows_insert: bool + """target database, when given a CTE with an INSERT statement, needs + the CTE to be below the INSERT""" + + colspecs: MutableMapping[Type[TypeEngine[Any]], Type[TypeEngine[Any]]] + """A dictionary of TypeEngine classes from sqlalchemy.types mapped + to subclasses that are specific to the dialect class. This + dictionary is class-level only and is not accessed from the + dialect instance itself. + """ + + supports_sequences: bool + """Indicates if the dialect supports CREATE SEQUENCE or similar.""" + + sequences_optional: bool + """If True, indicates if the :paramref:`_schema.Sequence.optional` + parameter on the :class:`_schema.Sequence` construct + should signal to not generate a CREATE SEQUENCE. Applies only to + dialects that support sequences. Currently used only to allow PostgreSQL + SERIAL to be used on a column that specifies Sequence() for usage on + other backends. + """ + + default_sequence_base: int + """the default value that will be rendered as the "START WITH" portion of + a CREATE SEQUENCE DDL statement. + + """ + + supports_native_enum: bool + """Indicates if the dialect supports a native ENUM construct. + This will prevent :class:`_types.Enum` from generating a CHECK + constraint when that type is used in "native" mode. + """ + + supports_native_boolean: bool + """Indicates if the dialect supports a native boolean construct. + This will prevent :class:`_types.Boolean` from generating a CHECK + constraint when that type is used. + """ + + supports_native_decimal: bool + """indicates if Decimal objects are handled and returned for precision + numeric types, or if floats are returned""" + + supports_native_uuid: bool + """indicates if Python UUID() objects are handled natively by the + driver for SQL UUID datatypes. + + .. versionadded:: 2.0 + + """ + + returns_native_bytes: bool + """indicates if Python bytes() objects are returned natively by the + driver for SQL "binary" datatypes. + + .. versionadded:: 2.0.11 + + """ + + construct_arguments: Optional[ + List[Tuple[Type[Union[SchemaItem, ClauseElement]], Mapping[str, Any]]] + ] = None + """Optional set of argument specifiers for various SQLAlchemy + constructs, typically schema items. + + To implement, establish as a series of tuples, as in:: + + construct_arguments = [ + (schema.Index, { + "using": False, + "where": None, + "ops": None + }) + ] + + If the above construct is established on the PostgreSQL dialect, + the :class:`.Index` construct will now accept the keyword arguments + ``postgresql_using``, ``postgresql_where``, nad ``postgresql_ops``. + Any other argument specified to the constructor of :class:`.Index` + which is prefixed with ``postgresql_`` will raise :class:`.ArgumentError`. + + A dialect which does not include a ``construct_arguments`` member will + not participate in the argument validation system. For such a dialect, + any argument name is accepted by all participating constructs, within + the namespace of arguments prefixed with that dialect name. The rationale + here is so that third-party dialects that haven't yet implemented this + feature continue to function in the old way. + + .. seealso:: + + :class:`.DialectKWArgs` - implementing base class which consumes + :attr:`.DefaultDialect.construct_arguments` + + + """ + + reflection_options: Sequence[str] = () + """Sequence of string names indicating keyword arguments that can be + established on a :class:`.Table` object which will be passed as + "reflection options" when using :paramref:`.Table.autoload_with`. + + Current example is "oracle_resolve_synonyms" in the Oracle dialect. + + """ + + dbapi_exception_translation_map: Mapping[str, str] = util.EMPTY_DICT + """A dictionary of names that will contain as values the names of + pep-249 exceptions ("IntegrityError", "OperationalError", etc) + keyed to alternate class names, to support the case where a + DBAPI has exception classes that aren't named as they are + referred to (e.g. IntegrityError = MyException). In the vast + majority of cases this dictionary is empty. + """ + + supports_comments: bool + """Indicates the dialect supports comment DDL on tables and columns.""" + + inline_comments: bool + """Indicates the dialect supports comment DDL that's inline with the + definition of a Table or Column. If False, this implies that ALTER must + be used to set table and column comments.""" + + supports_constraint_comments: bool + """Indicates if the dialect supports comment DDL on constraints. + + .. versionadded: 2.0 + """ + + _has_events = False + + supports_statement_cache: bool = True + """indicates if this dialect supports caching. + + All dialects that are compatible with statement caching should set this + flag to True directly on each dialect class and subclass that supports + it. SQLAlchemy tests that this flag is locally present on each dialect + subclass before it will use statement caching. This is to provide + safety for legacy or new dialects that are not yet fully tested to be + compliant with SQL statement caching. + + .. versionadded:: 1.4.5 + + .. seealso:: + + :ref:`engine_thirdparty_caching` + + """ + + _supports_statement_cache: bool + """internal evaluation for supports_statement_cache""" + + bind_typing = BindTyping.NONE + """define a means of passing typing information to the database and/or + driver for bound parameters. + + See :class:`.BindTyping` for values. + + .. versionadded:: 2.0 + + """ + + is_async: bool + """Whether or not this dialect is intended for asyncio use.""" + + has_terminate: bool + """Whether or not this dialect has a separate "terminate" implementation + that does not block or require awaiting.""" + + engine_config_types: Mapping[str, Any] + """a mapping of string keys that can be in an engine config linked to + type conversion functions. + + """ + + label_length: Optional[int] + """optional user-defined max length for SQL labels""" + + include_set_input_sizes: Optional[Set[Any]] + """set of DBAPI type objects that should be included in + automatic cursor.setinputsizes() calls. + + This is only used if bind_typing is BindTyping.SET_INPUT_SIZES + + """ + + exclude_set_input_sizes: Optional[Set[Any]] + """set of DBAPI type objects that should be excluded in + automatic cursor.setinputsizes() calls. + + This is only used if bind_typing is BindTyping.SET_INPUT_SIZES + + """ + + supports_simple_order_by_label: bool + """target database supports ORDER BY , where + refers to a label in the columns clause of the SELECT""" + + div_is_floordiv: bool + """target database treats the / division operator as "floor division" """ + + tuple_in_values: bool + """target database supports tuple IN, i.e. (x, y) IN ((q, p), (r, z))""" + + _bind_typing_render_casts: bool + + _type_memos: MutableMapping[TypeEngine[Any], _TypeMemoDict] + + def _builtin_onconnect(self) -> Optional[_ListenerFnType]: + raise NotImplementedError() + + def create_connect_args(self, url: URL) -> ConnectArgsType: + """Build DB-API compatible connection arguments. + + Given a :class:`.URL` object, returns a tuple + consisting of a ``(*args, **kwargs)`` suitable to send directly + to the dbapi's connect function. The arguments are sent to the + :meth:`.Dialect.connect` method which then runs the DBAPI-level + ``connect()`` function. + + The method typically makes use of the + :meth:`.URL.translate_connect_args` + method in order to generate a dictionary of options. + + The default implementation is:: + + def create_connect_args(self, url): + opts = url.translate_connect_args() + opts.update(url.query) + return ([], opts) + + :param url: a :class:`.URL` object + + :return: a tuple of ``(*args, **kwargs)`` which will be passed to the + :meth:`.Dialect.connect` method. + + .. seealso:: + + :meth:`.URL.translate_connect_args` + + """ + + raise NotImplementedError() + + @classmethod + def import_dbapi(cls) -> ModuleType: + """Import the DBAPI module that is used by this dialect. + + The Python module object returned here will be assigned as an + instance variable to a constructed dialect under the name + ``.dbapi``. + + .. versionchanged:: 2.0 The :meth:`.Dialect.import_dbapi` class + method is renamed from the previous method ``.Dialect.dbapi()``, + which would be replaced at dialect instantiation time by the + DBAPI module itself, thus using the same name in two different ways. + If a ``.Dialect.dbapi()`` classmethod is present on a third-party + dialect, it will be used and a deprecation warning will be emitted. + + """ + raise NotImplementedError() + + @classmethod + def type_descriptor(cls, typeobj: TypeEngine[_T]) -> TypeEngine[_T]: + """Transform a generic type to a dialect-specific type. + + Dialect classes will usually use the + :func:`_types.adapt_type` function in the types module to + accomplish this. + + The returned result is cached *per dialect class* so can + contain no dialect-instance state. + + """ + + raise NotImplementedError() + + def initialize(self, connection: Connection) -> None: + """Called during strategized creation of the dialect with a + connection. + + Allows dialects to configure options based on server version info or + other properties. + + The connection passed here is a SQLAlchemy Connection object, + with full capabilities. + + The initialize() method of the base dialect should be called via + super(). + + .. note:: as of SQLAlchemy 1.4, this method is called **before** + any :meth:`_engine.Dialect.on_connect` hooks are called. + + """ + + pass + + if TYPE_CHECKING: + + def _overrides_default(self, method_name: str) -> bool: ... + + def get_columns( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedColumn]: + """Return information about columns in ``table_name``. + + Given a :class:`_engine.Connection`, a string + ``table_name``, and an optional string ``schema``, return column + information as a list of dictionaries + corresponding to the :class:`.ReflectedColumn` dictionary. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_columns`. + + """ + + raise NotImplementedError() + + def get_multi_columns( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedColumn]]]: + """Return information about columns in all tables in the + given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_columns`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def get_pk_constraint( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> ReflectedPrimaryKeyConstraint: + """Return information about the primary key constraint on + table_name`. + + Given a :class:`_engine.Connection`, a string + ``table_name``, and an optional string ``schema``, return primary + key information as a dictionary corresponding to the + :class:`.ReflectedPrimaryKeyConstraint` dictionary. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_pk_constraint`. + + """ + raise NotImplementedError() + + def get_multi_pk_constraint( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, ReflectedPrimaryKeyConstraint]]: + """Return information about primary key constraints in + all tables in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_pk_constraint`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + raise NotImplementedError() + + def get_foreign_keys( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedForeignKeyConstraint]: + """Return information about foreign_keys in ``table_name``. + + Given a :class:`_engine.Connection`, a string + ``table_name``, and an optional string ``schema``, return foreign + key information as a list of dicts corresponding to the + :class:`.ReflectedForeignKeyConstraint` dictionary. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_foreign_keys`. + """ + + raise NotImplementedError() + + def get_multi_foreign_keys( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedForeignKeyConstraint]]]: + """Return information about foreign_keys in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_multi_foreign_keys`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def get_table_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of table names for ``schema``. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_table_names`. + + """ + + raise NotImplementedError() + + def get_temp_table_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of temporary table names on the given connection, + if supported by the underlying backend. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_temp_table_names`. + + """ + + raise NotImplementedError() + + def get_view_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of all non-materialized view names available in the + database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_view_names`. + + :param schema: schema name to query, if not the default schema. + + """ + + raise NotImplementedError() + + def get_materialized_view_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of all materialized view names available in the + database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_materialized_view_names`. + + :param schema: schema name to query, if not the default schema. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def get_sequence_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of all sequence names available in the database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_sequence_names`. + + :param schema: schema name to query, if not the default schema. + + .. versionadded:: 1.4 + """ + + raise NotImplementedError() + + def get_temp_view_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of temporary view names on the given connection, + if supported by the underlying backend. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_temp_view_names`. + + """ + + raise NotImplementedError() + + def get_schema_names(self, connection: Connection, **kw: Any) -> List[str]: + """Return a list of all schema names available in the database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_schema_names`. + """ + raise NotImplementedError() + + def get_view_definition( + self, + connection: Connection, + view_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> str: + """Return plain or materialized view definition. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_view_definition`. + + Given a :class:`_engine.Connection`, a string + ``view_name``, and an optional string ``schema``, return the view + definition. + """ + + raise NotImplementedError() + + def get_indexes( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedIndex]: + """Return information about indexes in ``table_name``. + + Given a :class:`_engine.Connection`, a string + ``table_name`` and an optional string ``schema``, return index + information as a list of dictionaries corresponding to the + :class:`.ReflectedIndex` dictionary. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_indexes`. + """ + + raise NotImplementedError() + + def get_multi_indexes( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedIndex]]]: + """Return information about indexes in in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_indexes`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def get_unique_constraints( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedUniqueConstraint]: + r"""Return information about unique constraints in ``table_name``. + + Given a string ``table_name`` and an optional string ``schema``, return + unique constraint information as a list of dicts corresponding + to the :class:`.ReflectedUniqueConstraint` dictionary. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_unique_constraints`. + """ + + raise NotImplementedError() + + def get_multi_unique_constraints( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedUniqueConstraint]]]: + """Return information about unique constraints in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_unique_constraints`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def get_check_constraints( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> List[ReflectedCheckConstraint]: + r"""Return information about check constraints in ``table_name``. + + Given a string ``table_name`` and an optional string ``schema``, return + check constraint information as a list of dicts corresponding + to the :class:`.ReflectedCheckConstraint` dictionary. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_check_constraints`. + + """ + + raise NotImplementedError() + + def get_multi_check_constraints( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedCheckConstraint]]]: + """Return information about check constraints in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_check_constraints`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def get_table_options( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> Dict[str, Any]: + """Return a dictionary of options specified when ``table_name`` + was created. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_table_options`. + """ + raise NotImplementedError() + + def get_multi_table_options( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, Dict[str, Any]]]: + """Return a dictionary of options specified when the tables in the + given schema were created. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_multi_table_options`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + raise NotImplementedError() + + def get_table_comment( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> ReflectedTableComment: + r"""Return the "comment" for the table identified by ``table_name``. + + Given a string ``table_name`` and an optional string ``schema``, return + table comment information as a dictionary corresponding to the + :class:`.ReflectedTableComment` dictionary. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_table_comment`. + + :raise: ``NotImplementedError`` for dialects that don't support + comments. + + .. versionadded:: 1.2 + + """ + + raise NotImplementedError() + + def get_multi_table_comment( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, ReflectedTableComment]]: + """Return information about the table comment in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_multi_table_comment`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def normalize_name(self, name: str) -> str: + """convert the given name to lowercase if it is detected as + case insensitive. + + This method is only used if the dialect defines + requires_name_normalize=True. + + """ + raise NotImplementedError() + + def denormalize_name(self, name: str) -> str: + """convert the given name to a case insensitive identifier + for the backend if it is an all-lowercase name. + + This method is only used if the dialect defines + requires_name_normalize=True. + + """ + raise NotImplementedError() + + def has_table( + self, + connection: Connection, + table_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> bool: + """For internal dialect use, check the existence of a particular table + or view in the database. + + Given a :class:`_engine.Connection` object, a string table_name and + optional schema name, return True if the given table exists in the + database, False otherwise. + + This method serves as the underlying implementation of the + public facing :meth:`.Inspector.has_table` method, and is also used + internally to implement the "checkfirst" behavior for methods like + :meth:`_schema.Table.create` and :meth:`_schema.MetaData.create_all`. + + .. note:: This method is used internally by SQLAlchemy, and is + published so that third-party dialects may provide an + implementation. It is **not** the public API for checking for table + presence. Please use the :meth:`.Inspector.has_table` method. + + .. versionchanged:: 2.0:: :meth:`_engine.Dialect.has_table` now + formally supports checking for additional table-like objects: + + * any type of views (plain or materialized) + * temporary tables of any kind + + Previously, these two checks were not formally specified and + different dialects would vary in their behavior. The dialect + testing suite now includes tests for all of these object types, + and dialects to the degree that the backing database supports views + or temporary tables should seek to support locating these objects + for full compliance. + + """ + + raise NotImplementedError() + + def has_index( + self, + connection: Connection, + table_name: str, + index_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> bool: + """Check the existence of a particular index name in the database. + + Given a :class:`_engine.Connection` object, a string + ``table_name`` and string index name, return ``True`` if an index of + the given name on the given table exists, ``False`` otherwise. + + The :class:`.DefaultDialect` implements this in terms of the + :meth:`.Dialect.has_table` and :meth:`.Dialect.get_indexes` methods, + however dialects can implement a more performant version. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.has_index`. + + .. versionadded:: 1.4 + + """ + + raise NotImplementedError() + + def has_sequence( + self, + connection: Connection, + sequence_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> bool: + """Check the existence of a particular sequence in the database. + + Given a :class:`_engine.Connection` object and a string + `sequence_name`, return ``True`` if the given sequence exists in + the database, ``False`` otherwise. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.has_sequence`. + """ + + raise NotImplementedError() + + def has_schema( + self, connection: Connection, schema_name: str, **kw: Any + ) -> bool: + """Check the existence of a particular schema name in the database. + + Given a :class:`_engine.Connection` object, a string + ``schema_name``, return ``True`` if a schema of the + given exists, ``False`` otherwise. + + The :class:`.DefaultDialect` implements this by checking + the presence of ``schema_name`` among the schemas returned by + :meth:`.Dialect.get_schema_names`, + however dialects can implement a more performant version. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.has_schema`. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + + def _get_server_version_info(self, connection: Connection) -> Any: + """Retrieve the server version info from the given connection. + + This is used by the default implementation to populate the + "server_version_info" attribute and is called exactly + once upon first connect. + + """ + + raise NotImplementedError() + + def _get_default_schema_name(self, connection: Connection) -> str: + """Return the string name of the currently selected schema from + the given connection. + + This is used by the default implementation to populate the + "default_schema_name" attribute and is called exactly + once upon first connect. + + """ + + raise NotImplementedError() + + def do_begin(self, dbapi_connection: PoolProxiedConnection) -> None: + """Provide an implementation of ``connection.begin()``, given a + DB-API connection. + + The DBAPI has no dedicated "begin" method and it is expected + that transactions are implicit. This hook is provided for those + DBAPIs that might need additional help in this area. + + :param dbapi_connection: a DBAPI connection, typically + proxied within a :class:`.ConnectionFairy`. + + """ + + raise NotImplementedError() + + def do_rollback(self, dbapi_connection: PoolProxiedConnection) -> None: + """Provide an implementation of ``connection.rollback()``, given + a DB-API connection. + + :param dbapi_connection: a DBAPI connection, typically + proxied within a :class:`.ConnectionFairy`. + + """ + + raise NotImplementedError() + + def do_commit(self, dbapi_connection: PoolProxiedConnection) -> None: + """Provide an implementation of ``connection.commit()``, given a + DB-API connection. + + :param dbapi_connection: a DBAPI connection, typically + proxied within a :class:`.ConnectionFairy`. + + """ + + raise NotImplementedError() + + def do_terminate(self, dbapi_connection: DBAPIConnection) -> None: + """Provide an implementation of ``connection.close()`` that tries as + much as possible to not block, given a DBAPI + connection. + + In the vast majority of cases this just calls .close(), however + for some asyncio dialects may call upon different API features. + + This hook is called by the :class:`_pool.Pool` + when a connection is being recycled or has been invalidated. + + .. versionadded:: 1.4.41 + + """ + + raise NotImplementedError() + + def do_close(self, dbapi_connection: DBAPIConnection) -> None: + """Provide an implementation of ``connection.close()``, given a DBAPI + connection. + + This hook is called by the :class:`_pool.Pool` + when a connection has been + detached from the pool, or is being returned beyond the normal + capacity of the pool. + + """ + + raise NotImplementedError() + + def _do_ping_w_event(self, dbapi_connection: DBAPIConnection) -> bool: + raise NotImplementedError() + + def do_ping(self, dbapi_connection: DBAPIConnection) -> bool: + """ping the DBAPI connection and return True if the connection is + usable.""" + raise NotImplementedError() + + def do_set_input_sizes( + self, + cursor: DBAPICursor, + list_of_tuples: _GenericSetInputSizesType, + context: ExecutionContext, + ) -> Any: + """invoke the cursor.setinputsizes() method with appropriate arguments + + This hook is called if the :attr:`.Dialect.bind_typing` attribute is + set to the + :attr:`.BindTyping.SETINPUTSIZES` value. + Parameter data is passed in a list of tuples (paramname, dbtype, + sqltype), where ``paramname`` is the key of the parameter in the + statement, ``dbtype`` is the DBAPI datatype and ``sqltype`` is the + SQLAlchemy type. The order of tuples is in the correct parameter order. + + .. versionadded:: 1.4 + + .. versionchanged:: 2.0 - setinputsizes mode is now enabled by + setting :attr:`.Dialect.bind_typing` to + :attr:`.BindTyping.SETINPUTSIZES`. Dialects which accept + a ``use_setinputsizes`` parameter should set this value + appropriately. + + + """ + raise NotImplementedError() + + def create_xid(self) -> Any: + """Create a two-phase transaction ID. + + This id will be passed to do_begin_twophase(), + do_rollback_twophase(), do_commit_twophase(). Its format is + unspecified. + """ + + raise NotImplementedError() + + def do_savepoint(self, connection: Connection, name: str) -> None: + """Create a savepoint with the given name. + + :param connection: a :class:`_engine.Connection`. + :param name: savepoint name. + + """ + + raise NotImplementedError() + + def do_rollback_to_savepoint( + self, connection: Connection, name: str + ) -> None: + """Rollback a connection to the named savepoint. + + :param connection: a :class:`_engine.Connection`. + :param name: savepoint name. + + """ + + raise NotImplementedError() + + def do_release_savepoint(self, connection: Connection, name: str) -> None: + """Release the named savepoint on a connection. + + :param connection: a :class:`_engine.Connection`. + :param name: savepoint name. + """ + + raise NotImplementedError() + + def do_begin_twophase(self, connection: Connection, xid: Any) -> None: + """Begin a two phase transaction on the given connection. + + :param connection: a :class:`_engine.Connection`. + :param xid: xid + + """ + + raise NotImplementedError() + + def do_prepare_twophase(self, connection: Connection, xid: Any) -> None: + """Prepare a two phase transaction on the given connection. + + :param connection: a :class:`_engine.Connection`. + :param xid: xid + + """ + + raise NotImplementedError() + + def do_rollback_twophase( + self, + connection: Connection, + xid: Any, + is_prepared: bool = True, + recover: bool = False, + ) -> None: + """Rollback a two phase transaction on the given connection. + + :param connection: a :class:`_engine.Connection`. + :param xid: xid + :param is_prepared: whether or not + :meth:`.TwoPhaseTransaction.prepare` was called. + :param recover: if the recover flag was passed. + + """ + + raise NotImplementedError() + + def do_commit_twophase( + self, + connection: Connection, + xid: Any, + is_prepared: bool = True, + recover: bool = False, + ) -> None: + """Commit a two phase transaction on the given connection. + + + :param connection: a :class:`_engine.Connection`. + :param xid: xid + :param is_prepared: whether or not + :meth:`.TwoPhaseTransaction.prepare` was called. + :param recover: if the recover flag was passed. + + """ + + raise NotImplementedError() + + def do_recover_twophase(self, connection: Connection) -> List[Any]: + """Recover list of uncommitted prepared two phase transaction + identifiers on the given connection. + + :param connection: a :class:`_engine.Connection`. + + """ + + raise NotImplementedError() + + def _deliver_insertmanyvalues_batches( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIMultiExecuteParams, + generic_setinputsizes: Optional[_GenericSetInputSizesType], + context: ExecutionContext, + ) -> Iterator[_InsertManyValuesBatch]: + """convert executemany parameters for an INSERT into an iterator + of statement/single execute values, used by the insertmanyvalues + feature. + + """ + raise NotImplementedError() + + def do_executemany( + self, + cursor: DBAPICursor, + statement: str, + parameters: _DBAPIMultiExecuteParams, + context: Optional[ExecutionContext] = None, + ) -> None: + """Provide an implementation of ``cursor.executemany(statement, + parameters)``.""" + + raise NotImplementedError() + + def do_execute( + self, + cursor: DBAPICursor, + statement: str, + parameters: Optional[_DBAPISingleExecuteParams], + context: Optional[ExecutionContext] = None, + ) -> None: + """Provide an implementation of ``cursor.execute(statement, + parameters)``.""" + + raise NotImplementedError() + + def do_execute_no_params( + self, + cursor: DBAPICursor, + statement: str, + context: Optional[ExecutionContext] = None, + ) -> None: + """Provide an implementation of ``cursor.execute(statement)``. + + The parameter collection should not be sent. + + """ + + raise NotImplementedError() + + def is_disconnect( + self, + e: Exception, + connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], + cursor: Optional[DBAPICursor], + ) -> bool: + """Return True if the given DB-API error indicates an invalid + connection""" + + raise NotImplementedError() + + def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection: + r"""Establish a connection using this dialect's DBAPI. + + The default implementation of this method is:: + + def connect(self, *cargs, **cparams): + return self.dbapi.connect(*cargs, **cparams) + + The ``*cargs, **cparams`` parameters are generated directly + from this dialect's :meth:`.Dialect.create_connect_args` method. + + This method may be used for dialects that need to perform programmatic + per-connection steps when a new connection is procured from the + DBAPI. + + + :param \*cargs: positional parameters returned from the + :meth:`.Dialect.create_connect_args` method + + :param \*\*cparams: keyword parameters returned from the + :meth:`.Dialect.create_connect_args` method. + + :return: a DBAPI connection, typically from the :pep:`249` module + level ``.connect()`` function. + + .. seealso:: + + :meth:`.Dialect.create_connect_args` + + :meth:`.Dialect.on_connect` + + """ + raise NotImplementedError() + + def on_connect_url(self, url: URL) -> Optional[Callable[[Any], Any]]: + """return a callable which sets up a newly created DBAPI connection. + + This method is a new hook that supersedes the + :meth:`_engine.Dialect.on_connect` method when implemented by a + dialect. When not implemented by a dialect, it invokes the + :meth:`_engine.Dialect.on_connect` method directly to maintain + compatibility with existing dialects. There is no deprecation + for :meth:`_engine.Dialect.on_connect` expected. + + The callable should accept a single argument "conn" which is the + DBAPI connection itself. The inner callable has no + return value. + + E.g.:: + + class MyDialect(default.DefaultDialect): + # ... + + def on_connect_url(self, url): + def do_on_connect(connection): + connection.execute("SET SPECIAL FLAGS etc") + + return do_on_connect + + This is used to set dialect-wide per-connection options such as + isolation modes, Unicode modes, etc. + + This method differs from :meth:`_engine.Dialect.on_connect` in that + it is passed the :class:`_engine.URL` object that's relevant to the + connect args. Normally the only way to get this is from the + :meth:`_engine.Dialect.on_connect` hook is to look on the + :class:`_engine.Engine` itself, however this URL object may have been + replaced by plugins. + + .. note:: + + The default implementation of + :meth:`_engine.Dialect.on_connect_url` is to invoke the + :meth:`_engine.Dialect.on_connect` method. Therefore if a dialect + implements this method, the :meth:`_engine.Dialect.on_connect` + method **will not be called** unless the overriding dialect calls + it directly from here. + + .. versionadded:: 1.4.3 added :meth:`_engine.Dialect.on_connect_url` + which normally calls into :meth:`_engine.Dialect.on_connect`. + + :param url: a :class:`_engine.URL` object representing the + :class:`_engine.URL` that was passed to the + :meth:`_engine.Dialect.create_connect_args` method. + + :return: a callable that accepts a single DBAPI connection as an + argument, or None. + + .. seealso:: + + :meth:`_engine.Dialect.on_connect` + + """ + return self.on_connect() + + def on_connect(self) -> Optional[Callable[[Any], Any]]: + """return a callable which sets up a newly created DBAPI connection. + + The callable should accept a single argument "conn" which is the + DBAPI connection itself. The inner callable has no + return value. + + E.g.:: + + class MyDialect(default.DefaultDialect): + # ... + + def on_connect(self): + def do_on_connect(connection): + connection.execute("SET SPECIAL FLAGS etc") + + return do_on_connect + + This is used to set dialect-wide per-connection options such as + isolation modes, Unicode modes, etc. + + The "do_on_connect" callable is invoked by using the + :meth:`_events.PoolEvents.connect` event + hook, then unwrapping the DBAPI connection and passing it into the + callable. + + .. versionchanged:: 1.4 the on_connect hook is no longer called twice + for the first connection of a dialect. The on_connect hook is still + called before the :meth:`_engine.Dialect.initialize` method however. + + .. versionchanged:: 1.4.3 the on_connect hook is invoked from a new + method on_connect_url that passes the URL that was used to create + the connect args. Dialects can implement on_connect_url instead + of on_connect if they need the URL object that was used for the + connection in order to get additional context. + + If None is returned, no event listener is generated. + + :return: a callable that accepts a single DBAPI connection as an + argument, or None. + + .. seealso:: + + :meth:`.Dialect.connect` - allows the DBAPI ``connect()`` sequence + itself to be controlled. + + :meth:`.Dialect.on_connect_url` - supersedes + :meth:`.Dialect.on_connect` to also receive the + :class:`_engine.URL` object in context. + + """ + return None + + def reset_isolation_level(self, dbapi_connection: DBAPIConnection) -> None: + """Given a DBAPI connection, revert its isolation to the default. + + Note that this is a dialect-level method which is used as part + of the implementation of the :class:`_engine.Connection` and + :class:`_engine.Engine` + isolation level facilities; these APIs should be preferred for + most typical use cases. + + .. seealso:: + + :meth:`_engine.Connection.get_isolation_level` + - view current level + + :attr:`_engine.Connection.default_isolation_level` + - view default level + + :paramref:`.Connection.execution_options.isolation_level` - + set per :class:`_engine.Connection` isolation level + + :paramref:`_sa.create_engine.isolation_level` - + set per :class:`_engine.Engine` isolation level + + """ + + raise NotImplementedError() + + def set_isolation_level( + self, dbapi_connection: DBAPIConnection, level: IsolationLevel + ) -> None: + """Given a DBAPI connection, set its isolation level. + + Note that this is a dialect-level method which is used as part + of the implementation of the :class:`_engine.Connection` and + :class:`_engine.Engine` + isolation level facilities; these APIs should be preferred for + most typical use cases. + + If the dialect also implements the + :meth:`.Dialect.get_isolation_level_values` method, then the given + level is guaranteed to be one of the string names within that sequence, + and the method will not need to anticipate a lookup failure. + + .. seealso:: + + :meth:`_engine.Connection.get_isolation_level` + - view current level + + :attr:`_engine.Connection.default_isolation_level` + - view default level + + :paramref:`.Connection.execution_options.isolation_level` - + set per :class:`_engine.Connection` isolation level + + :paramref:`_sa.create_engine.isolation_level` - + set per :class:`_engine.Engine` isolation level + + """ + + raise NotImplementedError() + + def get_isolation_level( + self, dbapi_connection: DBAPIConnection + ) -> IsolationLevel: + """Given a DBAPI connection, return its isolation level. + + When working with a :class:`_engine.Connection` object, + the corresponding + DBAPI connection may be procured using the + :attr:`_engine.Connection.connection` accessor. + + Note that this is a dialect-level method which is used as part + of the implementation of the :class:`_engine.Connection` and + :class:`_engine.Engine` isolation level facilities; + these APIs should be preferred for most typical use cases. + + + .. seealso:: + + :meth:`_engine.Connection.get_isolation_level` + - view current level + + :attr:`_engine.Connection.default_isolation_level` + - view default level + + :paramref:`.Connection.execution_options.isolation_level` - + set per :class:`_engine.Connection` isolation level + + :paramref:`_sa.create_engine.isolation_level` - + set per :class:`_engine.Engine` isolation level + + + """ + + raise NotImplementedError() + + def get_default_isolation_level( + self, dbapi_conn: DBAPIConnection + ) -> IsolationLevel: + """Given a DBAPI connection, return its isolation level, or + a default isolation level if one cannot be retrieved. + + This method may only raise NotImplementedError and + **must not raise any other exception**, as it is used implicitly upon + first connect. + + The method **must return a value** for a dialect that supports + isolation level settings, as this level is what will be reverted + towards when a per-connection isolation level change is made. + + The method defaults to using the :meth:`.Dialect.get_isolation_level` + method unless overridden by a dialect. + + .. versionadded:: 1.3.22 + + """ + raise NotImplementedError() + + def get_isolation_level_values( + self, dbapi_conn: DBAPIConnection + ) -> List[IsolationLevel]: + """return a sequence of string isolation level names that are accepted + by this dialect. + + The available names should use the following conventions: + + * use UPPERCASE names. isolation level methods will accept lowercase + names but these are normalized into UPPERCASE before being passed + along to the dialect. + * separate words should be separated by spaces, not underscores, e.g. + ``REPEATABLE READ``. isolation level names will have underscores + converted to spaces before being passed along to the dialect. + * The names for the four standard isolation names to the extent that + they are supported by the backend should be ``READ UNCOMMITTED`` + ``READ COMMITTED``, ``REPEATABLE READ``, ``SERIALIZABLE`` + * if the dialect supports an autocommit option it should be provided + using the isolation level name ``AUTOCOMMIT``. + * Other isolation modes may also be present, provided that they + are named in UPPERCASE and use spaces not underscores. + + This function is used so that the default dialect can check that + a given isolation level parameter is valid, else raises an + :class:`_exc.ArgumentError`. + + A DBAPI connection is passed to the method, in the unlikely event that + the dialect needs to interrogate the connection itself to determine + this list, however it is expected that most backends will return + a hardcoded list of values. If the dialect supports "AUTOCOMMIT", + that value should also be present in the sequence returned. + + The method raises ``NotImplementedError`` by default. If a dialect + does not implement this method, then the default dialect will not + perform any checking on a given isolation level value before passing + it onto the :meth:`.Dialect.set_isolation_level` method. This is + to allow backwards-compatibility with third party dialects that may + not yet be implementing this method. + + .. versionadded:: 2.0 + + """ + raise NotImplementedError() + + def _assert_and_set_isolation_level( + self, dbapi_conn: DBAPIConnection, level: IsolationLevel + ) -> None: + raise NotImplementedError() + + @classmethod + def get_dialect_cls(cls, url: URL) -> Type[Dialect]: + """Given a URL, return the :class:`.Dialect` that will be used. + + This is a hook that allows an external plugin to provide functionality + around an existing dialect, by allowing the plugin to be loaded + from the url based on an entrypoint, and then the plugin returns + the actual dialect to be used. + + By default this just returns the cls. + + """ + return cls + + @classmethod + def get_async_dialect_cls(cls, url: URL) -> Type[Dialect]: + """Given a URL, return the :class:`.Dialect` that will be used by + an async engine. + + By default this is an alias of :meth:`.Dialect.get_dialect_cls` and + just returns the cls. It may be used if a dialect provides + both a sync and async version under the same name, like the + ``psycopg`` driver. + + .. versionadded:: 2 + + .. seealso:: + + :meth:`.Dialect.get_dialect_cls` + + """ + return cls.get_dialect_cls(url) + + @classmethod + def load_provisioning(cls) -> None: + """set up the provision.py module for this dialect. + + For dialects that include a provision.py module that sets up + provisioning followers, this method should initiate that process. + + A typical implementation would be:: + + @classmethod + def load_provisioning(cls): + __import__("mydialect.provision") + + The default method assumes a module named ``provision.py`` inside + the owning package of the current dialect, based on the ``__module__`` + attribute:: + + @classmethod + def load_provisioning(cls): + package = ".".join(cls.__module__.split(".")[0:-1]) + try: + __import__(package + ".provision") + except ImportError: + pass + + .. versionadded:: 1.3.14 + + """ + + @classmethod + def engine_created(cls, engine: Engine) -> None: + """A convenience hook called before returning the final + :class:`_engine.Engine`. + + If the dialect returned a different class from the + :meth:`.get_dialect_cls` + method, then the hook is called on both classes, first on + the dialect class returned by the :meth:`.get_dialect_cls` method and + then on the class on which the method was called. + + The hook should be used by dialects and/or wrappers to apply special + events to the engine or its components. In particular, it allows + a dialect-wrapping class to apply dialect-level events. + + """ + + def get_driver_connection(self, connection: DBAPIConnection) -> Any: + """Returns the connection object as returned by the external driver + package. + + For normal dialects that use a DBAPI compliant driver this call + will just return the ``connection`` passed as argument. + For dialects that instead adapt a non DBAPI compliant driver, like + when adapting an asyncio driver, this call will return the + connection-like object as returned by the driver. + + .. versionadded:: 1.4.24 + + """ + raise NotImplementedError() + + def set_engine_execution_options( + self, engine: Engine, opts: CoreExecuteOptionsParameter + ) -> None: + """Establish execution options for a given engine. + + This is implemented by :class:`.DefaultDialect` to establish + event hooks for new :class:`.Connection` instances created + by the given :class:`.Engine` which will then invoke the + :meth:`.Dialect.set_connection_execution_options` method for that + connection. + + """ + raise NotImplementedError() + + def set_connection_execution_options( + self, connection: Connection, opts: CoreExecuteOptionsParameter + ) -> None: + """Establish execution options for a given connection. + + This is implemented by :class:`.DefaultDialect` in order to implement + the :paramref:`_engine.Connection.execution_options.isolation_level` + execution option. Dialects can intercept various execution options + which may need to modify state on a particular DBAPI connection. + + .. versionadded:: 1.4 + + """ + raise NotImplementedError() + + def get_dialect_pool_class(self, url: URL) -> Type[Pool]: + """return a Pool class to use for a given URL""" + raise NotImplementedError() + + +class CreateEnginePlugin: + """A set of hooks intended to augment the construction of an + :class:`_engine.Engine` object based on entrypoint names in a URL. + + The purpose of :class:`_engine.CreateEnginePlugin` is to allow third-party + systems to apply engine, pool and dialect level event listeners without + the need for the target application to be modified; instead, the plugin + names can be added to the database URL. Target applications for + :class:`_engine.CreateEnginePlugin` include: + + * connection and SQL performance tools, e.g. which use events to track + number of checkouts and/or time spent with statements + + * connectivity plugins such as proxies + + A rudimentary :class:`_engine.CreateEnginePlugin` that attaches a logger + to an :class:`_engine.Engine` object might look like:: + + + import logging + + from sqlalchemy.engine import CreateEnginePlugin + from sqlalchemy import event + + class LogCursorEventsPlugin(CreateEnginePlugin): + def __init__(self, url, kwargs): + # consume the parameter "log_cursor_logging_name" from the + # URL query + logging_name = url.query.get("log_cursor_logging_name", "log_cursor") + + self.log = logging.getLogger(logging_name) + + def update_url(self, url): + "update the URL to one that no longer includes our parameters" + return url.difference_update_query(["log_cursor_logging_name"]) + + def engine_created(self, engine): + "attach an event listener after the new Engine is constructed" + event.listen(engine, "before_cursor_execute", self._log_event) + + + def _log_event( + self, + conn, + cursor, + statement, + parameters, + context, + executemany): + + self.log.info("Plugin logged cursor event: %s", statement) + + + + Plugins are registered using entry points in a similar way as that + of dialects:: + + entry_points={ + 'sqlalchemy.plugins': [ + 'log_cursor_plugin = myapp.plugins:LogCursorEventsPlugin' + ] + + A plugin that uses the above names would be invoked from a database + URL as in:: + + from sqlalchemy import create_engine + + engine = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?" + "plugin=log_cursor_plugin&log_cursor_logging_name=mylogger" + ) + + The ``plugin`` URL parameter supports multiple instances, so that a URL + may specify multiple plugins; they are loaded in the order stated + in the URL:: + + engine = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?" + "plugin=plugin_one&plugin=plugin_twp&plugin=plugin_three") + + The plugin names may also be passed directly to :func:`_sa.create_engine` + using the :paramref:`_sa.create_engine.plugins` argument:: + + engine = create_engine( + "mysql+pymysql://scott:tiger@localhost/test", + plugins=["myplugin"]) + + .. versionadded:: 1.2.3 plugin names can also be specified + to :func:`_sa.create_engine` as a list + + A plugin may consume plugin-specific arguments from the + :class:`_engine.URL` object as well as the ``kwargs`` dictionary, which is + the dictionary of arguments passed to the :func:`_sa.create_engine` + call. "Consuming" these arguments includes that they must be removed + when the plugin initializes, so that the arguments are not passed along + to the :class:`_engine.Dialect` constructor, where they will raise an + :class:`_exc.ArgumentError` because they are not known by the dialect. + + As of version 1.4 of SQLAlchemy, arguments should continue to be consumed + from the ``kwargs`` dictionary directly, by removing the values with a + method such as ``dict.pop``. Arguments from the :class:`_engine.URL` object + should be consumed by implementing the + :meth:`_engine.CreateEnginePlugin.update_url` method, returning a new copy + of the :class:`_engine.URL` with plugin-specific parameters removed:: + + class MyPlugin(CreateEnginePlugin): + def __init__(self, url, kwargs): + self.my_argument_one = url.query['my_argument_one'] + self.my_argument_two = url.query['my_argument_two'] + self.my_argument_three = kwargs.pop('my_argument_three', None) + + def update_url(self, url): + return url.difference_update_query( + ["my_argument_one", "my_argument_two"] + ) + + Arguments like those illustrated above would be consumed from a + :func:`_sa.create_engine` call such as:: + + from sqlalchemy import create_engine + + engine = create_engine( + "mysql+pymysql://scott:tiger@localhost/test?" + "plugin=myplugin&my_argument_one=foo&my_argument_two=bar", + my_argument_three='bat' + ) + + .. versionchanged:: 1.4 + + The :class:`_engine.URL` object is now immutable; a + :class:`_engine.CreateEnginePlugin` that needs to alter the + :class:`_engine.URL` should implement the newly added + :meth:`_engine.CreateEnginePlugin.update_url` method, which + is invoked after the plugin is constructed. + + For migration, construct the plugin in the following way, checking + for the existence of the :meth:`_engine.CreateEnginePlugin.update_url` + method to detect which version is running:: + + class MyPlugin(CreateEnginePlugin): + def __init__(self, url, kwargs): + if hasattr(CreateEnginePlugin, "update_url"): + # detect the 1.4 API + self.my_argument_one = url.query['my_argument_one'] + self.my_argument_two = url.query['my_argument_two'] + else: + # detect the 1.3 and earlier API - mutate the + # URL directly + self.my_argument_one = url.query.pop('my_argument_one') + self.my_argument_two = url.query.pop('my_argument_two') + + self.my_argument_three = kwargs.pop('my_argument_three', None) + + def update_url(self, url): + # this method is only called in the 1.4 version + return url.difference_update_query( + ["my_argument_one", "my_argument_two"] + ) + + .. seealso:: + + :ref:`change_5526` - overview of the :class:`_engine.URL` change which + also includes notes regarding :class:`_engine.CreateEnginePlugin`. + + + When the engine creation process completes and produces the + :class:`_engine.Engine` object, it is again passed to the plugin via the + :meth:`_engine.CreateEnginePlugin.engine_created` hook. In this hook, additional + changes can be made to the engine, most typically involving setup of + events (e.g. those defined in :ref:`core_event_toplevel`). + + """ # noqa: E501 + + def __init__(self, url: URL, kwargs: Dict[str, Any]): + """Construct a new :class:`.CreateEnginePlugin`. + + The plugin object is instantiated individually for each call + to :func:`_sa.create_engine`. A single :class:`_engine. + Engine` will be + passed to the :meth:`.CreateEnginePlugin.engine_created` method + corresponding to this URL. + + :param url: the :class:`_engine.URL` object. The plugin may inspect + the :class:`_engine.URL` for arguments. Arguments used by the + plugin should be removed, by returning an updated :class:`_engine.URL` + from the :meth:`_engine.CreateEnginePlugin.update_url` method. + + .. versionchanged:: 1.4 + + The :class:`_engine.URL` object is now immutable, so a + :class:`_engine.CreateEnginePlugin` that needs to alter the + :class:`_engine.URL` object should implement the + :meth:`_engine.CreateEnginePlugin.update_url` method. + + :param kwargs: The keyword arguments passed to + :func:`_sa.create_engine`. + + """ + self.url = url + + def update_url(self, url: URL) -> URL: + """Update the :class:`_engine.URL`. + + A new :class:`_engine.URL` should be returned. This method is + typically used to consume configuration arguments from the + :class:`_engine.URL` which must be removed, as they will not be + recognized by the dialect. The + :meth:`_engine.URL.difference_update_query` method is available + to remove these arguments. See the docstring at + :class:`_engine.CreateEnginePlugin` for an example. + + + .. versionadded:: 1.4 + + """ + raise NotImplementedError() + + def handle_dialect_kwargs( + self, dialect_cls: Type[Dialect], dialect_args: Dict[str, Any] + ) -> None: + """parse and modify dialect kwargs""" + + def handle_pool_kwargs( + self, pool_cls: Type[Pool], pool_args: Dict[str, Any] + ) -> None: + """parse and modify pool kwargs""" + + def engine_created(self, engine: Engine) -> None: + """Receive the :class:`_engine.Engine` + object when it is fully constructed. + + The plugin may make additional changes to the engine, such as + registering engine or connection pool events. + + """ + + +class ExecutionContext: + """A messenger object for a Dialect that corresponds to a single + execution. + + """ + + engine: Engine + """engine which the Connection is associated with""" + + connection: Connection + """Connection object which can be freely used by default value + generators to execute SQL. This Connection should reference the + same underlying connection/transactional resources of + root_connection.""" + + root_connection: Connection + """Connection object which is the source of this ExecutionContext.""" + + dialect: Dialect + """dialect which created this ExecutionContext.""" + + cursor: DBAPICursor + """DB-API cursor procured from the connection""" + + compiled: Optional[Compiled] + """if passed to constructor, sqlalchemy.engine.base.Compiled object + being executed""" + + statement: str + """string version of the statement to be executed. Is either + passed to the constructor, or must be created from the + sql.Compiled object by the time pre_exec() has completed.""" + + invoked_statement: Optional[Executable] + """The Executable statement object that was given in the first place. + + This should be structurally equivalent to compiled.statement, but not + necessarily the same object as in a caching scenario the compiled form + will have been extracted from the cache. + + """ + + parameters: _AnyMultiExecuteParams + """bind parameters passed to the execute() or exec_driver_sql() methods. + + These are always stored as a list of parameter entries. A single-element + list corresponds to a ``cursor.execute()`` call and a multiple-element + list corresponds to ``cursor.executemany()``, except in the case + of :attr:`.ExecuteStyle.INSERTMANYVALUES` which will use + ``cursor.execute()`` one or more times. + + """ + + no_parameters: bool + """True if the execution style does not use parameters""" + + isinsert: bool + """True if the statement is an INSERT.""" + + isupdate: bool + """True if the statement is an UPDATE.""" + + execute_style: ExecuteStyle + """the style of DBAPI cursor method that will be used to execute + a statement. + + .. versionadded:: 2.0 + + """ + + executemany: bool + """True if the context has a list of more than one parameter set. + + Historically this attribute links to whether ``cursor.execute()`` or + ``cursor.executemany()`` will be used. It also can now mean that + "insertmanyvalues" may be used which indicates one or more + ``cursor.execute()`` calls. + + """ + + prefetch_cols: util.generic_fn_descriptor[Optional[Sequence[Column[Any]]]] + """a list of Column objects for which a client-side default + was fired off. Applies to inserts and updates.""" + + postfetch_cols: util.generic_fn_descriptor[Optional[Sequence[Column[Any]]]] + """a list of Column objects for which a server-side default or + inline SQL expression value was fired off. Applies to inserts + and updates.""" + + execution_options: _ExecuteOptions + """Execution options associated with the current statement execution""" + + @classmethod + def _init_ddl( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled_ddl: DDLCompiler, + ) -> ExecutionContext: + raise NotImplementedError() + + @classmethod + def _init_compiled( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + compiled: SQLCompiler, + parameters: _CoreMultiExecuteParams, + invoked_statement: Executable, + extracted_parameters: Optional[Sequence[BindParameter[Any]]], + cache_hit: CacheStats = CacheStats.CACHING_DISABLED, + ) -> ExecutionContext: + raise NotImplementedError() + + @classmethod + def _init_statement( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + statement: str, + parameters: _DBAPIMultiExecuteParams, + ) -> ExecutionContext: + raise NotImplementedError() + + @classmethod + def _init_default( + cls, + dialect: Dialect, + connection: Connection, + dbapi_connection: PoolProxiedConnection, + execution_options: _ExecuteOptions, + ) -> ExecutionContext: + raise NotImplementedError() + + def _exec_default( + self, + column: Optional[Column[Any]], + default: DefaultGenerator, + type_: Optional[TypeEngine[Any]], + ) -> Any: + raise NotImplementedError() + + def _prepare_set_input_sizes( + self, + ) -> Optional[List[Tuple[str, Any, TypeEngine[Any]]]]: + raise NotImplementedError() + + def _get_cache_stats(self) -> str: + raise NotImplementedError() + + def _setup_result_proxy(self) -> CursorResult[Any]: + raise NotImplementedError() + + def fire_sequence(self, seq: Sequence_SchemaItem, type_: Integer) -> int: + """given a :class:`.Sequence`, invoke it and return the next int + value""" + raise NotImplementedError() + + def create_cursor(self) -> DBAPICursor: + """Return a new cursor generated from this ExecutionContext's + connection. + + Some dialects may wish to change the behavior of + connection.cursor(), such as postgresql which may return a PG + "server side" cursor. + """ + + raise NotImplementedError() + + def pre_exec(self) -> None: + """Called before an execution of a compiled statement. + + If a compiled statement was passed to this ExecutionContext, + the `statement` and `parameters` datamembers must be + initialized after this statement is complete. + """ + + raise NotImplementedError() + + def get_out_parameter_values( + self, out_param_names: Sequence[str] + ) -> Sequence[Any]: + """Return a sequence of OUT parameter values from a cursor. + + For dialects that support OUT parameters, this method will be called + when there is a :class:`.SQLCompiler` object which has the + :attr:`.SQLCompiler.has_out_parameters` flag set. This flag in turn + will be set to True if the statement itself has :class:`.BindParameter` + objects that have the ``.isoutparam`` flag set which are consumed by + the :meth:`.SQLCompiler.visit_bindparam` method. If the dialect + compiler produces :class:`.BindParameter` objects with ``.isoutparam`` + set which are not handled by :meth:`.SQLCompiler.visit_bindparam`, it + should set this flag explicitly. + + The list of names that were rendered for each bound parameter + is passed to the method. The method should then return a sequence of + values corresponding to the list of parameter objects. Unlike in + previous SQLAlchemy versions, the values can be the **raw values** from + the DBAPI; the execution context will apply the appropriate type + handler based on what's present in self.compiled.binds and update the + values. The processed dictionary will then be made available via the + ``.out_parameters`` collection on the result object. Note that + SQLAlchemy 1.4 has multiple kinds of result object as part of the 2.0 + transition. + + .. versionadded:: 1.4 - added + :meth:`.ExecutionContext.get_out_parameter_values`, which is invoked + automatically by the :class:`.DefaultExecutionContext` when there + are :class:`.BindParameter` objects with the ``.isoutparam`` flag + set. This replaces the practice of setting out parameters within + the now-removed ``get_result_proxy()`` method. + + """ + raise NotImplementedError() + + def post_exec(self) -> None: + """Called after the execution of a compiled statement. + + If a compiled statement was passed to this ExecutionContext, + the `last_insert_ids`, `last_inserted_params`, etc. + datamembers should be available after this method completes. + """ + + raise NotImplementedError() + + def handle_dbapi_exception(self, e: BaseException) -> None: + """Receive a DBAPI exception which occurred upon execute, result + fetch, etc.""" + + raise NotImplementedError() + + def lastrow_has_defaults(self) -> bool: + """Return True if the last INSERT or UPDATE row contained + inlined or database-side defaults. + """ + + raise NotImplementedError() + + def get_rowcount(self) -> Optional[int]: + """Return the DBAPI ``cursor.rowcount`` value, or in some + cases an interpreted value. + + See :attr:`_engine.CursorResult.rowcount` for details on this. + + """ + + raise NotImplementedError() + + def fetchall_for_returning(self, cursor: DBAPICursor) -> Sequence[Any]: + """For a RETURNING result, deliver cursor.fetchall() from the + DBAPI cursor. + + This is a dialect-specific hook for dialects that have special + considerations when calling upon the rows delivered for a + "RETURNING" statement. Default implementation is + ``cursor.fetchall()``. + + This hook is currently used only by the :term:`insertmanyvalues` + feature. Dialects that don't set ``use_insertmanyvalues=True`` + don't need to consider this hook. + + .. versionadded:: 2.0.10 + + """ + raise NotImplementedError() + + +class ConnectionEventsTarget(EventTarget): + """An object which can accept events from :class:`.ConnectionEvents`. + + Includes :class:`_engine.Connection` and :class:`_engine.Engine`. + + .. versionadded:: 2.0 + + """ + + dispatch: dispatcher[ConnectionEventsTarget] + + +Connectable = ConnectionEventsTarget + + +class ExceptionContext: + """Encapsulate information about an error condition in progress. + + This object exists solely to be passed to the + :meth:`_events.DialectEvents.handle_error` event, + supporting an interface that + can be extended without backwards-incompatibility. + + + """ + + __slots__ = () + + dialect: Dialect + """The :class:`_engine.Dialect` in use. + + This member is present for all invocations of the event hook. + + .. versionadded:: 2.0 + + """ + + connection: Optional[Connection] + """The :class:`_engine.Connection` in use during the exception. + + This member is present, except in the case of a failure when + first connecting. + + .. seealso:: + + :attr:`.ExceptionContext.engine` + + + """ + + engine: Optional[Engine] + """The :class:`_engine.Engine` in use during the exception. + + This member is present in all cases except for when handling an error + within the connection pool "pre-ping" process. + + """ + + cursor: Optional[DBAPICursor] + """The DBAPI cursor object. + + May be None. + + """ + + statement: Optional[str] + """String SQL statement that was emitted directly to the DBAPI. + + May be None. + + """ + + parameters: Optional[_DBAPIAnyExecuteParams] + """Parameter collection that was emitted directly to the DBAPI. + + May be None. + + """ + + original_exception: BaseException + """The exception object which was caught. + + This member is always present. + + """ + + sqlalchemy_exception: Optional[StatementError] + """The :class:`sqlalchemy.exc.StatementError` which wraps the original, + and will be raised if exception handling is not circumvented by the event. + + May be None, as not all exception types are wrapped by SQLAlchemy. + For DBAPI-level exceptions that subclass the dbapi's Error class, this + field will always be present. + + """ + + chained_exception: Optional[BaseException] + """The exception that was returned by the previous handler in the + exception chain, if any. + + If present, this exception will be the one ultimately raised by + SQLAlchemy unless a subsequent handler replaces it. + + May be None. + + """ + + execution_context: Optional[ExecutionContext] + """The :class:`.ExecutionContext` corresponding to the execution + operation in progress. + + This is present for statement execution operations, but not for + operations such as transaction begin/end. It also is not present when + the exception was raised before the :class:`.ExecutionContext` + could be constructed. + + Note that the :attr:`.ExceptionContext.statement` and + :attr:`.ExceptionContext.parameters` members may represent a + different value than that of the :class:`.ExecutionContext`, + potentially in the case where a + :meth:`_events.ConnectionEvents.before_cursor_execute` event or similar + modified the statement/parameters to be sent. + + May be None. + + """ + + is_disconnect: bool + """Represent whether the exception as occurred represents a "disconnect" + condition. + + This flag will always be True or False within the scope of the + :meth:`_events.DialectEvents.handle_error` handler. + + SQLAlchemy will defer to this flag in order to determine whether or not + the connection should be invalidated subsequently. That is, by + assigning to this flag, a "disconnect" event which then results in + a connection and pool invalidation can be invoked or prevented by + changing this flag. + + + .. note:: The pool "pre_ping" handler enabled using the + :paramref:`_sa.create_engine.pool_pre_ping` parameter does **not** + consult this event before deciding if the "ping" returned false, + as opposed to receiving an unhandled error. For this use case, the + :ref:`legacy recipe based on engine_connect() may be used + `. A future API allow more + comprehensive customization of the "disconnect" detection mechanism + across all functions. + + """ + + invalidate_pool_on_disconnect: bool + """Represent whether all connections in the pool should be invalidated + when a "disconnect" condition is in effect. + + Setting this flag to False within the scope of the + :meth:`_events.DialectEvents.handle_error` + event will have the effect such + that the full collection of connections in the pool will not be + invalidated during a disconnect; only the current connection that is the + subject of the error will actually be invalidated. + + The purpose of this flag is for custom disconnect-handling schemes where + the invalidation of other connections in the pool is to be performed + based on other conditions, or even on a per-connection basis. + + """ + + is_pre_ping: bool + """Indicates if this error is occurring within the "pre-ping" step + performed when :paramref:`_sa.create_engine.pool_pre_ping` is set to + ``True``. In this mode, the :attr:`.ExceptionContext.engine` attribute + will be ``None``. The dialect in use is accessible via the + :attr:`.ExceptionContext.dialect` attribute. + + .. versionadded:: 2.0.5 + + """ + + +class AdaptedConnection: + """Interface of an adapted connection object to support the DBAPI protocol. + + Used by asyncio dialects to provide a sync-style pep-249 facade on top + of the asyncio connection/cursor API provided by the driver. + + .. versionadded:: 1.4.24 + + """ + + __slots__ = ("_connection",) + + _connection: Any + + @property + def driver_connection(self) -> Any: + """The connection object as returned by the driver after a connect.""" + return self._connection + + def run_async(self, fn: Callable[[Any], Awaitable[_T]]) -> _T: + """Run the awaitable returned by the given function, which is passed + the raw asyncio driver connection. + + This is used to invoke awaitable-only methods on the driver connection + within the context of a "synchronous" method, like a connection + pool event handler. + + E.g.:: + + engine = create_async_engine(...) + + @event.listens_for(engine.sync_engine, "connect") + def register_custom_types(dbapi_connection, ...): + dbapi_connection.run_async( + lambda connection: connection.set_type_codec( + 'MyCustomType', encoder, decoder, ... + ) + ) + + .. versionadded:: 1.4.30 + + .. seealso:: + + :ref:`asyncio_events_run_async` + + """ + return await_only(fn(self._connection)) + + def __repr__(self) -> str: + return "" % self._connection diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/mock.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/mock.py new file mode 100644 index 0000000..c9fa5eb --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/mock.py @@ -0,0 +1,131 @@ +# engine/mock.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from operator import attrgetter +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Optional +from typing import Type +from typing import Union + +from . import url as _url +from .. import util + + +if typing.TYPE_CHECKING: + from .base import Engine + from .interfaces import _CoreAnyExecuteParams + from .interfaces import CoreExecuteOptionsParameter + from .interfaces import Dialect + from .url import URL + from ..sql.base import Executable + from ..sql.ddl import SchemaDropper + from ..sql.ddl import SchemaGenerator + from ..sql.schema import HasSchemaAttr + from ..sql.schema import SchemaItem + + +class MockConnection: + def __init__(self, dialect: Dialect, execute: Callable[..., Any]): + self._dialect = dialect + self._execute_impl = execute + + engine: Engine = cast(Any, property(lambda s: s)) + dialect: Dialect = cast(Any, property(attrgetter("_dialect"))) + name: str = cast(Any, property(lambda s: s._dialect.name)) + + def connect(self, **kwargs: Any) -> MockConnection: + return self + + def schema_for_object(self, obj: HasSchemaAttr) -> Optional[str]: + return obj.schema + + def execution_options(self, **kw: Any) -> MockConnection: + return self + + def _run_ddl_visitor( + self, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: SchemaItem, + **kwargs: Any, + ) -> None: + kwargs["checkfirst"] = False + visitorcallable(self.dialect, self, **kwargs).traverse_single(element) + + def execute( + self, + obj: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Any: + return self._execute_impl(obj, parameters) + + +def create_mock_engine( + url: Union[str, URL], executor: Any, **kw: Any +) -> MockConnection: + """Create a "mock" engine used for echoing DDL. + + This is a utility function used for debugging or storing the output of DDL + sequences as generated by :meth:`_schema.MetaData.create_all` + and related methods. + + The function accepts a URL which is used only to determine the kind of + dialect to be used, as well as an "executor" callable function which + will receive a SQL expression object and parameters, which can then be + echoed or otherwise printed. The executor's return value is not handled, + nor does the engine allow regular string statements to be invoked, and + is therefore only useful for DDL that is sent to the database without + receiving any results. + + E.g.:: + + from sqlalchemy import create_mock_engine + + def dump(sql, *multiparams, **params): + print(sql.compile(dialect=engine.dialect)) + + engine = create_mock_engine('postgresql+psycopg2://', dump) + metadata.create_all(engine, checkfirst=False) + + :param url: A string URL which typically needs to contain only the + database backend name. + + :param executor: a callable which receives the arguments ``sql``, + ``*multiparams`` and ``**params``. The ``sql`` parameter is typically + an instance of :class:`.ExecutableDDLElement`, which can then be compiled + into a string using :meth:`.ExecutableDDLElement.compile`. + + .. versionadded:: 1.4 - the :func:`.create_mock_engine` function replaces + the previous "mock" engine strategy used with + :func:`_sa.create_engine`. + + .. seealso:: + + :ref:`faq_ddl_as_string` + + """ + + # create url.URL object + u = _url.make_url(url) + + dialect_cls = u.get_dialect() + + dialect_args = {} + # consume dialect arguments from kwargs + for k in util.get_cls_kwargs(dialect_cls): + if k in kw: + dialect_args[k] = kw.pop(k) + + # create dialect + dialect = dialect_cls(**dialect_args) + + return MockConnection(dialect, executor) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/processors.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/processors.py new file mode 100644 index 0000000..610e03d --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/processors.py @@ -0,0 +1,61 @@ +# engine/processors.py +# Copyright (C) 2010-2024 the SQLAlchemy authors and contributors +# +# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""defines generic type conversion functions, as used in bind and result +processors. + +They all share one common characteristic: None is passed through unchanged. + +""" +from __future__ import annotations + +import typing + +from ._py_processors import str_to_datetime_processor_factory # noqa +from ..util._has_cy import HAS_CYEXTENSION + +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_processors import int_to_boolean as int_to_boolean + from ._py_processors import str_to_date as str_to_date + from ._py_processors import str_to_datetime as str_to_datetime + from ._py_processors import str_to_time as str_to_time + from ._py_processors import ( + to_decimal_processor_factory as to_decimal_processor_factory, + ) + from ._py_processors import to_float as to_float + from ._py_processors import to_str as to_str +else: + from sqlalchemy.cyextension.processors import ( + DecimalResultProcessor, + ) + from sqlalchemy.cyextension.processors import ( # noqa: F401 + int_to_boolean as int_to_boolean, + ) + from sqlalchemy.cyextension.processors import ( # noqa: F401,E501 + str_to_date as str_to_date, + ) + from sqlalchemy.cyextension.processors import ( # noqa: F401 + str_to_datetime as str_to_datetime, + ) + from sqlalchemy.cyextension.processors import ( # noqa: F401,E501 + str_to_time as str_to_time, + ) + from sqlalchemy.cyextension.processors import ( # noqa: F401,E501 + to_float as to_float, + ) + from sqlalchemy.cyextension.processors import ( # noqa: F401,E501 + to_str as to_str, + ) + + def to_decimal_processor_factory(target_class, scale): + # Note that the scale argument is not taken into account for integer + # values in the C implementation while it is in the Python one. + # For example, the Python implementation might return + # Decimal('5.00000') whereas the C implementation will + # return Decimal('5'). These are equivalent of course. + return DecimalResultProcessor(target_class, "%%.%df" % scale).process diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/reflection.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/reflection.py new file mode 100644 index 0000000..ef1e566 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/reflection.py @@ -0,0 +1,2089 @@ +# engine/reflection.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Provides an abstraction for obtaining database schema information. + +Usage Notes: + +Here are some general conventions when accessing the low level inspector +methods such as get_table_names, get_columns, etc. + +1. Inspector methods return lists of dicts in most cases for the following + reasons: + + * They're both standard types that can be serialized. + * Using a dict instead of a tuple allows easy expansion of attributes. + * Using a list for the outer structure maintains order and is easy to work + with (e.g. list comprehension [d['name'] for d in cols]). + +2. Records that contain a name, such as the column name in a column record + use the key 'name'. So for most return values, each record will have a + 'name' attribute.. +""" +from __future__ import annotations + +import contextlib +from dataclasses import dataclass +from enum import auto +from enum import Flag +from enum import unique +from typing import Any +from typing import Callable +from typing import Collection +from typing import Dict +from typing import Generator +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .base import Connection +from .base import Engine +from .. import exc +from .. import inspection +from .. import sql +from .. import util +from ..sql import operators +from ..sql import schema as sa_schema +from ..sql.cache_key import _ad_hoc_cache_key_from_args +from ..sql.elements import TextClause +from ..sql.type_api import TypeEngine +from ..sql.visitors import InternalTraversal +from ..util import topological +from ..util.typing import final + +if TYPE_CHECKING: + from .interfaces import Dialect + from .interfaces import ReflectedCheckConstraint + from .interfaces import ReflectedColumn + from .interfaces import ReflectedForeignKeyConstraint + from .interfaces import ReflectedIndex + from .interfaces import ReflectedPrimaryKeyConstraint + from .interfaces import ReflectedTableComment + from .interfaces import ReflectedUniqueConstraint + from .interfaces import TableKey + +_R = TypeVar("_R") + + +@util.decorator +def cache( + fn: Callable[..., _R], + self: Dialect, + con: Connection, + *args: Any, + **kw: Any, +) -> _R: + info_cache = kw.get("info_cache", None) + if info_cache is None: + return fn(self, con, *args, **kw) + exclude = {"info_cache", "unreflectable"} + key = ( + fn.__name__, + tuple(a for a in args if isinstance(a, str)), + tuple((k, v) for k, v in kw.items() if k not in exclude), + ) + ret: _R = info_cache.get(key) + if ret is None: + ret = fn(self, con, *args, **kw) + info_cache[key] = ret + return ret + + +def flexi_cache( + *traverse_args: Tuple[str, InternalTraversal] +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: + @util.decorator + def go( + fn: Callable[..., _R], + self: Dialect, + con: Connection, + *args: Any, + **kw: Any, + ) -> _R: + info_cache = kw.get("info_cache", None) + if info_cache is None: + return fn(self, con, *args, **kw) + key = _ad_hoc_cache_key_from_args((fn.__name__,), traverse_args, args) + ret: _R = info_cache.get(key) + if ret is None: + ret = fn(self, con, *args, **kw) + info_cache[key] = ret + return ret + + return go + + +@unique +class ObjectKind(Flag): + """Enumerator that indicates which kind of object to return when calling + the ``get_multi`` methods. + + This is a Flag enum, so custom combinations can be passed. For example, + to reflect tables and plain views ``ObjectKind.TABLE | ObjectKind.VIEW`` + may be used. + + .. note:: + Not all dialect may support all kind of object. If a dialect does + not support a particular object an empty dict is returned. + In case a dialect supports an object, but the requested method + is not applicable for the specified kind the default value + will be returned for each reflected object. For example reflecting + check constraints of view return a dict with all the views with + empty lists as values. + """ + + TABLE = auto() + "Reflect table objects" + VIEW = auto() + "Reflect plain view objects" + MATERIALIZED_VIEW = auto() + "Reflect materialized view object" + + ANY_VIEW = VIEW | MATERIALIZED_VIEW + "Reflect any kind of view objects" + ANY = TABLE | VIEW | MATERIALIZED_VIEW + "Reflect all type of objects" + + +@unique +class ObjectScope(Flag): + """Enumerator that indicates which scope to use when calling + the ``get_multi`` methods. + """ + + DEFAULT = auto() + "Include default scope" + TEMPORARY = auto() + "Include only temp scope" + ANY = DEFAULT | TEMPORARY + "Include both default and temp scope" + + +@inspection._self_inspects +class Inspector(inspection.Inspectable["Inspector"]): + """Performs database schema inspection. + + The Inspector acts as a proxy to the reflection methods of the + :class:`~sqlalchemy.engine.interfaces.Dialect`, providing a + consistent interface as well as caching support for previously + fetched metadata. + + A :class:`_reflection.Inspector` object is usually created via the + :func:`_sa.inspect` function, which may be passed an + :class:`_engine.Engine` + or a :class:`_engine.Connection`:: + + from sqlalchemy import inspect, create_engine + engine = create_engine('...') + insp = inspect(engine) + + Where above, the :class:`~sqlalchemy.engine.interfaces.Dialect` associated + with the engine may opt to return an :class:`_reflection.Inspector` + subclass that + provides additional methods specific to the dialect's target database. + + """ + + bind: Union[Engine, Connection] + engine: Engine + _op_context_requires_connect: bool + dialect: Dialect + info_cache: Dict[Any, Any] + + @util.deprecated( + "1.4", + "The __init__() method on :class:`_reflection.Inspector` " + "is deprecated and " + "will be removed in a future release. Please use the " + ":func:`.sqlalchemy.inspect` " + "function on an :class:`_engine.Engine` or " + ":class:`_engine.Connection` " + "in order to " + "acquire an :class:`_reflection.Inspector`.", + ) + def __init__(self, bind: Union[Engine, Connection]): + """Initialize a new :class:`_reflection.Inspector`. + + :param bind: a :class:`~sqlalchemy.engine.Connection`, + which is typically an instance of + :class:`~sqlalchemy.engine.Engine` or + :class:`~sqlalchemy.engine.Connection`. + + For a dialect-specific instance of :class:`_reflection.Inspector`, see + :meth:`_reflection.Inspector.from_engine` + + """ + self._init_legacy(bind) + + @classmethod + def _construct( + cls, init: Callable[..., Any], bind: Union[Engine, Connection] + ) -> Inspector: + if hasattr(bind.dialect, "inspector"): + cls = bind.dialect.inspector + + self = cls.__new__(cls) + init(self, bind) + return self + + def _init_legacy(self, bind: Union[Engine, Connection]) -> None: + if hasattr(bind, "exec_driver_sql"): + self._init_connection(bind) # type: ignore[arg-type] + else: + self._init_engine(bind) + + def _init_engine(self, engine: Engine) -> None: + self.bind = self.engine = engine + engine.connect().close() + self._op_context_requires_connect = True + self.dialect = self.engine.dialect + self.info_cache = {} + + def _init_connection(self, connection: Connection) -> None: + self.bind = connection + self.engine = connection.engine + self._op_context_requires_connect = False + self.dialect = self.engine.dialect + self.info_cache = {} + + def clear_cache(self) -> None: + """reset the cache for this :class:`.Inspector`. + + Inspection methods that have data cached will emit SQL queries + when next called to get new data. + + .. versionadded:: 2.0 + + """ + self.info_cache.clear() + + @classmethod + @util.deprecated( + "1.4", + "The from_engine() method on :class:`_reflection.Inspector` " + "is deprecated and " + "will be removed in a future release. Please use the " + ":func:`.sqlalchemy.inspect` " + "function on an :class:`_engine.Engine` or " + ":class:`_engine.Connection` " + "in order to " + "acquire an :class:`_reflection.Inspector`.", + ) + def from_engine(cls, bind: Engine) -> Inspector: + """Construct a new dialect-specific Inspector object from the given + engine or connection. + + :param bind: a :class:`~sqlalchemy.engine.Connection` + or :class:`~sqlalchemy.engine.Engine`. + + This method differs from direct a direct constructor call of + :class:`_reflection.Inspector` in that the + :class:`~sqlalchemy.engine.interfaces.Dialect` is given a chance to + provide a dialect-specific :class:`_reflection.Inspector` instance, + which may + provide additional methods. + + See the example at :class:`_reflection.Inspector`. + + """ + return cls._construct(cls._init_legacy, bind) + + @inspection._inspects(Engine) + def _engine_insp(bind: Engine) -> Inspector: # type: ignore[misc] + return Inspector._construct(Inspector._init_engine, bind) + + @inspection._inspects(Connection) + def _connection_insp(bind: Connection) -> Inspector: # type: ignore[misc] + return Inspector._construct(Inspector._init_connection, bind) + + @contextlib.contextmanager + def _operation_context(self) -> Generator[Connection, None, None]: + """Return a context that optimizes for multiple operations on a single + transaction. + + This essentially allows connect()/close() to be called if we detected + that we're against an :class:`_engine.Engine` and not a + :class:`_engine.Connection`. + + """ + conn: Connection + if self._op_context_requires_connect: + conn = self.bind.connect() # type: ignore[union-attr] + else: + conn = self.bind # type: ignore[assignment] + try: + yield conn + finally: + if self._op_context_requires_connect: + conn.close() + + @contextlib.contextmanager + def _inspection_context(self) -> Generator[Inspector, None, None]: + """Return an :class:`_reflection.Inspector` + from this one that will run all + operations on a single connection. + + """ + + with self._operation_context() as conn: + sub_insp = self._construct(self.__class__._init_connection, conn) + sub_insp.info_cache = self.info_cache + yield sub_insp + + @property + def default_schema_name(self) -> Optional[str]: + """Return the default schema name presented by the dialect + for the current engine's database user. + + E.g. this is typically ``public`` for PostgreSQL and ``dbo`` + for SQL Server. + + """ + return self.dialect.default_schema_name + + def get_schema_names(self, **kw: Any) -> List[str]: + r"""Return all schema names. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + """ + + with self._operation_context() as conn: + return self.dialect.get_schema_names( + conn, info_cache=self.info_cache, **kw + ) + + def get_table_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all table names within a particular schema. + + The names are expected to be real tables only, not views. + Views are instead returned using the + :meth:`_reflection.Inspector.get_view_names` and/or + :meth:`_reflection.Inspector.get_materialized_view_names` + methods. + + :param schema: Schema name. If ``schema`` is left at ``None``, the + database's default schema is + used, else the named schema is searched. If the database does not + support named schemas, behavior is undefined if ``schema`` is not + passed as ``None``. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. seealso:: + + :meth:`_reflection.Inspector.get_sorted_table_and_fkc_names` + + :attr:`_schema.MetaData.sorted_tables` + + """ + + with self._operation_context() as conn: + return self.dialect.get_table_names( + conn, schema, info_cache=self.info_cache, **kw + ) + + def has_table( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> bool: + r"""Return True if the backend has a table, view, or temporary + table of the given name. + + :param table_name: name of the table to check + :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 1.4 - the :meth:`.Inspector.has_table` method + replaces the :meth:`_engine.Engine.has_table` method. + + .. versionchanged:: 2.0:: :meth:`.Inspector.has_table` now formally + supports checking for additional table-like objects: + + * any type of views (plain or materialized) + * temporary tables of any kind + + Previously, these two checks were not formally specified and + different dialects would vary in their behavior. The dialect + testing suite now includes tests for all of these object types + and should be supported by all SQLAlchemy-included dialects. + Support among third party dialects may be lagging, however. + + """ + with self._operation_context() as conn: + return self.dialect.has_table( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def has_sequence( + self, sequence_name: str, schema: Optional[str] = None, **kw: Any + ) -> bool: + r"""Return True if the backend has a sequence with the given name. + + :param sequence_name: name of the sequence to check + :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 1.4 + + """ + with self._operation_context() as conn: + return self.dialect.has_sequence( + conn, sequence_name, schema, info_cache=self.info_cache, **kw + ) + + def has_index( + self, + table_name: str, + index_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> bool: + r"""Check the existence of a particular index name in the database. + + :param table_name: the name of the table the index belongs to + :param index_name: the name of the index to check + :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect.has_index( + conn, + table_name, + index_name, + schema, + info_cache=self.info_cache, + **kw, + ) + + def has_schema(self, schema_name: str, **kw: Any) -> bool: + r"""Return True if the backend has a schema with the given name. + + :param schema_name: name of the schema to check + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect.has_schema( + conn, schema_name, info_cache=self.info_cache, **kw + ) + + def get_sorted_table_and_fkc_names( + self, + schema: Optional[str] = None, + **kw: Any, + ) -> List[Tuple[Optional[str], List[Tuple[str, Optional[str]]]]]: + r"""Return dependency-sorted table and foreign key constraint names in + referred to within a particular schema. + + This will yield 2-tuples of + ``(tablename, [(tname, fkname), (tname, fkname), ...])`` + consisting of table names in CREATE order grouped with the foreign key + constraint names that are not detected as belonging to a cycle. + The final element + will be ``(None, [(tname, fkname), (tname, fkname), ..])`` + which will consist of remaining + foreign key constraint names that would require a separate CREATE + step after-the-fact, based on dependencies between tables. + + :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. seealso:: + + :meth:`_reflection.Inspector.get_table_names` + + :func:`.sort_tables_and_constraints` - similar method which works + with an already-given :class:`_schema.MetaData`. + + """ + + return [ + ( + table_key[1] if table_key else None, + [(tname, fks) for (_, tname), fks in fk_collection], + ) + for ( + table_key, + fk_collection, + ) in self.sort_tables_on_foreign_key_dependency( + consider_schemas=(schema,) + ) + ] + + def sort_tables_on_foreign_key_dependency( + self, + consider_schemas: Collection[Optional[str]] = (None,), + **kw: Any, + ) -> List[ + Tuple[ + Optional[Tuple[Optional[str], str]], + List[Tuple[Tuple[Optional[str], str], Optional[str]]], + ] + ]: + r"""Return dependency-sorted table and foreign key constraint names + referred to within multiple schemas. + + This method may be compared to + :meth:`.Inspector.get_sorted_table_and_fkc_names`, which + works on one schema at a time; here, the method is a generalization + that will consider multiple schemas at once including that it will + resolve for cross-schema foreign keys. + + .. versionadded:: 2.0 + + """ + SchemaTab = Tuple[Optional[str], str] + + tuples: Set[Tuple[SchemaTab, SchemaTab]] = set() + remaining_fkcs: Set[Tuple[SchemaTab, Optional[str]]] = set() + fknames_for_table: Dict[SchemaTab, Set[Optional[str]]] = {} + tnames: List[SchemaTab] = [] + + for schname in consider_schemas: + schema_fkeys = self.get_multi_foreign_keys(schname, **kw) + tnames.extend(schema_fkeys) + for (_, tname), fkeys in schema_fkeys.items(): + fknames_for_table[(schname, tname)] = { + fk["name"] for fk in fkeys + } + for fkey in fkeys: + if ( + tname != fkey["referred_table"] + or schname != fkey["referred_schema"] + ): + tuples.add( + ( + ( + fkey["referred_schema"], + fkey["referred_table"], + ), + (schname, tname), + ) + ) + try: + candidate_sort = list(topological.sort(tuples, tnames)) + except exc.CircularDependencyError as err: + edge: Tuple[SchemaTab, SchemaTab] + for edge in err.edges: + tuples.remove(edge) + remaining_fkcs.update( + (edge[1], fkc) for fkc in fknames_for_table[edge[1]] + ) + + candidate_sort = list(topological.sort(tuples, tnames)) + ret: List[ + Tuple[Optional[SchemaTab], List[Tuple[SchemaTab, Optional[str]]]] + ] + ret = [ + ( + (schname, tname), + [ + ((schname, tname), fk) + for fk in fknames_for_table[(schname, tname)].difference( + name for _, name in remaining_fkcs + ) + ], + ) + for (schname, tname) in candidate_sort + ] + return ret + [(None, list(remaining_fkcs))] + + def get_temp_table_names(self, **kw: Any) -> List[str]: + r"""Return a list of temporary table names for the current bind. + + This method is unsupported by most dialects; currently + only Oracle, PostgreSQL and SQLite implements it. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + """ + + with self._operation_context() as conn: + return self.dialect.get_temp_table_names( + conn, info_cache=self.info_cache, **kw + ) + + def get_temp_view_names(self, **kw: Any) -> List[str]: + r"""Return a list of temporary view names for the current bind. + + This method is unsupported by most dialects; currently + only PostgreSQL and SQLite implements it. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + """ + with self._operation_context() as conn: + return self.dialect.get_temp_view_names( + conn, info_cache=self.info_cache, **kw + ) + + def get_table_options( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> Dict[str, Any]: + r"""Return a dictionary of options specified when the table of the + given name was created. + + This currently includes some options that apply to MySQL and Oracle + tables. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dict with the table options. The returned keys depend on the + dialect in use. Each one is prefixed with the dialect name. + + .. seealso:: :meth:`Inspector.get_multi_table_options` + + """ + with self._operation_context() as conn: + return self.dialect.get_table_options( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_table_options( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, Dict[str, Any]]: + r"""Return a dictionary of options specified when the tables in the + given schema were created. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + This currently includes some options that apply to MySQL and Oracle + tables. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if options of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are dictionaries with the table options. + The returned keys in each dict depend on the + dialect in use. Each one is prefixed with the dialect name. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_table_options` + """ + with self._operation_context() as conn: + res = self.dialect.get_multi_table_options( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + return dict(res) + + def get_view_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all non-materialized view names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + + .. versionchanged:: 2.0 For those dialects that previously included + the names of materialized views in this list (currently PostgreSQL), + this method no longer returns the names of materialized views. + the :meth:`.Inspector.get_materialized_view_names` method should + be used instead. + + .. seealso:: + + :meth:`.Inspector.get_materialized_view_names` + + """ + + with self._operation_context() as conn: + return self.dialect.get_view_names( + conn, schema, info_cache=self.info_cache, **kw + ) + + def get_materialized_view_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all materialized view names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.Inspector.get_view_names` + + """ + + with self._operation_context() as conn: + return self.dialect.get_materialized_view_names( + conn, schema, info_cache=self.info_cache, **kw + ) + + def get_sequence_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all sequence names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + """ + + with self._operation_context() as conn: + return self.dialect.get_sequence_names( + conn, schema, info_cache=self.info_cache, **kw + ) + + def get_view_definition( + self, view_name: str, schema: Optional[str] = None, **kw: Any + ) -> str: + r"""Return definition for the plain or materialized view called + ``view_name``. + + :param view_name: Name of the view. + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + """ + + with self._operation_context() as conn: + return self.dialect.get_view_definition( + conn, view_name, schema, info_cache=self.info_cache, **kw + ) + + def get_columns( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedColumn]: + r"""Return information about columns in ``table_name``. + + Given a string ``table_name`` and an optional string ``schema``, + return column information as a list of :class:`.ReflectedColumn`. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: list of dictionaries, each representing the definition of + a database column. + + .. seealso:: :meth:`Inspector.get_multi_columns`. + + """ + + with self._operation_context() as conn: + col_defs = self.dialect.get_columns( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + if col_defs: + self._instantiate_types([col_defs]) + return col_defs + + def _instantiate_types( + self, data: Iterable[List[ReflectedColumn]] + ) -> None: + # make this easy and only return instances for coltype + for col_defs in data: + for col_def in col_defs: + coltype = col_def["type"] + if not isinstance(coltype, TypeEngine): + col_def["type"] = coltype() + + def get_multi_columns( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedColumn]]: + r"""Return information about columns in all objects in the given + schema. + + The objects can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a list of :class:`.ReflectedColumn`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if columns of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of a database column. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_columns` + """ + + with self._operation_context() as conn: + table_col_defs = dict( + self.dialect.get_multi_columns( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + self._instantiate_types(table_col_defs.values()) + return table_col_defs + + def get_pk_constraint( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> ReflectedPrimaryKeyConstraint: + r"""Return information about primary key constraint in ``table_name``. + + Given a string ``table_name``, and an optional string `schema`, return + primary key information as a :class:`.ReflectedPrimaryKeyConstraint`. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary representing the definition of + a primary key constraint. + + .. seealso:: :meth:`Inspector.get_multi_pk_constraint` + """ + with self._operation_context() as conn: + return self.dialect.get_pk_constraint( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_pk_constraint( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, ReflectedPrimaryKeyConstraint]: + r"""Return information about primary key constraints in + all tables in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a :class:`.ReflectedPrimaryKeyConstraint`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if primary keys of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are dictionaries, each representing the + definition of a primary key constraint. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_pk_constraint` + """ + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_pk_constraint( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_foreign_keys( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedForeignKeyConstraint]: + r"""Return information about foreign_keys in ``table_name``. + + Given a string ``table_name``, and an optional string `schema`, return + foreign key information as a list of + :class:`.ReflectedForeignKeyConstraint`. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + a foreign key definition. + + .. seealso:: :meth:`Inspector.get_multi_foreign_keys` + """ + + with self._operation_context() as conn: + return self.dialect.get_foreign_keys( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_foreign_keys( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedForeignKeyConstraint]]: + r"""Return information about foreign_keys in all tables + in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a list of + :class:`.ReflectedForeignKeyConstraint`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if foreign keys of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing + a foreign key definition. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_foreign_keys` + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_foreign_keys( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_indexes( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedIndex]: + r"""Return information about indexes in ``table_name``. + + Given a string ``table_name`` and an optional string `schema`, return + index information as a list of :class:`.ReflectedIndex`. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + definition of an index. + + .. seealso:: :meth:`Inspector.get_multi_indexes` + """ + + with self._operation_context() as conn: + return self.dialect.get_indexes( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_indexes( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedIndex]]: + r"""Return information about indexes in in all objects + in the given schema. + + The objects can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a list of :class:`.ReflectedIndex`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if indexes of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of an index. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_indexes` + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_indexes( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_unique_constraints( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedUniqueConstraint]: + r"""Return information about unique constraints in ``table_name``. + + Given a string ``table_name`` and an optional string `schema`, return + unique constraint information as a list of + :class:`.ReflectedUniqueConstraint`. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + definition of an unique constraint. + + .. seealso:: :meth:`Inspector.get_multi_unique_constraints` + """ + + with self._operation_context() as conn: + return self.dialect.get_unique_constraints( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_unique_constraints( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedUniqueConstraint]]: + r"""Return information about unique constraints in all tables + in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a list of + :class:`.ReflectedUniqueConstraint`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if constraints of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of an unique constraint. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_unique_constraints` + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_unique_constraints( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_table_comment( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> ReflectedTableComment: + r"""Return information about the table comment for ``table_name``. + + Given a string ``table_name`` and an optional string ``schema``, + return table comment information as a :class:`.ReflectedTableComment`. + + Raises ``NotImplementedError`` for a dialect that does not support + comments. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary, with the table comment. + + .. versionadded:: 1.2 + + .. seealso:: :meth:`Inspector.get_multi_table_comment` + """ + + with self._operation_context() as conn: + return self.dialect.get_table_comment( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_table_comment( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, ReflectedTableComment]: + r"""Return information about the table comment in all objects + in the given schema. + + The objects can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a :class:`.ReflectedTableComment`. + + Raises ``NotImplementedError`` for a dialect that does not support + comments. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if comments of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are dictionaries, representing the + table comments. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_table_comment` + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_table_comment( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_check_constraints( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedCheckConstraint]: + r"""Return information about check constraints in ``table_name``. + + Given a string ``table_name`` and an optional string `schema`, return + check constraint information as a list of + :class:`.ReflectedCheckConstraint`. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + definition of a check constraints. + + .. seealso:: :meth:`Inspector.get_multi_check_constraints` + """ + + with self._operation_context() as conn: + return self.dialect.get_check_constraints( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_check_constraints( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedCheckConstraint]]: + r"""Return information about check constraints in all tables + in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + For each table the value is a list of + :class:`.ReflectedCheckConstraint`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if constraints of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of a check constraints. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + + .. seealso:: :meth:`Inspector.get_check_constraints` + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_check_constraints( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def reflect_table( + self, + table: sa_schema.Table, + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str] = (), + resolve_fks: bool = True, + _extend_on: Optional[Set[sa_schema.Table]] = None, + _reflect_info: Optional[_ReflectionInfo] = None, + ) -> None: + """Given a :class:`_schema.Table` object, load its internal + constructs based on introspection. + + This is the underlying method used by most dialects to produce + table reflection. Direct usage is like:: + + from sqlalchemy import create_engine, MetaData, Table + from sqlalchemy import inspect + + engine = create_engine('...') + meta = MetaData() + user_table = Table('user', meta) + insp = inspect(engine) + insp.reflect_table(user_table, None) + + .. versionchanged:: 1.4 Renamed from ``reflecttable`` to + ``reflect_table`` + + :param table: a :class:`~sqlalchemy.schema.Table` instance. + :param include_columns: a list of string column names to include + in the reflection process. If ``None``, all columns are reflected. + + """ + + if _extend_on is not None: + if table in _extend_on: + return + else: + _extend_on.add(table) + + dialect = self.bind.dialect + + with self._operation_context() as conn: + schema = conn.schema_for_object(table) + + table_name = table.name + + # get table-level arguments that are specifically + # intended for reflection, e.g. oracle_resolve_synonyms. + # these are unconditionally passed to related Table + # objects + reflection_options = { + k: table.dialect_kwargs.get(k) + for k in dialect.reflection_options + if k in table.dialect_kwargs + } + + table_key = (schema, table_name) + if _reflect_info is None or table_key not in _reflect_info.columns: + _reflect_info = self._get_reflection_info( + schema, + filter_names=[table_name], + kind=ObjectKind.ANY, + scope=ObjectScope.ANY, + _reflect_info=_reflect_info, + **table.dialect_kwargs, + ) + if table_key in _reflect_info.unreflectable: + raise _reflect_info.unreflectable[table_key] + + if table_key not in _reflect_info.columns: + raise exc.NoSuchTableError(table_name) + + # reflect table options, like mysql_engine + if _reflect_info.table_options: + tbl_opts = _reflect_info.table_options.get(table_key) + if tbl_opts: + # add additional kwargs to the Table if the dialect + # returned them + table._validate_dialect_kwargs(tbl_opts) + + found_table = False + cols_by_orig_name: Dict[str, sa_schema.Column[Any]] = {} + + for col_d in _reflect_info.columns[table_key]: + found_table = True + + self._reflect_column( + table, + col_d, + include_columns, + exclude_columns, + cols_by_orig_name, + ) + + # NOTE: support tables/views with no columns + if not found_table and not self.has_table(table_name, schema): + raise exc.NoSuchTableError(table_name) + + self._reflect_pk( + _reflect_info, table_key, table, cols_by_orig_name, exclude_columns + ) + + self._reflect_fk( + _reflect_info, + table_key, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + resolve_fks, + _extend_on, + reflection_options, + ) + + self._reflect_indexes( + _reflect_info, + table_key, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) + + self._reflect_unique_constraints( + _reflect_info, + table_key, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) + + self._reflect_check_constraints( + _reflect_info, + table_key, + table, + cols_by_orig_name, + include_columns, + exclude_columns, + reflection_options, + ) + + self._reflect_table_comment( + _reflect_info, + table_key, + table, + reflection_options, + ) + + def _reflect_column( + self, + table: sa_schema.Table, + col_d: ReflectedColumn, + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + ) -> None: + orig_name = col_d["name"] + + table.metadata.dispatch.column_reflect(self, table, col_d) + table.dispatch.column_reflect(self, table, col_d) + + # fetch name again as column_reflect is allowed to + # change it + name = col_d["name"] + if (include_columns and name not in include_columns) or ( + exclude_columns and name in exclude_columns + ): + return + + coltype = col_d["type"] + + col_kw = { + k: col_d[k] # type: ignore[literal-required] + for k in [ + "nullable", + "autoincrement", + "quote", + "info", + "key", + "comment", + ] + if k in col_d + } + + if "dialect_options" in col_d: + col_kw.update(col_d["dialect_options"]) + + colargs = [] + default: Any + if col_d.get("default") is not None: + default_text = col_d["default"] + assert default_text is not None + if isinstance(default_text, TextClause): + default = sa_schema.DefaultClause( + default_text, _reflected=True + ) + elif not isinstance(default_text, sa_schema.FetchedValue): + default = sa_schema.DefaultClause( + sql.text(default_text), _reflected=True + ) + else: + default = default_text + colargs.append(default) + + if "computed" in col_d: + computed = sa_schema.Computed(**col_d["computed"]) + colargs.append(computed) + + if "identity" in col_d: + identity = sa_schema.Identity(**col_d["identity"]) + colargs.append(identity) + + cols_by_orig_name[orig_name] = col = sa_schema.Column( + name, coltype, *colargs, **col_kw + ) + + if col.key in table.primary_key: + col.primary_key = True + table.append_column(col, replace_existing=True) + + def _reflect_pk( + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + exclude_columns: Collection[str], + ) -> None: + pk_cons = _reflect_info.pk_constraint.get(table_key) + if pk_cons: + pk_cols = [ + cols_by_orig_name[pk] + for pk in pk_cons["constrained_columns"] + if pk in cols_by_orig_name and pk not in exclude_columns + ] + + # update pk constraint name and comment + table.primary_key.name = pk_cons.get("name") + table.primary_key.comment = pk_cons.get("comment", None) + + # tell the PKConstraint to re-initialize + # its column collection + table.primary_key._reload(pk_cols) + + def _reflect_fk( + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + resolve_fks: bool, + _extend_on: Optional[Set[sa_schema.Table]], + reflection_options: Dict[str, Any], + ) -> None: + fkeys = _reflect_info.foreign_keys.get(table_key, []) + for fkey_d in fkeys: + conname = fkey_d["name"] + # look for columns by orig name in cols_by_orig_name, + # but support columns that are in-Python only as fallback + constrained_columns = [ + cols_by_orig_name[c].key if c in cols_by_orig_name else c + for c in fkey_d["constrained_columns"] + ] + + if ( + exclude_columns + and set(constrained_columns).intersection(exclude_columns) + or ( + include_columns + and set(constrained_columns).difference(include_columns) + ) + ): + continue + + referred_schema = fkey_d["referred_schema"] + referred_table = fkey_d["referred_table"] + referred_columns = fkey_d["referred_columns"] + refspec = [] + if referred_schema is not None: + if resolve_fks: + sa_schema.Table( + referred_table, + table.metadata, + schema=referred_schema, + autoload_with=self.bind, + _extend_on=_extend_on, + _reflect_info=_reflect_info, + **reflection_options, + ) + for column in referred_columns: + refspec.append( + ".".join([referred_schema, referred_table, column]) + ) + else: + if resolve_fks: + sa_schema.Table( + referred_table, + table.metadata, + autoload_with=self.bind, + schema=sa_schema.BLANK_SCHEMA, + _extend_on=_extend_on, + _reflect_info=_reflect_info, + **reflection_options, + ) + for column in referred_columns: + refspec.append(".".join([referred_table, column])) + if "options" in fkey_d: + options = fkey_d["options"] + else: + options = {} + + try: + table.append_constraint( + sa_schema.ForeignKeyConstraint( + constrained_columns, + refspec, + conname, + link_to_name=True, + comment=fkey_d.get("comment"), + **options, + ) + ) + except exc.ConstraintColumnNotFoundError: + util.warn( + f"On reflected table {table.name}, skipping reflection of " + "foreign key constraint " + f"{conname}; one or more subject columns within " + f"name(s) {', '.join(constrained_columns)} are not " + "present in the table" + ) + + _index_sort_exprs = { + "asc": operators.asc_op, + "desc": operators.desc_op, + "nulls_first": operators.nulls_first_op, + "nulls_last": operators.nulls_last_op, + } + + def _reflect_indexes( + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + reflection_options: Dict[str, Any], + ) -> None: + # Indexes + indexes = _reflect_info.indexes.get(table_key, []) + for index_d in indexes: + name = index_d["name"] + columns = index_d["column_names"] + expressions = index_d.get("expressions") + column_sorting = index_d.get("column_sorting", {}) + unique = index_d["unique"] + flavor = index_d.get("type", "index") + dialect_options = index_d.get("dialect_options", {}) + + duplicates = index_d.get("duplicates_constraint") + if include_columns and not set(columns).issubset(include_columns): + continue + if duplicates: + continue + # look for columns by orig name in cols_by_orig_name, + # but support columns that are in-Python only as fallback + idx_element: Any + idx_elements = [] + for index, c in enumerate(columns): + if c is None: + if not expressions: + util.warn( + f"Skipping {flavor} {name!r} because key " + f"{index + 1} reflected as None but no " + "'expressions' were returned" + ) + break + idx_element = sql.text(expressions[index]) + else: + try: + if c in cols_by_orig_name: + idx_element = cols_by_orig_name[c] + else: + idx_element = table.c[c] + except KeyError: + util.warn( + f"{flavor} key {c!r} was not located in " + f"columns for table {table.name!r}" + ) + continue + for option in column_sorting.get(c, ()): + if option in self._index_sort_exprs: + op = self._index_sort_exprs[option] + idx_element = op(idx_element) + idx_elements.append(idx_element) + else: + sa_schema.Index( + name, + *idx_elements, + _table=table, + unique=unique, + **dialect_options, + ) + + def _reflect_unique_constraints( + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + reflection_options: Dict[str, Any], + ) -> None: + constraints = _reflect_info.unique_constraints.get(table_key, []) + # Unique Constraints + for const_d in constraints: + conname = const_d["name"] + columns = const_d["column_names"] + comment = const_d.get("comment") + duplicates = const_d.get("duplicates_index") + dialect_options = const_d.get("dialect_options", {}) + if include_columns and not set(columns).issubset(include_columns): + continue + if duplicates: + continue + # look for columns by orig name in cols_by_orig_name, + # but support columns that are in-Python only as fallback + constrained_cols = [] + for c in columns: + try: + constrained_col = ( + cols_by_orig_name[c] + if c in cols_by_orig_name + else table.c[c] + ) + except KeyError: + util.warn( + "unique constraint key '%s' was not located in " + "columns for table '%s'" % (c, table.name) + ) + else: + constrained_cols.append(constrained_col) + table.append_constraint( + sa_schema.UniqueConstraint( + *constrained_cols, + name=conname, + comment=comment, + **dialect_options, + ) + ) + + def _reflect_check_constraints( + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + reflection_options: Dict[str, Any], + ) -> None: + constraints = _reflect_info.check_constraints.get(table_key, []) + for const_d in constraints: + table.append_constraint(sa_schema.CheckConstraint(**const_d)) + + def _reflect_table_comment( + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + reflection_options: Dict[str, Any], + ) -> None: + comment_dict = _reflect_info.table_comment.get(table_key) + if comment_dict: + table.comment = comment_dict["text"] + + def _get_reflection_info( + self, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + available: Optional[Collection[str]] = None, + _reflect_info: Optional[_ReflectionInfo] = None, + **kw: Any, + ) -> _ReflectionInfo: + kw["schema"] = schema + + if filter_names and available and len(filter_names) > 100: + fraction = len(filter_names) / len(available) + else: + fraction = None + + unreflectable: Dict[TableKey, exc.UnreflectableTableError] + kw["unreflectable"] = unreflectable = {} + + has_result: bool = True + + def run( + meth: Any, + *, + optional: bool = False, + check_filter_names_from_meth: bool = False, + ) -> Any: + nonlocal has_result + # simple heuristic to improve reflection performance if a + # dialect implements multi_reflection: + # if more than 50% of the tables in the db are in filter_names + # load all the tables, since it's most likely faster to avoid + # a filter on that many tables. + if ( + fraction is None + or fraction <= 0.5 + or not self.dialect._overrides_default(meth.__name__) + ): + _fn = filter_names + else: + _fn = None + try: + if has_result: + res = meth(filter_names=_fn, **kw) + if check_filter_names_from_meth and not res: + # method returned no result data. + # skip any future call methods + has_result = False + else: + res = {} + except NotImplementedError: + if not optional: + raise + res = {} + return res + + info = _ReflectionInfo( + columns=run( + self.get_multi_columns, check_filter_names_from_meth=True + ), + pk_constraint=run(self.get_multi_pk_constraint), + foreign_keys=run(self.get_multi_foreign_keys), + indexes=run(self.get_multi_indexes), + unique_constraints=run( + self.get_multi_unique_constraints, optional=True + ), + table_comment=run(self.get_multi_table_comment, optional=True), + check_constraints=run( + self.get_multi_check_constraints, optional=True + ), + table_options=run(self.get_multi_table_options, optional=True), + unreflectable=unreflectable, + ) + if _reflect_info: + _reflect_info.update(info) + return _reflect_info + else: + return info + + +@final +class ReflectionDefaults: + """provides blank default values for reflection methods.""" + + @classmethod + def columns(cls) -> List[ReflectedColumn]: + return [] + + @classmethod + def pk_constraint(cls) -> ReflectedPrimaryKeyConstraint: + return { + "name": None, + "constrained_columns": [], + } + + @classmethod + def foreign_keys(cls) -> List[ReflectedForeignKeyConstraint]: + return [] + + @classmethod + def indexes(cls) -> List[ReflectedIndex]: + return [] + + @classmethod + def unique_constraints(cls) -> List[ReflectedUniqueConstraint]: + return [] + + @classmethod + def check_constraints(cls) -> List[ReflectedCheckConstraint]: + return [] + + @classmethod + def table_options(cls) -> Dict[str, Any]: + return {} + + @classmethod + def table_comment(cls) -> ReflectedTableComment: + return {"text": None} + + +@dataclass +class _ReflectionInfo: + columns: Dict[TableKey, List[ReflectedColumn]] + pk_constraint: Dict[TableKey, Optional[ReflectedPrimaryKeyConstraint]] + foreign_keys: Dict[TableKey, List[ReflectedForeignKeyConstraint]] + indexes: Dict[TableKey, List[ReflectedIndex]] + # optionals + unique_constraints: Dict[TableKey, List[ReflectedUniqueConstraint]] + table_comment: Dict[TableKey, Optional[ReflectedTableComment]] + check_constraints: Dict[TableKey, List[ReflectedCheckConstraint]] + table_options: Dict[TableKey, Dict[str, Any]] + unreflectable: Dict[TableKey, exc.UnreflectableTableError] + + def update(self, other: _ReflectionInfo) -> None: + for k, v in self.__dict__.items(): + ov = getattr(other, k) + if ov is not None: + if v is None: + setattr(self, k, ov) + else: + v.update(ov) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/result.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/result.py new file mode 100644 index 0000000..56b3a68 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/result.py @@ -0,0 +1,2382 @@ +# engine/result.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Define generic result set constructs.""" + +from __future__ import annotations + +from enum import Enum +import functools +import itertools +import operator +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .row import Row +from .row import RowMapping +from .. import exc +from .. import util +from ..sql.base import _generative +from ..sql.base import HasMemoized +from ..sql.base import InPlaceGenerative +from ..util import HasMemoized_ro_memoized_attribute +from ..util import NONE_SET +from ..util._has_cy import HAS_CYEXTENSION +from ..util.typing import Literal +from ..util.typing import Self + +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_row import tuplegetter as tuplegetter +else: + from sqlalchemy.cyextension.resultproxy import tuplegetter as tuplegetter + +if typing.TYPE_CHECKING: + from ..sql.schema import Column + from ..sql.type_api import _ResultProcessorType + +_KeyType = Union[str, "Column[Any]"] +_KeyIndexType = Union[str, "Column[Any]", int] + +# is overridden in cursor using _CursorKeyMapRecType +_KeyMapRecType = Any + +_KeyMapType = Mapping[_KeyType, _KeyMapRecType] + + +_RowData = Union[Row[Any], RowMapping, Any] +"""A generic form of "row" that accommodates for the different kinds of +"rows" that different result objects return, including row, row mapping, and +scalar values""" + +_RawRowType = Tuple[Any, ...] +"""represents the kind of row we get from a DBAPI cursor""" + +_R = TypeVar("_R", bound=_RowData) +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + +_InterimRowType = Union[_R, _RawRowType] +"""a catchall "anything" kind of return type that can be applied +across all the result types + +""" + +_InterimSupportsScalarsRowType = Union[Row[Any], Any] + +_ProcessorsType = Sequence[Optional["_ResultProcessorType[Any]"]] +_TupleGetterType = Callable[[Sequence[Any]], Sequence[Any]] +_UniqueFilterType = Callable[[Any], Any] +_UniqueFilterStateType = Tuple[Set[Any], Optional[_UniqueFilterType]] + + +class ResultMetaData: + """Base for metadata about result rows.""" + + __slots__ = () + + _tuplefilter: Optional[_TupleGetterType] = None + _translated_indexes: Optional[Sequence[int]] = None + _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None + _keymap: _KeyMapType + _keys: Sequence[str] + _processors: Optional[_ProcessorsType] + _key_to_index: Mapping[_KeyType, int] + + @property + def keys(self) -> RMKeyView: + return RMKeyView(self) + + def _has_key(self, key: object) -> bool: + raise NotImplementedError() + + def _for_freeze(self) -> ResultMetaData: + raise NotImplementedError() + + @overload + def _key_fallback( + self, key: Any, err: Optional[Exception], raiseerr: Literal[True] = ... + ) -> NoReturn: ... + + @overload + def _key_fallback( + self, + key: Any, + err: Optional[Exception], + raiseerr: Literal[False] = ..., + ) -> None: ... + + @overload + def _key_fallback( + self, key: Any, err: Optional[Exception], raiseerr: bool = ... + ) -> Optional[NoReturn]: ... + + def _key_fallback( + self, key: Any, err: Optional[Exception], raiseerr: bool = True + ) -> Optional[NoReturn]: + assert raiseerr + raise KeyError(key) from err + + def _raise_for_ambiguous_column_name( + self, rec: _KeyMapRecType + ) -> NoReturn: + raise NotImplementedError( + "ambiguous column name logic is implemented for " + "CursorResultMetaData" + ) + + def _index_for_key( + self, key: _KeyIndexType, raiseerr: bool + ) -> Optional[int]: + raise NotImplementedError() + + def _indexes_for_keys( + self, keys: Sequence[_KeyIndexType] + ) -> Sequence[int]: + raise NotImplementedError() + + def _metadata_for_keys( + self, keys: Sequence[_KeyIndexType] + ) -> Iterator[_KeyMapRecType]: + raise NotImplementedError() + + def _reduce(self, keys: Sequence[_KeyIndexType]) -> ResultMetaData: + raise NotImplementedError() + + def _getter( + self, key: Any, raiseerr: bool = True + ) -> Optional[Callable[[Row[Any]], Any]]: + index = self._index_for_key(key, raiseerr) + + if index is not None: + return operator.itemgetter(index) + else: + return None + + def _row_as_tuple_getter( + self, keys: Sequence[_KeyIndexType] + ) -> _TupleGetterType: + indexes = self._indexes_for_keys(keys) + return tuplegetter(*indexes) + + def _make_key_to_index( + self, keymap: Mapping[_KeyType, Sequence[Any]], index: int + ) -> Mapping[_KeyType, int]: + return { + key: rec[index] + for key, rec in keymap.items() + if rec[index] is not None + } + + def _key_not_found(self, key: Any, attr_error: bool) -> NoReturn: + if key in self._keymap: + # the index must be none in this case + self._raise_for_ambiguous_column_name(self._keymap[key]) + else: + # unknown key + if attr_error: + try: + self._key_fallback(key, None) + except KeyError as ke: + raise AttributeError(ke.args[0]) from ke + else: + self._key_fallback(key, None) + + @property + def _effective_processors(self) -> Optional[_ProcessorsType]: + if not self._processors or NONE_SET.issuperset(self._processors): + return None + else: + return self._processors + + +class RMKeyView(typing.KeysView[Any]): + __slots__ = ("_parent", "_keys") + + _parent: ResultMetaData + _keys: Sequence[str] + + def __init__(self, parent: ResultMetaData): + self._parent = parent + self._keys = [k for k in parent._keys if k is not None] + + def __len__(self) -> int: + return len(self._keys) + + def __repr__(self) -> str: + return "{0.__class__.__name__}({0._keys!r})".format(self) + + def __iter__(self) -> Iterator[str]: + return iter(self._keys) + + def __contains__(self, item: Any) -> bool: + if isinstance(item, int): + return False + + # note this also includes special key fallback behaviors + # which also don't seem to be tested in test_resultset right now + return self._parent._has_key(item) + + def __eq__(self, other: Any) -> bool: + return list(other) == list(self) + + def __ne__(self, other: Any) -> bool: + return list(other) != list(self) + + +class SimpleResultMetaData(ResultMetaData): + """result metadata for in-memory collections.""" + + __slots__ = ( + "_keys", + "_keymap", + "_processors", + "_tuplefilter", + "_translated_indexes", + "_unique_filters", + "_key_to_index", + ) + + _keys: Sequence[str] + + def __init__( + self, + keys: Sequence[str], + extra: Optional[Sequence[Any]] = None, + _processors: Optional[_ProcessorsType] = None, + _tuplefilter: Optional[_TupleGetterType] = None, + _translated_indexes: Optional[Sequence[int]] = None, + _unique_filters: Optional[Sequence[Callable[[Any], Any]]] = None, + ): + self._keys = list(keys) + self._tuplefilter = _tuplefilter + self._translated_indexes = _translated_indexes + self._unique_filters = _unique_filters + if extra: + recs_names = [ + ( + (name,) + (extras if extras else ()), + (index, name, extras), + ) + for index, (name, extras) in enumerate(zip(self._keys, extra)) + ] + else: + recs_names = [ + ((name,), (index, name, ())) + for index, name in enumerate(self._keys) + ] + + self._keymap = {key: rec for keys, rec in recs_names for key in keys} + + self._processors = _processors + + self._key_to_index = self._make_key_to_index(self._keymap, 0) + + def _has_key(self, key: object) -> bool: + return key in self._keymap + + def _for_freeze(self) -> ResultMetaData: + unique_filters = self._unique_filters + if unique_filters and self._tuplefilter: + unique_filters = self._tuplefilter(unique_filters) + + # TODO: are we freezing the result with or without uniqueness + # applied? + return SimpleResultMetaData( + self._keys, + extra=[self._keymap[key][2] for key in self._keys], + _unique_filters=unique_filters, + ) + + def __getstate__(self) -> Dict[str, Any]: + return { + "_keys": self._keys, + "_translated_indexes": self._translated_indexes, + } + + def __setstate__(self, state: Dict[str, Any]) -> None: + if state["_translated_indexes"]: + _translated_indexes = state["_translated_indexes"] + _tuplefilter = tuplegetter(*_translated_indexes) + else: + _translated_indexes = _tuplefilter = None + self.__init__( # type: ignore + state["_keys"], + _translated_indexes=_translated_indexes, + _tuplefilter=_tuplefilter, + ) + + def _index_for_key(self, key: Any, raiseerr: bool = True) -> int: + if int in key.__class__.__mro__: + key = self._keys[key] + try: + rec = self._keymap[key] + except KeyError as ke: + rec = self._key_fallback(key, ke, raiseerr) + + return rec[0] # type: ignore[no-any-return] + + def _indexes_for_keys(self, keys: Sequence[Any]) -> Sequence[int]: + return [self._keymap[key][0] for key in keys] + + def _metadata_for_keys( + self, keys: Sequence[Any] + ) -> Iterator[_KeyMapRecType]: + for key in keys: + if int in key.__class__.__mro__: + key = self._keys[key] + + try: + rec = self._keymap[key] + except KeyError as ke: + rec = self._key_fallback(key, ke, True) + + yield rec + + def _reduce(self, keys: Sequence[Any]) -> ResultMetaData: + try: + metadata_for_keys = [ + self._keymap[ + self._keys[key] if int in key.__class__.__mro__ else key + ] + for key in keys + ] + except KeyError as ke: + self._key_fallback(ke.args[0], ke, True) + + indexes: Sequence[int] + new_keys: Sequence[str] + extra: Sequence[Any] + indexes, new_keys, extra = zip(*metadata_for_keys) + + if self._translated_indexes: + indexes = [self._translated_indexes[idx] for idx in indexes] + + tup = tuplegetter(*indexes) + + new_metadata = SimpleResultMetaData( + new_keys, + extra=extra, + _tuplefilter=tup, + _translated_indexes=indexes, + _processors=self._processors, + _unique_filters=self._unique_filters, + ) + + return new_metadata + + +def result_tuple( + fields: Sequence[str], extra: Optional[Any] = None +) -> Callable[[Iterable[Any]], Row[Any]]: + parent = SimpleResultMetaData(fields, extra) + return functools.partial( + Row, parent, parent._effective_processors, parent._key_to_index + ) + + +# a symbol that indicates to internal Result methods that +# "no row is returned". We can't use None for those cases where a scalar +# filter is applied to rows. +class _NoRow(Enum): + _NO_ROW = 0 + + +_NO_ROW = _NoRow._NO_ROW + + +class ResultInternal(InPlaceGenerative, Generic[_R]): + __slots__ = () + + _real_result: Optional[Result[Any]] = None + _generate_rows: bool = True + _row_logging_fn: Optional[Callable[[Any], Any]] + + _unique_filter_state: Optional[_UniqueFilterStateType] = None + _post_creational_filter: Optional[Callable[[Any], Any]] = None + _is_cursor = False + + _metadata: ResultMetaData + + _source_supports_scalars: bool + + def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]: + raise NotImplementedError() + + def _fetchone_impl( + self, hard_close: bool = False + ) -> Optional[_InterimRowType[Row[Any]]]: + raise NotImplementedError() + + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType[Row[Any]]]: + raise NotImplementedError() + + def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]: + raise NotImplementedError() + + def _soft_close(self, hard: bool = False) -> None: + raise NotImplementedError() + + @HasMemoized_ro_memoized_attribute + def _row_getter(self) -> Optional[Callable[..., _R]]: + real_result: Result[Any] = ( + self._real_result + if self._real_result + else cast("Result[Any]", self) + ) + + if real_result._source_supports_scalars: + if not self._generate_rows: + return None + else: + _proc = Row + + def process_row( + metadata: ResultMetaData, + processors: Optional[_ProcessorsType], + key_to_index: Mapping[_KeyType, int], + scalar_obj: Any, + ) -> Row[Any]: + return _proc( + metadata, processors, key_to_index, (scalar_obj,) + ) + + else: + process_row = Row # type: ignore + + metadata = self._metadata + + key_to_index = metadata._key_to_index + processors = metadata._effective_processors + tf = metadata._tuplefilter + + if tf and not real_result._source_supports_scalars: + if processors: + processors = tf(processors) + + _make_row_orig: Callable[..., _R] = functools.partial( # type: ignore # noqa E501 + process_row, metadata, processors, key_to_index + ) + + fixed_tf = tf + + def make_row(row: _InterimRowType[Row[Any]]) -> _R: + return _make_row_orig(fixed_tf(row)) + + else: + make_row = functools.partial( # type: ignore + process_row, metadata, processors, key_to_index + ) + + if real_result._row_logging_fn: + _log_row = real_result._row_logging_fn + _make_row = make_row + + def make_row(row: _InterimRowType[Row[Any]]) -> _R: + return _log_row(_make_row(row)) # type: ignore + + return make_row + + @HasMemoized_ro_memoized_attribute + def _iterator_getter(self) -> Callable[..., Iterator[_R]]: + make_row = self._row_getter + + post_creational_filter = self._post_creational_filter + + if self._unique_filter_state: + uniques, strategy = self._unique_strategy + + def iterrows(self: Result[Any]) -> Iterator[_R]: + for raw_row in self._fetchiter_impl(): + obj: _InterimRowType[Any] = ( + make_row(raw_row) if make_row else raw_row + ) + hashed = strategy(obj) if strategy else obj + if hashed in uniques: + continue + uniques.add(hashed) + if post_creational_filter: + obj = post_creational_filter(obj) + yield obj # type: ignore + + else: + + def iterrows(self: Result[Any]) -> Iterator[_R]: + for raw_row in self._fetchiter_impl(): + row: _InterimRowType[Any] = ( + make_row(raw_row) if make_row else raw_row + ) + if post_creational_filter: + row = post_creational_filter(row) + yield row # type: ignore + + return iterrows + + def _raw_all_rows(self) -> List[_R]: + make_row = self._row_getter + assert make_row is not None + rows = self._fetchall_impl() + return [make_row(row) for row in rows] + + def _allrows(self) -> List[_R]: + post_creational_filter = self._post_creational_filter + + make_row = self._row_getter + + rows = self._fetchall_impl() + made_rows: List[_InterimRowType[_R]] + if make_row: + made_rows = [make_row(row) for row in rows] + else: + made_rows = rows # type: ignore + + interim_rows: List[_R] + + if self._unique_filter_state: + uniques, strategy = self._unique_strategy + + interim_rows = [ + made_row # type: ignore + for made_row, sig_row in [ + ( + made_row, + strategy(made_row) if strategy else made_row, + ) + for made_row in made_rows + ] + if sig_row not in uniques and not uniques.add(sig_row) # type: ignore # noqa: E501 + ] + else: + interim_rows = made_rows # type: ignore + + if post_creational_filter: + interim_rows = [ + post_creational_filter(row) for row in interim_rows + ] + return interim_rows + + @HasMemoized_ro_memoized_attribute + def _onerow_getter( + self, + ) -> Callable[..., Union[Literal[_NoRow._NO_ROW], _R]]: + make_row = self._row_getter + + post_creational_filter = self._post_creational_filter + + if self._unique_filter_state: + uniques, strategy = self._unique_strategy + + def onerow(self: Result[Any]) -> Union[_NoRow, _R]: + _onerow = self._fetchone_impl + while True: + row = _onerow() + if row is None: + return _NO_ROW + else: + obj: _InterimRowType[Any] = ( + make_row(row) if make_row else row + ) + hashed = strategy(obj) if strategy else obj + if hashed in uniques: + continue + else: + uniques.add(hashed) + if post_creational_filter: + obj = post_creational_filter(obj) + return obj # type: ignore + + else: + + def onerow(self: Result[Any]) -> Union[_NoRow, _R]: + row = self._fetchone_impl() + if row is None: + return _NO_ROW + else: + interim_row: _InterimRowType[Any] = ( + make_row(row) if make_row else row + ) + if post_creational_filter: + interim_row = post_creational_filter(interim_row) + return interim_row # type: ignore + + return onerow + + @HasMemoized_ro_memoized_attribute + def _manyrow_getter(self) -> Callable[..., List[_R]]: + make_row = self._row_getter + + post_creational_filter = self._post_creational_filter + + if self._unique_filter_state: + uniques, strategy = self._unique_strategy + + def filterrows( + make_row: Optional[Callable[..., _R]], + rows: List[Any], + strategy: Optional[Callable[[List[Any]], Any]], + uniques: Set[Any], + ) -> List[_R]: + if make_row: + rows = [make_row(row) for row in rows] + + if strategy: + made_rows = ( + (made_row, strategy(made_row)) for made_row in rows + ) + else: + made_rows = ((made_row, made_row) for made_row in rows) + return [ + made_row + for made_row, sig_row in made_rows + if sig_row not in uniques and not uniques.add(sig_row) # type: ignore # noqa: E501 + ] + + def manyrows( + self: ResultInternal[_R], num: Optional[int] + ) -> List[_R]: + collect: List[_R] = [] + + _manyrows = self._fetchmany_impl + + if num is None: + # if None is passed, we don't know the default + # manyrows number, DBAPI has this as cursor.arraysize + # different DBAPIs / fetch strategies may be different. + # do a fetch to find what the number is. if there are + # only fewer rows left, then it doesn't matter. + real_result = ( + self._real_result + if self._real_result + else cast("Result[Any]", self) + ) + if real_result._yield_per: + num_required = num = real_result._yield_per + else: + rows = _manyrows(num) + num = len(rows) + assert make_row is not None + collect.extend( + filterrows(make_row, rows, strategy, uniques) + ) + num_required = num - len(collect) + else: + num_required = num + + assert num is not None + + while num_required: + rows = _manyrows(num_required) + if not rows: + break + + collect.extend( + filterrows(make_row, rows, strategy, uniques) + ) + num_required = num - len(collect) + + if post_creational_filter: + collect = [post_creational_filter(row) for row in collect] + return collect + + else: + + def manyrows( + self: ResultInternal[_R], num: Optional[int] + ) -> List[_R]: + if num is None: + real_result = ( + self._real_result + if self._real_result + else cast("Result[Any]", self) + ) + num = real_result._yield_per + + rows: List[_InterimRowType[Any]] = self._fetchmany_impl(num) + if make_row: + rows = [make_row(row) for row in rows] + if post_creational_filter: + rows = [post_creational_filter(row) for row in rows] + return rows # type: ignore + + return manyrows + + @overload + def _only_one_row( + self, + raise_for_second_row: bool, + raise_for_none: Literal[True], + scalar: bool, + ) -> _R: ... + + @overload + def _only_one_row( + self, + raise_for_second_row: bool, + raise_for_none: bool, + scalar: bool, + ) -> Optional[_R]: ... + + def _only_one_row( + self, + raise_for_second_row: bool, + raise_for_none: bool, + scalar: bool, + ) -> Optional[_R]: + onerow = self._fetchone_impl + + row: Optional[_InterimRowType[Any]] = onerow(hard_close=True) + if row is None: + if raise_for_none: + raise exc.NoResultFound( + "No row was found when one was required" + ) + else: + return None + + if scalar and self._source_supports_scalars: + self._generate_rows = False + make_row = None + else: + make_row = self._row_getter + + try: + row = make_row(row) if make_row else row + except: + self._soft_close(hard=True) + raise + + if raise_for_second_row: + if self._unique_filter_state: + # for no second row but uniqueness, need to essentially + # consume the entire result :( + uniques, strategy = self._unique_strategy + + existing_row_hash = strategy(row) if strategy else row + + while True: + next_row: Any = onerow(hard_close=True) + if next_row is None: + next_row = _NO_ROW + break + + try: + next_row = make_row(next_row) if make_row else next_row + + if strategy: + assert next_row is not _NO_ROW + if existing_row_hash == strategy(next_row): + continue + elif row == next_row: + continue + # here, we have a row and it's different + break + except: + self._soft_close(hard=True) + raise + else: + next_row = onerow(hard_close=True) + if next_row is None: + next_row = _NO_ROW + + if next_row is not _NO_ROW: + self._soft_close(hard=True) + raise exc.MultipleResultsFound( + "Multiple rows were found when exactly one was required" + if raise_for_none + else "Multiple rows were found when one or none " + "was required" + ) + else: + next_row = _NO_ROW + # if we checked for second row then that would have + # closed us :) + self._soft_close(hard=True) + + if not scalar: + post_creational_filter = self._post_creational_filter + if post_creational_filter: + row = post_creational_filter(row) + + if scalar and make_row: + return row[0] # type: ignore + else: + return row # type: ignore + + def _iter_impl(self) -> Iterator[_R]: + return self._iterator_getter(self) + + def _next_impl(self) -> _R: + row = self._onerow_getter(self) + if row is _NO_ROW: + raise StopIteration() + else: + return row + + @_generative + def _column_slices(self, indexes: Sequence[_KeyIndexType]) -> Self: + real_result = ( + self._real_result + if self._real_result + else cast("Result[Any]", self) + ) + + if not real_result._source_supports_scalars or len(indexes) != 1: + self._metadata = self._metadata._reduce(indexes) + + assert self._generate_rows + + return self + + @HasMemoized.memoized_attribute + def _unique_strategy(self) -> _UniqueFilterStateType: + assert self._unique_filter_state is not None + uniques, strategy = self._unique_filter_state + + real_result = ( + self._real_result + if self._real_result is not None + else cast("Result[Any]", self) + ) + + if not strategy and self._metadata._unique_filters: + if ( + real_result._source_supports_scalars + and not self._generate_rows + ): + strategy = self._metadata._unique_filters[0] + else: + filters = self._metadata._unique_filters + if self._metadata._tuplefilter: + filters = self._metadata._tuplefilter(filters) + + strategy = operator.methodcaller("_filter_on_values", filters) + return uniques, strategy + + +class _WithKeys: + __slots__ = () + + _metadata: ResultMetaData + + # used mainly to share documentation on the keys method. + def keys(self) -> RMKeyView: + """Return an iterable view which yields the string keys that would + be represented by each :class:`_engine.Row`. + + The keys can represent the labels of the columns returned by a core + statement or the names of the orm classes returned by an orm + execution. + + The view also can be tested for key containment using the Python + ``in`` operator, which will test both for the string keys represented + in the view, as well as for alternate keys such as column objects. + + .. versionchanged:: 1.4 a key view object is returned rather than a + plain list. + + + """ + return self._metadata.keys + + +class Result(_WithKeys, ResultInternal[Row[_TP]]): + """Represent a set of database results. + + .. versionadded:: 1.4 The :class:`_engine.Result` object provides a + completely updated usage model and calling facade for SQLAlchemy + Core and SQLAlchemy ORM. In Core, it forms the basis of the + :class:`_engine.CursorResult` object which replaces the previous + :class:`_engine.ResultProxy` interface. When using the ORM, a + higher level object called :class:`_engine.ChunkedIteratorResult` + is normally used. + + .. note:: In SQLAlchemy 1.4 and above, this object is + used for ORM results returned by :meth:`_orm.Session.execute`, which can + yield instances of ORM mapped objects either individually or within + tuple-like rows. Note that the :class:`_engine.Result` object does not + deduplicate instances or rows automatically as is the case with the + legacy :class:`_orm.Query` object. For in-Python de-duplication of + instances or rows, use the :meth:`_engine.Result.unique` modifier + method. + + .. seealso:: + + :ref:`tutorial_fetching_rows` - in the :doc:`/tutorial/index` + + """ + + __slots__ = ("_metadata", "__dict__") + + _row_logging_fn: Optional[Callable[[Row[Any]], Row[Any]]] = None + + _source_supports_scalars: bool = False + + _yield_per: Optional[int] = None + + _attributes: util.immutabledict[Any, Any] = util.immutabledict() + + def __init__(self, cursor_metadata: ResultMetaData): + self._metadata = cursor_metadata + + def __enter__(self) -> Self: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self.close() + + def close(self) -> None: + """close this :class:`_engine.Result`. + + The behavior of this method is implementation specific, and is + not implemented by default. The method should generally end + the resources in use by the result object and also cause any + subsequent iteration or row fetching to raise + :class:`.ResourceClosedError`. + + .. versionadded:: 1.4.27 - ``.close()`` was previously not generally + available for all :class:`_engine.Result` classes, instead only + being available on the :class:`_engine.CursorResult` returned for + Core statement executions. As most other result objects, namely the + ones used by the ORM, are proxying a :class:`_engine.CursorResult` + in any case, this allows the underlying cursor result to be closed + from the outside facade for the case when the ORM query is using + the ``yield_per`` execution option where it does not immediately + exhaust and autoclose the database cursor. + + """ + self._soft_close(hard=True) + + @property + def _soft_closed(self) -> bool: + raise NotImplementedError() + + @property + def closed(self) -> bool: + """return ``True`` if this :class:`_engine.Result` reports .closed + + .. versionadded:: 1.4.43 + + """ + raise NotImplementedError() + + @_generative + def yield_per(self, num: int) -> Self: + """Configure the row-fetching strategy to fetch ``num`` rows at a time. + + This impacts the underlying behavior of the result when iterating over + the result object, or otherwise making use of methods such as + :meth:`_engine.Result.fetchone` that return one row at a time. Data + from the underlying cursor or other data source will be buffered up to + this many rows in memory, and the buffered collection will then be + yielded out one row at a time or as many rows are requested. Each time + the buffer clears, it will be refreshed to this many rows or as many + rows remain if fewer remain. + + The :meth:`_engine.Result.yield_per` method is generally used in + conjunction with the + :paramref:`_engine.Connection.execution_options.stream_results` + execution option, which will allow the database dialect in use to make + use of a server side cursor, if the DBAPI supports a specific "server + side cursor" mode separate from its default mode of operation. + + .. tip:: + + Consider using the + :paramref:`_engine.Connection.execution_options.yield_per` + execution option, which will simultaneously set + :paramref:`_engine.Connection.execution_options.stream_results` + to ensure the use of server side cursors, as well as automatically + invoke the :meth:`_engine.Result.yield_per` method to establish + a fixed row buffer size at once. + + The :paramref:`_engine.Connection.execution_options.yield_per` + execution option is available for ORM operations, with + :class:`_orm.Session`-oriented use described at + :ref:`orm_queryguide_yield_per`. The Core-only version which works + with :class:`_engine.Connection` is new as of SQLAlchemy 1.4.40. + + .. versionadded:: 1.4 + + :param num: number of rows to fetch each time the buffer is refilled. + If set to a value below 1, fetches all rows for the next buffer. + + .. seealso:: + + :ref:`engine_stream_results` - describes Core behavior for + :meth:`_engine.Result.yield_per` + + :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel` + + """ + self._yield_per = num + return self + + @_generative + def unique(self, strategy: Optional[_UniqueFilterType] = None) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_engine.Result`. + + When this filter is applied with no arguments, the rows or objects + returned will filtered such that each row is returned uniquely. The + algorithm used to determine this uniqueness is by default the Python + hashing identity of the whole tuple. In some cases a specialized + per-entity hashing scheme may be used, such as when using the ORM, a + scheme is applied which works against the primary key identity of + returned objects. + + The unique filter is applied **after all other filters**, which means + if the columns returned have been refined using a method such as the + :meth:`_engine.Result.columns` or :meth:`_engine.Result.scalars` + method, the uniquing is applied to **only the column or columns + returned**. This occurs regardless of the order in which these + methods have been called upon the :class:`_engine.Result` object. + + The unique filter also changes the calculus used for methods like + :meth:`_engine.Result.fetchmany` and :meth:`_engine.Result.partitions`. + When using :meth:`_engine.Result.unique`, these methods will continue + to yield the number of rows or objects requested, after uniquing + has been applied. However, this necessarily impacts the buffering + behavior of the underlying cursor or datasource, such that multiple + underlying calls to ``cursor.fetchmany()`` may be necessary in order + to accumulate enough objects in order to provide a unique collection + of the requested size. + + :param strategy: a callable that will be applied to rows or objects + being iterated, which should return an object that represents the + unique value of the row. A Python ``set()`` is used to store + these identities. If not passed, a default uniqueness strategy + is used which may have been assembled by the source of this + :class:`_engine.Result` object. + + """ + self._unique_filter_state = (set(), strategy) + return self + + def columns(self, *col_expressions: _KeyIndexType) -> Self: + r"""Establish the columns that should be returned in each row. + + This method may be used to limit the columns returned as well + as to reorder them. The given list of expressions are normally + a series of integers or string key names. They may also be + appropriate :class:`.ColumnElement` objects which correspond to + a given statement construct. + + .. versionchanged:: 2.0 Due to a bug in 1.4, the + :meth:`_engine.Result.columns` method had an incorrect behavior + where calling upon the method with just one index would cause the + :class:`_engine.Result` object to yield scalar values rather than + :class:`_engine.Row` objects. In version 2.0, this behavior + has been corrected such that calling upon + :meth:`_engine.Result.columns` with a single index will + produce a :class:`_engine.Result` object that continues + to yield :class:`_engine.Row` objects, which include + only a single column. + + E.g.:: + + statement = select(table.c.x, table.c.y, table.c.z) + result = connection.execute(statement) + + for z, y in result.columns('z', 'y'): + # ... + + + Example of using the column objects from the statement itself:: + + for z, y in result.columns( + statement.selected_columns.c.z, + statement.selected_columns.c.y + ): + # ... + + .. versionadded:: 1.4 + + :param \*col_expressions: indicates columns to be returned. Elements + may be integer row indexes, string column names, or appropriate + :class:`.ColumnElement` objects corresponding to a select construct. + + :return: this :class:`_engine.Result` object with the modifications + given. + + """ + return self._column_slices(col_expressions) + + @overload + def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]: ... + + @overload + def scalars( + self: Result[Tuple[_T]], index: Literal[0] + ) -> ScalarResult[_T]: ... + + @overload + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: ... + + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: + """Return a :class:`_engine.ScalarResult` filtering object which + will return single elements rather than :class:`_row.Row` objects. + + E.g.:: + + >>> result = conn.execute(text("select int_id from table")) + >>> result.scalars().all() + [1, 2, 3] + + When results are fetched from the :class:`_engine.ScalarResult` + filtering object, the single column-row that would be returned by the + :class:`_engine.Result` is instead returned as the column's value. + + .. versionadded:: 1.4 + + :param index: integer or row key indicating the column to be fetched + from each row, defaults to ``0`` indicating the first column. + + :return: a new :class:`_engine.ScalarResult` filtering object referring + to this :class:`_engine.Result` object. + + """ + return ScalarResult(self, index) + + def _getter( + self, key: _KeyIndexType, raiseerr: bool = True + ) -> Optional[Callable[[Row[Any]], Any]]: + """return a callable that will retrieve the given key from a + :class:`_engine.Row`. + + """ + if self._source_supports_scalars: + raise NotImplementedError( + "can't use this function in 'only scalars' mode" + ) + return self._metadata._getter(key, raiseerr) + + def _tuple_getter(self, keys: Sequence[_KeyIndexType]) -> _TupleGetterType: + """return a callable that will retrieve the given keys from a + :class:`_engine.Row`. + + """ + if self._source_supports_scalars: + raise NotImplementedError( + "can't use this function in 'only scalars' mode" + ) + return self._metadata._row_as_tuple_getter(keys) + + def mappings(self) -> MappingResult: + """Apply a mappings filter to returned rows, returning an instance of + :class:`_engine.MappingResult`. + + When this filter is applied, fetching rows will return + :class:`_engine.RowMapping` objects instead of :class:`_engine.Row` + objects. + + .. versionadded:: 1.4 + + :return: a new :class:`_engine.MappingResult` filtering object + referring to this :class:`_engine.Result` object. + + """ + + return MappingResult(self) + + @property + def t(self) -> TupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + The :attr:`_engine.Result.t` attribute is a synonym for + calling the :meth:`_engine.Result.tuples` method. + + .. versionadded:: 2.0 + + """ + return self # type: ignore + + def tuples(self) -> TupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + This method returns the same :class:`_engine.Result` object + at runtime, + however annotates as returning a :class:`_engine.TupleResult` object + that will indicate to :pep:`484` typing tools that plain typed + ``Tuple`` instances are returned rather than rows. This allows + tuple unpacking and ``__getitem__`` access of :class:`_engine.Row` + objects to by typed, for those cases where the statement invoked + itself included typing information. + + .. versionadded:: 2.0 + + :return: the :class:`_engine.TupleResult` type at typing time. + + .. seealso:: + + :attr:`_engine.Result.t` - shorter synonym + + :attr:`_engine.Row._t` - :class:`_engine.Row` version + + """ + + return self # type: ignore + + def _raw_row_iterator(self) -> Iterator[_RowData]: + """Return a safe iterator that yields raw row data. + + This is used by the :meth:`_engine.Result.merge` method + to merge multiple compatible results together. + + """ + raise NotImplementedError() + + def __iter__(self) -> Iterator[Row[_TP]]: + return self._iter_impl() + + def __next__(self) -> Row[_TP]: + return self._next_impl() + + def partitions( + self, size: Optional[int] = None + ) -> Iterator[Sequence[Row[_TP]]]: + """Iterate through sub-lists of rows of the size given. + + Each list will be of the size given, excluding the last list to + be yielded, which may have a small number of rows. No empty + lists will be yielded. + + The result object is automatically closed when the iterator + is fully consumed. + + Note that the backend driver will usually buffer the entire result + ahead of time unless the + :paramref:`.Connection.execution_options.stream_results` execution + option is used indicating that the driver should not pre-buffer + results, if possible. Not all drivers support this option and + the option is silently ignored for those who do not. + + When using the ORM, the :meth:`_engine.Result.partitions` method + is typically more effective from a memory perspective when it is + combined with use of the + :ref:`yield_per execution option `, + which instructs both the DBAPI driver to use server side cursors, + if available, as well as instructs the ORM loading internals to only + build a certain amount of ORM objects from a result at a time before + yielding them out. + + .. versionadded:: 1.4 + + :param size: indicate the maximum number of rows to be present + in each list yielded. If None, makes use of the value set by + the :meth:`_engine.Result.yield_per`, method, if it were called, + or the :paramref:`_engine.Connection.execution_options.yield_per` + execution option, which is equivalent in this regard. If + yield_per weren't set, it makes use of the + :meth:`_engine.Result.fetchmany` default, which may be backend + specific and not well defined. + + :return: iterator of lists + + .. seealso:: + + :ref:`engine_stream_results` + + :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel` + + """ + + getter = self._manyrow_getter + + while True: + partition = getter(self, size) + if partition: + yield partition + else: + break + + def fetchall(self) -> Sequence[Row[_TP]]: + """A synonym for the :meth:`_engine.Result.all` method.""" + + return self._allrows() + + def fetchone(self) -> Optional[Row[_TP]]: + """Fetch one row. + + When all rows are exhausted, returns None. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch the first row of a result only, use the + :meth:`_engine.Result.first` method. To iterate through all + rows, iterate the :class:`_engine.Result` object directly. + + :return: a :class:`_engine.Row` object if no filters are applied, + or ``None`` if no rows remain. + + """ + row = self._onerow_getter(self) + if row is _NO_ROW: + return None + else: + return row + + def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]: + """Fetch many rows. + + When all rows are exhausted, returns an empty sequence. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch rows in groups, use the :meth:`_engine.Result.partitions` + method. + + :return: a sequence of :class:`_engine.Row` objects. + + .. seealso:: + + :meth:`_engine.Result.partitions` + + """ + + return self._manyrow_getter(self, size) + + def all(self) -> Sequence[Row[_TP]]: + """Return all rows in a sequence. + + Closes the result set after invocation. Subsequent invocations + will return an empty sequence. + + .. versionadded:: 1.4 + + :return: a sequence of :class:`_engine.Row` objects. + + .. seealso:: + + :ref:`engine_stream_results` - How to stream a large result set + without loading it completely in python. + + """ + + return self._allrows() + + def first(self) -> Optional[Row[_TP]]: + """Fetch the first row or ``None`` if no row is present. + + Closes the result set and discards remaining rows. + + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the + :meth:`_engine.Result.scalar` method, + or combine :meth:`_engine.Result.scalars` and + :meth:`_engine.Result.first`. + + Additionally, in contrast to the behavior of the legacy ORM + :meth:`_orm.Query.first` method, **no limit is applied** to the + SQL query which was invoked to produce this + :class:`_engine.Result`; + for a DBAPI driver that buffers results in memory before yielding + rows, all rows will be sent to the Python process and all but + the first row will be discarded. + + .. seealso:: + + :ref:`migration_20_unify_select` + + :return: a :class:`_engine.Row` object, or None + if no rows remain. + + .. seealso:: + + :meth:`_engine.Result.scalar` + + :meth:`_engine.Result.one` + + """ + + return self._only_one_row( + raise_for_second_row=False, raise_for_none=False, scalar=False + ) + + def one_or_none(self) -> Optional[Row[_TP]]: + """Return at most one result or raise an exception. + + Returns ``None`` if the result has no rows. + Raises :class:`.MultipleResultsFound` + if multiple rows are returned. + + .. versionadded:: 1.4 + + :return: The first :class:`_engine.Row` or ``None`` if no row + is available. + + :raises: :class:`.MultipleResultsFound` + + .. seealso:: + + :meth:`_engine.Result.first` + + :meth:`_engine.Result.one` + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=False, scalar=False + ) + + @overload + def scalar_one(self: Result[Tuple[_T]]) -> _T: ... + + @overload + def scalar_one(self) -> Any: ... + + def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`_engine.Result.scalars` and + then :meth:`_engine.Result.one`. + + .. seealso:: + + :meth:`_engine.Result.one` + + :meth:`_engine.Result.scalars` + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=True, scalar=True + ) + + @overload + def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]: ... + + @overload + def scalar_one_or_none(self) -> Optional[Any]: ... + + def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one scalar result or ``None``. + + This is equivalent to calling :meth:`_engine.Result.scalars` and + then :meth:`_engine.Result.one_or_none`. + + .. seealso:: + + :meth:`_engine.Result.one_or_none` + + :meth:`_engine.Result.scalars` + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=False, scalar=True + ) + + def one(self) -> Row[_TP]: + """Return exactly one row or raise an exception. + + Raises :class:`.NoResultFound` if the result returns no + rows, or :class:`.MultipleResultsFound` if multiple rows + would be returned. + + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the + :meth:`_engine.Result.scalar_one` method, or combine + :meth:`_engine.Result.scalars` and + :meth:`_engine.Result.one`. + + .. versionadded:: 1.4 + + :return: The first :class:`_engine.Row`. + + :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound` + + .. seealso:: + + :meth:`_engine.Result.first` + + :meth:`_engine.Result.one_or_none` + + :meth:`_engine.Result.scalar_one` + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=True, scalar=False + ) + + @overload + def scalar(self: Result[Tuple[_T]]) -> Optional[_T]: ... + + @overload + def scalar(self) -> Any: ... + + def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result set. + + Returns ``None`` if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value, or ``None`` if no rows remain. + + """ + return self._only_one_row( + raise_for_second_row=False, raise_for_none=False, scalar=True + ) + + def freeze(self) -> FrozenResult[_TP]: + """Return a callable object that will produce copies of this + :class:`_engine.Result` when invoked. + + The callable object returned is an instance of + :class:`_engine.FrozenResult`. + + This is used for result set caching. The method must be called + on the result when it has been unconsumed, and calling the method + will consume the result fully. When the :class:`_engine.FrozenResult` + is retrieved from a cache, it can be called any number of times where + it will produce a new :class:`_engine.Result` object each time + against its stored set of rows. + + .. seealso:: + + :ref:`do_orm_execute_re_executing` - example usage within the + ORM to implement a result-set cache. + + """ + + return FrozenResult(self) + + def merge(self, *others: Result[Any]) -> MergedResult[_TP]: + """Merge this :class:`_engine.Result` with other compatible result + objects. + + The object returned is an instance of :class:`_engine.MergedResult`, + which will be composed of iterators from the given result + objects. + + The new result will use the metadata from this result object. + The subsequent result objects must be against an identical + set of result / cursor metadata, otherwise the behavior is + undefined. + + """ + return MergedResult(self._metadata, (self,) + others) + + +class FilterResult(ResultInternal[_R]): + """A wrapper for a :class:`_engine.Result` that returns objects other than + :class:`_engine.Row` objects, such as dictionaries or scalar objects. + + :class:`_engine.FilterResult` is the common base for additional result + APIs including :class:`_engine.MappingResult`, + :class:`_engine.ScalarResult` and :class:`_engine.AsyncResult`. + + """ + + __slots__ = ( + "_real_result", + "_post_creational_filter", + "_metadata", + "_unique_filter_state", + "__dict__", + ) + + _post_creational_filter: Optional[Callable[[Any], Any]] + + _real_result: Result[Any] + + def __enter__(self) -> Self: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self._real_result.__exit__(type_, value, traceback) + + @_generative + def yield_per(self, num: int) -> Self: + """Configure the row-fetching strategy to fetch ``num`` rows at a time. + + The :meth:`_engine.FilterResult.yield_per` method is a pass through + to the :meth:`_engine.Result.yield_per` method. See that method's + documentation for usage notes. + + .. versionadded:: 1.4.40 - added :meth:`_engine.FilterResult.yield_per` + so that the method is available on all result set implementations + + .. seealso:: + + :ref:`engine_stream_results` - describes Core behavior for + :meth:`_engine.Result.yield_per` + + :ref:`orm_queryguide_yield_per` - in the :ref:`queryguide_toplevel` + + """ + self._real_result = self._real_result.yield_per(num) + return self + + def _soft_close(self, hard: bool = False) -> None: + self._real_result._soft_close(hard=hard) + + @property + def _soft_closed(self) -> bool: + return self._real_result._soft_closed + + @property + def closed(self) -> bool: + """Return ``True`` if the underlying :class:`_engine.Result` reports + closed + + .. versionadded:: 1.4.43 + + """ + return self._real_result.closed + + def close(self) -> None: + """Close this :class:`_engine.FilterResult`. + + .. versionadded:: 1.4.43 + + """ + self._real_result.close() + + @property + def _attributes(self) -> Dict[Any, Any]: + return self._real_result._attributes + + def _fetchiter_impl(self) -> Iterator[_InterimRowType[Row[Any]]]: + return self._real_result._fetchiter_impl() + + def _fetchone_impl( + self, hard_close: bool = False + ) -> Optional[_InterimRowType[Row[Any]]]: + return self._real_result._fetchone_impl(hard_close=hard_close) + + def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]: + return self._real_result._fetchall_impl() + + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType[Row[Any]]]: + return self._real_result._fetchmany_impl(size=size) + + +class ScalarResult(FilterResult[_R]): + """A wrapper for a :class:`_engine.Result` that returns scalar values + rather than :class:`_row.Row` values. + + The :class:`_engine.ScalarResult` object is acquired by calling the + :meth:`_engine.Result.scalars` method. + + A special limitation of :class:`_engine.ScalarResult` is that it has + no ``fetchone()`` method; since the semantics of ``fetchone()`` are that + the ``None`` value indicates no more results, this is not compatible + with :class:`_engine.ScalarResult` since there is no way to distinguish + between ``None`` as a row value versus ``None`` as an indicator. Use + ``next(result)`` to receive values individually. + + """ + + __slots__ = () + + _generate_rows = False + + _post_creational_filter: Optional[Callable[[Any], Any]] + + def __init__(self, real_result: Result[Any], index: _KeyIndexType): + self._real_result = real_result + + if real_result._source_supports_scalars: + self._metadata = real_result._metadata + self._post_creational_filter = None + else: + self._metadata = real_result._metadata._reduce([index]) + self._post_creational_filter = operator.itemgetter(0) + + self._unique_filter_state = real_result._unique_filter_state + + def unique(self, strategy: Optional[_UniqueFilterType] = None) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_engine.ScalarResult`. + + See :meth:`_engine.Result.unique` for usage details. + + """ + self._unique_filter_state = (set(), strategy) + return self + + def partitions(self, size: Optional[int] = None) -> Iterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_engine.Result.partitions` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = getter(self, size) + if partition: + yield partition + else: + break + + def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_engine.ScalarResult.all` method.""" + + return self._allrows() + + def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_engine.Result.fetchmany` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return self._manyrow_getter(self, size) + + def all(self) -> Sequence[_R]: + """Return all scalar values in a sequence. + + Equivalent to :meth:`_engine.Result.all` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return self._allrows() + + def __iter__(self) -> Iterator[_R]: + return self._iter_impl() + + def __next__(self) -> _R: + return self._next_impl() + + def first(self) -> Optional[_R]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_engine.Result.first` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + + """ + return self._only_one_row( + raise_for_second_row=False, raise_for_none=False, scalar=False + ) + + def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_engine.Result.one_or_none` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=False, scalar=False + ) + + def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_engine.Result.one` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=True, scalar=False + ) + + +class TupleResult(FilterResult[_R], util.TypingOnly): + """A :class:`_engine.Result` that's typed as returning plain + Python tuples instead of rows. + + Since :class:`_engine.Row` acts like a tuple in every way already, + this class is a typing only class, regular :class:`_engine.Result` is + still used at runtime. + + """ + + __slots__ = () + + if TYPE_CHECKING: + + def partitions( + self, size: Optional[int] = None + ) -> Iterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_engine.Result.partitions` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + def fetchone(self) -> Optional[_R]: + """Fetch one tuple. + + Equivalent to :meth:`_engine.Result.fetchone` except that + tuple values, rather than :class:`_engine.Row` + objects, are returned. + + """ + ... + + def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_engine.ScalarResult.all` method.""" + ... + + def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_engine.Result.fetchmany` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + def all(self) -> Sequence[_R]: # noqa: A001 + """Return all scalar values in a sequence. + + Equivalent to :meth:`_engine.Result.all` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + def __iter__(self) -> Iterator[_R]: ... + + def __next__(self) -> _R: ... + + def first(self) -> Optional[_R]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_engine.Result.first` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + + """ + ... + + def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_engine.Result.one_or_none` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_engine.Result.one` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + @overload + def scalar_one(self: TupleResult[Tuple[_T]]) -> _T: ... + + @overload + def scalar_one(self) -> Any: ... + + def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`_engine.Result.scalars` + and then :meth:`_engine.Result.one`. + + .. seealso:: + + :meth:`_engine.Result.one` + + :meth:`_engine.Result.scalars` + + """ + ... + + @overload + def scalar_one_or_none( + self: TupleResult[Tuple[_T]], + ) -> Optional[_T]: ... + + @overload + def scalar_one_or_none(self) -> Optional[Any]: ... + + def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one or no scalar result. + + This is equivalent to calling :meth:`_engine.Result.scalars` + and then :meth:`_engine.Result.one_or_none`. + + .. seealso:: + + :meth:`_engine.Result.one_or_none` + + :meth:`_engine.Result.scalars` + + """ + ... + + @overload + def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]: ... + + @overload + def scalar(self) -> Any: ... + + def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result + set. + + Returns ``None`` if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value , or ``None`` if no rows remain. + + """ + ... + + +class MappingResult(_WithKeys, FilterResult[RowMapping]): + """A wrapper for a :class:`_engine.Result` that returns dictionary values + rather than :class:`_engine.Row` values. + + The :class:`_engine.MappingResult` object is acquired by calling the + :meth:`_engine.Result.mappings` method. + + """ + + __slots__ = () + + _generate_rows = True + + _post_creational_filter = operator.attrgetter("_mapping") + + def __init__(self, result: Result[Any]): + self._real_result = result + self._unique_filter_state = result._unique_filter_state + self._metadata = result._metadata + if result._source_supports_scalars: + self._metadata = self._metadata._reduce([0]) + + def unique(self, strategy: Optional[_UniqueFilterType] = None) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_engine.MappingResult`. + + See :meth:`_engine.Result.unique` for usage details. + + """ + self._unique_filter_state = (set(), strategy) + return self + + def columns(self, *col_expressions: _KeyIndexType) -> Self: + r"""Establish the columns that should be returned in each row.""" + return self._column_slices(col_expressions) + + def partitions( + self, size: Optional[int] = None + ) -> Iterator[Sequence[RowMapping]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_engine.Result.partitions` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = getter(self, size) + if partition: + yield partition + else: + break + + def fetchall(self) -> Sequence[RowMapping]: + """A synonym for the :meth:`_engine.MappingResult.all` method.""" + + return self._allrows() + + def fetchone(self) -> Optional[RowMapping]: + """Fetch one object. + + Equivalent to :meth:`_engine.Result.fetchone` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + row = self._onerow_getter(self) + if row is _NO_ROW: + return None + else: + return row + + def fetchmany(self, size: Optional[int] = None) -> Sequence[RowMapping]: + """Fetch many objects. + + Equivalent to :meth:`_engine.Result.fetchmany` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + return self._manyrow_getter(self, size) + + def all(self) -> Sequence[RowMapping]: + """Return all scalar values in a sequence. + + Equivalent to :meth:`_engine.Result.all` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + return self._allrows() + + def __iter__(self) -> Iterator[RowMapping]: + return self._iter_impl() + + def __next__(self) -> RowMapping: + return self._next_impl() + + def first(self) -> Optional[RowMapping]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_engine.Result.first` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + + """ + return self._only_one_row( + raise_for_second_row=False, raise_for_none=False, scalar=False + ) + + def one_or_none(self) -> Optional[RowMapping]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_engine.Result.one_or_none` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=False, scalar=False + ) + + def one(self) -> RowMapping: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_engine.Result.one` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return self._only_one_row( + raise_for_second_row=True, raise_for_none=True, scalar=False + ) + + +class FrozenResult(Generic[_TP]): + """Represents a :class:`_engine.Result` object in a "frozen" state suitable + for caching. + + The :class:`_engine.FrozenResult` object is returned from the + :meth:`_engine.Result.freeze` method of any :class:`_engine.Result` + object. + + A new iterable :class:`_engine.Result` object is generated from a fixed + set of data each time the :class:`_engine.FrozenResult` is invoked as + a callable:: + + + result = connection.execute(query) + + frozen = result.freeze() + + unfrozen_result_one = frozen() + + for row in unfrozen_result_one: + print(row) + + unfrozen_result_two = frozen() + rows = unfrozen_result_two.all() + + # ... etc + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`do_orm_execute_re_executing` - example usage within the + ORM to implement a result-set cache. + + :func:`_orm.loading.merge_frozen_result` - ORM function to merge + a frozen result back into a :class:`_orm.Session`. + + """ + + data: Sequence[Any] + + def __init__(self, result: Result[_TP]): + self.metadata = result._metadata._for_freeze() + self._source_supports_scalars = result._source_supports_scalars + self._attributes = result._attributes + + if self._source_supports_scalars: + self.data = list(result._raw_row_iterator()) + else: + self.data = result.fetchall() + + def rewrite_rows(self) -> Sequence[Sequence[Any]]: + if self._source_supports_scalars: + return [[elem] for elem in self.data] + else: + return [list(row) for row in self.data] + + def with_new_rows( + self, tuple_data: Sequence[Row[_TP]] + ) -> FrozenResult[_TP]: + fr = FrozenResult.__new__(FrozenResult) + fr.metadata = self.metadata + fr._attributes = self._attributes + fr._source_supports_scalars = self._source_supports_scalars + + if self._source_supports_scalars: + fr.data = [d[0] for d in tuple_data] + else: + fr.data = tuple_data + return fr + + def __call__(self) -> Result[_TP]: + result: IteratorResult[_TP] = IteratorResult( + self.metadata, iter(self.data) + ) + result._attributes = self._attributes + result._source_supports_scalars = self._source_supports_scalars + return result + + +class IteratorResult(Result[_TP]): + """A :class:`_engine.Result` that gets data from a Python iterator of + :class:`_engine.Row` objects or similar row-like data. + + .. versionadded:: 1.4 + + """ + + _hard_closed = False + _soft_closed = False + + def __init__( + self, + cursor_metadata: ResultMetaData, + iterator: Iterator[_InterimSupportsScalarsRowType], + raw: Optional[Result[Any]] = None, + _source_supports_scalars: bool = False, + ): + self._metadata = cursor_metadata + self.iterator = iterator + self.raw = raw + self._source_supports_scalars = _source_supports_scalars + + @property + def closed(self) -> bool: + """Return ``True`` if this :class:`_engine.IteratorResult` has + been closed + + .. versionadded:: 1.4.43 + + """ + return self._hard_closed + + def _soft_close(self, hard: bool = False, **kw: Any) -> None: + if hard: + self._hard_closed = True + if self.raw is not None: + self.raw._soft_close(hard=hard, **kw) + self.iterator = iter([]) + self._reset_memoizations() + self._soft_closed = True + + def _raise_hard_closed(self) -> NoReturn: + raise exc.ResourceClosedError("This result object is closed.") + + def _raw_row_iterator(self) -> Iterator[_RowData]: + return self.iterator + + def _fetchiter_impl(self) -> Iterator[_InterimSupportsScalarsRowType]: + if self._hard_closed: + self._raise_hard_closed() + return self.iterator + + def _fetchone_impl( + self, hard_close: bool = False + ) -> Optional[_InterimRowType[Row[Any]]]: + if self._hard_closed: + self._raise_hard_closed() + + row = next(self.iterator, _NO_ROW) + if row is _NO_ROW: + self._soft_close(hard=hard_close) + return None + else: + return row + + def _fetchall_impl(self) -> List[_InterimRowType[Row[Any]]]: + if self._hard_closed: + self._raise_hard_closed() + try: + return list(self.iterator) + finally: + self._soft_close() + + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType[Row[Any]]]: + if self._hard_closed: + self._raise_hard_closed() + + return list(itertools.islice(self.iterator, 0, size)) + + +def null_result() -> IteratorResult[Any]: + return IteratorResult(SimpleResultMetaData([]), iter([])) + + +class ChunkedIteratorResult(IteratorResult[_TP]): + """An :class:`_engine.IteratorResult` that works from an + iterator-producing callable. + + The given ``chunks`` argument is a function that is given a number of rows + to return in each chunk, or ``None`` for all rows. The function should + then return an un-consumed iterator of lists, each list of the requested + size. + + The function can be called at any time again, in which case it should + continue from the same result set but adjust the chunk size as given. + + .. versionadded:: 1.4 + + """ + + def __init__( + self, + cursor_metadata: ResultMetaData, + chunks: Callable[ + [Optional[int]], Iterator[Sequence[_InterimRowType[_R]]] + ], + source_supports_scalars: bool = False, + raw: Optional[Result[Any]] = None, + dynamic_yield_per: bool = False, + ): + self._metadata = cursor_metadata + self.chunks = chunks + self._source_supports_scalars = source_supports_scalars + self.raw = raw + self.iterator = itertools.chain.from_iterable(self.chunks(None)) + self.dynamic_yield_per = dynamic_yield_per + + @_generative + def yield_per(self, num: int) -> Self: + # TODO: this throws away the iterator which may be holding + # onto a chunk. the yield_per cannot be changed once any + # rows have been fetched. either find a way to enforce this, + # or we can't use itertools.chain and will instead have to + # keep track. + + self._yield_per = num + self.iterator = itertools.chain.from_iterable(self.chunks(num)) + return self + + def _soft_close(self, hard: bool = False, **kw: Any) -> None: + super()._soft_close(hard=hard, **kw) + self.chunks = lambda size: [] # type: ignore + + def _fetchmany_impl( + self, size: Optional[int] = None + ) -> List[_InterimRowType[Row[Any]]]: + if self.dynamic_yield_per: + self.iterator = itertools.chain.from_iterable(self.chunks(size)) + return super()._fetchmany_impl(size=size) + + +class MergedResult(IteratorResult[_TP]): + """A :class:`_engine.Result` that is merged from any number of + :class:`_engine.Result` objects. + + Returned by the :meth:`_engine.Result.merge` method. + + .. versionadded:: 1.4 + + """ + + closed = False + rowcount: Optional[int] + + def __init__( + self, cursor_metadata: ResultMetaData, results: Sequence[Result[_TP]] + ): + self._results = results + super().__init__( + cursor_metadata, + itertools.chain.from_iterable( + r._raw_row_iterator() for r in results + ), + ) + + self._unique_filter_state = results[0]._unique_filter_state + self._yield_per = results[0]._yield_per + + # going to try something w/ this in next rev + self._source_supports_scalars = results[0]._source_supports_scalars + + self._attributes = self._attributes.merge_with( + *[r._attributes for r in results] + ) + + def _soft_close(self, hard: bool = False, **kw: Any) -> None: + for r in self._results: + r._soft_close(hard=hard, **kw) + if hard: + self.closed = True diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/row.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/row.py new file mode 100644 index 0000000..bcaffee --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/row.py @@ -0,0 +1,401 @@ +# engine/row.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Define row constructs including :class:`.Row`.""" + +from __future__ import annotations + +from abc import ABC +import collections.abc as collections_abc +import operator +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import Iterator +from typing import List +from typing import Mapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from ..sql import util as sql_util +from ..util import deprecated +from ..util._has_cy import HAS_CYEXTENSION + +if TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_row import BaseRow as BaseRow +else: + from sqlalchemy.cyextension.resultproxy import BaseRow as BaseRow + +if TYPE_CHECKING: + from .result import _KeyType + from .result import _ProcessorsType + from .result import RMKeyView + +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + + +class Row(BaseRow, Sequence[Any], Generic[_TP]): + """Represent a single result row. + + The :class:`.Row` object represents a row of a database result. It is + typically associated in the 1.x series of SQLAlchemy with the + :class:`_engine.CursorResult` object, however is also used by the ORM for + tuple-like results as of SQLAlchemy 1.4. + + The :class:`.Row` object seeks to act as much like a Python named + tuple as possible. For mapping (i.e. dictionary) behavior on a row, + such as testing for containment of keys, refer to the :attr:`.Row._mapping` + attribute. + + .. seealso:: + + :ref:`tutorial_selecting_data` - includes examples of selecting + rows from SELECT statements. + + .. versionchanged:: 1.4 + + Renamed ``RowProxy`` to :class:`.Row`. :class:`.Row` is no longer a + "proxy" object in that it contains the final form of data within it, + and now acts mostly like a named tuple. Mapping-like functionality is + moved to the :attr:`.Row._mapping` attribute. See + :ref:`change_4710_core` for background on this change. + + """ + + __slots__ = () + + def __setattr__(self, name: str, value: Any) -> NoReturn: + raise AttributeError("can't set attribute") + + def __delattr__(self, name: str) -> NoReturn: + raise AttributeError("can't delete attribute") + + def _tuple(self) -> _TP: + """Return a 'tuple' form of this :class:`.Row`. + + At runtime, this method returns "self"; the :class:`.Row` object is + already a named tuple. However, at the typing level, if this + :class:`.Row` is typed, the "tuple" return type will be a :pep:`484` + ``Tuple`` datatype that contains typing information about individual + elements, supporting typed unpacking and attribute access. + + .. versionadded:: 2.0.19 - The :meth:`.Row._tuple` method supersedes + the previous :meth:`.Row.tuple` method, which is now underscored + to avoid name conflicts with column names in the same way as other + named-tuple methods on :class:`.Row`. + + .. seealso:: + + :attr:`.Row._t` - shorthand attribute notation + + :meth:`.Result.tuples` + + + """ + return self # type: ignore + + @deprecated( + "2.0.19", + "The :meth:`.Row.tuple` method is deprecated in favor of " + ":meth:`.Row._tuple`; all :class:`.Row` " + "methods and library-level attributes are intended to be underscored " + "to avoid name conflicts. Please use :meth:`Row._tuple`.", + ) + def tuple(self) -> _TP: + """Return a 'tuple' form of this :class:`.Row`. + + .. versionadded:: 2.0 + + """ + return self._tuple() + + @property + def _t(self) -> _TP: + """A synonym for :meth:`.Row._tuple`. + + .. versionadded:: 2.0.19 - The :attr:`.Row._t` attribute supersedes + the previous :attr:`.Row.t` attribute, which is now underscored + to avoid name conflicts with column names in the same way as other + named-tuple methods on :class:`.Row`. + + .. seealso:: + + :attr:`.Result.t` + """ + return self # type: ignore + + @property + @deprecated( + "2.0.19", + "The :attr:`.Row.t` attribute is deprecated in favor of " + ":attr:`.Row._t`; all :class:`.Row` " + "methods and library-level attributes are intended to be underscored " + "to avoid name conflicts. Please use :attr:`Row._t`.", + ) + def t(self) -> _TP: + """A synonym for :meth:`.Row._tuple`. + + .. versionadded:: 2.0 + + """ + return self._t + + @property + def _mapping(self) -> RowMapping: + """Return a :class:`.RowMapping` for this :class:`.Row`. + + This object provides a consistent Python mapping (i.e. dictionary) + interface for the data contained within the row. The :class:`.Row` + by itself behaves like a named tuple. + + .. seealso:: + + :attr:`.Row._fields` + + .. versionadded:: 1.4 + + """ + return RowMapping(self._parent, None, self._key_to_index, self._data) + + def _filter_on_values( + self, processor: Optional[_ProcessorsType] + ) -> Row[Any]: + return Row(self._parent, processor, self._key_to_index, self._data) + + if not TYPE_CHECKING: + + def _special_name_accessor(name: str) -> Any: + """Handle ambiguous names such as "count" and "index" """ + + @property + def go(self: Row) -> Any: + if self._parent._has_key(name): + return self.__getattr__(name) + else: + + def meth(*arg: Any, **kw: Any) -> Any: + return getattr(collections_abc.Sequence, name)( + self, *arg, **kw + ) + + return meth + + return go + + count = _special_name_accessor("count") + index = _special_name_accessor("index") + + def __contains__(self, key: Any) -> bool: + return key in self._data + + def _op(self, other: Any, op: Callable[[Any, Any], bool]) -> bool: + return ( + op(self._to_tuple_instance(), other._to_tuple_instance()) + if isinstance(other, Row) + else op(self._to_tuple_instance(), other) + ) + + __hash__ = BaseRow.__hash__ + + if TYPE_CHECKING: + + @overload + def __getitem__(self, index: int) -> Any: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[Any]: ... + + def __getitem__(self, index: Union[int, slice]) -> Any: ... + + def __lt__(self, other: Any) -> bool: + return self._op(other, operator.lt) + + def __le__(self, other: Any) -> bool: + return self._op(other, operator.le) + + def __ge__(self, other: Any) -> bool: + return self._op(other, operator.ge) + + def __gt__(self, other: Any) -> bool: + return self._op(other, operator.gt) + + def __eq__(self, other: Any) -> bool: + return self._op(other, operator.eq) + + def __ne__(self, other: Any) -> bool: + return self._op(other, operator.ne) + + def __repr__(self) -> str: + return repr(sql_util._repr_row(self)) + + @property + def _fields(self) -> Tuple[str, ...]: + """Return a tuple of string keys as represented by this + :class:`.Row`. + + The keys can represent the labels of the columns returned by a core + statement or the names of the orm classes returned by an orm + execution. + + This attribute is analogous to the Python named tuple ``._fields`` + attribute. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`.Row._mapping` + + """ + return tuple([k for k in self._parent.keys if k is not None]) + + def _asdict(self) -> Dict[str, Any]: + """Return a new dict which maps field names to their corresponding + values. + + This method is analogous to the Python named tuple ``._asdict()`` + method, and works by applying the ``dict()`` constructor to the + :attr:`.Row._mapping` attribute. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`.Row._mapping` + + """ + return dict(self._mapping) + + +BaseRowProxy = BaseRow +RowProxy = Row + + +class ROMappingView(ABC): + __slots__ = () + + _items: Sequence[Any] + _mapping: Mapping["_KeyType", Any] + + def __init__( + self, mapping: Mapping["_KeyType", Any], items: Sequence[Any] + ): + self._mapping = mapping # type: ignore[misc] + self._items = items # type: ignore[misc] + + def __len__(self) -> int: + return len(self._items) + + def __repr__(self) -> str: + return "{0.__class__.__name__}({0._mapping!r})".format(self) + + def __iter__(self) -> Iterator[Any]: + return iter(self._items) + + def __contains__(self, item: Any) -> bool: + return item in self._items + + def __eq__(self, other: Any) -> bool: + return list(other) == list(self) + + def __ne__(self, other: Any) -> bool: + return list(other) != list(self) + + +class ROMappingKeysValuesView( + ROMappingView, typing.KeysView["_KeyType"], typing.ValuesView[Any] +): + __slots__ = ("_items",) # mapping slot is provided by KeysView + + +class ROMappingItemsView(ROMappingView, typing.ItemsView["_KeyType", Any]): + __slots__ = ("_items",) # mapping slot is provided by ItemsView + + +class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]): + """A ``Mapping`` that maps column names and objects to :class:`.Row` + values. + + The :class:`.RowMapping` is available from a :class:`.Row` via the + :attr:`.Row._mapping` attribute, as well as from the iterable interface + provided by the :class:`.MappingResult` object returned by the + :meth:`_engine.Result.mappings` method. + + :class:`.RowMapping` supplies Python mapping (i.e. dictionary) access to + the contents of the row. This includes support for testing of + containment of specific keys (string column names or objects), as well + as iteration of keys, values, and items:: + + for row in result: + if 'a' in row._mapping: + print("Column 'a': %s" % row._mapping['a']) + + print("Column b: %s" % row._mapping[table.c.b]) + + + .. versionadded:: 1.4 The :class:`.RowMapping` object replaces the + mapping-like access previously provided by a database result row, + which now seeks to behave mostly like a named tuple. + + """ + + __slots__ = () + + if TYPE_CHECKING: + + def __getitem__(self, key: _KeyType) -> Any: ... + + else: + __getitem__ = BaseRow._get_by_key_impl_mapping + + def _values_impl(self) -> List[Any]: + return list(self._data) + + def __iter__(self) -> Iterator[str]: + return (k for k in self._parent.keys if k is not None) + + def __len__(self) -> int: + return len(self._data) + + def __contains__(self, key: object) -> bool: + return self._parent._has_key(key) + + def __repr__(self) -> str: + return repr(dict(self)) + + def items(self) -> ROMappingItemsView: + """Return a view of key/value tuples for the elements in the + underlying :class:`.Row`. + + """ + return ROMappingItemsView( + self, [(key, self[key]) for key in self.keys()] + ) + + def keys(self) -> RMKeyView: + """Return a view of 'keys' for string column names represented + by the underlying :class:`.Row`. + + """ + + return self._parent.keys + + def values(self) -> ROMappingKeysValuesView: + """Return a view of values for the values represented in the + underlying :class:`.Row`. + + """ + return ROMappingKeysValuesView(self, self._values_impl()) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/strategies.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/strategies.py new file mode 100644 index 0000000..30c331e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/strategies.py @@ -0,0 +1,19 @@ +# engine/strategies.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Deprecated mock engine strategy used by Alembic. + + +""" + +from __future__ import annotations + +from .mock import MockConnection # noqa + + +class MockEngineStrategy: + MockConnection = MockConnection diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/url.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/url.py new file mode 100644 index 0000000..1eeb73a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/url.py @@ -0,0 +1,910 @@ +# engine/url.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Provides the :class:`~sqlalchemy.engine.url.URL` class which encapsulates +information about a database connection specification. + +The URL object is created automatically when +:func:`~sqlalchemy.engine.create_engine` is called with a string +argument; alternatively, the URL is a public-facing construct which can +be used directly and is also accepted directly by ``create_engine()``. +""" + +from __future__ import annotations + +import collections.abc as collections_abc +import re +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping +from typing import NamedTuple +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import Union +from urllib.parse import parse_qsl +from urllib.parse import quote +from urllib.parse import quote_plus +from urllib.parse import unquote + +from .interfaces import Dialect +from .. import exc +from .. import util +from ..dialects import plugins +from ..dialects import registry + + +class URL(NamedTuple): + """ + Represent the components of a URL used to connect to a database. + + URLs are typically constructed from a fully formatted URL string, where the + :func:`.make_url` function is used internally by the + :func:`_sa.create_engine` function in order to parse the URL string into + its individual components, which are then used to construct a new + :class:`.URL` object. When parsing from a formatted URL string, the parsing + format generally follows + `RFC-1738 `_, with some exceptions. + + A :class:`_engine.URL` object may also be produced directly, either by + using the :func:`.make_url` function with a fully formed URL string, or + by using the :meth:`_engine.URL.create` constructor in order + to construct a :class:`_engine.URL` programmatically given individual + fields. The resulting :class:`.URL` object may be passed directly to + :func:`_sa.create_engine` in place of a string argument, which will bypass + the usage of :func:`.make_url` within the engine's creation process. + + .. versionchanged:: 1.4 + + The :class:`_engine.URL` object is now an immutable object. To + create a URL, use the :func:`_engine.make_url` or + :meth:`_engine.URL.create` function / method. To modify + a :class:`_engine.URL`, use methods like + :meth:`_engine.URL.set` and + :meth:`_engine.URL.update_query_dict` to return a new + :class:`_engine.URL` object with modifications. See notes for this + change at :ref:`change_5526`. + + .. seealso:: + + :ref:`database_urls` + + :class:`_engine.URL` contains the following attributes: + + * :attr:`_engine.URL.drivername`: database backend and driver name, such as + ``postgresql+psycopg2`` + * :attr:`_engine.URL.username`: username string + * :attr:`_engine.URL.password`: password string + * :attr:`_engine.URL.host`: string hostname + * :attr:`_engine.URL.port`: integer port number + * :attr:`_engine.URL.database`: string database name + * :attr:`_engine.URL.query`: an immutable mapping representing the query + string. contains strings for keys and either strings or tuples of + strings for values. + + + """ + + drivername: str + """database backend and driver name, such as + ``postgresql+psycopg2`` + + """ + + username: Optional[str] + "username string" + + password: Optional[str] + """password, which is normally a string but may also be any + object that has a ``__str__()`` method.""" + + host: Optional[str] + """hostname or IP number. May also be a data source name for some + drivers.""" + + port: Optional[int] + """integer port number""" + + database: Optional[str] + """database name""" + + query: util.immutabledict[str, Union[Tuple[str, ...], str]] + """an immutable mapping representing the query string. contains strings + for keys and either strings or tuples of strings for values, e.g.:: + + >>> from sqlalchemy.engine import make_url + >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt") + >>> url.query + immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'}) + + To create a mutable copy of this mapping, use the ``dict`` constructor:: + + mutable_query_opts = dict(url.query) + + .. seealso:: + + :attr:`_engine.URL.normalized_query` - normalizes all values into sequences + for consistent processing + + Methods for altering the contents of :attr:`_engine.URL.query`: + + :meth:`_engine.URL.update_query_dict` + + :meth:`_engine.URL.update_query_string` + + :meth:`_engine.URL.update_query_pairs` + + :meth:`_engine.URL.difference_update_query` + + """ # noqa: E501 + + @classmethod + def create( + cls, + drivername: str, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Mapping[str, Union[Sequence[str], str]] = util.EMPTY_DICT, + ) -> URL: + """Create a new :class:`_engine.URL` object. + + .. seealso:: + + :ref:`database_urls` + + :param drivername: the name of the database backend. This name will + correspond to a module in sqlalchemy/databases or a third party + plug-in. + :param username: The user name. + :param password: database password. Is typically a string, but may + also be an object that can be stringified with ``str()``. + + .. note:: The password string should **not** be URL encoded when + passed as an argument to :meth:`_engine.URL.create`; the string + should contain the password characters exactly as they would be + typed. + + .. note:: A password-producing object will be stringified only + **once** per :class:`_engine.Engine` object. For dynamic password + generation per connect, see :ref:`engines_dynamic_tokens`. + + :param host: The name of the host. + :param port: The port number. + :param database: The database name. + :param query: A dictionary of string keys to string values to be passed + to the dialect and/or the DBAPI upon connect. To specify non-string + parameters to a Python DBAPI directly, use the + :paramref:`_sa.create_engine.connect_args` parameter to + :func:`_sa.create_engine`. See also + :attr:`_engine.URL.normalized_query` for a dictionary that is + consistently string->list of string. + :return: new :class:`_engine.URL` object. + + .. versionadded:: 1.4 + + The :class:`_engine.URL` object is now an **immutable named + tuple**. In addition, the ``query`` dictionary is also immutable. + To create a URL, use the :func:`_engine.url.make_url` or + :meth:`_engine.URL.create` function/ method. To modify a + :class:`_engine.URL`, use the :meth:`_engine.URL.set` and + :meth:`_engine.URL.update_query` methods. + + """ + + return cls( + cls._assert_str(drivername, "drivername"), + cls._assert_none_str(username, "username"), + password, + cls._assert_none_str(host, "host"), + cls._assert_port(port), + cls._assert_none_str(database, "database"), + cls._str_dict(query), + ) + + @classmethod + def _assert_port(cls, port: Optional[int]) -> Optional[int]: + if port is None: + return None + try: + return int(port) + except TypeError: + raise TypeError("Port argument must be an integer or None") + + @classmethod + def _assert_str(cls, v: str, paramname: str) -> str: + if not isinstance(v, str): + raise TypeError("%s must be a string" % paramname) + return v + + @classmethod + def _assert_none_str( + cls, v: Optional[str], paramname: str + ) -> Optional[str]: + if v is None: + return v + + return cls._assert_str(v, paramname) + + @classmethod + def _str_dict( + cls, + dict_: Optional[ + Union[ + Sequence[Tuple[str, Union[Sequence[str], str]]], + Mapping[str, Union[Sequence[str], str]], + ] + ], + ) -> util.immutabledict[str, Union[Tuple[str, ...], str]]: + if dict_ is None: + return util.EMPTY_DICT + + @overload + def _assert_value( + val: str, + ) -> str: ... + + @overload + def _assert_value( + val: Sequence[str], + ) -> Union[str, Tuple[str, ...]]: ... + + def _assert_value( + val: Union[str, Sequence[str]], + ) -> Union[str, Tuple[str, ...]]: + if isinstance(val, str): + return val + elif isinstance(val, collections_abc.Sequence): + return tuple(_assert_value(elem) for elem in val) + else: + raise TypeError( + "Query dictionary values must be strings or " + "sequences of strings" + ) + + def _assert_str(v: str) -> str: + if not isinstance(v, str): + raise TypeError("Query dictionary keys must be strings") + return v + + dict_items: Iterable[Tuple[str, Union[Sequence[str], str]]] + if isinstance(dict_, collections_abc.Sequence): + dict_items = dict_ + else: + dict_items = dict_.items() + + return util.immutabledict( + { + _assert_str(key): _assert_value( + value, + ) + for key, value in dict_items + } + ) + + def set( + self, + drivername: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Optional[Mapping[str, Union[Sequence[str], str]]] = None, + ) -> URL: + """return a new :class:`_engine.URL` object with modifications. + + Values are used if they are non-None. To set a value to ``None`` + explicitly, use the :meth:`_engine.URL._replace` method adapted + from ``namedtuple``. + + :param drivername: new drivername + :param username: new username + :param password: new password + :param host: new hostname + :param port: new port + :param query: new query parameters, passed a dict of string keys + referring to string or sequence of string values. Fully + replaces the previous list of arguments. + + :return: new :class:`_engine.URL` object. + + .. versionadded:: 1.4 + + .. seealso:: + + :meth:`_engine.URL.update_query_dict` + + """ + + kw: Dict[str, Any] = {} + if drivername is not None: + kw["drivername"] = drivername + if username is not None: + kw["username"] = username + if password is not None: + kw["password"] = password + if host is not None: + kw["host"] = host + if port is not None: + kw["port"] = port + if database is not None: + kw["database"] = database + if query is not None: + kw["query"] = query + + return self._assert_replace(**kw) + + def _assert_replace(self, **kw: Any) -> URL: + """argument checks before calling _replace()""" + + if "drivername" in kw: + self._assert_str(kw["drivername"], "drivername") + for name in "username", "host", "database": + if name in kw: + self._assert_none_str(kw[name], name) + if "port" in kw: + self._assert_port(kw["port"]) + if "query" in kw: + kw["query"] = self._str_dict(kw["query"]) + + return self._replace(**kw) + + def update_query_string( + self, query_string: str, append: bool = False + ) -> URL: + """Return a new :class:`_engine.URL` object with the :attr:`_engine.URL.query` + parameter dictionary updated by the given query string. + + E.g.:: + + >>> from sqlalchemy.engine import make_url + >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname") + >>> url = url.update_query_string("alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt") + >>> str(url) + 'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt' + + :param query_string: a URL escaped query string, not including the + question mark. + + :param append: if True, parameters in the existing query string will + not be removed; new parameters will be in addition to those present. + If left at its default of False, keys present in the given query + parameters will replace those of the existing query string. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_engine.URL.query` + + :meth:`_engine.URL.update_query_dict` + + """ # noqa: E501 + return self.update_query_pairs(parse_qsl(query_string), append=append) + + def update_query_pairs( + self, + key_value_pairs: Iterable[Tuple[str, Union[str, List[str]]]], + append: bool = False, + ) -> URL: + """Return a new :class:`_engine.URL` object with the + :attr:`_engine.URL.query` + parameter dictionary updated by the given sequence of key/value pairs + + E.g.:: + + >>> from sqlalchemy.engine import make_url + >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname") + >>> url = url.update_query_pairs([("alt_host", "host1"), ("alt_host", "host2"), ("ssl_cipher", "/path/to/crt")]) + >>> str(url) + 'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt' + + :param key_value_pairs: A sequence of tuples containing two strings + each. + + :param append: if True, parameters in the existing query string will + not be removed; new parameters will be in addition to those present. + If left at its default of False, keys present in the given query + parameters will replace those of the existing query string. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_engine.URL.query` + + :meth:`_engine.URL.difference_update_query` + + :meth:`_engine.URL.set` + + """ # noqa: E501 + + existing_query = self.query + new_keys: Dict[str, Union[str, List[str]]] = {} + + for key, value in key_value_pairs: + if key in new_keys: + new_keys[key] = util.to_list(new_keys[key]) + cast("List[str]", new_keys[key]).append(cast(str, value)) + else: + new_keys[key] = ( + list(value) if isinstance(value, (list, tuple)) else value + ) + + new_query: Mapping[str, Union[str, Sequence[str]]] + if append: + new_query = {} + + for k in new_keys: + if k in existing_query: + new_query[k] = tuple( + util.to_list(existing_query[k]) + + util.to_list(new_keys[k]) + ) + else: + new_query[k] = new_keys[k] + + new_query.update( + { + k: existing_query[k] + for k in set(existing_query).difference(new_keys) + } + ) + else: + new_query = self.query.union( + { + k: tuple(v) if isinstance(v, list) else v + for k, v in new_keys.items() + } + ) + return self.set(query=new_query) + + def update_query_dict( + self, + query_parameters: Mapping[str, Union[str, List[str]]], + append: bool = False, + ) -> URL: + """Return a new :class:`_engine.URL` object with the + :attr:`_engine.URL.query` parameter dictionary updated by the given + dictionary. + + The dictionary typically contains string keys and string values. + In order to represent a query parameter that is expressed multiple + times, pass a sequence of string values. + + E.g.:: + + + >>> from sqlalchemy.engine import make_url + >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname") + >>> url = url.update_query_dict({"alt_host": ["host1", "host2"], "ssl_cipher": "/path/to/crt"}) + >>> str(url) + 'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt' + + + :param query_parameters: A dictionary with string keys and values + that are either strings, or sequences of strings. + + :param append: if True, parameters in the existing query string will + not be removed; new parameters will be in addition to those present. + If left at its default of False, keys present in the given query + parameters will replace those of the existing query string. + + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_engine.URL.query` + + :meth:`_engine.URL.update_query_string` + + :meth:`_engine.URL.update_query_pairs` + + :meth:`_engine.URL.difference_update_query` + + :meth:`_engine.URL.set` + + """ # noqa: E501 + return self.update_query_pairs(query_parameters.items(), append=append) + + def difference_update_query(self, names: Iterable[str]) -> URL: + """ + Remove the given names from the :attr:`_engine.URL.query` dictionary, + returning the new :class:`_engine.URL`. + + E.g.:: + + url = url.difference_update_query(['foo', 'bar']) + + Equivalent to using :meth:`_engine.URL.set` as follows:: + + url = url.set( + query={ + key: url.query[key] + for key in set(url.query).difference(['foo', 'bar']) + } + ) + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_engine.URL.query` + + :meth:`_engine.URL.update_query_dict` + + :meth:`_engine.URL.set` + + """ + + if not set(names).intersection(self.query): + return self + + return URL( + self.drivername, + self.username, + self.password, + self.host, + self.port, + self.database, + util.immutabledict( + { + key: self.query[key] + for key in set(self.query).difference(names) + } + ), + ) + + @property + def normalized_query(self) -> Mapping[str, Sequence[str]]: + """Return the :attr:`_engine.URL.query` dictionary with values normalized + into sequences. + + As the :attr:`_engine.URL.query` dictionary may contain either + string values or sequences of string values to differentiate between + parameters that are specified multiple times in the query string, + code that needs to handle multiple parameters generically will wish + to use this attribute so that all parameters present are presented + as sequences. Inspiration is from Python's ``urllib.parse.parse_qs`` + function. E.g.:: + + + >>> from sqlalchemy.engine import make_url + >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt") + >>> url.query + immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'}) + >>> url.normalized_query + immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': ('/path/to/crt',)}) + + """ # noqa: E501 + + return util.immutabledict( + { + k: (v,) if not isinstance(v, tuple) else v + for k, v in self.query.items() + } + ) + + @util.deprecated( + "1.4", + "The :meth:`_engine.URL.__to_string__ method is deprecated and will " + "be removed in a future release. Please use the " + ":meth:`_engine.URL.render_as_string` method.", + ) + def __to_string__(self, hide_password: bool = True) -> str: + """Render this :class:`_engine.URL` object as a string. + + :param hide_password: Defaults to True. The password is not shown + in the string unless this is set to False. + + """ + return self.render_as_string(hide_password=hide_password) + + def render_as_string(self, hide_password: bool = True) -> str: + """Render this :class:`_engine.URL` object as a string. + + This method is used when the ``__str__()`` or ``__repr__()`` + methods are used. The method directly includes additional options. + + :param hide_password: Defaults to True. The password is not shown + in the string unless this is set to False. + + """ + s = self.drivername + "://" + if self.username is not None: + s += quote(self.username, safe=" +") + if self.password is not None: + s += ":" + ( + "***" + if hide_password + else quote(str(self.password), safe=" +") + ) + s += "@" + if self.host is not None: + if ":" in self.host: + s += f"[{self.host}]" + else: + s += self.host + if self.port is not None: + s += ":" + str(self.port) + if self.database is not None: + s += "/" + self.database + if self.query: + keys = list(self.query) + keys.sort() + s += "?" + "&".join( + f"{quote_plus(k)}={quote_plus(element)}" + for k in keys + for element in util.to_list(self.query[k]) + ) + return s + + def __repr__(self) -> str: + return self.render_as_string() + + def __copy__(self) -> URL: + return self.__class__.create( + self.drivername, + self.username, + self.password, + self.host, + self.port, + self.database, + # note this is an immutabledict of str-> str / tuple of str, + # also fully immutable. does not require deepcopy + self.query, + ) + + def __deepcopy__(self, memo: Any) -> URL: + return self.__copy__() + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, URL) + and self.drivername == other.drivername + and self.username == other.username + and self.password == other.password + and self.host == other.host + and self.database == other.database + and self.query == other.query + and self.port == other.port + ) + + def __ne__(self, other: Any) -> bool: + return not self == other + + def get_backend_name(self) -> str: + """Return the backend name. + + This is the name that corresponds to the database backend in + use, and is the portion of the :attr:`_engine.URL.drivername` + that is to the left of the plus sign. + + """ + if "+" not in self.drivername: + return self.drivername + else: + return self.drivername.split("+")[0] + + def get_driver_name(self) -> str: + """Return the backend name. + + This is the name that corresponds to the DBAPI driver in + use, and is the portion of the :attr:`_engine.URL.drivername` + that is to the right of the plus sign. + + If the :attr:`_engine.URL.drivername` does not include a plus sign, + then the default :class:`_engine.Dialect` for this :class:`_engine.URL` + is imported in order to get the driver name. + + """ + + if "+" not in self.drivername: + return self.get_dialect().driver + else: + return self.drivername.split("+")[1] + + def _instantiate_plugins( + self, kwargs: Mapping[str, Any] + ) -> Tuple[URL, List[Any], Dict[str, Any]]: + plugin_names = util.to_list(self.query.get("plugin", ())) + plugin_names += kwargs.get("plugins", []) + + kwargs = dict(kwargs) + + loaded_plugins = [ + plugins.load(plugin_name)(self, kwargs) + for plugin_name in plugin_names + ] + + u = self.difference_update_query(["plugin", "plugins"]) + + for plugin in loaded_plugins: + new_u = plugin.update_url(u) + if new_u is not None: + u = new_u + + kwargs.pop("plugins", None) + + return u, loaded_plugins, kwargs + + def _get_entrypoint(self) -> Type[Dialect]: + """Return the "entry point" dialect class. + + This is normally the dialect itself except in the case when the + returned class implements the get_dialect_cls() method. + + """ + if "+" not in self.drivername: + name = self.drivername + else: + name = self.drivername.replace("+", ".") + cls = registry.load(name) + # check for legacy dialects that + # would return a module with 'dialect' as the + # actual class + if ( + hasattr(cls, "dialect") + and isinstance(cls.dialect, type) + and issubclass(cls.dialect, Dialect) + ): + return cls.dialect + else: + return cast("Type[Dialect]", cls) + + def get_dialect(self, _is_async: bool = False) -> Type[Dialect]: + """Return the SQLAlchemy :class:`_engine.Dialect` class corresponding + to this URL's driver name. + + """ + entrypoint = self._get_entrypoint() + if _is_async: + dialect_cls = entrypoint.get_async_dialect_cls(self) + else: + dialect_cls = entrypoint.get_dialect_cls(self) + return dialect_cls + + def translate_connect_args( + self, names: Optional[List[str]] = None, **kw: Any + ) -> Dict[str, Any]: + r"""Translate url attributes into a dictionary of connection arguments. + + Returns attributes of this url (`host`, `database`, `username`, + `password`, `port`) as a plain dictionary. The attribute names are + used as the keys by default. Unset or false attributes are omitted + from the final dictionary. + + :param \**kw: Optional, alternate key names for url attributes. + + :param names: Deprecated. Same purpose as the keyword-based alternate + names, but correlates the name to the original positionally. + """ + + if names is not None: + util.warn_deprecated( + "The `URL.translate_connect_args.name`s parameter is " + "deprecated. Please pass the " + "alternate names as kw arguments.", + "1.4", + ) + + translated = {} + attribute_names = ["host", "database", "username", "password", "port"] + for sname in attribute_names: + if names: + name = names.pop(0) + elif sname in kw: + name = kw[sname] + else: + name = sname + if name is not None and getattr(self, sname, False): + if sname == "password": + translated[name] = str(getattr(self, sname)) + else: + translated[name] = getattr(self, sname) + + return translated + + +def make_url(name_or_url: Union[str, URL]) -> URL: + """Given a string, produce a new URL instance. + + The format of the URL generally follows `RFC-1738 + `_, with some exceptions, including + that underscores, and not dashes or periods, are accepted within the + "scheme" portion. + + If a :class:`.URL` object is passed, it is returned as is. + + .. seealso:: + + :ref:`database_urls` + + """ + + if isinstance(name_or_url, str): + return _parse_url(name_or_url) + elif not isinstance(name_or_url, URL) and not hasattr( + name_or_url, "_sqla_is_testing_if_this_is_a_mock_object" + ): + raise exc.ArgumentError( + f"Expected string or URL object, got {name_or_url!r}" + ) + else: + return name_or_url + + +def _parse_url(name: str) -> URL: + pattern = re.compile( + r""" + (?P[\w\+]+):// + (?: + (?P[^:/]*) + (?::(?P[^@]*))? + @)? + (?: + (?: + \[(?P[^/\?]+)\] | + (?P[^/:\?]+) + )? + (?::(?P[^/\?]*))? + )? + (?:/(?P[^\?]*))? + (?:\?(?P.*))? + """, + re.X, + ) + + m = pattern.match(name) + if m is not None: + components = m.groupdict() + query: Optional[Dict[str, Union[str, List[str]]]] + if components["query"] is not None: + query = {} + + for key, value in parse_qsl(components["query"]): + if key in query: + query[key] = util.to_list(query[key]) + cast("List[str]", query[key]).append(value) + else: + query[key] = value + else: + query = None + components["query"] = query + + if components["username"] is not None: + components["username"] = unquote(components["username"]) + + if components["password"] is not None: + components["password"] = unquote(components["password"]) + + ipv4host = components.pop("ipv4host") + ipv6host = components.pop("ipv6host") + components["host"] = ipv4host or ipv6host + name = components.pop("name") + + if components["port"]: + components["port"] = int(components["port"]) + + return URL.create(name, **components) # type: ignore + + else: + raise exc.ArgumentError( + "Could not parse SQLAlchemy URL from string '%s'" % name + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/engine/util.py b/venv/lib/python3.11/site-packages/sqlalchemy/engine/util.py new file mode 100644 index 0000000..186ca4c --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/engine/util.py @@ -0,0 +1,167 @@ +# engine/util.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import Optional +from typing import TypeVar + +from .. import exc +from .. import util +from ..util._has_cy import HAS_CYEXTENSION +from ..util.typing import Protocol +from ..util.typing import Self + +if typing.TYPE_CHECKING or not HAS_CYEXTENSION: + from ._py_util import _distill_params_20 as _distill_params_20 + from ._py_util import _distill_raw_params as _distill_raw_params +else: + from sqlalchemy.cyextension.util import ( # noqa: F401 + _distill_params_20 as _distill_params_20, + ) + from sqlalchemy.cyextension.util import ( # noqa: F401 + _distill_raw_params as _distill_raw_params, + ) + +_C = TypeVar("_C", bound=Callable[[], Any]) + + +def connection_memoize(key: str) -> Callable[[_C], _C]: + """Decorator, memoize a function in a connection.info stash. + + Only applicable to functions which take no arguments other than a + connection. The memo will be stored in ``connection.info[key]``. + """ + + @util.decorator + def decorated(fn, self, connection): # type: ignore + connection = connection.connect() + try: + return connection.info[key] + except KeyError: + connection.info[key] = val = fn(self, connection) + return val + + return decorated + + +class _TConsSubject(Protocol): + _trans_context_manager: Optional[TransactionalContext] + + +class TransactionalContext: + """Apply Python context manager behavior to transaction objects. + + Performs validation to ensure the subject of the transaction is not + used if the transaction were ended prematurely. + + """ + + __slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__") + + _trans_subject: Optional[_TConsSubject] + + def _transaction_is_active(self) -> bool: + raise NotImplementedError() + + def _transaction_is_closed(self) -> bool: + raise NotImplementedError() + + def _rollback_can_be_called(self) -> bool: + """indicates the object is in a state that is known to be acceptable + for rollback() to be called. + + This does not necessarily mean rollback() will succeed or not raise + an error, just that there is currently no state detected that indicates + rollback() would fail or emit warnings. + + It also does not mean that there's a transaction in progress, as + it is usually safe to call rollback() even if no transaction is + present. + + .. versionadded:: 1.4.28 + + """ + raise NotImplementedError() + + def _get_subject(self) -> _TConsSubject: + raise NotImplementedError() + + def commit(self) -> None: + raise NotImplementedError() + + def rollback(self) -> None: + raise NotImplementedError() + + def close(self) -> None: + raise NotImplementedError() + + @classmethod + def _trans_ctx_check(cls, subject: _TConsSubject) -> None: + trans_context = subject._trans_context_manager + if trans_context: + if not trans_context._transaction_is_active(): + raise exc.InvalidRequestError( + "Can't operate on closed transaction inside context " + "manager. Please complete the context manager " + "before emitting further commands." + ) + + def __enter__(self) -> Self: + subject = self._get_subject() + + # none for outer transaction, may be non-None for nested + # savepoint, legacy nesting cases + trans_context = subject._trans_context_manager + self._outer_trans_ctx = trans_context + + self._trans_subject = subject + subject._trans_context_manager = self + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + subject = getattr(self, "_trans_subject", None) + + # simplistically we could assume that + # "subject._trans_context_manager is self". However, any calling + # code that is manipulating __exit__ directly would break this + # assumption. alembic context manager + # is an example of partial use that just calls __exit__ and + # not __enter__ at the moment. it's safe to assume this is being done + # in the wild also + out_of_band_exit = ( + subject is None or subject._trans_context_manager is not self + ) + + if type_ is None and self._transaction_is_active(): + try: + self.commit() + except: + with util.safe_reraise(): + if self._rollback_can_be_called(): + self.rollback() + finally: + if not out_of_band_exit: + assert subject is not None + subject._trans_context_manager = self._outer_trans_ctx + self._trans_subject = self._outer_trans_ctx = None + else: + try: + if not self._transaction_is_active(): + if not self._transaction_is_closed(): + self.close() + else: + if self._rollback_can_be_called(): + self.rollback() + finally: + if not out_of_band_exit: + assert subject is not None + subject._trans_context_manager = self._outer_trans_ctx + self._trans_subject = self._outer_trans_ctx = None diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/event/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/event/__init__.py new file mode 100644 index 0000000..9b54f07 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/event/__init__.py @@ -0,0 +1,25 @@ +# event/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from .api import CANCEL as CANCEL +from .api import contains as contains +from .api import listen as listen +from .api import listens_for as listens_for +from .api import NO_RETVAL as NO_RETVAL +from .api import remove as remove +from .attr import _InstanceLevelDispatch as _InstanceLevelDispatch +from .attr import RefCollection as RefCollection +from .base import _Dispatch as _Dispatch +from .base import _DispatchCommon as _DispatchCommon +from .base import dispatcher as dispatcher +from .base import Events as Events +from .legacy import _legacy_signature as _legacy_signature +from .registry import _EventKey as _EventKey +from .registry import _ListenerFnType as _ListenerFnType +from .registry import EventTarget as EventTarget diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..cdacdf7 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/api.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/api.cpython-311.pyc new file mode 100644 index 0000000..9495b80 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/api.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/attr.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/attr.cpython-311.pyc new file mode 100644 index 0000000..264917e Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/attr.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..d2dce27 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/base.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/legacy.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/legacy.cpython-311.pyc new file mode 100644 index 0000000..f6911d7 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/legacy.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/registry.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/registry.cpython-311.pyc new file mode 100644 index 0000000..7bf4877 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/event/__pycache__/registry.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/event/api.py b/venv/lib/python3.11/site-packages/sqlalchemy/event/api.py new file mode 100644 index 0000000..4a39d10 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/event/api.py @@ -0,0 +1,225 @@ +# event/api.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Public API functions for the event system. + +""" +from __future__ import annotations + +from typing import Any +from typing import Callable + +from .base import _registrars +from .registry import _ET +from .registry import _EventKey +from .registry import _ListenerFnType +from .. import exc +from .. import util + + +CANCEL = util.symbol("CANCEL") +NO_RETVAL = util.symbol("NO_RETVAL") + + +def _event_key( + target: _ET, identifier: str, fn: _ListenerFnType +) -> _EventKey[_ET]: + for evt_cls in _registrars[identifier]: + tgt = evt_cls._accept_with(target, identifier) + if tgt is not None: + return _EventKey(target, identifier, fn, tgt) + else: + raise exc.InvalidRequestError( + "No such event '%s' for target '%s'" % (identifier, target) + ) + + +def listen( + target: Any, identifier: str, fn: Callable[..., Any], *args: Any, **kw: Any +) -> None: + """Register a listener function for the given target. + + The :func:`.listen` function is part of the primary interface for the + SQLAlchemy event system, documented at :ref:`event_toplevel`. + + e.g.:: + + from sqlalchemy import event + from sqlalchemy.schema import UniqueConstraint + + def unique_constraint_name(const, table): + const.name = "uq_%s_%s" % ( + table.name, + list(const.columns)[0].name + ) + event.listen( + UniqueConstraint, + "after_parent_attach", + unique_constraint_name) + + :param bool insert: The default behavior for event handlers is to append + the decorated user defined function to an internal list of registered + event listeners upon discovery. If a user registers a function with + ``insert=True``, SQLAlchemy will insert (prepend) the function to the + internal list upon discovery. This feature is not typically used or + recommended by the SQLAlchemy maintainers, but is provided to ensure + certain user defined functions can run before others, such as when + :ref:`Changing the sql_mode in MySQL `. + + :param bool named: When using named argument passing, the names listed in + the function argument specification will be used as keys in the + dictionary. + See :ref:`event_named_argument_styles`. + + :param bool once: Private/Internal API usage. Deprecated. This parameter + would provide that an event function would run only once per given + target. It does not however imply automatic de-registration of the + listener function; associating an arbitrarily high number of listeners + without explicitly removing them will cause memory to grow unbounded even + if ``once=True`` is specified. + + :param bool propagate: The ``propagate`` kwarg is available when working + with ORM instrumentation and mapping events. + See :class:`_ormevent.MapperEvents` and + :meth:`_ormevent.MapperEvents.before_mapper_configured` for examples. + + :param bool retval: This flag applies only to specific event listeners, + each of which includes documentation explaining when it should be used. + By default, no listener ever requires a return value. + However, some listeners do support special behaviors for return values, + and include in their documentation that the ``retval=True`` flag is + necessary for a return value to be processed. + + Event listener suites that make use of :paramref:`_event.listen.retval` + include :class:`_events.ConnectionEvents` and + :class:`_ormevent.AttributeEvents`. + + .. note:: + + The :func:`.listen` function cannot be called at the same time + that the target event is being run. This has implications + for thread safety, and also means an event cannot be added + from inside the listener function for itself. The list of + events to be run are present inside of a mutable collection + that can't be changed during iteration. + + Event registration and removal is not intended to be a "high + velocity" operation; it is a configurational operation. For + systems that need to quickly associate and deassociate with + events at high scale, use a mutable structure that is handled + from inside of a single listener. + + .. seealso:: + + :func:`.listens_for` + + :func:`.remove` + + """ + + _event_key(target, identifier, fn).listen(*args, **kw) + + +def listens_for( + target: Any, identifier: str, *args: Any, **kw: Any +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Decorate a function as a listener for the given target + identifier. + + The :func:`.listens_for` decorator is part of the primary interface for the + SQLAlchemy event system, documented at :ref:`event_toplevel`. + + This function generally shares the same kwargs as :func:`.listens`. + + e.g.:: + + from sqlalchemy import event + from sqlalchemy.schema import UniqueConstraint + + @event.listens_for(UniqueConstraint, "after_parent_attach") + def unique_constraint_name(const, table): + const.name = "uq_%s_%s" % ( + table.name, + list(const.columns)[0].name + ) + + A given function can also be invoked for only the first invocation + of the event using the ``once`` argument:: + + @event.listens_for(Mapper, "before_configure", once=True) + def on_config(): + do_config() + + + .. warning:: The ``once`` argument does not imply automatic de-registration + of the listener function after it has been invoked a first time; a + listener entry will remain associated with the target object. + Associating an arbitrarily high number of listeners without explicitly + removing them will cause memory to grow unbounded even if ``once=True`` + is specified. + + .. seealso:: + + :func:`.listen` - general description of event listening + + """ + + def decorate(fn: Callable[..., Any]) -> Callable[..., Any]: + listen(target, identifier, fn, *args, **kw) + return fn + + return decorate + + +def remove(target: Any, identifier: str, fn: Callable[..., Any]) -> None: + """Remove an event listener. + + The arguments here should match exactly those which were sent to + :func:`.listen`; all the event registration which proceeded as a result + of this call will be reverted by calling :func:`.remove` with the same + arguments. + + e.g.:: + + # if a function was registered like this... + @event.listens_for(SomeMappedClass, "before_insert", propagate=True) + def my_listener_function(*arg): + pass + + # ... it's removed like this + event.remove(SomeMappedClass, "before_insert", my_listener_function) + + Above, the listener function associated with ``SomeMappedClass`` was also + propagated to subclasses of ``SomeMappedClass``; the :func:`.remove` + function will revert all of these operations. + + .. note:: + + The :func:`.remove` function cannot be called at the same time + that the target event is being run. This has implications + for thread safety, and also means an event cannot be removed + from inside the listener function for itself. The list of + events to be run are present inside of a mutable collection + that can't be changed during iteration. + + Event registration and removal is not intended to be a "high + velocity" operation; it is a configurational operation. For + systems that need to quickly associate and deassociate with + events at high scale, use a mutable structure that is handled + from inside of a single listener. + + .. seealso:: + + :func:`.listen` + + """ + _event_key(target, identifier, fn).remove() + + +def contains(target: Any, identifier: str, fn: Callable[..., Any]) -> bool: + """Return True if the given target/ident/fn is set up to listen.""" + + return _event_key(target, identifier, fn).contains() diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/event/attr.py b/venv/lib/python3.11/site-packages/sqlalchemy/event/attr.py new file mode 100644 index 0000000..ef2b334 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/event/attr.py @@ -0,0 +1,655 @@ +# event/attr.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Attribute implementation for _Dispatch classes. + +The various listener targets for a particular event class are represented +as attributes, which refer to collections of listeners to be fired off. +These collections can exist at the class level as well as at the instance +level. An event is fired off using code like this:: + + some_object.dispatch.first_connect(arg1, arg2) + +Above, ``some_object.dispatch`` would be an instance of ``_Dispatch`` and +``first_connect`` is typically an instance of ``_ListenerCollection`` +if event listeners are present, or ``_EmptyListener`` if none are present. + +The attribute mechanics here spend effort trying to ensure listener functions +are available with a minimum of function call overhead, that unnecessary +objects aren't created (i.e. many empty per-instance listener collections), +as well as that everything is garbage collectable when owning references are +lost. Other features such as "propagation" of listener functions across +many ``_Dispatch`` instances, "joining" of multiple ``_Dispatch`` instances, +as well as support for subclass propagation (e.g. events assigned to +``Pool`` vs. ``QueuePool``) are all implemented here. + +""" +from __future__ import annotations + +import collections +from itertools import chain +import threading +from types import TracebackType +import typing +from typing import Any +from typing import cast +from typing import Collection +from typing import Deque +from typing import FrozenSet +from typing import Generic +from typing import Iterator +from typing import MutableMapping +from typing import MutableSequence +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +import weakref + +from . import legacy +from . import registry +from .registry import _ET +from .registry import _EventKey +from .registry import _ListenerFnType +from .. import exc +from .. import util +from ..util.concurrency import AsyncAdaptedLock +from ..util.typing import Protocol + +_T = TypeVar("_T", bound=Any) + +if typing.TYPE_CHECKING: + from .base import _Dispatch + from .base import _DispatchCommon + from .base import _HasEventsDispatch + + +class RefCollection(util.MemoizedSlots, Generic[_ET]): + __slots__ = ("ref",) + + ref: weakref.ref[RefCollection[_ET]] + + def _memoized_attr_ref(self) -> weakref.ref[RefCollection[_ET]]: + return weakref.ref(self, registry._collection_gced) + + +class _empty_collection(Collection[_T]): + def append(self, element: _T) -> None: + pass + + def appendleft(self, element: _T) -> None: + pass + + def extend(self, other: Sequence[_T]) -> None: + pass + + def remove(self, element: _T) -> None: + pass + + def __contains__(self, element: Any) -> bool: + return False + + def __iter__(self) -> Iterator[_T]: + return iter([]) + + def clear(self) -> None: + pass + + def __len__(self) -> int: + return 0 + + +_ListenerFnSequenceType = Union[Deque[_T], _empty_collection[_T]] + + +class _ClsLevelDispatch(RefCollection[_ET]): + """Class-level events on :class:`._Dispatch` classes.""" + + __slots__ = ( + "clsname", + "name", + "arg_names", + "has_kw", + "legacy_signatures", + "_clslevel", + "__weakref__", + ) + + clsname: str + name: str + arg_names: Sequence[str] + has_kw: bool + legacy_signatures: MutableSequence[legacy._LegacySignatureType] + _clslevel: MutableMapping[ + Type[_ET], _ListenerFnSequenceType[_ListenerFnType] + ] + + def __init__( + self, + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], + fn: _ListenerFnType, + ): + self.name = fn.__name__ + self.clsname = parent_dispatch_cls.__name__ + argspec = util.inspect_getfullargspec(fn) + self.arg_names = argspec.args[1:] + self.has_kw = bool(argspec.varkw) + self.legacy_signatures = list( + reversed( + sorted( + getattr(fn, "_legacy_signatures", []), key=lambda s: s[0] + ) + ) + ) + fn.__doc__ = legacy._augment_fn_docs(self, parent_dispatch_cls, fn) + + self._clslevel = weakref.WeakKeyDictionary() + + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: + if named: + fn = self._wrap_fn_for_kw(fn) + if self.legacy_signatures: + try: + argspec = util.get_callable_argspec(fn, no_self=True) + except TypeError: + pass + else: + fn = legacy._wrap_fn_for_legacy(self, fn, argspec) + return fn + + def _wrap_fn_for_kw(self, fn: _ListenerFnType) -> _ListenerFnType: + def wrap_kw(*args: Any, **kw: Any) -> Any: + argdict = dict(zip(self.arg_names, args)) + argdict.update(kw) + return fn(**argdict) + + return wrap_kw + + def _do_insert_or_append( + self, event_key: _EventKey[_ET], is_append: bool + ) -> None: + target = event_key.dispatch_target + assert isinstance( + target, type + ), "Class-level Event targets must be classes." + if not getattr(target, "_sa_propagate_class_events", True): + raise exc.InvalidRequestError( + f"Can't assign an event directly to the {target} class" + ) + + cls: Type[_ET] + + for cls in util.walk_subclasses(target): + if cls is not target and cls not in self._clslevel: + self.update_subclass(cls) + else: + if cls not in self._clslevel: + self.update_subclass(cls) + if is_append: + self._clslevel[cls].append(event_key._listen_fn) + else: + self._clslevel[cls].appendleft(event_key._listen_fn) + registry._stored_in_collection(event_key, self) + + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + self._do_insert_or_append(event_key, is_append=False) + + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: + self._do_insert_or_append(event_key, is_append=True) + + def update_subclass(self, target: Type[_ET]) -> None: + if target not in self._clslevel: + if getattr(target, "_sa_propagate_class_events", True): + self._clslevel[target] = collections.deque() + else: + self._clslevel[target] = _empty_collection() + + clslevel = self._clslevel[target] + cls: Type[_ET] + for cls in target.__mro__[1:]: + if cls in self._clslevel: + clslevel.extend( + [fn for fn in self._clslevel[cls] if fn not in clslevel] + ) + + def remove(self, event_key: _EventKey[_ET]) -> None: + target = event_key.dispatch_target + cls: Type[_ET] + for cls in util.walk_subclasses(target): + if cls in self._clslevel: + self._clslevel[cls].remove(event_key._listen_fn) + registry._removed_from_collection(event_key, self) + + def clear(self) -> None: + """Clear all class level listeners""" + + to_clear: Set[_ListenerFnType] = set() + for dispatcher in self._clslevel.values(): + to_clear.update(dispatcher) + dispatcher.clear() + registry._clear(self, to_clear) + + def for_modify(self, obj: _Dispatch[_ET]) -> _ClsLevelDispatch[_ET]: + """Return an event collection which can be modified. + + For _ClsLevelDispatch at the class level of + a dispatcher, this returns self. + + """ + return self + + +class _InstanceLevelDispatch(RefCollection[_ET], Collection[_ListenerFnType]): + __slots__ = () + + parent: _ClsLevelDispatch[_ET] + + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: + return self.parent._adjust_fn_spec(fn, named) + + def __contains__(self, item: Any) -> bool: + raise NotImplementedError() + + def __len__(self) -> int: + raise NotImplementedError() + + def __iter__(self) -> Iterator[_ListenerFnType]: + raise NotImplementedError() + + def __bool__(self) -> bool: + raise NotImplementedError() + + def exec_once(self, *args: Any, **kw: Any) -> None: + raise NotImplementedError() + + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> None: + raise NotImplementedError() + + def _exec_w_sync_on_first_run(self, *args: Any, **kw: Any) -> None: + raise NotImplementedError() + + def __call__(self, *args: Any, **kw: Any) -> None: + raise NotImplementedError() + + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + raise NotImplementedError() + + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: + raise NotImplementedError() + + def remove(self, event_key: _EventKey[_ET]) -> None: + raise NotImplementedError() + + def for_modify( + self, obj: _DispatchCommon[_ET] + ) -> _InstanceLevelDispatch[_ET]: + """Return an event collection which can be modified. + + For _ClsLevelDispatch at the class level of + a dispatcher, this returns self. + + """ + return self + + +class _EmptyListener(_InstanceLevelDispatch[_ET]): + """Serves as a proxy interface to the events + served by a _ClsLevelDispatch, when there are no + instance-level events present. + + Is replaced by _ListenerCollection when instance-level + events are added. + + """ + + __slots__ = "parent", "parent_listeners", "name" + + propagate: FrozenSet[_ListenerFnType] = frozenset() + listeners: Tuple[()] = () + parent: _ClsLevelDispatch[_ET] + parent_listeners: _ListenerFnSequenceType[_ListenerFnType] + name: str + + def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]): + if target_cls not in parent._clslevel: + parent.update_subclass(target_cls) + self.parent = parent + self.parent_listeners = parent._clslevel[target_cls] + self.name = parent.name + + def for_modify( + self, obj: _DispatchCommon[_ET] + ) -> _ListenerCollection[_ET]: + """Return an event collection which can be modified. + + For _EmptyListener at the instance level of + a dispatcher, this generates a new + _ListenerCollection, applies it to the instance, + and returns it. + + """ + obj = cast("_Dispatch[_ET]", obj) + + assert obj._instance_cls is not None + result = _ListenerCollection(self.parent, obj._instance_cls) + if getattr(obj, self.name) is self: + setattr(obj, self.name, result) + else: + assert isinstance(getattr(obj, self.name), _JoinedListener) + return result + + def _needs_modify(self, *args: Any, **kw: Any) -> NoReturn: + raise NotImplementedError("need to call for_modify()") + + def exec_once(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def insert(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def append(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def remove(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def clear(self, *args: Any, **kw: Any) -> NoReturn: + self._needs_modify(*args, **kw) + + def __call__(self, *args: Any, **kw: Any) -> None: + """Execute this event.""" + + for fn in self.parent_listeners: + fn(*args, **kw) + + def __contains__(self, item: Any) -> bool: + return item in self.parent_listeners + + def __len__(self) -> int: + return len(self.parent_listeners) + + def __iter__(self) -> Iterator[_ListenerFnType]: + return iter(self.parent_listeners) + + def __bool__(self) -> bool: + return bool(self.parent_listeners) + + +class _MutexProtocol(Protocol): + def __enter__(self) -> bool: ... + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: ... + + +class _CompoundListener(_InstanceLevelDispatch[_ET]): + __slots__ = ( + "_exec_once_mutex", + "_exec_once", + "_exec_w_sync_once", + "_is_asyncio", + ) + + _exec_once_mutex: _MutexProtocol + parent_listeners: Collection[_ListenerFnType] + listeners: Collection[_ListenerFnType] + _exec_once: bool + _exec_w_sync_once: bool + + def __init__(self, *arg: Any, **kw: Any): + super().__init__(*arg, **kw) + self._is_asyncio = False + + def _set_asyncio(self) -> None: + self._is_asyncio = True + + def _memoized_attr__exec_once_mutex(self) -> _MutexProtocol: + if self._is_asyncio: + return AsyncAdaptedLock() + else: + return threading.Lock() + + def _exec_once_impl( + self, retry_on_exception: bool, *args: Any, **kw: Any + ) -> None: + with self._exec_once_mutex: + if not self._exec_once: + try: + self(*args, **kw) + exception = False + except: + exception = True + raise + finally: + if not exception or not retry_on_exception: + self._exec_once = True + + def exec_once(self, *args: Any, **kw: Any) -> None: + """Execute this event, but only if it has not been + executed already for this collection.""" + + if not self._exec_once: + self._exec_once_impl(False, *args, **kw) + + def exec_once_unless_exception(self, *args: Any, **kw: Any) -> None: + """Execute this event, but only if it has not been + executed already for this collection, or was called + by a previous exec_once_unless_exception call and + raised an exception. + + If exec_once was already called, then this method will never run + the callable regardless of whether it raised or not. + + .. versionadded:: 1.3.8 + + """ + if not self._exec_once: + self._exec_once_impl(True, *args, **kw) + + def _exec_w_sync_on_first_run(self, *args: Any, **kw: Any) -> None: + """Execute this event, and use a mutex if it has not been + executed already for this collection, or was called + by a previous _exec_w_sync_on_first_run call and + raised an exception. + + If _exec_w_sync_on_first_run was already called and didn't raise an + exception, then a mutex is not used. + + .. versionadded:: 1.4.11 + + """ + if not self._exec_w_sync_once: + with self._exec_once_mutex: + try: + self(*args, **kw) + except: + raise + else: + self._exec_w_sync_once = True + else: + self(*args, **kw) + + def __call__(self, *args: Any, **kw: Any) -> None: + """Execute this event.""" + + for fn in self.parent_listeners: + fn(*args, **kw) + for fn in self.listeners: + fn(*args, **kw) + + def __contains__(self, item: Any) -> bool: + return item in self.parent_listeners or item in self.listeners + + def __len__(self) -> int: + return len(self.parent_listeners) + len(self.listeners) + + def __iter__(self) -> Iterator[_ListenerFnType]: + return chain(self.parent_listeners, self.listeners) + + def __bool__(self) -> bool: + return bool(self.listeners or self.parent_listeners) + + +class _ListenerCollection(_CompoundListener[_ET]): + """Instance-level attributes on instances of :class:`._Dispatch`. + + Represents a collection of listeners. + + As of 0.7.9, _ListenerCollection is only first + created via the _EmptyListener.for_modify() method. + + """ + + __slots__ = ( + "parent_listeners", + "parent", + "name", + "listeners", + "propagate", + "__weakref__", + ) + + parent_listeners: Collection[_ListenerFnType] + parent: _ClsLevelDispatch[_ET] + name: str + listeners: Deque[_ListenerFnType] + propagate: Set[_ListenerFnType] + + def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]): + super().__init__() + if target_cls not in parent._clslevel: + parent.update_subclass(target_cls) + self._exec_once = False + self._exec_w_sync_once = False + self.parent_listeners = parent._clslevel[target_cls] + self.parent = parent + self.name = parent.name + self.listeners = collections.deque() + self.propagate = set() + + def for_modify( + self, obj: _DispatchCommon[_ET] + ) -> _ListenerCollection[_ET]: + """Return an event collection which can be modified. + + For _ListenerCollection at the instance level of + a dispatcher, this returns self. + + """ + return self + + def _update( + self, other: _ListenerCollection[_ET], only_propagate: bool = True + ) -> None: + """Populate from the listeners in another :class:`_Dispatch` + object.""" + existing_listeners = self.listeners + existing_listener_set = set(existing_listeners) + self.propagate.update(other.propagate) + other_listeners = [ + l + for l in other.listeners + if l not in existing_listener_set + and not only_propagate + or l in self.propagate + ] + + existing_listeners.extend(other_listeners) + + if other._is_asyncio: + self._set_asyncio() + + to_associate = other.propagate.union(other_listeners) + registry._stored_in_collection_multi(self, other, to_associate) + + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + if event_key.prepend_to_list(self, self.listeners): + if propagate: + self.propagate.add(event_key._listen_fn) + + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: + if event_key.append_to_list(self, self.listeners): + if propagate: + self.propagate.add(event_key._listen_fn) + + def remove(self, event_key: _EventKey[_ET]) -> None: + self.listeners.remove(event_key._listen_fn) + self.propagate.discard(event_key._listen_fn) + registry._removed_from_collection(event_key, self) + + def clear(self) -> None: + registry._clear(self, self.listeners) + self.propagate.clear() + self.listeners.clear() + + +class _JoinedListener(_CompoundListener[_ET]): + __slots__ = "parent_dispatch", "name", "local", "parent_listeners" + + parent_dispatch: _DispatchCommon[_ET] + name: str + local: _InstanceLevelDispatch[_ET] + parent_listeners: Collection[_ListenerFnType] + + def __init__( + self, + parent_dispatch: _DispatchCommon[_ET], + name: str, + local: _EmptyListener[_ET], + ): + self._exec_once = False + self.parent_dispatch = parent_dispatch + self.name = name + self.local = local + self.parent_listeners = self.local + + if not typing.TYPE_CHECKING: + # first error, I don't really understand: + # Signature of "listeners" incompatible with + # supertype "_CompoundListener" [override] + # the name / return type are exactly the same + # second error is getattr_isn't typed, the cast() here + # adds too much method overhead + @property + def listeners(self) -> Collection[_ListenerFnType]: + return getattr(self.parent_dispatch, self.name) + + def _adjust_fn_spec( + self, fn: _ListenerFnType, named: bool + ) -> _ListenerFnType: + return self.local._adjust_fn_spec(fn, named) + + def for_modify(self, obj: _DispatchCommon[_ET]) -> _JoinedListener[_ET]: + self.local = self.parent_listeners = self.local.for_modify(obj) + return self + + def insert(self, event_key: _EventKey[_ET], propagate: bool) -> None: + self.local.insert(event_key, propagate) + + def append(self, event_key: _EventKey[_ET], propagate: bool) -> None: + self.local.append(event_key, propagate) + + def remove(self, event_key: _EventKey[_ET]) -> None: + self.local.remove(event_key) + + def clear(self) -> None: + raise NotImplementedError() diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/event/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/event/base.py new file mode 100644 index 0000000..1f52e2e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/event/base.py @@ -0,0 +1,462 @@ +# event/base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Base implementation classes. + +The public-facing ``Events`` serves as the base class for an event interface; +its public attributes represent different kinds of events. These attributes +are mirrored onto a ``_Dispatch`` class, which serves as a container for +collections of listener functions. These collections are represented both +at the class level of a particular ``_Dispatch`` class as well as within +instances of ``_Dispatch``. + +""" +from __future__ import annotations + +import typing +from typing import Any +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterator +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import Union +import weakref + +from .attr import _ClsLevelDispatch +from .attr import _EmptyListener +from .attr import _InstanceLevelDispatch +from .attr import _JoinedListener +from .registry import _ET +from .registry import _EventKey +from .. import util +from ..util.typing import Literal + +_registrars: MutableMapping[str, List[Type[_HasEventsDispatch[Any]]]] = ( + util.defaultdict(list) +) + + +def _is_event_name(name: str) -> bool: + # _sa_event prefix is special to support internal-only event names. + # most event names are just plain method names that aren't + # underscored. + + return ( + not name.startswith("_") and name != "dispatch" + ) or name.startswith("_sa_event") + + +class _UnpickleDispatch: + """Serializable callable that re-generates an instance of + :class:`_Dispatch` given a particular :class:`.Events` subclass. + + """ + + def __call__(self, _instance_cls: Type[_ET]) -> _Dispatch[_ET]: + for cls in _instance_cls.__mro__: + if "dispatch" in cls.__dict__: + return cast( + "_Dispatch[_ET]", cls.__dict__["dispatch"].dispatch + )._for_class(_instance_cls) + else: + raise AttributeError("No class with a 'dispatch' member present.") + + +class _DispatchCommon(Generic[_ET]): + __slots__ = () + + _instance_cls: Optional[Type[_ET]] + + def _join(self, other: _DispatchCommon[_ET]) -> _JoinedDispatcher[_ET]: + raise NotImplementedError() + + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: + raise NotImplementedError() + + @property + def _events(self) -> Type[_HasEventsDispatch[_ET]]: + raise NotImplementedError() + + +class _Dispatch(_DispatchCommon[_ET]): + """Mirror the event listening definitions of an Events class with + listener collections. + + Classes which define a "dispatch" member will return a + non-instantiated :class:`._Dispatch` subclass when the member + is accessed at the class level. When the "dispatch" member is + accessed at the instance level of its owner, an instance + of the :class:`._Dispatch` class is returned. + + A :class:`._Dispatch` class is generated for each :class:`.Events` + class defined, by the :meth:`._HasEventsDispatch._create_dispatcher_class` + method. The original :class:`.Events` classes remain untouched. + This decouples the construction of :class:`.Events` subclasses from + the implementation used by the event internals, and allows + inspecting tools like Sphinx to work in an unsurprising + way against the public API. + + """ + + # "active_history" is an ORM case we add here. ideally a better + # system would be in place for ad-hoc attributes. + __slots__ = "_parent", "_instance_cls", "__dict__", "_empty_listeners" + + _active_history: bool + + _empty_listener_reg: MutableMapping[ + Type[_ET], Dict[str, _EmptyListener[_ET]] + ] = weakref.WeakKeyDictionary() + + _empty_listeners: Dict[str, _EmptyListener[_ET]] + + _event_names: List[str] + + _instance_cls: Optional[Type[_ET]] + + _joined_dispatch_cls: Type[_JoinedDispatcher[_ET]] + + _events: Type[_HasEventsDispatch[_ET]] + """reference back to the Events class. + + Bidirectional against _HasEventsDispatch.dispatch + + """ + + def __init__( + self, + parent: Optional[_Dispatch[_ET]], + instance_cls: Optional[Type[_ET]] = None, + ): + self._parent = parent + self._instance_cls = instance_cls + + if instance_cls: + assert parent is not None + try: + self._empty_listeners = self._empty_listener_reg[instance_cls] + except KeyError: + self._empty_listeners = self._empty_listener_reg[ + instance_cls + ] = { + ls.name: _EmptyListener(ls, instance_cls) + for ls in parent._event_descriptors + } + else: + self._empty_listeners = {} + + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: + # Assign EmptyListeners as attributes on demand + # to reduce startup time for new dispatch objects. + try: + ls = self._empty_listeners[name] + except KeyError: + raise AttributeError(name) + else: + setattr(self, ls.name, ls) + return ls + + @property + def _event_descriptors(self) -> Iterator[_ClsLevelDispatch[_ET]]: + for k in self._event_names: + # Yield _ClsLevelDispatch related + # to relevant event name. + yield getattr(self, k) + + def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: + return self._events._listen(event_key, **kw) + + def _for_class(self, instance_cls: Type[_ET]) -> _Dispatch[_ET]: + return self.__class__(self, instance_cls) + + def _for_instance(self, instance: _ET) -> _Dispatch[_ET]: + instance_cls = instance.__class__ + return self._for_class(instance_cls) + + def _join(self, other: _DispatchCommon[_ET]) -> _JoinedDispatcher[_ET]: + """Create a 'join' of this :class:`._Dispatch` and another. + + This new dispatcher will dispatch events to both + :class:`._Dispatch` objects. + + """ + if "_joined_dispatch_cls" not in self.__class__.__dict__: + cls = type( + "Joined%s" % self.__class__.__name__, + (_JoinedDispatcher,), + {"__slots__": self._event_names}, + ) + self.__class__._joined_dispatch_cls = cls + return self._joined_dispatch_cls(self, other) + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return _UnpickleDispatch(), (self._instance_cls,) + + def _update( + self, other: _Dispatch[_ET], only_propagate: bool = True + ) -> None: + """Populate from the listeners in another :class:`_Dispatch` + object.""" + for ls in other._event_descriptors: + if isinstance(ls, _EmptyListener): + continue + getattr(self, ls.name).for_modify(self)._update( + ls, only_propagate=only_propagate + ) + + def _clear(self) -> None: + for ls in self._event_descriptors: + ls.for_modify(self).clear() + + +def _remove_dispatcher(cls: Type[_HasEventsDispatch[_ET]]) -> None: + for k in cls.dispatch._event_names: + _registrars[k].remove(cls) + if not _registrars[k]: + del _registrars[k] + + +class _HasEventsDispatch(Generic[_ET]): + _dispatch_target: Optional[Type[_ET]] + """class which will receive the .dispatch collection""" + + dispatch: _Dispatch[_ET] + """reference back to the _Dispatch class. + + Bidirectional against _Dispatch._events + + """ + + if typing.TYPE_CHECKING: + + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: ... + + def __init_subclass__(cls) -> None: + """Intercept new Event subclasses and create associated _Dispatch + classes.""" + + cls._create_dispatcher_class(cls.__name__, cls.__bases__, cls.__dict__) + + @classmethod + def _accept_with( + cls, target: Union[_ET, Type[_ET]], identifier: str + ) -> Optional[Union[_ET, Type[_ET]]]: + raise NotImplementedError() + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET], + *, + propagate: bool = False, + insert: bool = False, + named: bool = False, + asyncio: bool = False, + ) -> None: + raise NotImplementedError() + + @staticmethod + def _set_dispatch( + klass: Type[_HasEventsDispatch[_ET]], + dispatch_cls: Type[_Dispatch[_ET]], + ) -> _Dispatch[_ET]: + # This allows an Events subclass to define additional utility + # methods made available to the target via + # "self.dispatch._events." + # @staticmethod to allow easy "super" calls while in a metaclass + # constructor. + klass.dispatch = dispatch_cls(None) + dispatch_cls._events = klass + return klass.dispatch + + @classmethod + def _create_dispatcher_class( + cls, classname: str, bases: Tuple[type, ...], dict_: Mapping[str, Any] + ) -> None: + """Create a :class:`._Dispatch` class corresponding to an + :class:`.Events` class.""" + + # there's all kinds of ways to do this, + # i.e. make a Dispatch class that shares the '_listen' method + # of the Event class, this is the straight monkeypatch. + if hasattr(cls, "dispatch"): + dispatch_base = cls.dispatch.__class__ + else: + dispatch_base = _Dispatch + + event_names = [k for k in dict_ if _is_event_name(k)] + dispatch_cls = cast( + "Type[_Dispatch[_ET]]", + type( + "%sDispatch" % classname, + (dispatch_base,), + {"__slots__": event_names}, + ), + ) + + dispatch_cls._event_names = event_names + dispatch_inst = cls._set_dispatch(cls, dispatch_cls) + for k in dispatch_cls._event_names: + setattr(dispatch_inst, k, _ClsLevelDispatch(cls, dict_[k])) + _registrars[k].append(cls) + + for super_ in dispatch_cls.__bases__: + if issubclass(super_, _Dispatch) and super_ is not _Dispatch: + for ls in super_._events.dispatch._event_descriptors: + setattr(dispatch_inst, ls.name, ls) + dispatch_cls._event_names.append(ls.name) + + if getattr(cls, "_dispatch_target", None): + dispatch_target_cls = cls._dispatch_target + assert dispatch_target_cls is not None + if ( + hasattr(dispatch_target_cls, "__slots__") + and "_slots_dispatch" in dispatch_target_cls.__slots__ + ): + dispatch_target_cls.dispatch = slots_dispatcher(cls) + else: + dispatch_target_cls.dispatch = dispatcher(cls) + + +class Events(_HasEventsDispatch[_ET]): + """Define event listening functions for a particular target type.""" + + @classmethod + def _accept_with( + cls, target: Union[_ET, Type[_ET]], identifier: str + ) -> Optional[Union[_ET, Type[_ET]]]: + def dispatch_is(*types: Type[Any]) -> bool: + return all(isinstance(target.dispatch, t) for t in types) + + def dispatch_parent_is(t: Type[Any]) -> bool: + return isinstance( + cast("_JoinedDispatcher[_ET]", target.dispatch).parent, t + ) + + # Mapper, ClassManager, Session override this to + # also accept classes, scoped_sessions, sessionmakers, etc. + if hasattr(target, "dispatch"): + if ( + dispatch_is(cls.dispatch.__class__) + or dispatch_is(type, cls.dispatch.__class__) + or ( + dispatch_is(_JoinedDispatcher) + and dispatch_parent_is(cls.dispatch.__class__) + ) + ): + return target + + return None + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET], + *, + propagate: bool = False, + insert: bool = False, + named: bool = False, + asyncio: bool = False, + ) -> None: + event_key.base_listen( + propagate=propagate, insert=insert, named=named, asyncio=asyncio + ) + + @classmethod + def _remove(cls, event_key: _EventKey[_ET]) -> None: + event_key.remove() + + @classmethod + def _clear(cls) -> None: + cls.dispatch._clear() + + +class _JoinedDispatcher(_DispatchCommon[_ET]): + """Represent a connection between two _Dispatch objects.""" + + __slots__ = "local", "parent", "_instance_cls" + + local: _DispatchCommon[_ET] + parent: _DispatchCommon[_ET] + _instance_cls: Optional[Type[_ET]] + + def __init__( + self, local: _DispatchCommon[_ET], parent: _DispatchCommon[_ET] + ): + self.local = local + self.parent = parent + self._instance_cls = self.local._instance_cls + + def __getattr__(self, name: str) -> _JoinedListener[_ET]: + # Assign _JoinedListeners as attributes on demand + # to reduce startup time for new dispatch objects. + ls = getattr(self.local, name) + jl = _JoinedListener(self.parent, ls.name, ls) + setattr(self, ls.name, jl) + return jl + + def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: + return self.parent._listen(event_key, **kw) + + @property + def _events(self) -> Type[_HasEventsDispatch[_ET]]: + return self.parent._events + + +class dispatcher(Generic[_ET]): + """Descriptor used by target classes to + deliver the _Dispatch class at the class level + and produce new _Dispatch instances for target + instances. + + """ + + def __init__(self, events: Type[_HasEventsDispatch[_ET]]): + self.dispatch = events.dispatch + self.events = events + + @overload + def __get__( + self, obj: Literal[None], cls: Type[Any] + ) -> Type[_Dispatch[_ET]]: ... + + @overload + def __get__(self, obj: Any, cls: Type[Any]) -> _DispatchCommon[_ET]: ... + + def __get__(self, obj: Any, cls: Type[Any]) -> Any: + if obj is None: + return self.dispatch + + disp = self.dispatch._for_instance(obj) + try: + obj.__dict__["dispatch"] = disp + except AttributeError as ae: + raise TypeError( + "target %r doesn't have __dict__, should it be " + "defining _slots_dispatch?" % (obj,) + ) from ae + return disp + + +class slots_dispatcher(dispatcher[_ET]): + def __get__(self, obj: Any, cls: Type[Any]) -> Any: + if obj is None: + return self.dispatch + + if hasattr(obj, "_slots_dispatch"): + return obj._slots_dispatch + + disp = self.dispatch._for_instance(obj) + obj._slots_dispatch = disp + return disp diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/event/legacy.py b/venv/lib/python3.11/site-packages/sqlalchemy/event/legacy.py new file mode 100644 index 0000000..57e561c --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/event/legacy.py @@ -0,0 +1,246 @@ +# event/legacy.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Routines to handle adaption of legacy call signatures, +generation of deprecation notes and docstrings. + +""" +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type + +from .registry import _ET +from .registry import _ListenerFnType +from .. import util +from ..util.compat import FullArgSpec + +if typing.TYPE_CHECKING: + from .attr import _ClsLevelDispatch + from .base import _HasEventsDispatch + + +_LegacySignatureType = Tuple[str, List[str], Optional[Callable[..., Any]]] + + +def _legacy_signature( + since: str, + argnames: List[str], + converter: Optional[Callable[..., Any]] = None, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """legacy sig decorator + + + :param since: string version for deprecation warning + :param argnames: list of strings, which is *all* arguments that the legacy + version accepted, including arguments that are still there + :param converter: lambda that will accept tuple of this full arg signature + and return tuple of new arg signature. + + """ + + def leg(fn: Callable[..., Any]) -> Callable[..., Any]: + if not hasattr(fn, "_legacy_signatures"): + fn._legacy_signatures = [] # type: ignore[attr-defined] + fn._legacy_signatures.append((since, argnames, converter)) # type: ignore[attr-defined] # noqa: E501 + return fn + + return leg + + +def _wrap_fn_for_legacy( + dispatch_collection: _ClsLevelDispatch[_ET], + fn: _ListenerFnType, + argspec: FullArgSpec, +) -> _ListenerFnType: + for since, argnames, conv in dispatch_collection.legacy_signatures: + if argnames[-1] == "**kw": + has_kw = True + argnames = argnames[0:-1] + else: + has_kw = False + + if len(argnames) == len(argspec.args) and has_kw is bool( + argspec.varkw + ): + formatted_def = "def %s(%s%s)" % ( + dispatch_collection.name, + ", ".join(dispatch_collection.arg_names), + ", **kw" if has_kw else "", + ) + warning_txt = ( + 'The argument signature for the "%s.%s" event listener ' + "has changed as of version %s, and conversion for " + "the old argument signature will be removed in a " + 'future release. The new signature is "%s"' + % ( + dispatch_collection.clsname, + dispatch_collection.name, + since, + formatted_def, + ) + ) + + if conv is not None: + assert not has_kw + + def wrap_leg(*args: Any, **kw: Any) -> Any: + util.warn_deprecated(warning_txt, version=since) + assert conv is not None + return fn(*conv(*args)) + + else: + + def wrap_leg(*args: Any, **kw: Any) -> Any: + util.warn_deprecated(warning_txt, version=since) + argdict = dict(zip(dispatch_collection.arg_names, args)) + args_from_dict = [argdict[name] for name in argnames] + if has_kw: + return fn(*args_from_dict, **kw) + else: + return fn(*args_from_dict) + + return wrap_leg + else: + return fn + + +def _indent(text: str, indent: str) -> str: + return "\n".join(indent + line for line in text.split("\n")) + + +def _standard_listen_example( + dispatch_collection: _ClsLevelDispatch[_ET], + sample_target: Any, + fn: _ListenerFnType, +) -> str: + example_kw_arg = _indent( + "\n".join( + "%(arg)s = kw['%(arg)s']" % {"arg": arg} + for arg in dispatch_collection.arg_names[0:2] + ), + " ", + ) + if dispatch_collection.legacy_signatures: + current_since = max( + since + for since, args, conv in dispatch_collection.legacy_signatures + ) + else: + current_since = None + text = ( + "from sqlalchemy import event\n\n\n" + "@event.listens_for(%(sample_target)s, '%(event_name)s')\n" + "def receive_%(event_name)s(" + "%(named_event_arguments)s%(has_kw_arguments)s):\n" + " \"listen for the '%(event_name)s' event\"\n" + "\n # ... (event handling logic) ...\n" + ) + + text %= { + "current_since": ( + " (arguments as of %s)" % current_since if current_since else "" + ), + "event_name": fn.__name__, + "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "", + "named_event_arguments": ", ".join(dispatch_collection.arg_names), + "example_kw_arg": example_kw_arg, + "sample_target": sample_target, + } + return text + + +def _legacy_listen_examples( + dispatch_collection: _ClsLevelDispatch[_ET], + sample_target: str, + fn: _ListenerFnType, +) -> str: + text = "" + for since, args, conv in dispatch_collection.legacy_signatures: + text += ( + "\n# DEPRECATED calling style (pre-%(since)s, " + "will be removed in a future release)\n" + "@event.listens_for(%(sample_target)s, '%(event_name)s')\n" + "def receive_%(event_name)s(" + "%(named_event_arguments)s%(has_kw_arguments)s):\n" + " \"listen for the '%(event_name)s' event\"\n" + "\n # ... (event handling logic) ...\n" + % { + "since": since, + "event_name": fn.__name__, + "has_kw_arguments": ( + " **kw" if dispatch_collection.has_kw else "" + ), + "named_event_arguments": ", ".join(args), + "sample_target": sample_target, + } + ) + return text + + +def _version_signature_changes( + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], + dispatch_collection: _ClsLevelDispatch[_ET], +) -> str: + since, args, conv = dispatch_collection.legacy_signatures[0] + return ( + "\n.. versionchanged:: %(since)s\n" + " The :meth:`.%(clsname)s.%(event_name)s` event now accepts the \n" + " arguments %(named_event_arguments)s%(has_kw_arguments)s.\n" + " Support for listener functions which accept the previous \n" + ' argument signature(s) listed above as "deprecated" will be \n' + " removed in a future release." + % { + "since": since, + "clsname": parent_dispatch_cls.__name__, + "event_name": dispatch_collection.name, + "named_event_arguments": ", ".join( + ":paramref:`.%(clsname)s.%(event_name)s.%(param_name)s`" + % { + "clsname": parent_dispatch_cls.__name__, + "event_name": dispatch_collection.name, + "param_name": param_name, + } + for param_name in dispatch_collection.arg_names + ), + "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "", + } + ) + + +def _augment_fn_docs( + dispatch_collection: _ClsLevelDispatch[_ET], + parent_dispatch_cls: Type[_HasEventsDispatch[_ET]], + fn: _ListenerFnType, +) -> str: + header = ( + ".. container:: event_signatures\n\n" + " Example argument forms::\n" + "\n" + ) + + sample_target = getattr(parent_dispatch_cls, "_target_class_doc", "obj") + text = header + _indent( + _standard_listen_example(dispatch_collection, sample_target, fn), + " " * 8, + ) + if dispatch_collection.legacy_signatures: + text += _indent( + _legacy_listen_examples(dispatch_collection, sample_target, fn), + " " * 8, + ) + + text += _version_signature_changes( + parent_dispatch_cls, dispatch_collection + ) + + return util.inject_docstring_text(fn.__doc__, text, 1) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/event/registry.py b/venv/lib/python3.11/site-packages/sqlalchemy/event/registry.py new file mode 100644 index 0000000..773620f --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/event/registry.py @@ -0,0 +1,386 @@ +# event/registry.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Provides managed registration services on behalf of :func:`.listen` +arguments. + +By "managed registration", we mean that event listening functions and +other objects can be added to various collections in such a way that their +membership in all those collections can be revoked at once, based on +an equivalent :class:`._EventKey`. + +""" +from __future__ import annotations + +import collections +import types +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Deque +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import Union +import weakref + +from .. import exc +from .. import util + +if typing.TYPE_CHECKING: + from .attr import RefCollection + from .base import dispatcher + +_ListenerFnType = Callable[..., Any] +_ListenerFnKeyType = Union[int, Tuple[int, int]] +_EventKeyTupleType = Tuple[int, str, _ListenerFnKeyType] + + +_ET = TypeVar("_ET", bound="EventTarget") + + +class EventTarget: + """represents an event target, that is, something we can listen on + either with that target as a class or as an instance. + + Examples include: Connection, Mapper, Table, Session, + InstrumentedAttribute, Engine, Pool, Dialect. + + """ + + __slots__ = () + + dispatch: dispatcher[Any] + + +_RefCollectionToListenerType = Dict[ + "weakref.ref[RefCollection[Any]]", + "weakref.ref[_ListenerFnType]", +] + +_key_to_collection: Dict[_EventKeyTupleType, _RefCollectionToListenerType] = ( + collections.defaultdict(dict) +) +""" +Given an original listen() argument, can locate all +listener collections and the listener fn contained + +(target, identifier, fn) -> { + ref(listenercollection) -> ref(listener_fn) + ref(listenercollection) -> ref(listener_fn) + ref(listenercollection) -> ref(listener_fn) + } +""" + +_ListenerToEventKeyType = Dict[ + "weakref.ref[_ListenerFnType]", + _EventKeyTupleType, +] +_collection_to_key: Dict[ + weakref.ref[RefCollection[Any]], + _ListenerToEventKeyType, +] = collections.defaultdict(dict) +""" +Given a _ListenerCollection or _ClsLevelListener, can locate +all the original listen() arguments and the listener fn contained + +ref(listenercollection) -> { + ref(listener_fn) -> (target, identifier, fn), + ref(listener_fn) -> (target, identifier, fn), + ref(listener_fn) -> (target, identifier, fn), + } +""" + + +def _collection_gced(ref: weakref.ref[Any]) -> None: + # defaultdict, so can't get a KeyError + if not _collection_to_key or ref not in _collection_to_key: + return + + ref = cast("weakref.ref[RefCollection[EventTarget]]", ref) + + listener_to_key = _collection_to_key.pop(ref) + for key in listener_to_key.values(): + if key in _key_to_collection: + # defaultdict, so can't get a KeyError + dispatch_reg = _key_to_collection[key] + dispatch_reg.pop(ref) + if not dispatch_reg: + _key_to_collection.pop(key) + + +def _stored_in_collection( + event_key: _EventKey[_ET], owner: RefCollection[_ET] +) -> bool: + key = event_key._key + + dispatch_reg = _key_to_collection[key] + + owner_ref = owner.ref + listen_ref = weakref.ref(event_key._listen_fn) + + if owner_ref in dispatch_reg: + return False + + dispatch_reg[owner_ref] = listen_ref + + listener_to_key = _collection_to_key[owner_ref] + listener_to_key[listen_ref] = key + + return True + + +def _removed_from_collection( + event_key: _EventKey[_ET], owner: RefCollection[_ET] +) -> None: + key = event_key._key + + dispatch_reg = _key_to_collection[key] + + listen_ref = weakref.ref(event_key._listen_fn) + + owner_ref = owner.ref + dispatch_reg.pop(owner_ref, None) + if not dispatch_reg: + del _key_to_collection[key] + + if owner_ref in _collection_to_key: + listener_to_key = _collection_to_key[owner_ref] + listener_to_key.pop(listen_ref) + + +def _stored_in_collection_multi( + newowner: RefCollection[_ET], + oldowner: RefCollection[_ET], + elements: Iterable[_ListenerFnType], +) -> None: + if not elements: + return + + oldowner_ref = oldowner.ref + newowner_ref = newowner.ref + + old_listener_to_key = _collection_to_key[oldowner_ref] + new_listener_to_key = _collection_to_key[newowner_ref] + + for listen_fn in elements: + listen_ref = weakref.ref(listen_fn) + try: + key = old_listener_to_key[listen_ref] + except KeyError: + # can occur during interpreter shutdown. + # see #6740 + continue + + try: + dispatch_reg = _key_to_collection[key] + except KeyError: + continue + + if newowner_ref in dispatch_reg: + assert dispatch_reg[newowner_ref] == listen_ref + else: + dispatch_reg[newowner_ref] = listen_ref + + new_listener_to_key[listen_ref] = key + + +def _clear( + owner: RefCollection[_ET], + elements: Iterable[_ListenerFnType], +) -> None: + if not elements: + return + + owner_ref = owner.ref + listener_to_key = _collection_to_key[owner_ref] + for listen_fn in elements: + listen_ref = weakref.ref(listen_fn) + key = listener_to_key[listen_ref] + dispatch_reg = _key_to_collection[key] + dispatch_reg.pop(owner_ref, None) + + if not dispatch_reg: + del _key_to_collection[key] + + +class _EventKey(Generic[_ET]): + """Represent :func:`.listen` arguments.""" + + __slots__ = ( + "target", + "identifier", + "fn", + "fn_key", + "fn_wrap", + "dispatch_target", + ) + + target: _ET + identifier: str + fn: _ListenerFnType + fn_key: _ListenerFnKeyType + dispatch_target: Any + _fn_wrap: Optional[_ListenerFnType] + + def __init__( + self, + target: _ET, + identifier: str, + fn: _ListenerFnType, + dispatch_target: Any, + _fn_wrap: Optional[_ListenerFnType] = None, + ): + self.target = target + self.identifier = identifier + self.fn = fn + if isinstance(fn, types.MethodType): + self.fn_key = id(fn.__func__), id(fn.__self__) + else: + self.fn_key = id(fn) + self.fn_wrap = _fn_wrap + self.dispatch_target = dispatch_target + + @property + def _key(self) -> _EventKeyTupleType: + return (id(self.target), self.identifier, self.fn_key) + + def with_wrapper(self, fn_wrap: _ListenerFnType) -> _EventKey[_ET]: + if fn_wrap is self._listen_fn: + return self + else: + return _EventKey( + self.target, + self.identifier, + self.fn, + self.dispatch_target, + _fn_wrap=fn_wrap, + ) + + def with_dispatch_target(self, dispatch_target: Any) -> _EventKey[_ET]: + if dispatch_target is self.dispatch_target: + return self + else: + return _EventKey( + self.target, + self.identifier, + self.fn, + dispatch_target, + _fn_wrap=self.fn_wrap, + ) + + def listen(self, *args: Any, **kw: Any) -> None: + once = kw.pop("once", False) + once_unless_exception = kw.pop("_once_unless_exception", False) + named = kw.pop("named", False) + + target, identifier, fn = ( + self.dispatch_target, + self.identifier, + self._listen_fn, + ) + + dispatch_collection = getattr(target.dispatch, identifier) + + adjusted_fn = dispatch_collection._adjust_fn_spec(fn, named) + + self = self.with_wrapper(adjusted_fn) + + stub_function = getattr( + self.dispatch_target.dispatch._events, self.identifier + ) + if hasattr(stub_function, "_sa_warn"): + stub_function._sa_warn() + + if once or once_unless_exception: + self.with_wrapper( + util.only_once( + self._listen_fn, retry_on_exception=once_unless_exception + ) + ).listen(*args, **kw) + else: + self.dispatch_target.dispatch._listen(self, *args, **kw) + + def remove(self) -> None: + key = self._key + + if key not in _key_to_collection: + raise exc.InvalidRequestError( + "No listeners found for event %s / %r / %s " + % (self.target, self.identifier, self.fn) + ) + + dispatch_reg = _key_to_collection.pop(key) + + for collection_ref, listener_ref in dispatch_reg.items(): + collection = collection_ref() + listener_fn = listener_ref() + if collection is not None and listener_fn is not None: + collection.remove(self.with_wrapper(listener_fn)) + + def contains(self) -> bool: + """Return True if this event key is registered to listen.""" + return self._key in _key_to_collection + + def base_listen( + self, + propagate: bool = False, + insert: bool = False, + named: bool = False, + retval: Optional[bool] = None, + asyncio: bool = False, + ) -> None: + target, identifier = self.dispatch_target, self.identifier + + dispatch_collection = getattr(target.dispatch, identifier) + + for_modify = dispatch_collection.for_modify(target.dispatch) + if asyncio: + for_modify._set_asyncio() + + if insert: + for_modify.insert(self, propagate) + else: + for_modify.append(self, propagate) + + @property + def _listen_fn(self) -> _ListenerFnType: + return self.fn_wrap or self.fn + + def append_to_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> bool: + if _stored_in_collection(self, owner): + list_.append(self._listen_fn) + return True + else: + return False + + def remove_from_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> None: + _removed_from_collection(self, owner) + list_.remove(self._listen_fn) + + def prepend_to_list( + self, + owner: RefCollection[_ET], + list_: Deque[_ListenerFnType], + ) -> bool: + if _stored_in_collection(self, owner): + list_.appendleft(self._listen_fn) + return True + else: + return False diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/events.py b/venv/lib/python3.11/site-packages/sqlalchemy/events.py new file mode 100644 index 0000000..8c3bf01 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/events.py @@ -0,0 +1,17 @@ +# events.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Core event interfaces.""" + +from __future__ import annotations + +from .engine.events import ConnectionEvents +from .engine.events import DialectEvents +from .pool import PoolResetState +from .pool.events import PoolEvents +from .sql.base import SchemaEventTarget +from .sql.events import DDLEvents diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/exc.py b/venv/lib/python3.11/site-packages/sqlalchemy/exc.py new file mode 100644 index 0000000..7d7eff3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/exc.py @@ -0,0 +1,830 @@ +# exc.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Exceptions used with SQLAlchemy. + +The base exception class is :exc:`.SQLAlchemyError`. Exceptions which are +raised as a result of DBAPI exceptions are all subclasses of +:exc:`.DBAPIError`. + +""" +from __future__ import annotations + +import typing +from typing import Any +from typing import List +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import Union + +from .util import compat +from .util import preloaded as _preloaded + +if typing.TYPE_CHECKING: + from .engine.interfaces import _AnyExecuteParams + from .engine.interfaces import Dialect + from .sql.compiler import Compiled + from .sql.compiler import TypeCompiler + from .sql.elements import ClauseElement + +if typing.TYPE_CHECKING: + _version_token: str +else: + # set by __init__.py + _version_token = None + + +class HasDescriptionCode: + """helper which adds 'code' as an attribute and '_code_str' as a method""" + + code: Optional[str] = None + + def __init__(self, *arg: Any, **kw: Any): + code = kw.pop("code", None) + if code is not None: + self.code = code + super().__init__(*arg, **kw) + + _what_are_we = "error" + + def _code_str(self) -> str: + if not self.code: + return "" + else: + return ( + f"(Background on this {self._what_are_we} at: " + f"https://sqlalche.me/e/{_version_token}/{self.code})" + ) + + def __str__(self) -> str: + message = super().__str__() + if self.code: + message = "%s %s" % (message, self._code_str()) + return message + + +class SQLAlchemyError(HasDescriptionCode, Exception): + """Generic error class.""" + + def _message(self) -> str: + # rules: + # + # 1. single arg string will usually be a unicode + # object, but since __str__() must return unicode, check for + # bytestring just in case + # + # 2. for multiple self.args, this is not a case in current + # SQLAlchemy though this is happening in at least one known external + # library, call str() which does a repr(). + # + text: str + + if len(self.args) == 1: + arg_text = self.args[0] + + if isinstance(arg_text, bytes): + text = compat.decode_backslashreplace(arg_text, "utf-8") + # This is for when the argument is not a string of any sort. + # Otherwise, converting this exception to string would fail for + # non-string arguments. + else: + text = str(arg_text) + + return text + else: + # this is not a normal case within SQLAlchemy but is here for + # compatibility with Exception.args - the str() comes out as + # a repr() of the tuple + return str(self.args) + + def _sql_message(self) -> str: + message = self._message() + + if self.code: + message = "%s %s" % (message, self._code_str()) + + return message + + def __str__(self) -> str: + return self._sql_message() + + +class ArgumentError(SQLAlchemyError): + """Raised when an invalid or conflicting function argument is supplied. + + This error generally corresponds to construction time state errors. + + """ + + +class DuplicateColumnError(ArgumentError): + """a Column is being added to a Table that would replace another + Column, without appropriate parameters to allow this in place. + + .. versionadded:: 2.0.0b4 + + """ + + +class ObjectNotExecutableError(ArgumentError): + """Raised when an object is passed to .execute() that can't be + executed as SQL. + + """ + + def __init__(self, target: Any): + super().__init__("Not an executable object: %r" % target) + self.target = target + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return self.__class__, (self.target,) + + +class NoSuchModuleError(ArgumentError): + """Raised when a dynamically-loaded module (usually a database dialect) + of a particular name cannot be located.""" + + +class NoForeignKeysError(ArgumentError): + """Raised when no foreign keys can be located between two selectables + during a join.""" + + +class AmbiguousForeignKeysError(ArgumentError): + """Raised when more than one foreign key matching can be located + between two selectables during a join.""" + + +class ConstraintColumnNotFoundError(ArgumentError): + """raised when a constraint refers to a string column name that + is not present in the table being constrained. + + .. versionadded:: 2.0 + + """ + + +class CircularDependencyError(SQLAlchemyError): + """Raised by topological sorts when a circular dependency is detected. + + There are two scenarios where this error occurs: + + * In a Session flush operation, if two objects are mutually dependent + on each other, they can not be inserted or deleted via INSERT or + DELETE statements alone; an UPDATE will be needed to post-associate + or pre-deassociate one of the foreign key constrained values. + The ``post_update`` flag described at :ref:`post_update` can resolve + this cycle. + * In a :attr:`_schema.MetaData.sorted_tables` operation, two + :class:`_schema.ForeignKey` + or :class:`_schema.ForeignKeyConstraint` objects mutually refer to each + other. Apply the ``use_alter=True`` flag to one or both, + see :ref:`use_alter`. + + """ + + def __init__( + self, + message: str, + cycles: Any, + edges: Any, + msg: Optional[str] = None, + code: Optional[str] = None, + ): + if msg is None: + message += " (%s)" % ", ".join(repr(s) for s in cycles) + else: + message = msg + SQLAlchemyError.__init__(self, message, code=code) + self.cycles = cycles + self.edges = edges + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return ( + self.__class__, + (None, self.cycles, self.edges, self.args[0]), + {"code": self.code} if self.code is not None else {}, + ) + + +class CompileError(SQLAlchemyError): + """Raised when an error occurs during SQL compilation""" + + +class UnsupportedCompilationError(CompileError): + """Raised when an operation is not supported by the given compiler. + + .. seealso:: + + :ref:`faq_sql_expression_string` + + :ref:`error_l7de` + """ + + code = "l7de" + + def __init__( + self, + compiler: Union[Compiled, TypeCompiler], + element_type: Type[ClauseElement], + message: Optional[str] = None, + ): + super().__init__( + "Compiler %r can't render element of type %s%s" + % (compiler, element_type, ": %s" % message if message else "") + ) + self.compiler = compiler + self.element_type = element_type + self.message = message + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return self.__class__, (self.compiler, self.element_type, self.message) + + +class IdentifierError(SQLAlchemyError): + """Raised when a schema name is beyond the max character limit""" + + +class DisconnectionError(SQLAlchemyError): + """A disconnect is detected on a raw DB-API connection. + + This error is raised and consumed internally by a connection pool. It can + be raised by the :meth:`_events.PoolEvents.checkout` + event so that the host pool + forces a retry; the exception will be caught three times in a row before + the pool gives up and raises :class:`~sqlalchemy.exc.InvalidRequestError` + regarding the connection attempt. + + """ + + invalidate_pool: bool = False + + +class InvalidatePoolError(DisconnectionError): + """Raised when the connection pool should invalidate all stale connections. + + A subclass of :class:`_exc.DisconnectionError` that indicates that the + disconnect situation encountered on the connection probably means the + entire pool should be invalidated, as the database has been restarted. + + This exception will be handled otherwise the same way as + :class:`_exc.DisconnectionError`, allowing three attempts to reconnect + before giving up. + + .. versionadded:: 1.2 + + """ + + invalidate_pool: bool = True + + +class TimeoutError(SQLAlchemyError): # noqa + """Raised when a connection pool times out on getting a connection.""" + + +class InvalidRequestError(SQLAlchemyError): + """SQLAlchemy was asked to do something it can't do. + + This error generally corresponds to runtime state errors. + + """ + + +class IllegalStateChangeError(InvalidRequestError): + """An object that tracks state encountered an illegal state change + of some kind. + + .. versionadded:: 2.0 + + """ + + +class NoInspectionAvailable(InvalidRequestError): + """A subject passed to :func:`sqlalchemy.inspection.inspect` produced + no context for inspection.""" + + +class PendingRollbackError(InvalidRequestError): + """A transaction has failed and needs to be rolled back before + continuing. + + .. versionadded:: 1.4 + + """ + + +class ResourceClosedError(InvalidRequestError): + """An operation was requested from a connection, cursor, or other + object that's in a closed state.""" + + +class NoSuchColumnError(InvalidRequestError, KeyError): + """A nonexistent column is requested from a ``Row``.""" + + +class NoResultFound(InvalidRequestError): + """A database result was required but none was found. + + + .. versionchanged:: 1.4 This exception is now part of the + ``sqlalchemy.exc`` module in Core, moved from the ORM. The symbol + remains importable from ``sqlalchemy.orm.exc``. + + + """ + + +class MultipleResultsFound(InvalidRequestError): + """A single database result was required but more than one were found. + + .. versionchanged:: 1.4 This exception is now part of the + ``sqlalchemy.exc`` module in Core, moved from the ORM. The symbol + remains importable from ``sqlalchemy.orm.exc``. + + + """ + + +class NoReferenceError(InvalidRequestError): + """Raised by ``ForeignKey`` to indicate a reference cannot be resolved.""" + + table_name: str + + +class AwaitRequired(InvalidRequestError): + """Error raised by the async greenlet spawn if no async operation + was awaited when it required one. + + """ + + code = "xd1r" + + +class MissingGreenlet(InvalidRequestError): + r"""Error raised by the async greenlet await\_ if called while not inside + the greenlet spawn context. + + """ + + code = "xd2s" + + +class NoReferencedTableError(NoReferenceError): + """Raised by ``ForeignKey`` when the referred ``Table`` cannot be + located. + + """ + + def __init__(self, message: str, tname: str): + NoReferenceError.__init__(self, message) + self.table_name = tname + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return self.__class__, (self.args[0], self.table_name) + + +class NoReferencedColumnError(NoReferenceError): + """Raised by ``ForeignKey`` when the referred ``Column`` cannot be + located. + + """ + + def __init__(self, message: str, tname: str, cname: str): + NoReferenceError.__init__(self, message) + self.table_name = tname + self.column_name = cname + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return ( + self.__class__, + (self.args[0], self.table_name, self.column_name), + ) + + +class NoSuchTableError(InvalidRequestError): + """Table does not exist or is not visible to a connection.""" + + +class UnreflectableTableError(InvalidRequestError): + """Table exists but can't be reflected for some reason. + + .. versionadded:: 1.2 + + """ + + +class UnboundExecutionError(InvalidRequestError): + """SQL was attempted without a database connection to execute it on.""" + + +class DontWrapMixin: + """A mixin class which, when applied to a user-defined Exception class, + will not be wrapped inside of :exc:`.StatementError` if the error is + emitted within the process of executing a statement. + + E.g.:: + + from sqlalchemy.exc import DontWrapMixin + + class MyCustomException(Exception, DontWrapMixin): + pass + + class MySpecialType(TypeDecorator): + impl = String + + def process_bind_param(self, value, dialect): + if value == 'invalid': + raise MyCustomException("invalid!") + + """ + + +class StatementError(SQLAlchemyError): + """An error occurred during execution of a SQL statement. + + :class:`StatementError` wraps the exception raised + during execution, and features :attr:`.statement` + and :attr:`.params` attributes which supply context regarding + the specifics of the statement which had an issue. + + The wrapped exception object is available in + the :attr:`.orig` attribute. + + """ + + statement: Optional[str] = None + """The string SQL statement being invoked when this exception occurred.""" + + params: Optional[_AnyExecuteParams] = None + """The parameter list being used when this exception occurred.""" + + orig: Optional[BaseException] = None + """The original exception that was thrown. + + """ + + ismulti: Optional[bool] = None + """multi parameter passed to repr_params(). None is meaningful.""" + + connection_invalidated: bool = False + + def __init__( + self, + message: str, + statement: Optional[str], + params: Optional[_AnyExecuteParams], + orig: Optional[BaseException], + hide_parameters: bool = False, + code: Optional[str] = None, + ismulti: Optional[bool] = None, + ): + SQLAlchemyError.__init__(self, message, code=code) + self.statement = statement + self.params = params + self.orig = orig + self.ismulti = ismulti + self.hide_parameters = hide_parameters + self.detail: List[str] = [] + + def add_detail(self, msg: str) -> None: + self.detail.append(msg) + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return ( + self.__class__, + ( + self.args[0], + self.statement, + self.params, + self.orig, + self.hide_parameters, + self.__dict__.get("code"), + self.ismulti, + ), + {"detail": self.detail}, + ) + + @_preloaded.preload_module("sqlalchemy.sql.util") + def _sql_message(self) -> str: + util = _preloaded.sql_util + + details = [self._message()] + if self.statement: + stmt_detail = "[SQL: %s]" % self.statement + details.append(stmt_detail) + if self.params: + if self.hide_parameters: + details.append( + "[SQL parameters hidden due to hide_parameters=True]" + ) + else: + params_repr = util._repr_params( + self.params, 10, ismulti=self.ismulti + ) + details.append("[parameters: %r]" % params_repr) + code_str = self._code_str() + if code_str: + details.append(code_str) + return "\n".join(["(%s)" % det for det in self.detail] + details) + + +class DBAPIError(StatementError): + """Raised when the execution of a database operation fails. + + Wraps exceptions raised by the DB-API underlying the + database operation. Driver-specific implementations of the standard + DB-API exception types are wrapped by matching sub-types of SQLAlchemy's + :class:`DBAPIError` when possible. DB-API's ``Error`` type maps to + :class:`DBAPIError` in SQLAlchemy, otherwise the names are identical. Note + that there is no guarantee that different DB-API implementations will + raise the same exception type for any given error condition. + + :class:`DBAPIError` features :attr:`~.StatementError.statement` + and :attr:`~.StatementError.params` attributes which supply context + regarding the specifics of the statement which had an issue, for the + typical case when the error was raised within the context of + emitting a SQL statement. + + The wrapped exception object is available in the + :attr:`~.StatementError.orig` attribute. Its type and properties are + DB-API implementation specific. + + """ + + code = "dbapi" + + @overload + @classmethod + def instance( + cls, + statement: Optional[str], + params: Optional[_AnyExecuteParams], + orig: Exception, + dbapi_base_err: Type[Exception], + hide_parameters: bool = False, + connection_invalidated: bool = False, + dialect: Optional[Dialect] = None, + ismulti: Optional[bool] = None, + ) -> StatementError: ... + + @overload + @classmethod + def instance( + cls, + statement: Optional[str], + params: Optional[_AnyExecuteParams], + orig: DontWrapMixin, + dbapi_base_err: Type[Exception], + hide_parameters: bool = False, + connection_invalidated: bool = False, + dialect: Optional[Dialect] = None, + ismulti: Optional[bool] = None, + ) -> DontWrapMixin: ... + + @overload + @classmethod + def instance( + cls, + statement: Optional[str], + params: Optional[_AnyExecuteParams], + orig: BaseException, + dbapi_base_err: Type[Exception], + hide_parameters: bool = False, + connection_invalidated: bool = False, + dialect: Optional[Dialect] = None, + ismulti: Optional[bool] = None, + ) -> BaseException: ... + + @classmethod + def instance( + cls, + statement: Optional[str], + params: Optional[_AnyExecuteParams], + orig: Union[BaseException, DontWrapMixin], + dbapi_base_err: Type[Exception], + hide_parameters: bool = False, + connection_invalidated: bool = False, + dialect: Optional[Dialect] = None, + ismulti: Optional[bool] = None, + ) -> Union[BaseException, DontWrapMixin]: + # Don't ever wrap these, just return them directly as if + # DBAPIError didn't exist. + if ( + isinstance(orig, BaseException) and not isinstance(orig, Exception) + ) or isinstance(orig, DontWrapMixin): + return orig + + if orig is not None: + # not a DBAPI error, statement is present. + # raise a StatementError + if isinstance(orig, SQLAlchemyError) and statement: + return StatementError( + "(%s.%s) %s" + % ( + orig.__class__.__module__, + orig.__class__.__name__, + orig.args[0], + ), + statement, + params, + orig, + hide_parameters=hide_parameters, + code=orig.code, + ismulti=ismulti, + ) + elif not isinstance(orig, dbapi_base_err) and statement: + return StatementError( + "(%s.%s) %s" + % ( + orig.__class__.__module__, + orig.__class__.__name__, + orig, + ), + statement, + params, + orig, + hide_parameters=hide_parameters, + ismulti=ismulti, + ) + + glob = globals() + for super_ in orig.__class__.__mro__: + name = super_.__name__ + if dialect: + name = dialect.dbapi_exception_translation_map.get( + name, name + ) + if name in glob and issubclass(glob[name], DBAPIError): + cls = glob[name] + break + + return cls( + statement, + params, + orig, + connection_invalidated=connection_invalidated, + hide_parameters=hide_parameters, + code=cls.code, + ismulti=ismulti, + ) + + def __reduce__(self) -> Union[str, Tuple[Any, ...]]: + return ( + self.__class__, + ( + self.statement, + self.params, + self.orig, + self.hide_parameters, + self.connection_invalidated, + self.__dict__.get("code"), + self.ismulti, + ), + {"detail": self.detail}, + ) + + def __init__( + self, + statement: Optional[str], + params: Optional[_AnyExecuteParams], + orig: BaseException, + hide_parameters: bool = False, + connection_invalidated: bool = False, + code: Optional[str] = None, + ismulti: Optional[bool] = None, + ): + try: + text = str(orig) + except Exception as e: + text = "Error in str() of DB-API-generated exception: " + str(e) + StatementError.__init__( + self, + "(%s.%s) %s" + % (orig.__class__.__module__, orig.__class__.__name__, text), + statement, + params, + orig, + hide_parameters, + code=code, + ismulti=ismulti, + ) + self.connection_invalidated = connection_invalidated + + +class InterfaceError(DBAPIError): + """Wraps a DB-API InterfaceError.""" + + code = "rvf5" + + +class DatabaseError(DBAPIError): + """Wraps a DB-API DatabaseError.""" + + code = "4xp6" + + +class DataError(DatabaseError): + """Wraps a DB-API DataError.""" + + code = "9h9h" + + +class OperationalError(DatabaseError): + """Wraps a DB-API OperationalError.""" + + code = "e3q8" + + +class IntegrityError(DatabaseError): + """Wraps a DB-API IntegrityError.""" + + code = "gkpj" + + +class InternalError(DatabaseError): + """Wraps a DB-API InternalError.""" + + code = "2j85" + + +class ProgrammingError(DatabaseError): + """Wraps a DB-API ProgrammingError.""" + + code = "f405" + + +class NotSupportedError(DatabaseError): + """Wraps a DB-API NotSupportedError.""" + + code = "tw8g" + + +# Warnings + + +class SATestSuiteWarning(Warning): + """warning for a condition detected during tests that is non-fatal + + Currently outside of SAWarning so that we can work around tools like + Alembic doing the wrong thing with warnings. + + """ + + +class SADeprecationWarning(HasDescriptionCode, DeprecationWarning): + """Issued for usage of deprecated APIs.""" + + deprecated_since: Optional[str] = None + "Indicates the version that started raising this deprecation warning" + + +class Base20DeprecationWarning(SADeprecationWarning): + """Issued for usage of APIs specifically deprecated or legacy in + SQLAlchemy 2.0. + + .. seealso:: + + :ref:`error_b8d9`. + + :ref:`deprecation_20_mode` + + """ + + deprecated_since: Optional[str] = "1.4" + "Indicates the version that started raising this deprecation warning" + + def __str__(self) -> str: + return ( + super().__str__() + + " (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)" + ) + + +class LegacyAPIWarning(Base20DeprecationWarning): + """indicates an API that is in 'legacy' status, a long term deprecation.""" + + +class MovedIn20Warning(Base20DeprecationWarning): + """Subtype of RemovedIn20Warning to indicate an API that moved only.""" + + +class SAPendingDeprecationWarning(PendingDeprecationWarning): + """A similar warning as :class:`_exc.SADeprecationWarning`, this warning + is not used in modern versions of SQLAlchemy. + + """ + + deprecated_since: Optional[str] = None + "Indicates the version that started raising this deprecation warning" + + +class SAWarning(HasDescriptionCode, RuntimeWarning): + """Issued at runtime.""" + + _what_are_we = "warning" diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__init__.py new file mode 100644 index 0000000..f03ed94 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__init__.py @@ -0,0 +1,11 @@ +# ext/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .. import util as _sa_util + + +_sa_util.preloaded.import_prefix("sqlalchemy.ext") diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..0340e5e Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/associationproxy.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/associationproxy.cpython-311.pyc new file mode 100644 index 0000000..5e08d9b Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/associationproxy.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/automap.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/automap.cpython-311.pyc new file mode 100644 index 0000000..846e172 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/automap.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/baked.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/baked.cpython-311.pyc new file mode 100644 index 0000000..0e36847 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/baked.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/compiler.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/compiler.cpython-311.pyc new file mode 100644 index 0000000..0b1beaa Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/compiler.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/horizontal_shard.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/horizontal_shard.cpython-311.pyc new file mode 100644 index 0000000..2bd5054 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/horizontal_shard.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/hybrid.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/hybrid.cpython-311.pyc new file mode 100644 index 0000000..31a156f Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/hybrid.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/indexable.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/indexable.cpython-311.pyc new file mode 100644 index 0000000..d7bde5e Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/indexable.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/instrumentation.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/instrumentation.cpython-311.pyc new file mode 100644 index 0000000..90b77be Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/instrumentation.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/mutable.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/mutable.cpython-311.pyc new file mode 100644 index 0000000..0247602 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/mutable.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/orderinglist.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/orderinglist.cpython-311.pyc new file mode 100644 index 0000000..c51955b Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/orderinglist.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/serializer.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/serializer.cpython-311.pyc new file mode 100644 index 0000000..3d5c8d3 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/__pycache__/serializer.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/associationproxy.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/associationproxy.py new file mode 100644 index 0000000..80e6fda --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/associationproxy.py @@ -0,0 +1,2005 @@ +# ext/associationproxy.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Contain the ``AssociationProxy`` class. + +The ``AssociationProxy`` is a Python property object which provides +transparent proxied access to the endpoint of an association object. + +See the example ``examples/association/proxied_association.py``. + +""" +from __future__ import annotations + +import operator +import typing +from typing import AbstractSet +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Dict +from typing import Generic +from typing import ItemsView +from typing import Iterable +from typing import Iterator +from typing import KeysView +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import MutableSequence +from typing import MutableSet +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +from typing import ValuesView + +from .. import ColumnElement +from .. import exc +from .. import inspect +from .. import orm +from .. import util +from ..orm import collections +from ..orm import InspectionAttrExtensionType +from ..orm import interfaces +from ..orm import ORMDescriptor +from ..orm.base import SQLORMOperations +from ..orm.interfaces import _AttributeOptions +from ..orm.interfaces import _DCAttributeOptions +from ..orm.interfaces import _DEFAULT_ATTRIBUTE_OPTIONS +from ..sql import operators +from ..sql import or_ +from ..sql.base import _NoArg +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import Self +from ..util.typing import SupportsIndex +from ..util.typing import SupportsKeysAndGetItem + +if typing.TYPE_CHECKING: + from ..orm.interfaces import MapperProperty + from ..orm.interfaces import PropComparator + from ..orm.mapper import Mapper + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _InfoType + + +_T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) +_T_con = TypeVar("_T_con", bound=Any, contravariant=True) +_S = TypeVar("_S", bound=Any) +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + + +def association_proxy( + target_collection: str, + attr: str, + *, + creator: Optional[_CreatorProtocol] = None, + getset_factory: Optional[_GetSetFactoryProtocol] = None, + proxy_factory: Optional[_ProxyFactoryProtocol] = None, + proxy_bulk_set: Optional[_ProxyBulkSetProtocol] = None, + info: Optional[_InfoType] = None, + cascade_scalar_deletes: bool = False, + create_on_none_assignment: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, +) -> AssociationProxy[Any]: + r"""Return a Python property implementing a view of a target + attribute which references an attribute on members of the + target. + + The returned value is an instance of :class:`.AssociationProxy`. + + Implements a Python property representing a relationship as a collection + of simpler values, or a scalar value. The proxied property will mimic + the collection type of the target (list, dict or set), or, in the case of + a one to one relationship, a simple scalar value. + + :param target_collection: Name of the attribute that is the immediate + target. This attribute is typically mapped by + :func:`~sqlalchemy.orm.relationship` to link to a target collection, but + can also be a many-to-one or non-scalar relationship. + + :param attr: Attribute on the associated instance or instances that + are available on instances of the target object. + + :param creator: optional. + + Defines custom behavior when new items are added to the proxied + collection. + + By default, adding new items to the collection will trigger a + construction of an instance of the target object, passing the given + item as a positional argument to the target constructor. For cases + where this isn't sufficient, :paramref:`.association_proxy.creator` + can supply a callable that will construct the object in the + appropriate way, given the item that was passed. + + For list- and set- oriented collections, a single argument is + passed to the callable. For dictionary oriented collections, two + arguments are passed, corresponding to the key and value. + + The :paramref:`.association_proxy.creator` callable is also invoked + for scalar (i.e. many-to-one, one-to-one) relationships. If the + current value of the target relationship attribute is ``None``, the + callable is used to construct a new object. If an object value already + exists, the given attribute value is populated onto that object. + + .. seealso:: + + :ref:`associationproxy_creator` + + :param cascade_scalar_deletes: when True, indicates that setting + the proxied value to ``None``, or deleting it via ``del``, should + also remove the source object. Only applies to scalar attributes. + Normally, removing the proxied target will not remove the proxy + source, as this object may have other state that is still to be + kept. + + .. versionadded:: 1.3 + + .. seealso:: + + :ref:`cascade_scalar_deletes` - complete usage example + + :param create_on_none_assignment: when True, indicates that setting + the proxied value to ``None`` should **create** the source object + if it does not exist, using the creator. Only applies to scalar + attributes. This is mutually exclusive + vs. the :paramref:`.assocation_proxy.cascade_scalar_deletes`. + + .. versionadded:: 2.0.18 + + :param init: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__init__()`` + method as generated by the dataclass process. + + .. versionadded:: 2.0.0b4 + + :param repr: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the attribute established by this :class:`.AssociationProxy` + should be part of the ``__repr__()`` method as generated by the dataclass + process. + + .. versionadded:: 2.0.0b4 + + :param default_factory: Specific to + :ref:`orm_declarative_native_dataclasses`, specifies a default-value + generation function that will take place as part of the ``__init__()`` + method as generated by the dataclass process. + + .. versionadded:: 2.0.0b4 + + :param compare: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be included in comparison operations when generating the + ``__eq__()`` and ``__ne__()`` methods for the mapped class. + + .. versionadded:: 2.0.0b4 + + :param kw_only: Specific to :ref:`orm_declarative_native_dataclasses`, + indicates if this field should be marked as keyword-only when generating + the ``__init__()`` method as generated by the dataclass process. + + .. versionadded:: 2.0.0b4 + + :param info: optional, will be assigned to + :attr:`.AssociationProxy.info` if present. + + + The following additional parameters involve injection of custom behaviors + within the :class:`.AssociationProxy` object and are for advanced use + only: + + :param getset_factory: Optional. Proxied attribute access is + automatically handled by routines that get and set values based on + the `attr` argument for this proxy. + + If you would like to customize this behavior, you may supply a + `getset_factory` callable that produces a tuple of `getter` and + `setter` functions. The factory is called with two arguments, the + abstract type of the underlying collection and this proxy instance. + + :param proxy_factory: Optional. The type of collection to emulate is + determined by sniffing the target collection. If your collection + type can't be determined by duck typing or you'd like to use a + different collection implementation, you may supply a factory + function to produce those collections. Only applicable to + non-scalar relationships. + + :param proxy_bulk_set: Optional, use with proxy_factory. + + + """ + return AssociationProxy( + target_collection, + attr, + creator=creator, + getset_factory=getset_factory, + proxy_factory=proxy_factory, + proxy_bulk_set=proxy_bulk_set, + info=info, + cascade_scalar_deletes=cascade_scalar_deletes, + create_on_none_assignment=create_on_none_assignment, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + ) + + +class AssociationProxyExtensionType(InspectionAttrExtensionType): + ASSOCIATION_PROXY = "ASSOCIATION_PROXY" + """Symbol indicating an :class:`.InspectionAttr` that's + of type :class:`.AssociationProxy`. + + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + + """ + + +class _GetterProtocol(Protocol[_T_co]): + def __call__(self, instance: Any) -> _T_co: ... + + +# mypy 0.990 we are no longer allowed to make this Protocol[_T_con] +class _SetterProtocol(Protocol): ... + + +class _PlainSetterProtocol(_SetterProtocol, Protocol[_T_con]): + def __call__(self, instance: Any, value: _T_con) -> None: ... + + +class _DictSetterProtocol(_SetterProtocol, Protocol[_T_con]): + def __call__(self, instance: Any, key: Any, value: _T_con) -> None: ... + + +# mypy 0.990 we are no longer allowed to make this Protocol[_T_con] +class _CreatorProtocol(Protocol): ... + + +class _PlainCreatorProtocol(_CreatorProtocol, Protocol[_T_con]): + def __call__(self, value: _T_con) -> Any: ... + + +class _KeyCreatorProtocol(_CreatorProtocol, Protocol[_T_con]): + def __call__(self, key: Any, value: Optional[_T_con]) -> Any: ... + + +class _LazyCollectionProtocol(Protocol[_T]): + def __call__( + self, + ) -> Union[ + MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T] + ]: ... + + +class _GetSetFactoryProtocol(Protocol): + def __call__( + self, + collection_class: Optional[Type[Any]], + assoc_instance: AssociationProxyInstance[Any], + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: ... + + +class _ProxyFactoryProtocol(Protocol): + def __call__( + self, + lazy_collection: _LazyCollectionProtocol[Any], + creator: _CreatorProtocol, + value_attr: str, + parent: AssociationProxyInstance[Any], + ) -> Any: ... + + +class _ProxyBulkSetProtocol(Protocol): + def __call__( + self, proxy: _AssociationCollection[Any], collection: Iterable[Any] + ) -> None: ... + + +class _AssociationProxyProtocol(Protocol[_T]): + """describes the interface of :class:`.AssociationProxy` + without including descriptor methods in the interface.""" + + creator: Optional[_CreatorProtocol] + key: str + target_collection: str + value_attr: str + cascade_scalar_deletes: bool + create_on_none_assignment: bool + getset_factory: Optional[_GetSetFactoryProtocol] + proxy_factory: Optional[_ProxyFactoryProtocol] + proxy_bulk_set: Optional[_ProxyBulkSetProtocol] + + @util.ro_memoized_property + def info(self) -> _InfoType: ... + + def for_class( + self, class_: Type[Any], obj: Optional[object] = None + ) -> AssociationProxyInstance[_T]: ... + + def _default_getset( + self, collection_class: Any + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: ... + + +class AssociationProxy( + interfaces.InspectionAttrInfo, + ORMDescriptor[_T], + _DCAttributeOptions, + _AssociationProxyProtocol[_T], +): + """A descriptor that presents a read/write view of an object attribute.""" + + is_attribute = True + extension_type = AssociationProxyExtensionType.ASSOCIATION_PROXY + + def __init__( + self, + target_collection: str, + attr: str, + *, + creator: Optional[_CreatorProtocol] = None, + getset_factory: Optional[_GetSetFactoryProtocol] = None, + proxy_factory: Optional[_ProxyFactoryProtocol] = None, + proxy_bulk_set: Optional[_ProxyBulkSetProtocol] = None, + info: Optional[_InfoType] = None, + cascade_scalar_deletes: bool = False, + create_on_none_assignment: bool = False, + attribute_options: Optional[_AttributeOptions] = None, + ): + """Construct a new :class:`.AssociationProxy`. + + The :class:`.AssociationProxy` object is typically constructed using + the :func:`.association_proxy` constructor function. See the + description of :func:`.association_proxy` for a description of all + parameters. + + + """ + self.target_collection = target_collection + self.value_attr = attr + self.creator = creator + self.getset_factory = getset_factory + self.proxy_factory = proxy_factory + self.proxy_bulk_set = proxy_bulk_set + + if cascade_scalar_deletes and create_on_none_assignment: + raise exc.ArgumentError( + "The cascade_scalar_deletes and create_on_none_assignment " + "parameters are mutually exclusive." + ) + self.cascade_scalar_deletes = cascade_scalar_deletes + self.create_on_none_assignment = create_on_none_assignment + + self.key = "_%s_%s_%s" % ( + type(self).__name__, + target_collection, + id(self), + ) + if info: + self.info = info # type: ignore + + if ( + attribute_options + and attribute_options != _DEFAULT_ATTRIBUTE_OPTIONS + ): + self._has_dataclass_arguments = True + self._attribute_options = attribute_options + else: + self._has_dataclass_arguments = False + self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS + + @overload + def __get__( + self, instance: Literal[None], owner: Literal[None] + ) -> Self: ... + + @overload + def __get__( + self, instance: Literal[None], owner: Any + ) -> AssociationProxyInstance[_T]: ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T: ... + + def __get__( + self, instance: object, owner: Any + ) -> Union[AssociationProxyInstance[_T], _T, AssociationProxy[_T]]: + if owner is None: + return self + inst = self._as_instance(owner, instance) + if inst: + return inst.get(instance) + + assert instance is None + + return self + + def __set__(self, instance: object, values: _T) -> None: + class_ = type(instance) + self._as_instance(class_, instance).set(instance, values) + + def __delete__(self, instance: object) -> None: + class_ = type(instance) + self._as_instance(class_, instance).delete(instance) + + def for_class( + self, class_: Type[Any], obj: Optional[object] = None + ) -> AssociationProxyInstance[_T]: + r"""Return the internal state local to a specific mapped class. + + E.g., given a class ``User``:: + + class User(Base): + # ... + + keywords = association_proxy('kws', 'keyword') + + If we access this :class:`.AssociationProxy` from + :attr:`_orm.Mapper.all_orm_descriptors`, and we want to view the + target class for this proxy as mapped by ``User``:: + + inspect(User).all_orm_descriptors["keywords"].for_class(User).target_class + + This returns an instance of :class:`.AssociationProxyInstance` that + is specific to the ``User`` class. The :class:`.AssociationProxy` + object remains agnostic of its parent class. + + :param class\_: the class that we are returning state for. + + :param obj: optional, an instance of the class that is required + if the attribute refers to a polymorphic target, e.g. where we have + to look at the type of the actual destination object to get the + complete path. + + .. versionadded:: 1.3 - :class:`.AssociationProxy` no longer stores + any state specific to a particular parent class; the state is now + stored in per-class :class:`.AssociationProxyInstance` objects. + + + """ + return self._as_instance(class_, obj) + + def _as_instance( + self, class_: Any, obj: Any + ) -> AssociationProxyInstance[_T]: + try: + inst = class_.__dict__[self.key + "_inst"] + except KeyError: + inst = None + + # avoid exception context + if inst is None: + owner = self._calc_owner(class_) + if owner is not None: + inst = AssociationProxyInstance.for_proxy(self, owner, obj) + setattr(class_, self.key + "_inst", inst) + else: + inst = None + + if inst is not None and not inst._is_canonical: + # the AssociationProxyInstance can't be generalized + # since the proxied attribute is not on the targeted + # class, only on subclasses of it, which might be + # different. only return for the specific + # object's current value + return inst._non_canonical_get_for_object(obj) # type: ignore + else: + return inst # type: ignore # TODO + + def _calc_owner(self, target_cls: Any) -> Any: + # we might be getting invoked for a subclass + # that is not mapped yet, in some declarative situations. + # save until we are mapped + try: + insp = inspect(target_cls) + except exc.NoInspectionAvailable: + # can't find a mapper, don't set owner. if we are a not-yet-mapped + # subclass, we can also scan through __mro__ to find a mapped + # class, but instead just wait for us to be called again against a + # mapped class normally. + return None + else: + return insp.mapper.class_manager.class_ + + def _default_getset( + self, collection_class: Any + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: + attr = self.value_attr + _getter = operator.attrgetter(attr) + + def getter(instance: Any) -> Optional[Any]: + return _getter(instance) if instance is not None else None + + if collection_class is dict: + + def dict_setter(instance: Any, k: Any, value: Any) -> None: + setattr(instance, attr, value) + + return getter, dict_setter + + else: + + def plain_setter(o: Any, v: Any) -> None: + setattr(o, attr, v) + + return getter, plain_setter + + def __repr__(self) -> str: + return "AssociationProxy(%r, %r)" % ( + self.target_collection, + self.value_attr, + ) + + +# the pep-673 Self type does not work in Mypy for a "hybrid" +# style method that returns type or Self, so for one specific case +# we still need to use the pre-pep-673 workaround. +_Self = TypeVar("_Self", bound="AssociationProxyInstance[Any]") + + +class AssociationProxyInstance(SQLORMOperations[_T]): + """A per-class object that serves class- and object-specific results. + + This is used by :class:`.AssociationProxy` when it is invoked + in terms of a specific class or instance of a class, i.e. when it is + used as a regular Python descriptor. + + When referring to the :class:`.AssociationProxy` as a normal Python + descriptor, the :class:`.AssociationProxyInstance` is the object that + actually serves the information. Under normal circumstances, its presence + is transparent:: + + >>> User.keywords.scalar + False + + In the special case that the :class:`.AssociationProxy` object is being + accessed directly, in order to get an explicit handle to the + :class:`.AssociationProxyInstance`, use the + :meth:`.AssociationProxy.for_class` method:: + + proxy_state = inspect(User).all_orm_descriptors["keywords"].for_class(User) + + # view if proxy object is scalar or not + >>> proxy_state.scalar + False + + .. versionadded:: 1.3 + + """ # noqa + + collection_class: Optional[Type[Any]] + parent: _AssociationProxyProtocol[_T] + + def __init__( + self, + parent: _AssociationProxyProtocol[_T], + owning_class: Type[Any], + target_class: Type[Any], + value_attr: str, + ): + self.parent = parent + self.key = parent.key + self.owning_class = owning_class + self.target_collection = parent.target_collection + self.collection_class = None + self.target_class = target_class + self.value_attr = value_attr + + target_class: Type[Any] + """The intermediary class handled by this + :class:`.AssociationProxyInstance`. + + Intercepted append/set/assignment events will result + in the generation of new instances of this class. + + """ + + @classmethod + def for_proxy( + cls, + parent: AssociationProxy[_T], + owning_class: Type[Any], + parent_instance: Any, + ) -> AssociationProxyInstance[_T]: + target_collection = parent.target_collection + value_attr = parent.value_attr + prop = cast( + "orm.RelationshipProperty[_T]", + orm.class_mapper(owning_class).get_property(target_collection), + ) + + # this was never asserted before but this should be made clear. + if not isinstance(prop, orm.RelationshipProperty): + raise NotImplementedError( + "association proxy to a non-relationship " + "intermediary is not supported" + ) from None + + target_class = prop.mapper.class_ + + try: + target_assoc = cast( + "AssociationProxyInstance[_T]", + cls._cls_unwrap_target_assoc_proxy(target_class, value_attr), + ) + except AttributeError: + # the proxied attribute doesn't exist on the target class; + # return an "ambiguous" instance that will work on a per-object + # basis + return AmbiguousAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + except Exception as err: + raise exc.InvalidRequestError( + f"Association proxy received an unexpected error when " + f"trying to retreive attribute " + f'"{target_class.__name__}.{parent.value_attr}" from ' + f'class "{target_class.__name__}": {err}' + ) from err + else: + return cls._construct_for_assoc( + target_assoc, parent, owning_class, target_class, value_attr + ) + + @classmethod + def _construct_for_assoc( + cls, + target_assoc: Optional[AssociationProxyInstance[_T]], + parent: _AssociationProxyProtocol[_T], + owning_class: Type[Any], + target_class: Type[Any], + value_attr: str, + ) -> AssociationProxyInstance[_T]: + if target_assoc is not None: + return ObjectAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + + attr = getattr(target_class, value_attr) + if not hasattr(attr, "_is_internal_proxy"): + return AmbiguousAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + is_object = attr._impl_uses_objects + if is_object: + return ObjectAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + else: + return ColumnAssociationProxyInstance( + parent, owning_class, target_class, value_attr + ) + + def _get_property(self) -> MapperProperty[Any]: + return orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) + + @property + def _comparator(self) -> PropComparator[Any]: + return getattr( # type: ignore + self.owning_class, self.target_collection + ).comparator + + def __clause_element__(self) -> NoReturn: + raise NotImplementedError( + "The association proxy can't be used as a plain column " + "expression; it only works inside of a comparison expression" + ) + + @classmethod + def _cls_unwrap_target_assoc_proxy( + cls, target_class: Any, value_attr: str + ) -> Optional[AssociationProxyInstance[_T]]: + attr = getattr(target_class, value_attr) + assert not isinstance(attr, AssociationProxy) + if isinstance(attr, AssociationProxyInstance): + return attr + return None + + @util.memoized_property + def _unwrap_target_assoc_proxy( + self, + ) -> Optional[AssociationProxyInstance[_T]]: + return self._cls_unwrap_target_assoc_proxy( + self.target_class, self.value_attr + ) + + @property + def remote_attr(self) -> SQLORMOperations[_T]: + """The 'remote' class attribute referenced by this + :class:`.AssociationProxyInstance`. + + .. seealso:: + + :attr:`.AssociationProxyInstance.attr` + + :attr:`.AssociationProxyInstance.local_attr` + + """ + return cast( + "SQLORMOperations[_T]", getattr(self.target_class, self.value_attr) + ) + + @property + def local_attr(self) -> SQLORMOperations[Any]: + """The 'local' class attribute referenced by this + :class:`.AssociationProxyInstance`. + + .. seealso:: + + :attr:`.AssociationProxyInstance.attr` + + :attr:`.AssociationProxyInstance.remote_attr` + + """ + return cast( + "SQLORMOperations[Any]", + getattr(self.owning_class, self.target_collection), + ) + + @property + def attr(self) -> Tuple[SQLORMOperations[Any], SQLORMOperations[_T]]: + """Return a tuple of ``(local_attr, remote_attr)``. + + This attribute was originally intended to facilitate using the + :meth:`_query.Query.join` method to join across the two relationships + at once, however this makes use of a deprecated calling style. + + To use :meth:`_sql.select.join` or :meth:`_orm.Query.join` with + an association proxy, the current method is to make use of the + :attr:`.AssociationProxyInstance.local_attr` and + :attr:`.AssociationProxyInstance.remote_attr` attributes separately:: + + stmt = ( + select(Parent). + join(Parent.proxied.local_attr). + join(Parent.proxied.remote_attr) + ) + + A future release may seek to provide a more succinct join pattern + for association proxy attributes. + + .. seealso:: + + :attr:`.AssociationProxyInstance.local_attr` + + :attr:`.AssociationProxyInstance.remote_attr` + + """ + return (self.local_attr, self.remote_attr) + + @util.memoized_property + def scalar(self) -> bool: + """Return ``True`` if this :class:`.AssociationProxyInstance` + proxies a scalar relationship on the local side.""" + + scalar = not self._get_property().uselist + if scalar: + self._initialize_scalar_accessors() + return scalar + + @util.memoized_property + def _value_is_scalar(self) -> bool: + return ( + not self._get_property() + .mapper.get_property(self.value_attr) + .uselist + ) + + @property + def _target_is_object(self) -> bool: + raise NotImplementedError() + + _scalar_get: _GetterProtocol[_T] + _scalar_set: _PlainSetterProtocol[_T] + + def _initialize_scalar_accessors(self) -> None: + if self.parent.getset_factory: + get, set_ = self.parent.getset_factory(None, self) + else: + get, set_ = self.parent._default_getset(None) + self._scalar_get, self._scalar_set = get, cast( + "_PlainSetterProtocol[_T]", set_ + ) + + def _default_getset( + self, collection_class: Any + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: + attr = self.value_attr + _getter = operator.attrgetter(attr) + + def getter(instance: Any) -> Optional[_T]: + return _getter(instance) if instance is not None else None + + if collection_class is dict: + + def dict_setter(instance: Any, k: Any, value: _T) -> None: + setattr(instance, attr, value) + + return getter, dict_setter + else: + + def plain_setter(o: Any, v: _T) -> None: + setattr(o, attr, v) + + return getter, plain_setter + + @util.ro_non_memoized_property + def info(self) -> _InfoType: + return self.parent.info + + @overload + def get(self: _Self, obj: Literal[None]) -> _Self: ... + + @overload + def get(self, obj: Any) -> _T: ... + + def get( + self, obj: Any + ) -> Union[Optional[_T], AssociationProxyInstance[_T]]: + if obj is None: + return self + + proxy: _T + + if self.scalar: + target = getattr(obj, self.target_collection) + return self._scalar_get(target) + else: + try: + # If the owning instance is reborn (orm session resurrect, + # etc.), refresh the proxy cache. + creator_id, self_id, proxy = cast( + "Tuple[int, int, _T]", getattr(obj, self.key) + ) + except AttributeError: + pass + else: + if id(obj) == creator_id and id(self) == self_id: + assert self.collection_class is not None + return proxy + + self.collection_class, proxy = self._new( + _lazy_collection(obj, self.target_collection) + ) + setattr(obj, self.key, (id(obj), id(self), proxy)) + return proxy + + def set(self, obj: Any, values: _T) -> None: + if self.scalar: + creator = cast( + "_PlainCreatorProtocol[_T]", + ( + self.parent.creator + if self.parent.creator + else self.target_class + ), + ) + target = getattr(obj, self.target_collection) + if target is None: + if ( + values is None + and not self.parent.create_on_none_assignment + ): + return + setattr(obj, self.target_collection, creator(values)) + else: + self._scalar_set(target, values) + if values is None and self.parent.cascade_scalar_deletes: + setattr(obj, self.target_collection, None) + else: + proxy = self.get(obj) + assert self.collection_class is not None + if proxy is not values: + proxy._bulk_replace(self, values) + + def delete(self, obj: Any) -> None: + if self.owning_class is None: + self._calc_owner(obj, None) + + if self.scalar: + target = getattr(obj, self.target_collection) + if target is not None: + delattr(target, self.value_attr) + delattr(obj, self.target_collection) + + def _new( + self, lazy_collection: _LazyCollectionProtocol[_T] + ) -> Tuple[Type[Any], _T]: + creator = ( + self.parent.creator + if self.parent.creator is not None + else cast("_CreatorProtocol", self.target_class) + ) + collection_class = util.duck_type_collection(lazy_collection()) + + if collection_class is None: + raise exc.InvalidRequestError( + f"lazy collection factory did not return a " + f"valid collection type, got {collection_class}" + ) + if self.parent.proxy_factory: + return ( + collection_class, + self.parent.proxy_factory( + lazy_collection, creator, self.value_attr, self + ), + ) + + if self.parent.getset_factory: + getter, setter = self.parent.getset_factory(collection_class, self) + else: + getter, setter = self.parent._default_getset(collection_class) + + if collection_class is list: + return ( + collection_class, + cast( + _T, + _AssociationList( + lazy_collection, creator, getter, setter, self + ), + ), + ) + elif collection_class is dict: + return ( + collection_class, + cast( + _T, + _AssociationDict( + lazy_collection, creator, getter, setter, self + ), + ), + ) + elif collection_class is set: + return ( + collection_class, + cast( + _T, + _AssociationSet( + lazy_collection, creator, getter, setter, self + ), + ), + ) + else: + raise exc.ArgumentError( + "could not guess which interface to use for " + 'collection_class "%s" backing "%s"; specify a ' + "proxy_factory and proxy_bulk_set manually" + % (self.collection_class, self.target_collection) + ) + + def _set( + self, proxy: _AssociationCollection[Any], values: Iterable[Any] + ) -> None: + if self.parent.proxy_bulk_set: + self.parent.proxy_bulk_set(proxy, values) + elif self.collection_class is list: + cast("_AssociationList[Any]", proxy).extend(values) + elif self.collection_class is dict: + cast("_AssociationDict[Any, Any]", proxy).update(values) + elif self.collection_class is set: + cast("_AssociationSet[Any]", proxy).update(values) + else: + raise exc.ArgumentError( + "no proxy_bulk_set supplied for custom " + "collection_class implementation" + ) + + def _inflate(self, proxy: _AssociationCollection[Any]) -> None: + creator = ( + self.parent.creator + and self.parent.creator + or cast(_CreatorProtocol, self.target_class) + ) + + if self.parent.getset_factory: + getter, setter = self.parent.getset_factory( + self.collection_class, self + ) + else: + getter, setter = self.parent._default_getset(self.collection_class) + + proxy.creator = creator + proxy.getter = getter + proxy.setter = setter + + def _criterion_exists( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + is_has = kwargs.pop("is_has", None) + + target_assoc = self._unwrap_target_assoc_proxy + if target_assoc is not None: + inner = target_assoc._criterion_exists( + criterion=criterion, **kwargs + ) + return self._comparator._criterion_exists(inner) + + if self._target_is_object: + attr = getattr(self.target_class, self.value_attr) + value_expr = attr.comparator._criterion_exists(criterion, **kwargs) + else: + if kwargs: + raise exc.ArgumentError( + "Can't apply keyword arguments to column-targeted " + "association proxy; use ==" + ) + elif is_has and criterion is not None: + raise exc.ArgumentError( + "Non-empty has() not allowed for " + "column-targeted association proxy; use ==" + ) + + value_expr = criterion + + return self._comparator._criterion_exists(value_expr) + + def any( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + """Produce a proxied 'any' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.Relationship.Comparator.any` + and/or :meth:`.Relationship.Comparator.has` + operators of the underlying proxied attributes. + + """ + if self._unwrap_target_assoc_proxy is None and ( + self.scalar + and (not self._target_is_object or self._value_is_scalar) + ): + raise exc.InvalidRequestError( + "'any()' not implemented for scalar attributes. Use has()." + ) + return self._criterion_exists( + criterion=criterion, is_has=False, **kwargs + ) + + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + """Produce a proxied 'has' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.Relationship.Comparator.any` + and/or :meth:`.Relationship.Comparator.has` + operators of the underlying proxied attributes. + + """ + if self._unwrap_target_assoc_proxy is None and ( + not self.scalar + or (self._target_is_object and not self._value_is_scalar) + ): + raise exc.InvalidRequestError( + "'has()' not implemented for collections. Use any()." + ) + return self._criterion_exists( + criterion=criterion, is_has=True, **kwargs + ) + + def __repr__(self) -> str: + return "%s(%r)" % (self.__class__.__name__, self.parent) + + +class AmbiguousAssociationProxyInstance(AssociationProxyInstance[_T]): + """an :class:`.AssociationProxyInstance` where we cannot determine + the type of target object. + """ + + _is_canonical = False + + def _ambiguous(self) -> NoReturn: + raise AttributeError( + "Association proxy %s.%s refers to an attribute '%s' that is not " + "directly mapped on class %s; therefore this operation cannot " + "proceed since we don't know what type of object is referred " + "towards" + % ( + self.owning_class.__name__, + self.target_collection, + self.value_attr, + self.target_class, + ) + ) + + def get(self, obj: Any) -> Any: + if obj is None: + return self + else: + return super().get(obj) + + def __eq__(self, obj: object) -> NoReturn: + self._ambiguous() + + def __ne__(self, obj: object) -> NoReturn: + self._ambiguous() + + def any( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> NoReturn: + self._ambiguous() + + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> NoReturn: + self._ambiguous() + + @util.memoized_property + def _lookup_cache(self) -> Dict[Type[Any], AssociationProxyInstance[_T]]: + # mapping of ->AssociationProxyInstance. + # e.g. proxy is A-> A.b -> B -> B.b_attr, but B.b_attr doesn't exist; + # only B1(B) and B2(B) have "b_attr", keys in here would be B1, B2 + return {} + + def _non_canonical_get_for_object( + self, parent_instance: Any + ) -> AssociationProxyInstance[_T]: + if parent_instance is not None: + actual_obj = getattr(parent_instance, self.target_collection) + if actual_obj is not None: + try: + insp = inspect(actual_obj) + except exc.NoInspectionAvailable: + pass + else: + mapper = insp.mapper + instance_class = mapper.class_ + if instance_class not in self._lookup_cache: + self._populate_cache(instance_class, mapper) + + try: + return self._lookup_cache[instance_class] + except KeyError: + pass + + # no object or ambiguous object given, so return "self", which + # is a proxy with generally only instance-level functionality + return self + + def _populate_cache( + self, instance_class: Any, mapper: Mapper[Any] + ) -> None: + prop = orm.class_mapper(self.owning_class).get_property( + self.target_collection + ) + + if mapper.isa(prop.mapper): + target_class = instance_class + try: + target_assoc = self._cls_unwrap_target_assoc_proxy( + target_class, self.value_attr + ) + except AttributeError: + pass + else: + self._lookup_cache[instance_class] = self._construct_for_assoc( + cast("AssociationProxyInstance[_T]", target_assoc), + self.parent, + self.owning_class, + target_class, + self.value_attr, + ) + + +class ObjectAssociationProxyInstance(AssociationProxyInstance[_T]): + """an :class:`.AssociationProxyInstance` that has an object as a target.""" + + _target_is_object: bool = True + _is_canonical = True + + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: + """Produce a proxied 'contains' expression using EXISTS. + + This expression will be a composed product + using the :meth:`.Relationship.Comparator.any`, + :meth:`.Relationship.Comparator.has`, + and/or :meth:`.Relationship.Comparator.contains` + operators of the underlying proxied attributes. + """ + + target_assoc = self._unwrap_target_assoc_proxy + if target_assoc is not None: + return self._comparator._criterion_exists( + target_assoc.contains(other) + if not target_assoc.scalar + else target_assoc == other + ) + elif ( + self._target_is_object + and self.scalar + and not self._value_is_scalar + ): + return self._comparator.has( + getattr(self.target_class, self.value_attr).contains(other) + ) + elif self._target_is_object and self.scalar and self._value_is_scalar: + raise exc.InvalidRequestError( + "contains() doesn't apply to a scalar object endpoint; use ==" + ) + else: + return self._comparator._criterion_exists( + **{self.value_attr: other} + ) + + def __eq__(self, obj: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + if obj is None: + return or_( + self._comparator.has(**{self.value_attr: obj}), + self._comparator == None, + ) + else: + return self._comparator.has(**{self.value_attr: obj}) + + def __ne__(self, obj: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + # note the has() here will fail for collections; eq_() + # is only allowed with a scalar. + return self._comparator.has( + getattr(self.target_class, self.value_attr) != obj + ) + + +class ColumnAssociationProxyInstance(AssociationProxyInstance[_T]): + """an :class:`.AssociationProxyInstance` that has a database column as a + target. + """ + + _target_is_object: bool = False + _is_canonical = True + + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + # special case "is None" to check for no related row as well + expr = self._criterion_exists( + self.remote_attr.operate(operators.eq, other) + ) + if other is None: + return or_(expr, self._comparator == None) + else: + return expr + + def operate( + self, op: operators.OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return self._criterion_exists( + self.remote_attr.operate(op, *other, **kwargs) + ) + + +class _lazy_collection(_LazyCollectionProtocol[_T]): + def __init__(self, obj: Any, target: str): + self.parent = obj + self.target = target + + def __call__( + self, + ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]: + return getattr(self.parent, self.target) # type: ignore[no-any-return] + + def __getstate__(self) -> Any: + return {"obj": self.parent, "target": self.target} + + def __setstate__(self, state: Any) -> None: + self.parent = state["obj"] + self.target = state["target"] + + +_IT = TypeVar("_IT", bound="Any") +"""instance type - this is the type of object inside a collection. + +this is not the same as the _T of AssociationProxy and +AssociationProxyInstance itself, which will often refer to the +collection[_IT] type. + +""" + + +class _AssociationCollection(Generic[_IT]): + getter: _GetterProtocol[_IT] + """A function. Given an associated object, return the 'value'.""" + + creator: _CreatorProtocol + """ + A function that creates new target entities. Given one parameter: + value. This assertion is assumed:: + + obj = creator(somevalue) + assert getter(obj) == somevalue + """ + + parent: AssociationProxyInstance[_IT] + setter: _SetterProtocol + """A function. Given an associated object and a value, store that + value on the object. + """ + + lazy_collection: _LazyCollectionProtocol[_IT] + """A callable returning a list-based collection of entities (usually an + object attribute managed by a SQLAlchemy relationship())""" + + def __init__( + self, + lazy_collection: _LazyCollectionProtocol[_IT], + creator: _CreatorProtocol, + getter: _GetterProtocol[_IT], + setter: _SetterProtocol, + parent: AssociationProxyInstance[_IT], + ): + """Constructs an _AssociationCollection. + + This will always be a subclass of either _AssociationList, + _AssociationSet, or _AssociationDict. + + """ + self.lazy_collection = lazy_collection + self.creator = creator + self.getter = getter + self.setter = setter + self.parent = parent + + if typing.TYPE_CHECKING: + col: Collection[_IT] + else: + col = property(lambda self: self.lazy_collection()) + + def __len__(self) -> int: + return len(self.col) + + def __bool__(self) -> bool: + return bool(self.col) + + def __getstate__(self) -> Any: + return {"parent": self.parent, "lazy_collection": self.lazy_collection} + + def __setstate__(self, state: Any) -> None: + self.parent = state["parent"] + self.lazy_collection = state["lazy_collection"] + self.parent._inflate(self) + + def clear(self) -> None: + raise NotImplementedError() + + +class _AssociationSingleItem(_AssociationCollection[_T]): + setter: _PlainSetterProtocol[_T] + creator: _PlainCreatorProtocol[_T] + + def _create(self, value: _T) -> Any: + return self.creator(value) + + def _get(self, object_: Any) -> _T: + return self.getter(object_) + + def _bulk_replace( + self, assoc_proxy: AssociationProxyInstance[Any], values: Iterable[_IT] + ) -> None: + self.clear() + assoc_proxy._set(self, values) + + +class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]): + """Generic, converting, list-to-list proxy.""" + + col: MutableSequence[_T] + + def _set(self, object_: Any, value: _T) -> None: + self.setter(object_, value) + + @overload + def __getitem__(self, index: int) -> _T: ... + + @overload + def __getitem__(self, index: slice) -> MutableSequence[_T]: ... + + def __getitem__( + self, index: Union[int, slice] + ) -> Union[_T, MutableSequence[_T]]: + if not isinstance(index, slice): + return self._get(self.col[index]) + else: + return [self._get(member) for member in self.col[index]] + + @overload + def __setitem__(self, index: int, value: _T) -> None: ... + + @overload + def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ... + + def __setitem__( + self, index: Union[int, slice], value: Union[_T, Iterable[_T]] + ) -> None: + if not isinstance(index, slice): + self._set(self.col[index], cast("_T", value)) + else: + if index.stop is None: + stop = len(self) + elif index.stop < 0: + stop = len(self) + index.stop + else: + stop = index.stop + step = index.step or 1 + + start = index.start or 0 + rng = list(range(index.start or 0, stop, step)) + + sized_value = list(value) + + if step == 1: + for i in rng: + del self[start] + i = start + for item in sized_value: + self.insert(i, item) + i += 1 + else: + if len(sized_value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" + % (len(sized_value), len(rng)) + ) + for i, item in zip(rng, value): + self._set(self.col[i], item) + + @overload + def __delitem__(self, index: int) -> None: ... + + @overload + def __delitem__(self, index: slice) -> None: ... + + def __delitem__(self, index: Union[slice, int]) -> None: + del self.col[index] + + def __contains__(self, value: object) -> bool: + for member in self.col: + # testlib.pragma exempt:__eq__ + if self._get(member) == value: + return True + return False + + def __iter__(self) -> Iterator[_T]: + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or + just use the underlying collection directly from its property + on the parent. + """ + + for member in self.col: + yield self._get(member) + return + + def append(self, value: _T) -> None: + col = self.col + item = self._create(value) + col.append(item) + + def count(self, value: Any) -> int: + count = 0 + for v in self: + if v == value: + count += 1 + return count + + def extend(self, values: Iterable[_T]) -> None: + for v in values: + self.append(v) + + def insert(self, index: int, value: _T) -> None: + self.col[index:index] = [self._create(value)] + + def pop(self, index: int = -1) -> _T: + return self.getter(self.col.pop(index)) + + def remove(self, value: _T) -> None: + for i, val in enumerate(self): + if val == value: + del self.col[i] + return + raise ValueError("value not in list") + + def reverse(self) -> NoReturn: + """Not supported, use reversed(mylist)""" + + raise NotImplementedError() + + def sort(self) -> NoReturn: + """Not supported, use sorted(mylist)""" + + raise NotImplementedError() + + def clear(self) -> None: + del self.col[0 : len(self.col)] + + def __eq__(self, other: object) -> bool: + return list(self) == other + + def __ne__(self, other: object) -> bool: + return list(self) != other + + def __lt__(self, other: List[_T]) -> bool: + return list(self) < other + + def __le__(self, other: List[_T]) -> bool: + return list(self) <= other + + def __gt__(self, other: List[_T]) -> bool: + return list(self) > other + + def __ge__(self, other: List[_T]) -> bool: + return list(self) >= other + + def __add__(self, other: List[_T]) -> List[_T]: + try: + other = list(other) + except TypeError: + return NotImplemented + return list(self) + other + + def __radd__(self, other: List[_T]) -> List[_T]: + try: + other = list(other) + except TypeError: + return NotImplemented + return other + list(self) + + def __mul__(self, n: SupportsIndex) -> List[_T]: + if not isinstance(n, int): + return NotImplemented + return list(self) * n + + def __rmul__(self, n: SupportsIndex) -> List[_T]: + if not isinstance(n, int): + return NotImplemented + return n * list(self) + + def __iadd__(self, iterable: Iterable[_T]) -> Self: + self.extend(iterable) + return self + + def __imul__(self, n: SupportsIndex) -> Self: + # unlike a regular list *=, proxied __imul__ will generate unique + # backing objects for each copy. *= on proxied lists is a bit of + # a stretch anyhow, and this interpretation of the __imul__ contract + # is more plausibly useful than copying the backing objects. + if not isinstance(n, int): + raise NotImplementedError() + if n == 0: + self.clear() + elif n > 1: + self.extend(list(self) * (n - 1)) + return self + + if typing.TYPE_CHECKING: + # TODO: no idea how to do this without separate "stub" + def index( + self, value: Any, start: int = ..., stop: int = ... + ) -> int: ... + + else: + + def index(self, value: Any, *arg) -> int: + ls = list(self) + return ls.index(value, *arg) + + def copy(self) -> List[_T]: + return list(self) + + def __repr__(self) -> str: + return repr(list(self)) + + def __hash__(self) -> NoReturn: + raise TypeError("%s objects are unhashable" % type(self).__name__) + + if not typing.TYPE_CHECKING: + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func + + +class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]): + """Generic, converting, dict-to-dict proxy.""" + + setter: _DictSetterProtocol[_VT] + creator: _KeyCreatorProtocol[_VT] + col: MutableMapping[_KT, Optional[_VT]] + + def _create(self, key: _KT, value: Optional[_VT]) -> Any: + return self.creator(key, value) + + def _get(self, object_: Any) -> _VT: + return self.getter(object_) + + def _set(self, object_: Any, key: _KT, value: _VT) -> None: + return self.setter(object_, key, value) + + def __getitem__(self, key: _KT) -> _VT: + return self._get(self.col[key]) + + def __setitem__(self, key: _KT, value: _VT) -> None: + if key in self.col: + self._set(self.col[key], key, value) + else: + self.col[key] = self._create(key, value) + + def __delitem__(self, key: _KT) -> None: + del self.col[key] + + def __contains__(self, key: object) -> bool: + return key in self.col + + def __iter__(self) -> Iterator[_KT]: + return iter(self.col.keys()) + + def clear(self) -> None: + self.col.clear() + + def __eq__(self, other: object) -> bool: + return dict(self) == other + + def __ne__(self, other: object) -> bool: + return dict(self) != other + + def __repr__(self) -> str: + return repr(dict(self)) + + @overload + def get(self, __key: _KT) -> Optional[_VT]: ... + + @overload + def get(self, __key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ... + + def get( + self, key: _KT, default: Optional[Union[_VT, _T]] = None + ) -> Union[_VT, _T, None]: + try: + return self[key] + except KeyError: + return default + + def setdefault(self, key: _KT, default: Optional[_VT] = None) -> _VT: + # TODO: again, no idea how to create an actual MutableMapping. + # default must allow None, return type can't include None, + # the stub explicitly allows for default of None with a cryptic message + # "This overload should be allowed only if the value type is + # compatible with None.". + if key not in self.col: + self.col[key] = self._create(key, default) + return default # type: ignore + else: + return self[key] + + def keys(self) -> KeysView[_KT]: + return self.col.keys() + + def items(self) -> ItemsView[_KT, _VT]: + return ItemsView(self) + + def values(self) -> ValuesView[_VT]: + return ValuesView(self) + + @overload + def pop(self, __key: _KT) -> _VT: ... + + @overload + def pop( + self, __key: _KT, default: Union[_VT, _T] = ... + ) -> Union[_VT, _T]: ... + + def pop(self, __key: _KT, *arg: Any, **kw: Any) -> Union[_VT, _T]: + member = self.col.pop(__key, *arg, **kw) + return self._get(member) + + def popitem(self) -> Tuple[_KT, _VT]: + item = self.col.popitem() + return (item[0], self._get(item[1])) + + @overload + def update( + self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT + ) -> None: ... + + @overload + def update( + self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... + + @overload + def update(self, **kwargs: _VT) -> None: ... + + def update(self, *a: Any, **kw: Any) -> None: + up: Dict[_KT, _VT] = {} + up.update(*a, **kw) + + for key, value in up.items(): + self[key] = value + + def _bulk_replace( + self, + assoc_proxy: AssociationProxyInstance[Any], + values: Mapping[_KT, _VT], + ) -> None: + existing = set(self) + constants = existing.intersection(values or ()) + additions = set(values or ()).difference(constants) + removals = existing.difference(constants) + + for key, member in values.items() or (): + if key in additions: + self[key] = member + elif key in constants: + self[key] = member + + for key in removals: + del self[key] + + def copy(self) -> Dict[_KT, _VT]: + return dict(self.items()) + + def __hash__(self) -> NoReturn: + raise TypeError("%s objects are unhashable" % type(self).__name__) + + if not typing.TYPE_CHECKING: + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(dict, func_name) + ): + func.__doc__ = getattr(dict, func_name).__doc__ + del func_name, func + + +class _AssociationSet(_AssociationSingleItem[_T], MutableSet[_T]): + """Generic, converting, set-to-set proxy.""" + + col: MutableSet[_T] + + def __len__(self) -> int: + return len(self.col) + + def __bool__(self) -> bool: + if self.col: + return True + else: + return False + + def __contains__(self, __o: object) -> bool: + for member in self.col: + if self._get(member) == __o: + return True + return False + + def __iter__(self) -> Iterator[_T]: + """Iterate over proxied values. + + For the actual domain objects, iterate over .col instead or just use + the underlying collection directly from its property on the parent. + + """ + for member in self.col: + yield self._get(member) + return + + def add(self, __element: _T) -> None: + if __element not in self: + self.col.add(self._create(__element)) + + # for discard and remove, choosing a more expensive check strategy rather + # than call self.creator() + def discard(self, __element: _T) -> None: + for member in self.col: + if self._get(member) == __element: + self.col.discard(member) + break + + def remove(self, __element: _T) -> None: + for member in self.col: + if self._get(member) == __element: + self.col.discard(member) + return + raise KeyError(__element) + + def pop(self) -> _T: + if not self.col: + raise KeyError("pop from an empty set") + member = self.col.pop() + return self._get(member) + + def update(self, *s: Iterable[_T]) -> None: + for iterable in s: + for value in iterable: + self.add(value) + + def _bulk_replace(self, assoc_proxy: Any, values: Iterable[_T]) -> None: + existing = set(self) + constants = existing.intersection(values or ()) + additions = set(values or ()).difference(constants) + removals = existing.difference(constants) + + appender = self.add + remover = self.remove + + for member in values or (): + if member in additions: + appender(member) + elif member in constants: + appender(member) + + for member in removals: + remover(member) + + def __ior__( # type: ignore + self, other: AbstractSet[_S] + ) -> MutableSet[Union[_T, _S]]: + if not collections._set_binops_check_strict(self, other): + raise NotImplementedError() + for value in other: + self.add(value) + return self + + def _set(self) -> Set[_T]: + return set(iter(self)) + + def union(self, *s: Iterable[_S]) -> MutableSet[Union[_T, _S]]: + return set(self).union(*s) + + def __or__(self, __s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: + return self.union(__s) + + def difference(self, *s: Iterable[Any]) -> MutableSet[_T]: + return set(self).difference(*s) + + def __sub__(self, s: AbstractSet[Any]) -> MutableSet[_T]: + return self.difference(s) + + def difference_update(self, *s: Iterable[Any]) -> None: + for other in s: + for value in other: + self.discard(value) + + def __isub__(self, s: AbstractSet[Any]) -> Self: + if not collections._set_binops_check_strict(self, s): + raise NotImplementedError() + for value in s: + self.discard(value) + return self + + def intersection(self, *s: Iterable[Any]) -> MutableSet[_T]: + return set(self).intersection(*s) + + def __and__(self, s: AbstractSet[Any]) -> MutableSet[_T]: + return self.intersection(s) + + def intersection_update(self, *s: Iterable[Any]) -> None: + for other in s: + want, have = self.intersection(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + def __iand__(self, s: AbstractSet[Any]) -> Self: + if not collections._set_binops_check_strict(self, s): + raise NotImplementedError() + want = self.intersection(s) + have: Set[_T] = set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + return self + + def symmetric_difference(self, __s: Iterable[_T]) -> MutableSet[_T]: + return set(self).symmetric_difference(__s) + + def __xor__(self, s: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: + return self.symmetric_difference(s) + + def symmetric_difference_update(self, other: Iterable[Any]) -> None: + want, have = self.symmetric_difference(other), set(self) + + remove, add = have - want, want - have + + for value in remove: + self.remove(value) + for value in add: + self.add(value) + + def __ixor__(self, other: AbstractSet[_S]) -> MutableSet[Union[_T, _S]]: # type: ignore # noqa: E501 + if not collections._set_binops_check_strict(self, other): + raise NotImplementedError() + + self.symmetric_difference_update(other) + return self + + def issubset(self, __s: Iterable[Any]) -> bool: + return set(self).issubset(__s) + + def issuperset(self, __s: Iterable[Any]) -> bool: + return set(self).issuperset(__s) + + def clear(self) -> None: + self.col.clear() + + def copy(self) -> AbstractSet[_T]: + return set(self) + + def __eq__(self, other: object) -> bool: + return set(self) == other + + def __ne__(self, other: object) -> bool: + return set(self) != other + + def __lt__(self, other: AbstractSet[Any]) -> bool: + return set(self) < other + + def __le__(self, other: AbstractSet[Any]) -> bool: + return set(self) <= other + + def __gt__(self, other: AbstractSet[Any]) -> bool: + return set(self) > other + + def __ge__(self, other: AbstractSet[Any]) -> bool: + return set(self) >= other + + def __repr__(self) -> str: + return repr(set(self)) + + def __hash__(self) -> NoReturn: + raise TypeError("%s objects are unhashable" % type(self).__name__) + + if not typing.TYPE_CHECKING: + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(set, func_name) + ): + func.__doc__ = getattr(set, func_name).__doc__ + del func_name, func diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__init__.py new file mode 100644 index 0000000..78c707b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__init__.py @@ -0,0 +1,25 @@ +# ext/asyncio/__init__.py +# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from .engine import async_engine_from_config as async_engine_from_config +from .engine import AsyncConnection as AsyncConnection +from .engine import AsyncEngine as AsyncEngine +from .engine import AsyncTransaction as AsyncTransaction +from .engine import create_async_engine as create_async_engine +from .engine import create_async_pool_from_url as create_async_pool_from_url +from .result import AsyncMappingResult as AsyncMappingResult +from .result import AsyncResult as AsyncResult +from .result import AsyncScalarResult as AsyncScalarResult +from .result import AsyncTupleResult as AsyncTupleResult +from .scoping import async_scoped_session as async_scoped_session +from .session import async_object_session as async_object_session +from .session import async_session as async_session +from .session import async_sessionmaker as async_sessionmaker +from .session import AsyncAttrs as AsyncAttrs +from .session import AsyncSession as AsyncSession +from .session import AsyncSessionTransaction as AsyncSessionTransaction +from .session import close_all_sessions as close_all_sessions diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..a647d42 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..785ef03 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/base.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/engine.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/engine.cpython-311.pyc new file mode 100644 index 0000000..4326d1c Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/engine.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/exc.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/exc.cpython-311.pyc new file mode 100644 index 0000000..5a71fac Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/exc.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/result.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/result.cpython-311.pyc new file mode 100644 index 0000000..c6ae583 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/result.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/scoping.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/scoping.cpython-311.pyc new file mode 100644 index 0000000..8839d42 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/scoping.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/session.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/session.cpython-311.pyc new file mode 100644 index 0000000..0c267a0 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/__pycache__/session.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/base.py new file mode 100644 index 0000000..9899364 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/base.py @@ -0,0 +1,279 @@ +# ext/asyncio/base.py +# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import abc +import functools +from typing import Any +from typing import AsyncGenerator +from typing import AsyncIterator +from typing import Awaitable +from typing import Callable +from typing import ClassVar +from typing import Dict +from typing import Generator +from typing import Generic +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Tuple +from typing import TypeVar +import weakref + +from . import exc as async_exc +from ... import util +from ...util.typing import Literal +from ...util.typing import Self + +_T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) + + +_PT = TypeVar("_PT", bound=Any) + + +class ReversibleProxy(Generic[_PT]): + _proxy_objects: ClassVar[ + Dict[weakref.ref[Any], weakref.ref[ReversibleProxy[Any]]] + ] = {} + __slots__ = ("__weakref__",) + + @overload + def _assign_proxied(self, target: _PT) -> _PT: ... + + @overload + def _assign_proxied(self, target: None) -> None: ... + + def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]: + if target is not None: + target_ref: weakref.ref[_PT] = weakref.ref( + target, ReversibleProxy._target_gced + ) + proxy_ref = weakref.ref( + self, + functools.partial(ReversibleProxy._target_gced, target_ref), + ) + ReversibleProxy._proxy_objects[target_ref] = proxy_ref + + return target + + @classmethod + def _target_gced( + cls, + ref: weakref.ref[_PT], + proxy_ref: Optional[weakref.ref[Self]] = None, # noqa: U100 + ) -> None: + cls._proxy_objects.pop(ref, None) + + @classmethod + def _regenerate_proxy_for_target(cls, target: _PT) -> Self: + raise NotImplementedError() + + @overload + @classmethod + def _retrieve_proxy_for_target( + cls, + target: _PT, + regenerate: Literal[True] = ..., + ) -> Self: ... + + @overload + @classmethod + def _retrieve_proxy_for_target( + cls, target: _PT, regenerate: bool = True + ) -> Optional[Self]: ... + + @classmethod + def _retrieve_proxy_for_target( + cls, target: _PT, regenerate: bool = True + ) -> Optional[Self]: + try: + proxy_ref = cls._proxy_objects[weakref.ref(target)] + except KeyError: + pass + else: + proxy = proxy_ref() + if proxy is not None: + return proxy # type: ignore + + if regenerate: + return cls._regenerate_proxy_for_target(target) + else: + return None + + +class StartableContext(Awaitable[_T_co], abc.ABC): + __slots__ = () + + @abc.abstractmethod + async def start(self, is_ctxmanager: bool = False) -> _T_co: + raise NotImplementedError() + + def __await__(self) -> Generator[Any, Any, _T_co]: + return self.start().__await__() + + async def __aenter__(self) -> _T_co: + return await self.start(is_ctxmanager=True) + + @abc.abstractmethod + async def __aexit__( + self, type_: Any, value: Any, traceback: Any + ) -> Optional[bool]: + pass + + def _raise_for_not_started(self) -> NoReturn: + raise async_exc.AsyncContextNotStarted( + "%s context has not been started and object has not been awaited." + % (self.__class__.__name__) + ) + + +class GeneratorStartableContext(StartableContext[_T_co]): + __slots__ = ("gen",) + + gen: AsyncGenerator[_T_co, Any] + + def __init__( + self, + func: Callable[..., AsyncIterator[_T_co]], + args: Tuple[Any, ...], + kwds: Dict[str, Any], + ): + self.gen = func(*args, **kwds) # type: ignore + + async def start(self, is_ctxmanager: bool = False) -> _T_co: + try: + start_value = await util.anext_(self.gen) + except StopAsyncIteration: + raise RuntimeError("generator didn't yield") from None + + # if not a context manager, then interrupt the generator, don't + # let it complete. this step is technically not needed, as the + # generator will close in any case at gc time. not clear if having + # this here is a good idea or not (though it helps for clarity IMO) + if not is_ctxmanager: + await self.gen.aclose() + + return start_value + + async def __aexit__( + self, typ: Any, value: Any, traceback: Any + ) -> Optional[bool]: + # vendored from contextlib.py + if typ is None: + try: + await util.anext_(self.gen) + except StopAsyncIteration: + return False + else: + raise RuntimeError("generator didn't stop") + else: + if value is None: + # Need to force instantiation so we can reliably + # tell if we get the same exception back + value = typ() + try: + await self.gen.athrow(value) + except StopAsyncIteration as exc: + # Suppress StopIteration *unless* it's the same exception that + # was passed to throw(). This prevents a StopIteration + # raised inside the "with" statement from being suppressed. + return exc is not value + except RuntimeError as exc: + # Don't re-raise the passed in exception. (issue27122) + if exc is value: + return False + # Avoid suppressing if a Stop(Async)Iteration exception + # was passed to athrow() and later wrapped into a RuntimeError + # (see PEP 479 for sync generators; async generators also + # have this behavior). But do this only if the exception + # wrapped + # by the RuntimeError is actully Stop(Async)Iteration (see + # issue29692). + if ( + isinstance(value, (StopIteration, StopAsyncIteration)) + and exc.__cause__ is value + ): + return False + raise + except BaseException as exc: + # only re-raise if it's *not* the exception that was + # passed to throw(), because __exit__() must not raise + # an exception unless __exit__() itself failed. But throw() + # has to raise the exception to signal propagation, so this + # fixes the impedance mismatch between the throw() protocol + # and the __exit__() protocol. + if exc is not value: + raise + return False + raise RuntimeError("generator didn't stop after athrow()") + + +def asyncstartablecontext( + func: Callable[..., AsyncIterator[_T_co]] +) -> Callable[..., GeneratorStartableContext[_T_co]]: + """@asyncstartablecontext decorator. + + the decorated function can be called either as ``async with fn()``, **or** + ``await fn()``. This is decidedly different from what + ``@contextlib.asynccontextmanager`` supports, and the usage pattern + is different as well. + + Typical usage:: + + @asyncstartablecontext + async def some_async_generator(): + + try: + yield + except GeneratorExit: + # return value was awaited, no context manager is present + # and caller will .close() the resource explicitly + pass + else: + + + + Above, ``GeneratorExit`` is caught if the function were used as an + ``await``. In this case, it's essential that the cleanup does **not** + occur, so there should not be a ``finally`` block. + + If ``GeneratorExit`` is not invoked, this means we're in ``__aexit__`` + and we were invoked as a context manager, and cleanup should proceed. + + + """ + + @functools.wraps(func) + def helper(*args: Any, **kwds: Any) -> GeneratorStartableContext[_T_co]: + return GeneratorStartableContext(func, args, kwds) + + return helper + + +class ProxyComparable(ReversibleProxy[_PT]): + __slots__ = () + + @util.ro_non_memoized_property + def _proxied(self) -> _PT: + raise NotImplementedError() + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, self.__class__) + and self._proxied == other._proxied + ) + + def __ne__(self, other: Any) -> bool: + return ( + not isinstance(other, self.__class__) + or self._proxied != other._proxied + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/engine.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/engine.py new file mode 100644 index 0000000..8fc8e96 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/engine.py @@ -0,0 +1,1466 @@ +# ext/asyncio/engine.py +# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import asyncio +import contextlib +from typing import Any +from typing import AsyncIterator +from typing import Callable +from typing import Dict +from typing import Generator +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import exc as async_exc +from .base import asyncstartablecontext +from .base import GeneratorStartableContext +from .base import ProxyComparable +from .base import StartableContext +from .result import _ensure_sync_result +from .result import AsyncResult +from .result import AsyncScalarResult +from ... import exc +from ... import inspection +from ... import util +from ...engine import Connection +from ...engine import create_engine as _create_engine +from ...engine import create_pool_from_url as _create_pool_from_url +from ...engine import Engine +from ...engine.base import NestedTransaction +from ...engine.base import Transaction +from ...exc import ArgumentError +from ...util.concurrency import greenlet_spawn +from ...util.typing import Concatenate +from ...util.typing import ParamSpec + +if TYPE_CHECKING: + from ...engine.cursor import CursorResult + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import _CoreSingleExecuteParams + from ...engine.interfaces import _DBAPIAnyExecuteParams + from ...engine.interfaces import _ExecuteOptions + from ...engine.interfaces import CompiledCacheType + from ...engine.interfaces import CoreExecuteOptionsParameter + from ...engine.interfaces import Dialect + from ...engine.interfaces import IsolationLevel + from ...engine.interfaces import SchemaTranslateMapType + from ...engine.result import ScalarResult + from ...engine.url import URL + from ...pool import Pool + from ...pool import PoolProxiedConnection + from ...sql._typing import _InfoType + from ...sql.base import Executable + from ...sql.selectable import TypedReturnsRows + +_P = ParamSpec("_P") +_T = TypeVar("_T", bound=Any) + + +def create_async_engine(url: Union[str, URL], **kw: Any) -> AsyncEngine: + """Create a new async engine instance. + + Arguments passed to :func:`_asyncio.create_async_engine` are mostly + identical to those passed to the :func:`_sa.create_engine` function. + The specified dialect must be an asyncio-compatible dialect + such as :ref:`dialect-postgresql-asyncpg`. + + .. versionadded:: 1.4 + + :param async_creator: an async callable which returns a driver-level + asyncio connection. If given, the function should take no arguments, + and return a new asyncio connection from the underlying asyncio + database driver; the connection will be wrapped in the appropriate + structures to be used with the :class:`.AsyncEngine`. Note that the + parameters specified in the URL are not applied here, and the creator + function should use its own connection parameters. + + This parameter is the asyncio equivalent of the + :paramref:`_sa.create_engine.creator` parameter of the + :func:`_sa.create_engine` function. + + .. versionadded:: 2.0.16 + + """ + + if kw.get("server_side_cursors", False): + raise async_exc.AsyncMethodRequired( + "Can't set server_side_cursors for async engine globally; " + "use the connection.stream() method for an async " + "streaming result set" + ) + kw["_is_async"] = True + async_creator = kw.pop("async_creator", None) + if async_creator: + if kw.get("creator", None): + raise ArgumentError( + "Can only specify one of 'async_creator' or 'creator', " + "not both." + ) + + def creator() -> Any: + # note that to send adapted arguments like + # prepared_statement_cache_size, user would use + # "creator" and emulate this form here + return sync_engine.dialect.dbapi.connect( # type: ignore + async_creator_fn=async_creator + ) + + kw["creator"] = creator + sync_engine = _create_engine(url, **kw) + return AsyncEngine(sync_engine) + + +def async_engine_from_config( + configuration: Dict[str, Any], prefix: str = "sqlalchemy.", **kwargs: Any +) -> AsyncEngine: + """Create a new AsyncEngine instance using a configuration dictionary. + + This function is analogous to the :func:`_sa.engine_from_config` function + in SQLAlchemy Core, except that the requested dialect must be an + asyncio-compatible dialect such as :ref:`dialect-postgresql-asyncpg`. + The argument signature of the function is identical to that + of :func:`_sa.engine_from_config`. + + .. versionadded:: 1.4.29 + + """ + options = { + key[len(prefix) :]: value + for key, value in configuration.items() + if key.startswith(prefix) + } + options["_coerce_config"] = True + options.update(kwargs) + url = options.pop("url") + return create_async_engine(url, **options) + + +def create_async_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: + """Create a new async engine instance. + + Arguments passed to :func:`_asyncio.create_async_pool_from_url` are mostly + identical to those passed to the :func:`_sa.create_pool_from_url` function. + The specified dialect must be an asyncio-compatible dialect + such as :ref:`dialect-postgresql-asyncpg`. + + .. versionadded:: 2.0.10 + + """ + kwargs["_is_async"] = True + return _create_pool_from_url(url, **kwargs) + + +class AsyncConnectable: + __slots__ = "_slots_dispatch", "__weakref__" + + @classmethod + def _no_async_engine_events(cls) -> NoReturn: + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncEngine.sync_engine or " + "AsyncConnection.sync_connection attributes." + ) + + +@util.create_proxy_methods( + Connection, + ":class:`_engine.Connection`", + ":class:`_asyncio.AsyncConnection`", + classmethods=[], + methods=[], + attributes=[ + "closed", + "invalidated", + "dialect", + "default_isolation_level", + ], +) +class AsyncConnection( + ProxyComparable[Connection], + StartableContext["AsyncConnection"], + AsyncConnectable, +): + """An asyncio proxy for a :class:`_engine.Connection`. + + :class:`_asyncio.AsyncConnection` is acquired using the + :meth:`_asyncio.AsyncEngine.connect` + method of :class:`_asyncio.AsyncEngine`:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") + + async with engine.connect() as conn: + result = await conn.execute(select(table)) + + .. versionadded:: 1.4 + + """ # noqa + + # AsyncConnection is a thin proxy; no state should be added here + # that is not retrievable from the "sync" engine / connection, e.g. + # current transaction, info, etc. It should be possible to + # create a new AsyncConnection that matches this one given only the + # "sync" elements. + __slots__ = ( + "engine", + "sync_engine", + "sync_connection", + ) + + def __init__( + self, + async_engine: AsyncEngine, + sync_connection: Optional[Connection] = None, + ): + self.engine = async_engine + self.sync_engine = async_engine.sync_engine + self.sync_connection = self._assign_proxied(sync_connection) + + sync_connection: Optional[Connection] + """Reference to the sync-style :class:`_engine.Connection` this + :class:`_asyncio.AsyncConnection` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + + """ + + sync_engine: Engine + """Reference to the sync-style :class:`_engine.Engine` this + :class:`_asyncio.AsyncConnection` is associated with via its underlying + :class:`_engine.Connection`. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + + """ + + @classmethod + def _regenerate_proxy_for_target( + cls, target: Connection + ) -> AsyncConnection: + return AsyncConnection( + AsyncEngine._retrieve_proxy_for_target(target.engine), target + ) + + async def start( + self, is_ctxmanager: bool = False # noqa: U100 + ) -> AsyncConnection: + """Start this :class:`_asyncio.AsyncConnection` object's context + outside of using a Python ``with:`` block. + + """ + if self.sync_connection: + raise exc.InvalidRequestError("connection is already started") + self.sync_connection = self._assign_proxied( + await greenlet_spawn(self.sync_engine.connect) + ) + return self + + @property + def connection(self) -> NoReturn: + """Not implemented for async; call + :meth:`_asyncio.AsyncConnection.get_raw_connection`. + """ + raise exc.InvalidRequestError( + "AsyncConnection.connection accessor is not implemented as the " + "attribute may need to reconnect on an invalidated connection. " + "Use the get_raw_connection() method." + ) + + async def get_raw_connection(self) -> PoolProxiedConnection: + """Return the pooled DBAPI-level connection in use by this + :class:`_asyncio.AsyncConnection`. + + This is a SQLAlchemy connection-pool proxied connection + which then has the attribute + :attr:`_pool._ConnectionFairy.driver_connection` that refers to the + actual driver connection. Its + :attr:`_pool._ConnectionFairy.dbapi_connection` refers instead + to an :class:`_engine.AdaptedConnection` instance that + adapts the driver connection to the DBAPI protocol. + + """ + + return await greenlet_spawn(getattr, self._proxied, "connection") + + @util.ro_non_memoized_property + def info(self) -> _InfoType: + """Return the :attr:`_engine.Connection.info` dictionary of the + underlying :class:`_engine.Connection`. + + This dictionary is freely writable for user-defined state to be + associated with the database connection. + + This attribute is only available if the :class:`.AsyncConnection` is + currently connected. If the :attr:`.AsyncConnection.closed` attribute + is ``True``, then accessing this attribute will raise + :class:`.ResourceClosedError`. + + .. versionadded:: 1.4.0b2 + + """ + return self._proxied.info + + @util.ro_non_memoized_property + def _proxied(self) -> Connection: + if not self.sync_connection: + self._raise_for_not_started() + return self.sync_connection + + def begin(self) -> AsyncTransaction: + """Begin a transaction prior to autobegin occurring.""" + assert self._proxied + return AsyncTransaction(self) + + def begin_nested(self) -> AsyncTransaction: + """Begin a nested transaction and return a transaction handle.""" + assert self._proxied + return AsyncTransaction(self, nested=True) + + async def invalidate( + self, exception: Optional[BaseException] = None + ) -> None: + """Invalidate the underlying DBAPI connection associated with + this :class:`_engine.Connection`. + + See the method :meth:`_engine.Connection.invalidate` for full + detail on this method. + + """ + + return await greenlet_spawn( + self._proxied.invalidate, exception=exception + ) + + async def get_isolation_level(self) -> IsolationLevel: + return await greenlet_spawn(self._proxied.get_isolation_level) + + def in_transaction(self) -> bool: + """Return True if a transaction is in progress.""" + + return self._proxied.in_transaction() + + def in_nested_transaction(self) -> bool: + """Return True if a transaction is in progress. + + .. versionadded:: 1.4.0b2 + + """ + return self._proxied.in_nested_transaction() + + def get_transaction(self) -> Optional[AsyncTransaction]: + """Return an :class:`.AsyncTransaction` representing the current + transaction, if any. + + This makes use of the underlying synchronous connection's + :meth:`_engine.Connection.get_transaction` method to get the current + :class:`_engine.Transaction`, which is then proxied in a new + :class:`.AsyncTransaction` object. + + .. versionadded:: 1.4.0b2 + + """ + + trans = self._proxied.get_transaction() + if trans is not None: + return AsyncTransaction._retrieve_proxy_for_target(trans) + else: + return None + + def get_nested_transaction(self) -> Optional[AsyncTransaction]: + """Return an :class:`.AsyncTransaction` representing the current + nested (savepoint) transaction, if any. + + This makes use of the underlying synchronous connection's + :meth:`_engine.Connection.get_nested_transaction` method to get the + current :class:`_engine.Transaction`, which is then proxied in a new + :class:`.AsyncTransaction` object. + + .. versionadded:: 1.4.0b2 + + """ + + trans = self._proxied.get_nested_transaction() + if trans is not None: + return AsyncTransaction._retrieve_proxy_for_target(trans) + else: + return None + + @overload + async def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + no_parameters: bool = False, + stream_results: bool = False, + max_row_buffer: int = ..., + yield_per: int = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + preserve_rowcount: bool = False, + **opt: Any, + ) -> AsyncConnection: ... + + @overload + async def execution_options(self, **opt: Any) -> AsyncConnection: ... + + async def execution_options(self, **opt: Any) -> AsyncConnection: + r"""Set non-SQL options for the connection which take effect + during execution. + + This returns this :class:`_asyncio.AsyncConnection` object with + the new options added. + + See :meth:`_engine.Connection.execution_options` for full details + on this method. + + """ + + conn = self._proxied + c2 = await greenlet_spawn(conn.execution_options, **opt) + assert c2 is conn + return self + + async def commit(self) -> None: + """Commit the transaction that is currently in progress. + + This method commits the current transaction if one has been started. + If no transaction was started, the method has no effect, assuming + the connection is in a non-invalidated state. + + A transaction is begun on a :class:`_engine.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_engine.Connection.begin` method is called. + + """ + await greenlet_spawn(self._proxied.commit) + + async def rollback(self) -> None: + """Roll back the transaction that is currently in progress. + + This method rolls back the current transaction if one has been started. + If no transaction was started, the method has no effect. If a + transaction was started and the connection is in an invalidated state, + the transaction is cleared using this method. + + A transaction is begun on a :class:`_engine.Connection` automatically + whenever a statement is first executed, or when the + :meth:`_engine.Connection.begin` method is called. + + + """ + await greenlet_spawn(self._proxied.rollback) + + async def close(self) -> None: + """Close this :class:`_asyncio.AsyncConnection`. + + This has the effect of also rolling back the transaction if one + is in place. + + """ + await greenlet_spawn(self._proxied.close) + + async def aclose(self) -> None: + """A synonym for :meth:`_asyncio.AsyncConnection.close`. + + The :meth:`_asyncio.AsyncConnection.aclose` name is specifically + to support the Python standard library ``@contextlib.aclosing`` + context manager function. + + .. versionadded:: 2.0.20 + + """ + await self.close() + + async def exec_driver_sql( + self, + statement: str, + parameters: Optional[_DBAPIAnyExecuteParams] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + r"""Executes a driver-level SQL string and return buffered + :class:`_engine.Result`. + + """ + + result = await greenlet_spawn( + self._proxied.exec_driver_sql, + statement, + parameters, + execution_options, + _require_await=True, + ) + + return await _ensure_sync_result(result, self.exec_driver_sql) + + @overload + def stream( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> GeneratorStartableContext[AsyncResult[_T]]: ... + + @overload + def stream( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> GeneratorStartableContext[AsyncResult[Any]]: ... + + @asyncstartablecontext + async def stream( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> AsyncIterator[AsyncResult[Any]]: + """Execute a statement and return an awaitable yielding a + :class:`_asyncio.AsyncResult` object. + + E.g.:: + + result = await conn.stream(stmt): + async for row in result: + print(f"{row}") + + The :meth:`.AsyncConnection.stream` + method supports optional context manager use against the + :class:`.AsyncResult` object, as in:: + + async with conn.stream(stmt) as result: + async for row in result: + print(f"{row}") + + In the above pattern, the :meth:`.AsyncResult.close` method is + invoked unconditionally, even if the iterator is interrupted by an + exception throw. Context manager use remains optional, however, + and the function may be called in either an ``async with fn():`` or + ``await fn()`` style. + + .. versionadded:: 2.0.0b3 added context manager support + + + :return: an awaitable object that will yield an + :class:`_asyncio.AsyncResult` object. + + .. seealso:: + + :meth:`.AsyncConnection.stream_scalars` + + """ + if not self.dialect.supports_server_side_cursors: + raise exc.InvalidRequestError( + "Cant use `stream` or `stream_scalars` with the current " + "dialect since it does not support server side cursors." + ) + + result = await greenlet_spawn( + self._proxied.execute, + statement, + parameters, + execution_options=util.EMPTY_DICT.merge_with( + execution_options, {"stream_results": True} + ), + _require_await=True, + ) + assert result.context._is_server_side + ar = AsyncResult(result) + try: + yield ar + except GeneratorExit: + pass + else: + task = asyncio.create_task(ar.close()) + await asyncio.shield(task) + + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[_T]: ... + + @overload + async def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: ... + + async def execute( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> CursorResult[Any]: + r"""Executes a SQL statement construct and return a buffered + :class:`_engine.Result`. + + :param object: The statement to be executed. This is always + an object that is in both the :class:`_expression.ClauseElement` and + :class:`_expression.Executable` hierarchies, including: + + * :class:`_expression.Select` + * :class:`_expression.Insert`, :class:`_expression.Update`, + :class:`_expression.Delete` + * :class:`_expression.TextClause` and + :class:`_expression.TextualSelect` + * :class:`_schema.DDL` and objects which inherit from + :class:`_schema.ExecutableDDLElement` + + :param parameters: parameters which will be bound into the statement. + This may be either a dictionary of parameter names to values, + or a mutable sequence (e.g. a list) of dictionaries. When a + list of dictionaries is passed, the underlying statement execution + will make use of the DBAPI ``cursor.executemany()`` method. + When a single dictionary is passed, the DBAPI ``cursor.execute()`` + method will be used. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. This + dictionary can provide a subset of the options that are accepted + by :meth:`_engine.Connection.execution_options`. + + :return: a :class:`_engine.Result` object. + + """ + result = await greenlet_spawn( + self._proxied.execute, + statement, + parameters, + execution_options=execution_options, + _require_await=True, + ) + return await _ensure_sync_result(result, self.execute) + + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Optional[_T]: ... + + @overload + async def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Any: ... + + async def scalar( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Any: + r"""Executes a SQL statement construct and returns a scalar object. + + This method is shorthand for invoking the + :meth:`_engine.Result.scalar` method after invoking the + :meth:`_engine.Connection.execute` method. Parameters are equivalent. + + :return: a scalar Python value representing the first column of the + first row returned. + + """ + result = await self.execute( + statement, parameters, execution_options=execution_options + ) + return result.scalar() + + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[_T]: ... + + @overload + async def scalars( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: ... + + async def scalars( + self, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> ScalarResult[Any]: + r"""Executes a SQL statement construct and returns a scalar objects. + + This method is shorthand for invoking the + :meth:`_engine.Result.scalars` method after invoking the + :meth:`_engine.Connection.execute` method. Parameters are equivalent. + + :return: a :class:`_engine.ScalarResult` object. + + .. versionadded:: 1.4.24 + + """ + result = await self.execute( + statement, parameters, execution_options=execution_options + ) + return result.scalars() + + @overload + def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> GeneratorStartableContext[AsyncScalarResult[_T]]: ... + + @overload + def stream_scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> GeneratorStartableContext[AsyncScalarResult[Any]]: ... + + @asyncstartablecontext + async def stream_scalars( + self, + statement: Executable, + parameters: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> AsyncIterator[AsyncScalarResult[Any]]: + r"""Execute a statement and return an awaitable yielding a + :class:`_asyncio.AsyncScalarResult` object. + + E.g.:: + + result = await conn.stream_scalars(stmt) + async for scalar in result: + print(f"{scalar}") + + This method is shorthand for invoking the + :meth:`_engine.AsyncResult.scalars` method after invoking the + :meth:`_engine.Connection.stream` method. Parameters are equivalent. + + The :meth:`.AsyncConnection.stream_scalars` + method supports optional context manager use against the + :class:`.AsyncScalarResult` object, as in:: + + async with conn.stream_scalars(stmt) as result: + async for scalar in result: + print(f"{scalar}") + + In the above pattern, the :meth:`.AsyncScalarResult.close` method is + invoked unconditionally, even if the iterator is interrupted by an + exception throw. Context manager use remains optional, however, + and the function may be called in either an ``async with fn():`` or + ``await fn()`` style. + + .. versionadded:: 2.0.0b3 added context manager support + + :return: an awaitable object that will yield an + :class:`_asyncio.AsyncScalarResult` object. + + .. versionadded:: 1.4.24 + + .. seealso:: + + :meth:`.AsyncConnection.stream` + + """ + + async with self.stream( + statement, parameters, execution_options=execution_options + ) as result: + yield result.scalars() + + async def run_sync( + self, + fn: Callable[Concatenate[Connection, _P], _T], + *arg: _P.args, + **kw: _P.kwargs, + ) -> _T: + """Invoke the given synchronous (i.e. not async) callable, + passing a synchronous-style :class:`_engine.Connection` as the first + argument. + + This method allows traditional synchronous SQLAlchemy functions to + run within the context of an asyncio application. + + E.g.:: + + def do_something_with_core(conn: Connection, arg1: int, arg2: str) -> str: + '''A synchronous function that does not require awaiting + + :param conn: a Core SQLAlchemy Connection, used synchronously + + :return: an optional return value is supported + + ''' + conn.execute( + some_table.insert().values(int_col=arg1, str_col=arg2) + ) + return "success" + + + async def do_something_async(async_engine: AsyncEngine) -> None: + '''an async function that uses awaiting''' + + async with async_engine.begin() as async_conn: + # run do_something_with_core() with a sync-style + # Connection, proxied into an awaitable + return_code = await async_conn.run_sync(do_something_with_core, 5, "strval") + print(return_code) + + This method maintains the asyncio event loop all the way through + to the database connection by running the given callable in a + specially instrumented greenlet. + + The most rudimentary use of :meth:`.AsyncConnection.run_sync` is to + invoke methods such as :meth:`_schema.MetaData.create_all`, given + an :class:`.AsyncConnection` that needs to be provided to + :meth:`_schema.MetaData.create_all` as a :class:`_engine.Connection` + object:: + + # run metadata.create_all(conn) with a sync-style Connection, + # proxied into an awaitable + with async_engine.begin() as conn: + await conn.run_sync(metadata.create_all) + + .. note:: + + The provided callable is invoked inline within the asyncio event + loop, and will block on traditional IO calls. IO within this + callable should only call into SQLAlchemy's asyncio database + APIs which will be properly adapted to the greenlet context. + + .. seealso:: + + :meth:`.AsyncSession.run_sync` + + :ref:`session_run_sync` + + """ # noqa: E501 + + return await greenlet_spawn( + fn, self._proxied, *arg, _require_await=False, **kw + ) + + def __await__(self) -> Generator[Any, None, AsyncConnection]: + return self.start().__await__() + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + task = asyncio.create_task(self.close()) + await asyncio.shield(task) + + # START PROXY METHODS AsyncConnection + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + @property + def closed(self) -> Any: + r"""Return True if this connection is closed. + + .. container:: class_bases + + Proxied for the :class:`_engine.Connection` class + on behalf of the :class:`_asyncio.AsyncConnection` class. + + """ # noqa: E501 + + return self._proxied.closed + + @property + def invalidated(self) -> Any: + r"""Return True if this connection was invalidated. + + .. container:: class_bases + + Proxied for the :class:`_engine.Connection` class + on behalf of the :class:`_asyncio.AsyncConnection` class. + + This does not indicate whether or not the connection was + invalidated at the pool level, however + + + """ # noqa: E501 + + return self._proxied.invalidated + + @property + def dialect(self) -> Dialect: + r"""Proxy for the :attr:`_engine.Connection.dialect` attribute + on behalf of the :class:`_asyncio.AsyncConnection` class. + + """ # noqa: E501 + + return self._proxied.dialect + + @dialect.setter + def dialect(self, attr: Dialect) -> None: + self._proxied.dialect = attr + + @property + def default_isolation_level(self) -> Any: + r"""The initial-connection time isolation level associated with the + :class:`_engine.Dialect` in use. + + .. container:: class_bases + + Proxied for the :class:`_engine.Connection` class + on behalf of the :class:`_asyncio.AsyncConnection` class. + + This value is independent of the + :paramref:`.Connection.execution_options.isolation_level` and + :paramref:`.Engine.execution_options.isolation_level` execution + options, and is determined by the :class:`_engine.Dialect` when the + first connection is created, by performing a SQL query against the + database for the current isolation level before any additional commands + have been emitted. + + Calling this accessor does not invoke any new SQL queries. + + .. seealso:: + + :meth:`_engine.Connection.get_isolation_level` + - view current actual isolation level + + :paramref:`_sa.create_engine.isolation_level` + - set per :class:`_engine.Engine` isolation level + + :paramref:`.Connection.execution_options.isolation_level` + - set per :class:`_engine.Connection` isolation level + + + """ # noqa: E501 + + return self._proxied.default_isolation_level + + # END PROXY METHODS AsyncConnection + + +@util.create_proxy_methods( + Engine, + ":class:`_engine.Engine`", + ":class:`_asyncio.AsyncEngine`", + classmethods=[], + methods=[ + "clear_compiled_cache", + "update_execution_options", + "get_execution_options", + ], + attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"], +) +class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): + """An asyncio proxy for a :class:`_engine.Engine`. + + :class:`_asyncio.AsyncEngine` is acquired using the + :func:`_asyncio.create_async_engine` function:: + + from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") + + .. versionadded:: 1.4 + + """ # noqa + + # AsyncEngine is a thin proxy; no state should be added here + # that is not retrievable from the "sync" engine / connection, e.g. + # current transaction, info, etc. It should be possible to + # create a new AsyncEngine that matches this one given only the + # "sync" elements. + __slots__ = "sync_engine" + + _connection_cls: Type[AsyncConnection] = AsyncConnection + + sync_engine: Engine + """Reference to the sync-style :class:`_engine.Engine` this + :class:`_asyncio.AsyncEngine` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + """ + + def __init__(self, sync_engine: Engine): + if not sync_engine.dialect.is_async: + raise exc.InvalidRequestError( + "The asyncio extension requires an async driver to be used. " + f"The loaded {sync_engine.dialect.driver!r} is not async." + ) + self.sync_engine = self._assign_proxied(sync_engine) + + @util.ro_non_memoized_property + def _proxied(self) -> Engine: + return self.sync_engine + + @classmethod + def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine: + return AsyncEngine(target) + + @contextlib.asynccontextmanager + async def begin(self) -> AsyncIterator[AsyncConnection]: + """Return a context manager which when entered will deliver an + :class:`_asyncio.AsyncConnection` with an + :class:`_asyncio.AsyncTransaction` established. + + E.g.:: + + async with async_engine.begin() as conn: + await conn.execute( + text("insert into table (x, y, z) values (1, 2, 3)") + ) + await conn.execute(text("my_special_procedure(5)")) + + + """ + conn = self.connect() + + async with conn: + async with conn.begin(): + yield conn + + def connect(self) -> AsyncConnection: + """Return an :class:`_asyncio.AsyncConnection` object. + + The :class:`_asyncio.AsyncConnection` will procure a database + connection from the underlying connection pool when it is entered + as an async context manager:: + + async with async_engine.connect() as conn: + result = await conn.execute(select(user_table)) + + The :class:`_asyncio.AsyncConnection` may also be started outside of a + context manager by invoking its :meth:`_asyncio.AsyncConnection.start` + method. + + """ + + return self._connection_cls(self) + + async def raw_connection(self) -> PoolProxiedConnection: + """Return a "raw" DBAPI connection from the connection pool. + + .. seealso:: + + :ref:`dbapi_connections` + + """ + return await greenlet_spawn(self.sync_engine.raw_connection) + + @overload + def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + **opt: Any, + ) -> AsyncEngine: ... + + @overload + def execution_options(self, **opt: Any) -> AsyncEngine: ... + + def execution_options(self, **opt: Any) -> AsyncEngine: + """Return a new :class:`_asyncio.AsyncEngine` that will provide + :class:`_asyncio.AsyncConnection` objects with the given execution + options. + + Proxied from :meth:`_engine.Engine.execution_options`. See that + method for details. + + """ + + return AsyncEngine(self.sync_engine.execution_options(**opt)) + + async def dispose(self, close: bool = True) -> None: + """Dispose of the connection pool used by this + :class:`_asyncio.AsyncEngine`. + + :param close: if left at its default of ``True``, has the + effect of fully closing all **currently checked in** + database connections. Connections that are still checked out + will **not** be closed, however they will no longer be associated + with this :class:`_engine.Engine`, + so when they are closed individually, eventually the + :class:`_pool.Pool` which they are associated with will + be garbage collected and they will be closed out fully, if + not already closed on checkin. + + If set to ``False``, the previous connection pool is de-referenced, + and otherwise not touched in any way. + + .. seealso:: + + :meth:`_engine.Engine.dispose` + + """ + + await greenlet_spawn(self.sync_engine.dispose, close=close) + + # START PROXY METHODS AsyncEngine + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + def clear_compiled_cache(self) -> None: + r"""Clear the compiled cache associated with the dialect. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class on + behalf of the :class:`_asyncio.AsyncEngine` class. + + This applies **only** to the built-in cache that is established + via the :paramref:`_engine.create_engine.query_cache_size` parameter. + It will not impact any dictionary caches that were passed via the + :paramref:`.Connection.execution_options.compiled_cache` parameter. + + .. versionadded:: 1.4 + + + """ # noqa: E501 + + return self._proxied.clear_compiled_cache() + + def update_execution_options(self, **opt: Any) -> None: + r"""Update the default execution_options dictionary + of this :class:`_engine.Engine`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class on + behalf of the :class:`_asyncio.AsyncEngine` class. + + The given keys/values in \**opt are added to the + default execution options that will be used for + all connections. The initial contents of this dictionary + can be sent via the ``execution_options`` parameter + to :func:`_sa.create_engine`. + + .. seealso:: + + :meth:`_engine.Connection.execution_options` + + :meth:`_engine.Engine.execution_options` + + + """ # noqa: E501 + + return self._proxied.update_execution_options(**opt) + + def get_execution_options(self) -> _ExecuteOptions: + r"""Get the non-SQL options which will take effect during execution. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class on + behalf of the :class:`_asyncio.AsyncEngine` class. + + .. versionadded: 1.3 + + .. seealso:: + + :meth:`_engine.Engine.execution_options` + + """ # noqa: E501 + + return self._proxied.get_execution_options() + + @property + def url(self) -> URL: + r"""Proxy for the :attr:`_engine.Engine.url` attribute + on behalf of the :class:`_asyncio.AsyncEngine` class. + + """ # noqa: E501 + + return self._proxied.url + + @url.setter + def url(self, attr: URL) -> None: + self._proxied.url = attr + + @property + def pool(self) -> Pool: + r"""Proxy for the :attr:`_engine.Engine.pool` attribute + on behalf of the :class:`_asyncio.AsyncEngine` class. + + """ # noqa: E501 + + return self._proxied.pool + + @pool.setter + def pool(self, attr: Pool) -> None: + self._proxied.pool = attr + + @property + def dialect(self) -> Dialect: + r"""Proxy for the :attr:`_engine.Engine.dialect` attribute + on behalf of the :class:`_asyncio.AsyncEngine` class. + + """ # noqa: E501 + + return self._proxied.dialect + + @dialect.setter + def dialect(self, attr: Dialect) -> None: + self._proxied.dialect = attr + + @property + def engine(self) -> Any: + r"""Returns this :class:`.Engine`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + Used for legacy schemes that accept :class:`.Connection` / + :class:`.Engine` objects within the same variable. + + + """ # noqa: E501 + + return self._proxied.engine + + @property + def name(self) -> Any: + r"""String name of the :class:`~sqlalchemy.engine.interfaces.Dialect` + in use by this :class:`Engine`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + + """ # noqa: E501 + + return self._proxied.name + + @property + def driver(self) -> Any: + r"""Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect` + in use by this :class:`Engine`. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + + """ # noqa: E501 + + return self._proxied.driver + + @property + def echo(self) -> Any: + r"""When ``True``, enable log output for this element. + + .. container:: class_bases + + Proxied for the :class:`_engine.Engine` class + on behalf of the :class:`_asyncio.AsyncEngine` class. + + This has the effect of setting the Python logging level for the namespace + of this element's class and object reference. A value of boolean ``True`` + indicates that the loglevel ``logging.INFO`` will be set for the logger, + whereas the string value ``debug`` will set the loglevel to + ``logging.DEBUG``. + + """ # noqa: E501 + + return self._proxied.echo + + @echo.setter + def echo(self, attr: Any) -> None: + self._proxied.echo = attr + + # END PROXY METHODS AsyncEngine + + +class AsyncTransaction( + ProxyComparable[Transaction], StartableContext["AsyncTransaction"] +): + """An asyncio proxy for a :class:`_engine.Transaction`.""" + + __slots__ = ("connection", "sync_transaction", "nested") + + sync_transaction: Optional[Transaction] + connection: AsyncConnection + nested: bool + + def __init__(self, connection: AsyncConnection, nested: bool = False): + self.connection = connection + self.sync_transaction = None + self.nested = nested + + @classmethod + def _regenerate_proxy_for_target( + cls, target: Transaction + ) -> AsyncTransaction: + sync_connection = target.connection + sync_transaction = target + nested = isinstance(target, NestedTransaction) + + async_connection = AsyncConnection._retrieve_proxy_for_target( + sync_connection + ) + assert async_connection is not None + + obj = cls.__new__(cls) + obj.connection = async_connection + obj.sync_transaction = obj._assign_proxied(sync_transaction) + obj.nested = nested + return obj + + @util.ro_non_memoized_property + def _proxied(self) -> Transaction: + if not self.sync_transaction: + self._raise_for_not_started() + return self.sync_transaction + + @property + def is_valid(self) -> bool: + return self._proxied.is_valid + + @property + def is_active(self) -> bool: + return self._proxied.is_active + + async def close(self) -> None: + """Close this :class:`.AsyncTransaction`. + + If this transaction is the base transaction in a begin/commit + nesting, the transaction will rollback(). Otherwise, the + method returns. + + This is used to cancel a Transaction without affecting the scope of + an enclosing transaction. + + """ + await greenlet_spawn(self._proxied.close) + + async def rollback(self) -> None: + """Roll back this :class:`.AsyncTransaction`.""" + await greenlet_spawn(self._proxied.rollback) + + async def commit(self) -> None: + """Commit this :class:`.AsyncTransaction`.""" + + await greenlet_spawn(self._proxied.commit) + + async def start(self, is_ctxmanager: bool = False) -> AsyncTransaction: + """Start this :class:`_asyncio.AsyncTransaction` object's context + outside of using a Python ``with:`` block. + + """ + + self.sync_transaction = self._assign_proxied( + await greenlet_spawn( + self.connection._proxied.begin_nested + if self.nested + else self.connection._proxied.begin + ) + ) + if is_ctxmanager: + self.sync_transaction.__enter__() + return self + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + await greenlet_spawn(self._proxied.__exit__, type_, value, traceback) + + +@overload +def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: ... + + +@overload +def _get_sync_engine_or_connection( + async_engine: AsyncConnection, +) -> Connection: ... + + +def _get_sync_engine_or_connection( + async_engine: Union[AsyncEngine, AsyncConnection] +) -> Union[Engine, Connection]: + if isinstance(async_engine, AsyncConnection): + return async_engine._proxied + + try: + return async_engine.sync_engine + except AttributeError as e: + raise exc.ArgumentError( + "AsyncEngine expected, got %r" % async_engine + ) from e + + +@inspection._inspects(AsyncConnection) +def _no_insp_for_async_conn_yet( + subject: AsyncConnection, # noqa: U100 +) -> NoReturn: + raise exc.NoInspectionAvailable( + "Inspection on an AsyncConnection is currently not supported. " + "Please use ``run_sync`` to pass a callable where it's possible " + "to call ``inspect`` on the passed connection.", + code="xd3s", + ) + + +@inspection._inspects(AsyncEngine) +def _no_insp_for_async_engine_xyet( + subject: AsyncEngine, # noqa: U100 +) -> NoReturn: + raise exc.NoInspectionAvailable( + "Inspection on an AsyncEngine is currently not supported. " + "Please obtain a connection then use ``conn.run_sync`` to pass a " + "callable where it's possible to call ``inspect`` on the " + "passed connection.", + code="xd3s", + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/exc.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/exc.py new file mode 100644 index 0000000..1cf6f36 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/exc.py @@ -0,0 +1,21 @@ +# ext/asyncio/exc.py +# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from ... import exc + + +class AsyncMethodRequired(exc.InvalidRequestError): + """an API can't be used because its result would not be + compatible with async""" + + +class AsyncContextNotStarted(exc.InvalidRequestError): + """a startable context manager has not been started.""" + + +class AsyncContextAlreadyStarted(exc.InvalidRequestError): + """a startable context manager is already started.""" diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/result.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/result.py new file mode 100644 index 0000000..7dcbe32 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/result.py @@ -0,0 +1,961 @@ +# ext/asyncio/result.py +# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import operator +from typing import Any +from typing import AsyncIterator +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar + +from . import exc as async_exc +from ... import util +from ...engine import Result +from ...engine.result import _NO_ROW +from ...engine.result import _R +from ...engine.result import _WithKeys +from ...engine.result import FilterResult +from ...engine.result import FrozenResult +from ...engine.result import ResultMetaData +from ...engine.row import Row +from ...engine.row import RowMapping +from ...sql.base import _generative +from ...util.concurrency import greenlet_spawn +from ...util.typing import Literal +from ...util.typing import Self + +if TYPE_CHECKING: + from ...engine import CursorResult + from ...engine.result import _KeyIndexType + from ...engine.result import _UniqueFilterType + +_T = TypeVar("_T", bound=Any) +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + + +class AsyncCommon(FilterResult[_R]): + __slots__ = () + + _real_result: Result[Any] + _metadata: ResultMetaData + + async def close(self) -> None: # type: ignore[override] + """Close this result.""" + + await greenlet_spawn(self._real_result.close) + + @property + def closed(self) -> bool: + """proxies the .closed attribute of the underlying result object, + if any, else raises ``AttributeError``. + + .. versionadded:: 2.0.0b3 + + """ + return self._real_result.closed + + +class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): + """An asyncio wrapper around a :class:`_result.Result` object. + + The :class:`_asyncio.AsyncResult` only applies to statement executions that + use a server-side cursor. It is returned only from the + :meth:`_asyncio.AsyncConnection.stream` and + :meth:`_asyncio.AsyncSession.stream` methods. + + .. note:: As is the case with :class:`_engine.Result`, this object is + used for ORM results returned by :meth:`_asyncio.AsyncSession.execute`, + which can yield instances of ORM mapped objects either individually or + within tuple-like rows. Note that these result objects do not + deduplicate instances or rows automatically as is the case with the + legacy :class:`_orm.Query` object. For in-Python de-duplication of + instances or rows, use the :meth:`_asyncio.AsyncResult.unique` modifier + method. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _real_result: Result[_TP] + + def __init__(self, real_result: Result[_TP]): + self._real_result = real_result + + self._metadata = real_result._metadata + self._unique_filter_state = real_result._unique_filter_state + self._post_creational_filter = None + + # BaseCursorResult pre-generates the "_row_getter". Use that + # if available rather than building a second one + if "_row_getter" in real_result.__dict__: + self._set_memoized_attribute( + "_row_getter", real_result.__dict__["_row_getter"] + ) + + @property + def t(self) -> AsyncTupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + The :attr:`_asyncio.AsyncResult.t` attribute is a synonym for + calling the :meth:`_asyncio.AsyncResult.tuples` method. + + .. versionadded:: 2.0 + + """ + return self # type: ignore + + def tuples(self) -> AsyncTupleResult[_TP]: + """Apply a "typed tuple" typing filter to returned rows. + + This method returns the same :class:`_asyncio.AsyncResult` object + at runtime, + however annotates as returning a :class:`_asyncio.AsyncTupleResult` + object that will indicate to :pep:`484` typing tools that plain typed + ``Tuple`` instances are returned rather than rows. This allows + tuple unpacking and ``__getitem__`` access of :class:`_engine.Row` + objects to by typed, for those cases where the statement invoked + itself included typing information. + + .. versionadded:: 2.0 + + :return: the :class:`_result.AsyncTupleResult` type at typing time. + + .. seealso:: + + :attr:`_asyncio.AsyncResult.t` - shorter synonym + + :attr:`_engine.Row.t` - :class:`_engine.Row` version + + """ + + return self # type: ignore + + @_generative + def unique(self, strategy: Optional[_UniqueFilterType] = None) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncResult`. + + Refer to :meth:`_engine.Result.unique` in the synchronous + SQLAlchemy API for a complete behavioral description. + + """ + self._unique_filter_state = (set(), strategy) + return self + + def columns(self, *col_expressions: _KeyIndexType) -> Self: + r"""Establish the columns that should be returned in each row. + + Refer to :meth:`_engine.Result.columns` in the synchronous + SQLAlchemy API for a complete behavioral description. + + """ + return self._column_slices(col_expressions) + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[Row[_TP]]]: + """Iterate through sub-lists of rows of the size given. + + An async iterator is returned:: + + async def scroll_results(connection): + result = await connection.stream(select(users_table)) + + async for partition in result.partitions(100): + print("list of rows: %s" % partition) + + Refer to :meth:`_engine.Result.partitions` in the synchronous + SQLAlchemy API for a complete behavioral description. + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchall(self) -> Sequence[Row[_TP]]: + """A synonym for the :meth:`_asyncio.AsyncResult.all` method. + + .. versionadded:: 2.0 + + """ + + return await greenlet_spawn(self._allrows) + + async def fetchone(self) -> Optional[Row[_TP]]: + """Fetch one row. + + When all rows are exhausted, returns None. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch the first row of a result only, use the + :meth:`_asyncio.AsyncResult.first` method. To iterate through all + rows, iterate the :class:`_asyncio.AsyncResult` object directly. + + :return: a :class:`_engine.Row` object if no filters are applied, + or ``None`` if no rows remain. + + """ + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + return None + else: + return row + + async def fetchmany( + self, size: Optional[int] = None + ) -> Sequence[Row[_TP]]: + """Fetch many rows. + + When all rows are exhausted, returns an empty list. + + This method is provided for backwards compatibility with + SQLAlchemy 1.x.x. + + To fetch rows in groups, use the + :meth:`._asyncio.AsyncResult.partitions` method. + + :return: a list of :class:`_engine.Row` objects. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.partitions` + + """ + + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self) -> Sequence[Row[_TP]]: + """Return all rows in a list. + + Closes the result set after invocation. Subsequent invocations + will return an empty list. + + :return: a list of :class:`_engine.Row` objects. + + """ + + return await greenlet_spawn(self._allrows) + + def __aiter__(self) -> AsyncResult[_TP]: + return self + + async def __anext__(self) -> Row[_TP]: + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self) -> Optional[Row[_TP]]: + """Fetch the first row or ``None`` if no row is present. + + Closes the result set and discards remaining rows. + + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the + :meth:`_asyncio.AsyncResult.scalar` method, + or combine :meth:`_asyncio.AsyncResult.scalars` and + :meth:`_asyncio.AsyncResult.first`. + + Additionally, in contrast to the behavior of the legacy ORM + :meth:`_orm.Query.first` method, **no limit is applied** to the + SQL query which was invoked to produce this + :class:`_asyncio.AsyncResult`; + for a DBAPI driver that buffers results in memory before yielding + rows, all rows will be sent to the Python process and all but + the first row will be discarded. + + .. seealso:: + + :ref:`migration_20_unify_select` + + :return: a :class:`_engine.Row` object, or None + if no rows remain. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.scalar` + + :meth:`_asyncio.AsyncResult.one` + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self) -> Optional[Row[_TP]]: + """Return at most one result or raise an exception. + + Returns ``None`` if the result has no rows. + Raises :class:`.MultipleResultsFound` + if multiple rows are returned. + + .. versionadded:: 1.4 + + :return: The first :class:`_engine.Row` or ``None`` if no row + is available. + + :raises: :class:`.MultipleResultsFound` + + .. seealso:: + + :meth:`_asyncio.AsyncResult.first` + + :meth:`_asyncio.AsyncResult.one` + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + @overload + async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: ... + + @overload + async def scalar_one(self) -> Any: ... + + async def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and + then :meth:`_asyncio.AsyncResult.one`. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.one` + + :meth:`_asyncio.AsyncResult.scalars` + + """ + return await greenlet_spawn(self._only_one_row, True, True, True) + + @overload + async def scalar_one_or_none( + self: AsyncResult[Tuple[_T]], + ) -> Optional[_T]: ... + + @overload + async def scalar_one_or_none(self) -> Optional[Any]: ... + + async def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one scalar result or ``None``. + + This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and + then :meth:`_asyncio.AsyncResult.one_or_none`. + + .. seealso:: + + :meth:`_asyncio.AsyncResult.one_or_none` + + :meth:`_asyncio.AsyncResult.scalars` + + """ + return await greenlet_spawn(self._only_one_row, True, False, True) + + async def one(self) -> Row[_TP]: + """Return exactly one row or raise an exception. + + Raises :class:`.NoResultFound` if the result returns no + rows, or :class:`.MultipleResultsFound` if multiple rows + would be returned. + + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the + :meth:`_asyncio.AsyncResult.scalar_one` method, or combine + :meth:`_asyncio.AsyncResult.scalars` and + :meth:`_asyncio.AsyncResult.one`. + + .. versionadded:: 1.4 + + :return: The first :class:`_engine.Row`. + + :raises: :class:`.MultipleResultsFound`, :class:`.NoResultFound` + + .. seealso:: + + :meth:`_asyncio.AsyncResult.first` + + :meth:`_asyncio.AsyncResult.one_or_none` + + :meth:`_asyncio.AsyncResult.scalar_one` + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + @overload + async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: ... + + @overload + async def scalar(self) -> Any: ... + + async def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result set. + + Returns ``None`` if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value, or ``None`` if no rows remain. + + """ + return await greenlet_spawn(self._only_one_row, False, False, True) + + async def freeze(self) -> FrozenResult[_TP]: + """Return a callable object that will produce copies of this + :class:`_asyncio.AsyncResult` when invoked. + + The callable object returned is an instance of + :class:`_engine.FrozenResult`. + + This is used for result set caching. The method must be called + on the result when it has been unconsumed, and calling the method + will consume the result fully. When the :class:`_engine.FrozenResult` + is retrieved from a cache, it can be called any number of times where + it will produce a new :class:`_engine.Result` object each time + against its stored set of rows. + + .. seealso:: + + :ref:`do_orm_execute_re_executing` - example usage within the + ORM to implement a result-set cache. + + """ + + return await greenlet_spawn(FrozenResult, self) + + @overload + def scalars( + self: AsyncResult[Tuple[_T]], index: Literal[0] + ) -> AsyncScalarResult[_T]: ... + + @overload + def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: ... + + @overload + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: ... + + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: + """Return an :class:`_asyncio.AsyncScalarResult` filtering object which + will return single elements rather than :class:`_row.Row` objects. + + Refer to :meth:`_result.Result.scalars` in the synchronous + SQLAlchemy API for a complete behavioral description. + + :param index: integer or row key indicating the column to be fetched + from each row, defaults to ``0`` indicating the first column. + + :return: a new :class:`_asyncio.AsyncScalarResult` filtering object + referring to this :class:`_asyncio.AsyncResult` object. + + """ + return AsyncScalarResult(self._real_result, index) + + def mappings(self) -> AsyncMappingResult: + """Apply a mappings filter to returned rows, returning an instance of + :class:`_asyncio.AsyncMappingResult`. + + When this filter is applied, fetching rows will return + :class:`_engine.RowMapping` objects instead of :class:`_engine.Row` + objects. + + :return: a new :class:`_asyncio.AsyncMappingResult` filtering object + referring to the underlying :class:`_result.Result` object. + + """ + + return AsyncMappingResult(self._real_result) + + +class AsyncScalarResult(AsyncCommon[_R]): + """A wrapper for a :class:`_asyncio.AsyncResult` that returns scalar values + rather than :class:`_row.Row` values. + + The :class:`_asyncio.AsyncScalarResult` object is acquired by calling the + :meth:`_asyncio.AsyncResult.scalars` method. + + Refer to the :class:`_result.ScalarResult` object in the synchronous + SQLAlchemy API for a complete behavioral description. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _generate_rows = False + + def __init__(self, real_result: Result[Any], index: _KeyIndexType): + self._real_result = real_result + + if real_result._source_supports_scalars: + self._metadata = real_result._metadata + self._post_creational_filter = None + else: + self._metadata = real_result._metadata._reduce([index]) + self._post_creational_filter = operator.itemgetter(0) + + self._unique_filter_state = real_result._unique_filter_state + + def unique( + self, + strategy: Optional[_UniqueFilterType] = None, + ) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncScalarResult`. + + See :meth:`_asyncio.AsyncResult.unique` for usage details. + + """ + self._unique_filter_state = (set(), strategy) + return self + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_asyncio.AsyncScalarResult.all` method.""" + + return await greenlet_spawn(self._allrows) + + async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self) -> Sequence[_R]: + """Return all scalar values in a list. + + Equivalent to :meth:`_asyncio.AsyncResult.all` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._allrows) + + def __aiter__(self) -> AsyncScalarResult[_R]: + return self + + async def __anext__(self) -> _R: + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self) -> Optional[_R]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_asyncio.AsyncResult.first` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + async def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one` except that + scalar values, rather than :class:`_engine.Row` objects, + are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + +class AsyncMappingResult(_WithKeys, AsyncCommon[RowMapping]): + """A wrapper for a :class:`_asyncio.AsyncResult` that returns dictionary + values rather than :class:`_engine.Row` values. + + The :class:`_asyncio.AsyncMappingResult` object is acquired by calling the + :meth:`_asyncio.AsyncResult.mappings` method. + + Refer to the :class:`_result.MappingResult` object in the synchronous + SQLAlchemy API for a complete behavioral description. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _generate_rows = True + + _post_creational_filter = operator.attrgetter("_mapping") + + def __init__(self, result: Result[Any]): + self._real_result = result + self._unique_filter_state = result._unique_filter_state + self._metadata = result._metadata + if result._source_supports_scalars: + self._metadata = self._metadata._reduce([0]) + + def unique( + self, + strategy: Optional[_UniqueFilterType] = None, + ) -> Self: + """Apply unique filtering to the objects returned by this + :class:`_asyncio.AsyncMappingResult`. + + See :meth:`_asyncio.AsyncResult.unique` for usage details. + + """ + self._unique_filter_state = (set(), strategy) + return self + + def columns(self, *col_expressions: _KeyIndexType) -> Self: + r"""Establish the columns that should be returned in each row.""" + return self._column_slices(col_expressions) + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[RowMapping]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_asyncio.AsyncResult.partitions` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + getter = self._manyrow_getter + + while True: + partition = await greenlet_spawn(getter, self, size) + if partition: + yield partition + else: + break + + async def fetchall(self) -> Sequence[RowMapping]: + """A synonym for the :meth:`_asyncio.AsyncMappingResult.all` method.""" + + return await greenlet_spawn(self._allrows) + + async def fetchone(self) -> Optional[RowMapping]: + """Fetch one object. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchone` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + return None + else: + return row + + async def fetchmany( + self, size: Optional[int] = None + ) -> Sequence[RowMapping]: + """Fetch many rows. + + Equivalent to :meth:`_asyncio.AsyncResult.fetchmany` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + return await greenlet_spawn(self._manyrow_getter, self, size) + + async def all(self) -> Sequence[RowMapping]: + """Return all rows in a list. + + Equivalent to :meth:`_asyncio.AsyncResult.all` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + + return await greenlet_spawn(self._allrows) + + def __aiter__(self) -> AsyncMappingResult: + return self + + async def __anext__(self) -> RowMapping: + row = await greenlet_spawn(self._onerow_getter, self) + if row is _NO_ROW: + raise StopAsyncIteration() + else: + return row + + async def first(self) -> Optional[RowMapping]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_asyncio.AsyncResult.first` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return await greenlet_spawn(self._only_one_row, False, False, False) + + async def one_or_none(self) -> Optional[RowMapping]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one_or_none` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, False, False) + + async def one(self) -> RowMapping: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_asyncio.AsyncResult.one` except that + :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` + objects, are returned. + + """ + return await greenlet_spawn(self._only_one_row, True, True, False) + + +class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): + """A :class:`_asyncio.AsyncResult` that's typed as returning plain + Python tuples instead of rows. + + Since :class:`_engine.Row` acts like a tuple in every way already, + this class is a typing only class, regular :class:`_asyncio.AsyncResult` is + still used at runtime. + + """ + + __slots__ = () + + if TYPE_CHECKING: + + async def partitions( + self, size: Optional[int] = None + ) -> AsyncIterator[Sequence[_R]]: + """Iterate through sub-lists of elements of the size given. + + Equivalent to :meth:`_result.Result.partitions` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + async def fetchone(self) -> Optional[_R]: + """Fetch one tuple. + + Equivalent to :meth:`_result.Result.fetchone` except that + tuple values, rather than :class:`_engine.Row` + objects, are returned. + + """ + ... + + async def fetchall(self) -> Sequence[_R]: + """A synonym for the :meth:`_engine.ScalarResult.all` method.""" + ... + + async def fetchmany(self, size: Optional[int] = None) -> Sequence[_R]: + """Fetch many objects. + + Equivalent to :meth:`_result.Result.fetchmany` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + async def all(self) -> Sequence[_R]: # noqa: A001 + """Return all scalar values in a list. + + Equivalent to :meth:`_result.Result.all` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + async def __aiter__(self) -> AsyncIterator[_R]: ... + + async def __anext__(self) -> _R: ... + + async def first(self) -> Optional[_R]: + """Fetch the first object or ``None`` if no object is present. + + Equivalent to :meth:`_result.Result.first` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + + """ + ... + + async def one_or_none(self) -> Optional[_R]: + """Return at most one object or raise an exception. + + Equivalent to :meth:`_result.Result.one_or_none` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + async def one(self) -> _R: + """Return exactly one object or raise an exception. + + Equivalent to :meth:`_result.Result.one` except that + tuple values, rather than :class:`_engine.Row` objects, + are returned. + + """ + ... + + @overload + async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: ... + + @overload + async def scalar_one(self) -> Any: ... + + async def scalar_one(self) -> Any: + """Return exactly one scalar result or raise an exception. + + This is equivalent to calling :meth:`_engine.Result.scalars` + and then :meth:`_engine.Result.one`. + + .. seealso:: + + :meth:`_engine.Result.one` + + :meth:`_engine.Result.scalars` + + """ + ... + + @overload + async def scalar_one_or_none( + self: AsyncTupleResult[Tuple[_T]], + ) -> Optional[_T]: ... + + @overload + async def scalar_one_or_none(self) -> Optional[Any]: ... + + async def scalar_one_or_none(self) -> Optional[Any]: + """Return exactly one or no scalar result. + + This is equivalent to calling :meth:`_engine.Result.scalars` + and then :meth:`_engine.Result.one_or_none`. + + .. seealso:: + + :meth:`_engine.Result.one_or_none` + + :meth:`_engine.Result.scalars` + + """ + ... + + @overload + async def scalar( + self: AsyncTupleResult[Tuple[_T]], + ) -> Optional[_T]: ... + + @overload + async def scalar(self) -> Any: ... + + async def scalar(self) -> Any: + """Fetch the first column of the first row, and close the result + set. + + Returns ``None`` if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + + After calling this method, the object is fully closed, + e.g. the :meth:`_engine.CursorResult.close` + method will have been called. + + :return: a Python scalar value , or ``None`` if no rows remain. + + """ + ... + + +_RT = TypeVar("_RT", bound="Result[Any]") + + +async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT: + cursor_result: CursorResult[Any] + + try: + is_cursor = result._is_cursor + except AttributeError: + # legacy execute(DefaultGenerator) case + return result + + if not is_cursor: + cursor_result = getattr(result, "raw", None) # type: ignore + else: + cursor_result = result # type: ignore + if cursor_result and cursor_result.context._is_server_side: + await greenlet_spawn(cursor_result.close) + raise async_exc.AsyncMethodRequired( + "Can't use the %s.%s() method with a " + "server-side cursor. " + "Use the %s.stream() method for an async " + "streaming result set." + % ( + calling_method.__self__.__class__.__name__, + calling_method.__name__, + calling_method.__self__.__class__.__name__, + ) + ) + return result diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/scoping.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/scoping.py new file mode 100644 index 0000000..e879a16 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/scoping.py @@ -0,0 +1,1614 @@ +# ext/asyncio/scoping.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .session import _AS +from .session import async_sessionmaker +from .session import AsyncSession +from ... import exc as sa_exc +from ... import util +from ...orm.session import Session +from ...util import create_proxy_methods +from ...util import ScopedRegistry +from ...util import warn +from ...util import warn_deprecated + +if TYPE_CHECKING: + from .engine import AsyncConnection + from .result import AsyncResult + from .result import AsyncScalarResult + from .session import AsyncSessionTransaction + from ...engine import Connection + from ...engine import CursorResult + from ...engine import Engine + from ...engine import Result + from ...engine import Row + from ...engine import RowMapping + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import CoreExecuteOptionsParameter + from ...engine.result import ScalarResult + from ...orm._typing import _IdentityKeyType + from ...orm._typing import _O + from ...orm._typing import OrmExecuteOptionsParameter + from ...orm.interfaces import ORMOption + from ...orm.session import _BindArguments + from ...orm.session import _EntityBindKey + from ...orm.session import _PKIdentityArgument + from ...orm.session import _SessionBind + from ...sql.base import Executable + from ...sql.dml import UpdateBase + from ...sql.elements import ClauseElement + from ...sql.selectable import ForUpdateParameter + from ...sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) + + +@create_proxy_methods( + AsyncSession, + ":class:`_asyncio.AsyncSession`", + ":class:`_asyncio.scoping.async_scoped_session`", + classmethods=["close_all", "object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "aclose", + "add", + "add_all", + "begin", + "begin_nested", + "close", + "reset", + "commit", + "connection", + "delete", + "execute", + "expire", + "expire_all", + "expunge", + "expunge_all", + "flush", + "get_bind", + "is_modified", + "invalidate", + "merge", + "refresh", + "rollback", + "scalar", + "scalars", + "get", + "get_one", + "stream", + "stream_scalars", + ], + attributes=[ + "bind", + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], + use_intermediate_variable=["get"], +) +class async_scoped_session(Generic[_AS]): + """Provides scoped management of :class:`.AsyncSession` objects. + + See the section :ref:`asyncio_scoped_session` for usage details. + + .. versionadded:: 1.4.19 + + + """ + + _support_async = True + + session_factory: async_sessionmaker[_AS] + """The `session_factory` provided to `__init__` is stored in this + attribute and may be accessed at a later time. This can be useful when + a new non-scoped :class:`.AsyncSession` is needed.""" + + registry: ScopedRegistry[_AS] + + def __init__( + self, + session_factory: async_sessionmaker[_AS], + scopefunc: Callable[[], Any], + ): + """Construct a new :class:`_asyncio.async_scoped_session`. + + :param session_factory: a factory to create new :class:`_asyncio.AsyncSession` + instances. This is usually, but not necessarily, an instance + of :class:`_asyncio.async_sessionmaker`. + + :param scopefunc: function which defines + the current scope. A function such as ``asyncio.current_task`` + may be useful here. + + """ # noqa: E501 + + self.session_factory = session_factory + self.registry = ScopedRegistry(session_factory, scopefunc) + + @property + def _proxied(self) -> _AS: + return self.registry() + + def __call__(self, **kw: Any) -> _AS: + r"""Return the current :class:`.AsyncSession`, creating it + using the :attr:`.scoped_session.session_factory` if not present. + + :param \**kw: Keyword arguments will be passed to the + :attr:`.scoped_session.session_factory` callable, if an existing + :class:`.AsyncSession` is not present. If the + :class:`.AsyncSession` is present + and keyword arguments have been passed, + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if kw: + if self.registry.has(): + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified." + ) + else: + sess = self.session_factory(**kw) + self.registry.set(sess) + else: + sess = self.registry() + if not self._support_async and sess._is_asyncio: + warn_deprecated( + "Using `scoped_session` with asyncio is deprecated and " + "will raise an error in a future version. " + "Please use `async_scoped_session` instead.", + "1.4.23", + ) + return sess + + def configure(self, **kwargs: Any) -> None: + """reconfigure the :class:`.sessionmaker` used by this + :class:`.scoped_session`. + + See :meth:`.sessionmaker.configure`. + + """ + + if self.registry.has(): + warn( + "At least one scoped session is already present. " + " configure() can not affect sessions that have " + "already been created." + ) + + self.session_factory.configure(**kwargs) + + async def remove(self) -> None: + """Dispose of the current :class:`.AsyncSession`, if present. + + Different from scoped_session's remove method, this method would use + await to wait for the close method of AsyncSession. + + """ + + if self.registry.has(): + await self.registry().close() + self.registry.clear() + + # START PROXY METHODS async_scoped_session + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + def __contains__(self, instance: object) -> bool: + r"""Return True if the instance is associated with this session. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + The instance may be pending or persistent within the Session for a + result of True. + + + + """ # noqa: E501 + + return self._proxied.__contains__(instance) + + def __iter__(self) -> Iterator[object]: + r"""Iterate over all pending or persistent instances within this + Session. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + + + """ # noqa: E501 + + return self._proxied.__iter__() + + async def aclose(self) -> None: + r"""A synonym for :meth:`_asyncio.AsyncSession.close`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + The :meth:`_asyncio.AsyncSession.aclose` name is specifically + to support the Python standard library ``@contextlib.aclosing`` + context manager function. + + .. versionadded:: 2.0.20 + + + """ # noqa: E501 + + return await self._proxied.aclose() + + def add(self, instance: object, _warn: bool = True) -> None: + r"""Place an object into this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + Objects that are in the :term:`transient` state when passed to the + :meth:`_orm.Session.add` method will move to the + :term:`pending` state, until the next flush, at which point they + will move to the :term:`persistent` state. + + Objects that are in the :term:`detached` state when passed to the + :meth:`_orm.Session.add` method will move to the :term:`persistent` + state directly. + + If the transaction used by the :class:`_orm.Session` is rolled back, + objects which were transient when they were passed to + :meth:`_orm.Session.add` will be moved back to the + :term:`transient` state, and will no longer be present within this + :class:`_orm.Session`. + + .. seealso:: + + :meth:`_orm.Session.add_all` + + :ref:`session_adding` - at :ref:`session_basics` + + + + """ # noqa: E501 + + return self._proxied.add(instance, _warn=_warn) + + def add_all(self, instances: Iterable[object]) -> None: + r"""Add the given collection of instances to this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + See the documentation for :meth:`_orm.Session.add` for a general + behavioral description. + + .. seealso:: + + :meth:`_orm.Session.add` + + :ref:`session_adding` - at :ref:`session_basics` + + + + """ # noqa: E501 + + return self._proxied.add_all(instances) + + def begin(self) -> AsyncSessionTransaction: + r"""Return an :class:`_asyncio.AsyncSessionTransaction` object. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + The underlying :class:`_orm.Session` will perform the + "begin" action when the :class:`_asyncio.AsyncSessionTransaction` + object is entered:: + + async with async_session.begin(): + # .. ORM transaction is begun + + Note that database IO will not normally occur when the session-level + transaction is begun, as database transactions begin on an + on-demand basis. However, the begin block is async to accommodate + for a :meth:`_orm.SessionEvents.after_transaction_create` + event hook that may perform IO. + + For a general description of ORM begin, see + :meth:`_orm.Session.begin`. + + + """ # noqa: E501 + + return self._proxied.begin() + + def begin_nested(self) -> AsyncSessionTransaction: + r"""Return an :class:`_asyncio.AsyncSessionTransaction` object + which will begin a "nested" transaction, e.g. SAVEPOINT. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`. + + For a general description of ORM begin nested, see + :meth:`_orm.Session.begin_nested`. + + .. seealso:: + + :ref:`aiosqlite_serializable` - special workarounds required + with the SQLite asyncio driver in order for SAVEPOINT to work + correctly. + + + """ # noqa: E501 + + return self._proxied.begin_nested() + + async def close(self) -> None: + r"""Close out the transactional resources and ORM objects used by this + :class:`_asyncio.AsyncSession`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.close` - main documentation for + "close" + + :ref:`session_closing` - detail on the semantics of + :meth:`_asyncio.AsyncSession.close` and + :meth:`_asyncio.AsyncSession.reset`. + + + """ # noqa: E501 + + return await self._proxied.close() + + async def reset(self) -> None: + r"""Close out the transactional resources and ORM objects used by this + :class:`_orm.Session`, resetting the session to its initial state. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. versionadded:: 2.0.22 + + .. seealso:: + + :meth:`_orm.Session.reset` - main documentation for + "reset" + + :ref:`session_closing` - detail on the semantics of + :meth:`_asyncio.AsyncSession.close` and + :meth:`_asyncio.AsyncSession.reset`. + + + """ # noqa: E501 + + return await self._proxied.reset() + + async def commit(self) -> None: + r"""Commit the current transaction in progress. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.commit` - main documentation for + "commit" + + """ # noqa: E501 + + return await self._proxied.commit() + + async def connection( + self, + bind_arguments: Optional[_BindArguments] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + **kw: Any, + ) -> AsyncConnection: + r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to + this :class:`.Session` object's transactional state. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + This method may also be used to establish execution options for the + database connection used by the current transaction. + + .. versionadded:: 1.4.24 Added \**kw arguments which are passed + through to the underlying :meth:`_orm.Session.connection` method. + + .. seealso:: + + :meth:`_orm.Session.connection` - main documentation for + "connection" + + + """ # noqa: E501 + + return await self._proxied.connection( + bind_arguments=bind_arguments, + execution_options=execution_options, + **kw, + ) + + async def delete(self, instance: object) -> None: + r"""Mark an instance as deleted. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + The database delete operation occurs upon ``flush()``. + + As this operation may need to cascade along unloaded relationships, + it is awaitable to allow for those queries to take place. + + .. seealso:: + + :meth:`_orm.Session.delete` - main documentation for delete + + + """ # noqa: E501 + + return await self._proxied.delete(instance) + + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: ... + + @overload + async def execute( + self, + statement: UpdateBase, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> CursorResult[Any]: ... + + @overload + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: ... + + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Result[Any]: + r"""Execute a statement and return a buffered + :class:`_engine.Result` object. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.execute` - main documentation for execute + + + """ # noqa: E501 + + return await self._proxied.execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: + r"""Expire the attributes on an instance. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + Marks the attributes of an instance as out of date. When an expired + attribute is next accessed, a query will be issued to the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire all objects in the :class:`.Session` simultaneously, + use :meth:`Session.expire_all`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire` only makes sense for the specific + case that a non-ORM SQL statement was emitted in the current + transaction. + + :param instance: The instance to be refreshed. + :param attribute_names: optional list of string attribute names + indicating a subset of attributes to be expired. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + + """ # noqa: E501 + + return self._proxied.expire(instance, attribute_names=attribute_names) + + def expire_all(self) -> None: + r"""Expires all persistent instances within this Session. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + When any attributes on a persistent instance is next accessed, + a query will be issued using the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire individual objects and individual attributes + on those objects, use :meth:`Session.expire`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire_all` is not usually needed, + assuming the transaction is isolated. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + + """ # noqa: E501 + + return self._proxied.expire_all() + + def expunge(self, instance: object) -> None: + r"""Remove the `instance` from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This will free all internal references to the instance. Cascading + will be applied according to the *expunge* cascade rule. + + + + """ # noqa: E501 + + return self._proxied.expunge(instance) + + def expunge_all(self) -> None: + r"""Remove all object instances from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is equivalent to calling ``expunge(obj)`` on all objects in this + ``Session``. + + + + """ # noqa: E501 + + return self._proxied.expunge_all() + + async def flush(self, objects: Optional[Sequence[Any]] = None) -> None: + r"""Flush all the object changes to the database. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.flush` - main documentation for flush + + + """ # noqa: E501 + + return await self._proxied.flush(objects=objects) + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + **kw: Any, + ) -> Union[Engine, Connection]: + r"""Return a "bind" to which the synchronous proxied :class:`_orm.Session` + is bound. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + Unlike the :meth:`_orm.Session.get_bind` method, this method is + currently **not** used by this :class:`.AsyncSession` in any way + in order to resolve engines for requests. + + .. note:: + + This method proxies directly to the :meth:`_orm.Session.get_bind` + method, however is currently **not** useful as an override target, + in contrast to that of the :meth:`_orm.Session.get_bind` method. + The example below illustrates how to implement custom + :meth:`_orm.Session.get_bind` schemes that work with + :class:`.AsyncSession` and :class:`.AsyncEngine`. + + The pattern introduced at :ref:`session_custom_partitioning` + illustrates how to apply a custom bind-lookup scheme to a + :class:`_orm.Session` given a set of :class:`_engine.Engine` objects. + To apply a corresponding :meth:`_orm.Session.get_bind` implementation + for use with a :class:`.AsyncSession` and :class:`.AsyncEngine` + objects, continue to subclass :class:`_orm.Session` and apply it to + :class:`.AsyncSession` using + :paramref:`.AsyncSession.sync_session_class`. The inner method must + continue to return :class:`_engine.Engine` instances, which can be + acquired from a :class:`_asyncio.AsyncEngine` using the + :attr:`_asyncio.AsyncEngine.sync_engine` attribute:: + + # using example from "Custom Vertical Partitioning" + + + import random + + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session + + # construct async engines w/ async drivers + engines = { + 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"), + 'other':create_async_engine("sqlite+aiosqlite:///other.db"), + 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"), + 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"), + } + + class RoutingSession(Session): + def get_bind(self, mapper=None, clause=None, **kw): + # within get_bind(), return sync engines + if mapper and issubclass(mapper.class_, MyOtherClass): + return engines['other'].sync_engine + elif self._flushing or isinstance(clause, (Update, Delete)): + return engines['leader'].sync_engine + else: + return engines[ + random.choice(['follower1','follower2']) + ].sync_engine + + # apply to AsyncSession using sync_session_class + AsyncSessionMaker = async_sessionmaker( + sync_session_class=RoutingSession + ) + + The :meth:`_orm.Session.get_bind` method is called in a non-asyncio, + implicitly non-blocking context in the same manner as ORM event hooks + and functions that are invoked via :meth:`.AsyncSession.run_sync`, so + routines that wish to run SQL commands inside of + :meth:`_orm.Session.get_bind` can continue to do so using + blocking-style code, which will be translated to implicitly async calls + at the point of invoking IO on the database drivers. + + + """ # noqa: E501 + + return self._proxied.get_bind( + mapper=mapper, clause=clause, bind=bind, **kw + ) + + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: + r"""Return ``True`` if the given instance has locally + modified attributes. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This method retrieves the history for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value, if any. + + It is in effect a more expensive and accurate + version of checking for the given instance in the + :attr:`.Session.dirty` collection; a full test for + each attribute's net "dirty" status is performed. + + E.g.:: + + return session.is_modified(someobject) + + A few caveats to this method apply: + + * Instances present in the :attr:`.Session.dirty` collection may + report ``False`` when tested with this method. This is because + the object may have received change events via attribute mutation, + thus placing it in :attr:`.Session.dirty`, but ultimately the state + is the same as that loaded from the database, resulting in no net + change here. + * Scalar attributes may not have recorded the previously set + value when a new value was applied, if the attribute was not loaded, + or was expired, at the time the new value was received - in these + cases, the attribute is assumed to have a change, even if there is + ultimately no net change against its database value. SQLAlchemy in + most cases does not need the "old" value when a set event occurs, so + it skips the expense of a SQL call if the old value isn't present, + based on the assumption that an UPDATE of the scalar value is + usually needed, and in those few cases where it isn't, is less + expensive on average than issuing a defensive SELECT. + + The "old" value is fetched unconditionally upon set only if the + attribute container has the ``active_history`` flag set to ``True``. + This flag is set typically for primary key attributes and scalar + object references that are not a simple many-to-one. To set this + flag for any arbitrary mapped column, use the ``active_history`` + argument with :func:`.column_property`. + + :param instance: mapped instance to be tested for pending changes. + :param include_collections: Indicates if multivalued collections + should be included in the operation. Setting this to ``False`` is a + way to detect only local-column based properties (i.e. scalar columns + or many-to-one foreign keys) that would result in an UPDATE for this + instance upon flush. + + + + """ # noqa: E501 + + return self._proxied.is_modified( + instance, include_collections=include_collections + ) + + async def invalidate(self) -> None: + r"""Close this Session, using connection invalidation. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + For a complete description, see :meth:`_orm.Session.invalidate`. + + """ # noqa: E501 + + return await self._proxied.invalidate() + + async def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: + r"""Copy the state of a given instance into a corresponding instance + within this :class:`_asyncio.AsyncSession`. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.merge` - main documentation for merge + + + """ # noqa: E501 + + return await self._proxied.merge(instance, load=load, options=options) + + async def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: ForUpdateParameter = None, + ) -> None: + r"""Expire and refresh the attributes on the given instance. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + A query will be issued to the database and all attributes will be + refreshed with their current database value. + + This is the async version of the :meth:`_orm.Session.refresh` method. + See that method for a complete description of all options. + + .. seealso:: + + :meth:`_orm.Session.refresh` - main documentation for refresh + + + """ # noqa: E501 + + return await self._proxied.refresh( + instance, + attribute_names=attribute_names, + with_for_update=with_for_update, + ) + + async def rollback(self) -> None: + r"""Rollback the current transaction in progress. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.rollback` - main documentation for + "rollback" + + """ # noqa: E501 + + return await self._proxied.rollback() + + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: ... + + @overload + async def scalar( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: ... + + async def scalar( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + r"""Execute a statement and return a scalar result. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.scalar` - main documentation for scalar + + + """ # noqa: E501 + + return await self._proxied.scalar( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: ... + + @overload + async def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: ... + + async def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + r"""Execute a statement and return scalar results. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + :return: a :class:`_result.ScalarResult` object + + .. versionadded:: 1.4.24 Added :meth:`_asyncio.AsyncSession.scalars` + + .. versionadded:: 1.4.26 Added + :meth:`_asyncio.async_scoped_session.scalars` + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version + + + """ # noqa: E501 + + return await self._proxied.scalars( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + async def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Union[_O, None]: + r"""Return an instance based on the given primary key identifier, + or ``None`` if not found. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. seealso:: + + :meth:`_orm.Session.get` - main documentation for get + + + + """ # noqa: E501 + + result = await self._proxied.get( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + return result + + async def get_one( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> _O: + r"""Return an instance based on the given primary key identifier, + or raise an exception if not found. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. + + ..versionadded: 2.0.22 + + .. seealso:: + + :meth:`_orm.Session.get_one` - main documentation for get_one + + + """ # noqa: E501 + + return await self._proxied.get_one( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + + @overload + async def stream( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[_T]: ... + + @overload + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: ... + + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: + r"""Execute a statement and return a streaming + :class:`_asyncio.AsyncResult` object. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + + """ # noqa: E501 + + return await self._proxied.stream( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @overload + async def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[_T]: ... + + @overload + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: ... + + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: + r"""Execute a statement and return a stream of scalar results. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + :return: an :class:`_asyncio.AsyncScalarResult` object + + .. versionadded:: 1.4.24 + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.scalars` - non streaming version + + + """ # noqa: E501 + + return await self._proxied.stream_scalars( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @property + def bind(self) -> Any: + r"""Proxy for the :attr:`_asyncio.AsyncSession.bind` attribute + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + """ # noqa: E501 + + return self._proxied.bind + + @bind.setter + def bind(self, attr: Any) -> None: + self._proxied.bind = attr + + @property + def dirty(self) -> Any: + r"""The set of all persistent instances considered dirty. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + E.g.:: + + some_mapped_object in session.dirty + + Instances are considered dirty when they were modified but not + deleted. + + Note that this 'dirty' calculation is 'optimistic'; most + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). + + To check if an instance has actionable net changes to its + attributes, use the :meth:`.Session.is_modified` method. + + + + """ # noqa: E501 + + return self._proxied.dirty + + @property + def deleted(self) -> Any: + r"""The set of all instances marked as 'deleted' within this ``Session`` + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + + """ # noqa: E501 + + return self._proxied.deleted + + @property + def new(self) -> Any: + r"""The set of all instances marked as 'new' within this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + + """ # noqa: E501 + + return self._proxied.new + + @property + def identity_map(self) -> Any: + r"""Proxy for the :attr:`_orm.Session.identity_map` attribute + on behalf of the :class:`_asyncio.AsyncSession` class. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + + """ # noqa: E501 + + return self._proxied.identity_map + + @identity_map.setter + def identity_map(self, attr: Any) -> None: + self._proxied.identity_map = attr + + @property + def is_active(self) -> Any: + r"""True if this :class:`.Session` not in "partial rollback" state. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins + a new transaction immediately, so this attribute will be False + when the :class:`_orm.Session` is first instantiated. + + "partial rollback" state typically indicates that the flush process + of the :class:`_orm.Session` has failed, and that the + :meth:`_orm.Session.rollback` method must be emitted in order to + fully roll back the transaction. + + If this :class:`_orm.Session` is not in a transaction at all, the + :class:`_orm.Session` will autobegin when it is first used, so in this + case :attr:`_orm.Session.is_active` will return True. + + Otherwise, if this :class:`_orm.Session` is within a transaction, + and that transaction has not been rolled back internally, the + :attr:`_orm.Session.is_active` will also return True. + + .. seealso:: + + :ref:`faq_session_rollback` + + :meth:`_orm.Session.in_transaction` + + + + """ # noqa: E501 + + return self._proxied.is_active + + @property + def autoflush(self) -> Any: + r"""Proxy for the :attr:`_orm.Session.autoflush` attribute + on behalf of the :class:`_asyncio.AsyncSession` class. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + + """ # noqa: E501 + + return self._proxied.autoflush + + @autoflush.setter + def autoflush(self, attr: Any) -> None: + self._proxied.autoflush = attr + + @property + def no_autoflush(self) -> Any: + r"""Return a context manager that disables autoflush. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + + + """ # noqa: E501 + + return self._proxied.no_autoflush + + @property + def info(self) -> Any: + r"""A user-modifiable dictionary. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class + on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + The initial value of this dictionary can be populated using the + ``info`` argument to the :class:`.Session` constructor or + :class:`.sessionmaker` constructor or factory methods. The dictionary + here is always local to this :class:`.Session` and can be modified + independently of all other :class:`.Session` objects. + + + + """ # noqa: E501 + + return self._proxied.info + + @classmethod + async def close_all(cls) -> None: + r"""Close all :class:`_asyncio.AsyncSession` sessions. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. deprecated:: 2.0 The :meth:`.AsyncSession.close_all` method is deprecated and will be removed in a future release. Please refer to :func:`_asyncio.close_all_sessions`. + + """ # noqa: E501 + + return await AsyncSession.close_all() + + @classmethod + def object_session(cls, instance: object) -> Optional[Session]: + r"""Return the :class:`.Session` to which an object belongs. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is an alias of :func:`.object_session`. + + + + """ # noqa: E501 + + return AsyncSession.object_session(instance) + + @classmethod + def identity_key( + cls, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[Any] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: + r"""Return an identity key. + + .. container:: class_bases + + Proxied for the :class:`_asyncio.AsyncSession` class on + behalf of the :class:`_asyncio.scoping.async_scoped_session` class. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is an alias of :func:`.util.identity_key`. + + + + """ # noqa: E501 + + return AsyncSession.identity_key( + class_=class_, + ident=ident, + instance=instance, + row=row, + identity_token=identity_token, + ) + + # END PROXY METHODS async_scoped_session diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/session.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/session.py new file mode 100644 index 0000000..c5fe469 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/asyncio/session.py @@ -0,0 +1,1936 @@ +# ext/asyncio/session.py +# Copyright (C) 2020-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import asyncio +from typing import Any +from typing import Awaitable +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import engine +from .base import ReversibleProxy +from .base import StartableContext +from .result import _ensure_sync_result +from .result import AsyncResult +from .result import AsyncScalarResult +from ... import util +from ...orm import close_all_sessions as _sync_close_all_sessions +from ...orm import object_session +from ...orm import Session +from ...orm import SessionTransaction +from ...orm import state as _instance_state +from ...util.concurrency import greenlet_spawn +from ...util.typing import Concatenate +from ...util.typing import ParamSpec + + +if TYPE_CHECKING: + from .engine import AsyncConnection + from .engine import AsyncEngine + from ...engine import Connection + from ...engine import CursorResult + from ...engine import Engine + from ...engine import Result + from ...engine import Row + from ...engine import RowMapping + from ...engine import ScalarResult + from ...engine.interfaces import _CoreAnyExecuteParams + from ...engine.interfaces import CoreExecuteOptionsParameter + from ...event import dispatcher + from ...orm._typing import _IdentityKeyType + from ...orm._typing import _O + from ...orm._typing import OrmExecuteOptionsParameter + from ...orm.identity import IdentityMap + from ...orm.interfaces import ORMOption + from ...orm.session import _BindArguments + from ...orm.session import _EntityBindKey + from ...orm.session import _PKIdentityArgument + from ...orm.session import _SessionBind + from ...orm.session import _SessionBindKey + from ...sql._typing import _InfoType + from ...sql.base import Executable + from ...sql.dml import UpdateBase + from ...sql.elements import ClauseElement + from ...sql.selectable import ForUpdateParameter + from ...sql.selectable import TypedReturnsRows + +_AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] + +_P = ParamSpec("_P") +_T = TypeVar("_T", bound=Any) + + +_EXECUTE_OPTIONS = util.immutabledict({"prebuffer_rows": True}) +_STREAM_OPTIONS = util.immutabledict({"stream_results": True}) + + +class AsyncAttrs: + """Mixin class which provides an awaitable accessor for all attributes. + + E.g.:: + + from __future__ import annotations + + from typing import List + + from sqlalchemy import ForeignKey + from sqlalchemy import func + from sqlalchemy.ext.asyncio import AsyncAttrs + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy.orm import relationship + + + class Base(AsyncAttrs, DeclarativeBase): + pass + + + class A(Base): + __tablename__ = "a" + + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[str] + bs: Mapped[List[B]] = relationship() + + + class B(Base): + __tablename__ = "b" + id: Mapped[int] = mapped_column(primary_key=True) + a_id: Mapped[int] = mapped_column(ForeignKey("a.id")) + data: Mapped[str] + + In the above example, the :class:`_asyncio.AsyncAttrs` mixin is applied to + the declarative ``Base`` class where it takes effect for all subclasses. + This mixin adds a single new attribute + :attr:`_asyncio.AsyncAttrs.awaitable_attrs` to all classes, which will + yield the value of any attribute as an awaitable. This allows attributes + which may be subject to lazy loading or deferred / unexpiry loading to be + accessed such that IO can still be emitted:: + + a1 = (await async_session.scalars(select(A).where(A.id == 5))).one() + + # use the lazy loader on ``a1.bs`` via the ``.awaitable_attrs`` + # interface, so that it may be awaited + for b1 in await a1.awaitable_attrs.bs: + print(b1) + + The :attr:`_asyncio.AsyncAttrs.awaitable_attrs` performs a call against the + attribute that is approximately equivalent to using the + :meth:`_asyncio.AsyncSession.run_sync` method, e.g.:: + + for b1 in await async_session.run_sync(lambda sess: a1.bs): + print(b1) + + .. versionadded:: 2.0.13 + + .. seealso:: + + :ref:`asyncio_orm_avoid_lazyloads` + + """ + + class _AsyncAttrGetitem: + __slots__ = "_instance" + + def __init__(self, _instance: Any): + self._instance = _instance + + def __getattr__(self, name: str) -> Awaitable[Any]: + return greenlet_spawn(getattr, self._instance, name) + + @property + def awaitable_attrs(self) -> AsyncAttrs._AsyncAttrGetitem: + """provide a namespace of all attributes on this object wrapped + as awaitables. + + e.g.:: + + + a1 = (await async_session.scalars(select(A).where(A.id == 5))).one() + + some_attribute = await a1.awaitable_attrs.some_deferred_attribute + some_collection = await a1.awaitable_attrs.some_collection + + """ # noqa: E501 + + return AsyncAttrs._AsyncAttrGetitem(self) + + +@util.create_proxy_methods( + Session, + ":class:`_orm.Session`", + ":class:`_asyncio.AsyncSession`", + classmethods=["object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "expire", + "expire_all", + "expunge", + "expunge_all", + "is_modified", + "in_transaction", + "in_nested_transaction", + ], + attributes=[ + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], +) +class AsyncSession(ReversibleProxy[Session]): + """Asyncio version of :class:`_orm.Session`. + + The :class:`_asyncio.AsyncSession` is a proxy for a traditional + :class:`_orm.Session` instance. + + The :class:`_asyncio.AsyncSession` is **not safe for use in concurrent + tasks.**. See :ref:`session_faq_threadsafe` for background. + + .. versionadded:: 1.4 + + To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session` + implementations, see the + :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. + + + """ + + _is_asyncio = True + + dispatch: dispatcher[Session] + + def __init__( + self, + bind: Optional[_AsyncSessionBind] = None, + *, + binds: Optional[Dict[_SessionBindKey, _AsyncSessionBind]] = None, + sync_session_class: Optional[Type[Session]] = None, + **kw: Any, + ): + r"""Construct a new :class:`_asyncio.AsyncSession`. + + All parameters other than ``sync_session_class`` are passed to the + ``sync_session_class`` callable directly to instantiate a new + :class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for + parameter documentation. + + :param sync_session_class: + A :class:`_orm.Session` subclass or other callable which will be used + to construct the :class:`_orm.Session` which will be proxied. This + parameter may be used to provide custom :class:`_orm.Session` + subclasses. Defaults to the + :attr:`_asyncio.AsyncSession.sync_session_class` class-level + attribute. + + .. versionadded:: 1.4.24 + + """ + sync_bind = sync_binds = None + + if bind: + self.bind = bind + sync_bind = engine._get_sync_engine_or_connection(bind) + + if binds: + self.binds = binds + sync_binds = { + key: engine._get_sync_engine_or_connection(b) + for key, b in binds.items() + } + + if sync_session_class: + self.sync_session_class = sync_session_class + + self.sync_session = self._proxied = self._assign_proxied( + self.sync_session_class(bind=sync_bind, binds=sync_binds, **kw) + ) + + sync_session_class: Type[Session] = Session + """The class or callable that provides the + underlying :class:`_orm.Session` instance for a particular + :class:`_asyncio.AsyncSession`. + + At the class level, this attribute is the default value for the + :paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom + subclasses of :class:`_asyncio.AsyncSession` can override this. + + At the instance level, this attribute indicates the current class or + callable that was used to provide the :class:`_orm.Session` instance for + this :class:`_asyncio.AsyncSession` instance. + + .. versionadded:: 1.4.24 + + """ + + sync_session: Session + """Reference to the underlying :class:`_orm.Session` this + :class:`_asyncio.AsyncSession` proxies requests towards. + + This instance can be used as an event target. + + .. seealso:: + + :ref:`asyncio_events` + + """ + + @classmethod + def _no_async_engine_events(cls) -> NoReturn: + raise NotImplementedError( + "asynchronous events are not implemented at this time. Apply " + "synchronous listeners to the AsyncSession.sync_session." + ) + + async def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: ForUpdateParameter = None, + ) -> None: + """Expire and refresh the attributes on the given instance. + + A query will be issued to the database and all attributes will be + refreshed with their current database value. + + This is the async version of the :meth:`_orm.Session.refresh` method. + See that method for a complete description of all options. + + .. seealso:: + + :meth:`_orm.Session.refresh` - main documentation for refresh + + """ + + await greenlet_spawn( + self.sync_session.refresh, + instance, + attribute_names=attribute_names, + with_for_update=with_for_update, + ) + + async def run_sync( + self, + fn: Callable[Concatenate[Session, _P], _T], + *arg: _P.args, + **kw: _P.kwargs, + ) -> _T: + """Invoke the given synchronous (i.e. not async) callable, + passing a synchronous-style :class:`_orm.Session` as the first + argument. + + This method allows traditional synchronous SQLAlchemy functions to + run within the context of an asyncio application. + + E.g.:: + + def some_business_method(session: Session, param: str) -> str: + '''A synchronous function that does not require awaiting + + :param session: a SQLAlchemy Session, used synchronously + + :return: an optional return value is supported + + ''' + session.add(MyObject(param=param)) + session.flush() + return "success" + + + async def do_something_async(async_engine: AsyncEngine) -> None: + '''an async function that uses awaiting''' + + with AsyncSession(async_engine) as async_session: + # run some_business_method() with a sync-style + # Session, proxied into an awaitable + return_code = await async_session.run_sync(some_business_method, param="param1") + print(return_code) + + This method maintains the asyncio event loop all the way through + to the database connection by running the given callable in a + specially instrumented greenlet. + + .. tip:: + + The provided callable is invoked inline within the asyncio event + loop, and will block on traditional IO calls. IO within this + callable should only call into SQLAlchemy's asyncio database + APIs which will be properly adapted to the greenlet context. + + .. seealso:: + + :class:`.AsyncAttrs` - a mixin for ORM mapped classes that provides + a similar feature more succinctly on a per-attribute basis + + :meth:`.AsyncConnection.run_sync` + + :ref:`session_run_sync` + """ # noqa: E501 + + return await greenlet_spawn( + fn, self.sync_session, *arg, _require_await=False, **kw + ) + + @overload + async def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: ... + + @overload + async def execute( + self, + statement: UpdateBase, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> CursorResult[Any]: ... + + @overload + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: ... + + async def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Result[Any]: + """Execute a statement and return a buffered + :class:`_engine.Result` object. + + .. seealso:: + + :meth:`_orm.Session.execute` - main documentation for execute + + """ + + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _EXECUTE_OPTIONS + ) + else: + execution_options = _EXECUTE_OPTIONS + + result = await greenlet_spawn( + self.sync_session.execute, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return await _ensure_sync_result(result, self.execute) + + @overload + async def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: ... + + @overload + async def scalar( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: ... + + async def scalar( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + """Execute a statement and return a scalar result. + + .. seealso:: + + :meth:`_orm.Session.scalar` - main documentation for scalar + + """ + + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _EXECUTE_OPTIONS + ) + else: + execution_options = _EXECUTE_OPTIONS + + return await greenlet_spawn( + self.sync_session.scalar, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @overload + async def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: ... + + @overload + async def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: ... + + async def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + """Execute a statement and return scalar results. + + :return: a :class:`_result.ScalarResult` object + + .. versionadded:: 1.4.24 Added :meth:`_asyncio.AsyncSession.scalars` + + .. versionadded:: 1.4.26 Added + :meth:`_asyncio.async_scoped_session.scalars` + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.stream_scalars` - streaming version + + """ + + result = await self.execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return result.scalars() + + async def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> Union[_O, None]: + """Return an instance based on the given primary key identifier, + or ``None`` if not found. + + .. seealso:: + + :meth:`_orm.Session.get` - main documentation for get + + + """ + + return await greenlet_spawn( + cast("Callable[..., _O]", self.sync_session.get), + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + + async def get_one( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + ) -> _O: + """Return an instance based on the given primary key identifier, + or raise an exception if not found. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. + + ..versionadded: 2.0.22 + + .. seealso:: + + :meth:`_orm.Session.get_one` - main documentation for get_one + + """ + + return await greenlet_spawn( + cast("Callable[..., _O]", self.sync_session.get_one), + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) + + @overload + async def stream( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[_T]: ... + + @overload + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: ... + + async def stream( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncResult[Any]: + """Execute a statement and return a streaming + :class:`_asyncio.AsyncResult` object. + + """ + + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _STREAM_OPTIONS + ) + else: + execution_options = _STREAM_OPTIONS + + result = await greenlet_spawn( + self.sync_session.execute, + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return AsyncResult(result) + + @overload + async def stream_scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[_T]: ... + + @overload + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: ... + + async def stream_scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> AsyncScalarResult[Any]: + """Execute a statement and return a stream of scalar results. + + :return: an :class:`_asyncio.AsyncScalarResult` object + + .. versionadded:: 1.4.24 + + .. seealso:: + + :meth:`_orm.Session.scalars` - main documentation for scalars + + :meth:`_asyncio.AsyncSession.scalars` - non streaming version + + """ + + result = await self.stream( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + return result.scalars() + + async def delete(self, instance: object) -> None: + """Mark an instance as deleted. + + The database delete operation occurs upon ``flush()``. + + As this operation may need to cascade along unloaded relationships, + it is awaitable to allow for those queries to take place. + + .. seealso:: + + :meth:`_orm.Session.delete` - main documentation for delete + + """ + await greenlet_spawn(self.sync_session.delete, instance) + + async def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: + """Copy the state of a given instance into a corresponding instance + within this :class:`_asyncio.AsyncSession`. + + .. seealso:: + + :meth:`_orm.Session.merge` - main documentation for merge + + """ + return await greenlet_spawn( + self.sync_session.merge, instance, load=load, options=options + ) + + async def flush(self, objects: Optional[Sequence[Any]] = None) -> None: + """Flush all the object changes to the database. + + .. seealso:: + + :meth:`_orm.Session.flush` - main documentation for flush + + """ + await greenlet_spawn(self.sync_session.flush, objects=objects) + + def get_transaction(self) -> Optional[AsyncSessionTransaction]: + """Return the current root transaction in progress, if any. + + :return: an :class:`_asyncio.AsyncSessionTransaction` object, or + ``None``. + + .. versionadded:: 1.4.18 + + """ + trans = self.sync_session.get_transaction() + if trans is not None: + return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + else: + return None + + def get_nested_transaction(self) -> Optional[AsyncSessionTransaction]: + """Return the current nested transaction in progress, if any. + + :return: an :class:`_asyncio.AsyncSessionTransaction` object, or + ``None``. + + .. versionadded:: 1.4.18 + + """ + + trans = self.sync_session.get_nested_transaction() + if trans is not None: + return AsyncSessionTransaction._retrieve_proxy_for_target(trans) + else: + return None + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + **kw: Any, + ) -> Union[Engine, Connection]: + """Return a "bind" to which the synchronous proxied :class:`_orm.Session` + is bound. + + Unlike the :meth:`_orm.Session.get_bind` method, this method is + currently **not** used by this :class:`.AsyncSession` in any way + in order to resolve engines for requests. + + .. note:: + + This method proxies directly to the :meth:`_orm.Session.get_bind` + method, however is currently **not** useful as an override target, + in contrast to that of the :meth:`_orm.Session.get_bind` method. + The example below illustrates how to implement custom + :meth:`_orm.Session.get_bind` schemes that work with + :class:`.AsyncSession` and :class:`.AsyncEngine`. + + The pattern introduced at :ref:`session_custom_partitioning` + illustrates how to apply a custom bind-lookup scheme to a + :class:`_orm.Session` given a set of :class:`_engine.Engine` objects. + To apply a corresponding :meth:`_orm.Session.get_bind` implementation + for use with a :class:`.AsyncSession` and :class:`.AsyncEngine` + objects, continue to subclass :class:`_orm.Session` and apply it to + :class:`.AsyncSession` using + :paramref:`.AsyncSession.sync_session_class`. The inner method must + continue to return :class:`_engine.Engine` instances, which can be + acquired from a :class:`_asyncio.AsyncEngine` using the + :attr:`_asyncio.AsyncEngine.sync_engine` attribute:: + + # using example from "Custom Vertical Partitioning" + + + import random + + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import async_sessionmaker + from sqlalchemy.orm import Session + + # construct async engines w/ async drivers + engines = { + 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"), + 'other':create_async_engine("sqlite+aiosqlite:///other.db"), + 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"), + 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"), + } + + class RoutingSession(Session): + def get_bind(self, mapper=None, clause=None, **kw): + # within get_bind(), return sync engines + if mapper and issubclass(mapper.class_, MyOtherClass): + return engines['other'].sync_engine + elif self._flushing or isinstance(clause, (Update, Delete)): + return engines['leader'].sync_engine + else: + return engines[ + random.choice(['follower1','follower2']) + ].sync_engine + + # apply to AsyncSession using sync_session_class + AsyncSessionMaker = async_sessionmaker( + sync_session_class=RoutingSession + ) + + The :meth:`_orm.Session.get_bind` method is called in a non-asyncio, + implicitly non-blocking context in the same manner as ORM event hooks + and functions that are invoked via :meth:`.AsyncSession.run_sync`, so + routines that wish to run SQL commands inside of + :meth:`_orm.Session.get_bind` can continue to do so using + blocking-style code, which will be translated to implicitly async calls + at the point of invoking IO on the database drivers. + + """ # noqa: E501 + + return self.sync_session.get_bind( + mapper=mapper, clause=clause, bind=bind, **kw + ) + + async def connection( + self, + bind_arguments: Optional[_BindArguments] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + **kw: Any, + ) -> AsyncConnection: + r"""Return a :class:`_asyncio.AsyncConnection` object corresponding to + this :class:`.Session` object's transactional state. + + This method may also be used to establish execution options for the + database connection used by the current transaction. + + .. versionadded:: 1.4.24 Added \**kw arguments which are passed + through to the underlying :meth:`_orm.Session.connection` method. + + .. seealso:: + + :meth:`_orm.Session.connection` - main documentation for + "connection" + + """ + + sync_connection = await greenlet_spawn( + self.sync_session.connection, + bind_arguments=bind_arguments, + execution_options=execution_options, + **kw, + ) + return engine.AsyncConnection._retrieve_proxy_for_target( + sync_connection + ) + + def begin(self) -> AsyncSessionTransaction: + """Return an :class:`_asyncio.AsyncSessionTransaction` object. + + The underlying :class:`_orm.Session` will perform the + "begin" action when the :class:`_asyncio.AsyncSessionTransaction` + object is entered:: + + async with async_session.begin(): + # .. ORM transaction is begun + + Note that database IO will not normally occur when the session-level + transaction is begun, as database transactions begin on an + on-demand basis. However, the begin block is async to accommodate + for a :meth:`_orm.SessionEvents.after_transaction_create` + event hook that may perform IO. + + For a general description of ORM begin, see + :meth:`_orm.Session.begin`. + + """ + + return AsyncSessionTransaction(self) + + def begin_nested(self) -> AsyncSessionTransaction: + """Return an :class:`_asyncio.AsyncSessionTransaction` object + which will begin a "nested" transaction, e.g. SAVEPOINT. + + Behavior is the same as that of :meth:`_asyncio.AsyncSession.begin`. + + For a general description of ORM begin nested, see + :meth:`_orm.Session.begin_nested`. + + .. seealso:: + + :ref:`aiosqlite_serializable` - special workarounds required + with the SQLite asyncio driver in order for SAVEPOINT to work + correctly. + + """ + + return AsyncSessionTransaction(self, nested=True) + + async def rollback(self) -> None: + """Rollback the current transaction in progress. + + .. seealso:: + + :meth:`_orm.Session.rollback` - main documentation for + "rollback" + """ + await greenlet_spawn(self.sync_session.rollback) + + async def commit(self) -> None: + """Commit the current transaction in progress. + + .. seealso:: + + :meth:`_orm.Session.commit` - main documentation for + "commit" + """ + await greenlet_spawn(self.sync_session.commit) + + async def close(self) -> None: + """Close out the transactional resources and ORM objects used by this + :class:`_asyncio.AsyncSession`. + + .. seealso:: + + :meth:`_orm.Session.close` - main documentation for + "close" + + :ref:`session_closing` - detail on the semantics of + :meth:`_asyncio.AsyncSession.close` and + :meth:`_asyncio.AsyncSession.reset`. + + """ + await greenlet_spawn(self.sync_session.close) + + async def reset(self) -> None: + """Close out the transactional resources and ORM objects used by this + :class:`_orm.Session`, resetting the session to its initial state. + + .. versionadded:: 2.0.22 + + .. seealso:: + + :meth:`_orm.Session.reset` - main documentation for + "reset" + + :ref:`session_closing` - detail on the semantics of + :meth:`_asyncio.AsyncSession.close` and + :meth:`_asyncio.AsyncSession.reset`. + + """ + await greenlet_spawn(self.sync_session.reset) + + async def aclose(self) -> None: + """A synonym for :meth:`_asyncio.AsyncSession.close`. + + The :meth:`_asyncio.AsyncSession.aclose` name is specifically + to support the Python standard library ``@contextlib.aclosing`` + context manager function. + + .. versionadded:: 2.0.20 + + """ + await self.close() + + async def invalidate(self) -> None: + """Close this Session, using connection invalidation. + + For a complete description, see :meth:`_orm.Session.invalidate`. + """ + await greenlet_spawn(self.sync_session.invalidate) + + @classmethod + @util.deprecated( + "2.0", + "The :meth:`.AsyncSession.close_all` method is deprecated and will be " + "removed in a future release. Please refer to " + ":func:`_asyncio.close_all_sessions`.", + ) + async def close_all(cls) -> None: + """Close all :class:`_asyncio.AsyncSession` sessions.""" + await close_all_sessions() + + async def __aenter__(self: _AS) -> _AS: + return self + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + task = asyncio.create_task(self.close()) + await asyncio.shield(task) + + def _maker_context_manager(self: _AS) -> _AsyncSessionContextManager[_AS]: + return _AsyncSessionContextManager(self) + + # START PROXY METHODS AsyncSession + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + def __contains__(self, instance: object) -> bool: + r"""Return True if the instance is associated with this session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + The instance may be pending or persistent within the Session for a + result of True. + + + """ # noqa: E501 + + return self._proxied.__contains__(instance) + + def __iter__(self) -> Iterator[object]: + r"""Iterate over all pending or persistent instances within this + Session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + + """ # noqa: E501 + + return self._proxied.__iter__() + + def add(self, instance: object, _warn: bool = True) -> None: + r"""Place an object into this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + Objects that are in the :term:`transient` state when passed to the + :meth:`_orm.Session.add` method will move to the + :term:`pending` state, until the next flush, at which point they + will move to the :term:`persistent` state. + + Objects that are in the :term:`detached` state when passed to the + :meth:`_orm.Session.add` method will move to the :term:`persistent` + state directly. + + If the transaction used by the :class:`_orm.Session` is rolled back, + objects which were transient when they were passed to + :meth:`_orm.Session.add` will be moved back to the + :term:`transient` state, and will no longer be present within this + :class:`_orm.Session`. + + .. seealso:: + + :meth:`_orm.Session.add_all` + + :ref:`session_adding` - at :ref:`session_basics` + + + """ # noqa: E501 + + return self._proxied.add(instance, _warn=_warn) + + def add_all(self, instances: Iterable[object]) -> None: + r"""Add the given collection of instances to this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + See the documentation for :meth:`_orm.Session.add` for a general + behavioral description. + + .. seealso:: + + :meth:`_orm.Session.add` + + :ref:`session_adding` - at :ref:`session_basics` + + + """ # noqa: E501 + + return self._proxied.add_all(instances) + + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: + r"""Expire the attributes on an instance. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + Marks the attributes of an instance as out of date. When an expired + attribute is next accessed, a query will be issued to the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire all objects in the :class:`.Session` simultaneously, + use :meth:`Session.expire_all`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire` only makes sense for the specific + case that a non-ORM SQL statement was emitted in the current + transaction. + + :param instance: The instance to be refreshed. + :param attribute_names: optional list of string attribute names + indicating a subset of attributes to be expired. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + """ # noqa: E501 + + return self._proxied.expire(instance, attribute_names=attribute_names) + + def expire_all(self) -> None: + r"""Expires all persistent instances within this Session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + When any attributes on a persistent instance is next accessed, + a query will be issued using the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire individual objects and individual attributes + on those objects, use :meth:`Session.expire`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire_all` is not usually needed, + assuming the transaction is isolated. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + """ # noqa: E501 + + return self._proxied.expire_all() + + def expunge(self, instance: object) -> None: + r"""Remove the `instance` from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This will free all internal references to the instance. Cascading + will be applied according to the *expunge* cascade rule. + + + """ # noqa: E501 + + return self._proxied.expunge(instance) + + def expunge_all(self) -> None: + r"""Remove all object instances from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is equivalent to calling ``expunge(obj)`` on all objects in this + ``Session``. + + + """ # noqa: E501 + + return self._proxied.expunge_all() + + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: + r"""Return ``True`` if the given instance has locally + modified attributes. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This method retrieves the history for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value, if any. + + It is in effect a more expensive and accurate + version of checking for the given instance in the + :attr:`.Session.dirty` collection; a full test for + each attribute's net "dirty" status is performed. + + E.g.:: + + return session.is_modified(someobject) + + A few caveats to this method apply: + + * Instances present in the :attr:`.Session.dirty` collection may + report ``False`` when tested with this method. This is because + the object may have received change events via attribute mutation, + thus placing it in :attr:`.Session.dirty`, but ultimately the state + is the same as that loaded from the database, resulting in no net + change here. + * Scalar attributes may not have recorded the previously set + value when a new value was applied, if the attribute was not loaded, + or was expired, at the time the new value was received - in these + cases, the attribute is assumed to have a change, even if there is + ultimately no net change against its database value. SQLAlchemy in + most cases does not need the "old" value when a set event occurs, so + it skips the expense of a SQL call if the old value isn't present, + based on the assumption that an UPDATE of the scalar value is + usually needed, and in those few cases where it isn't, is less + expensive on average than issuing a defensive SELECT. + + The "old" value is fetched unconditionally upon set only if the + attribute container has the ``active_history`` flag set to ``True``. + This flag is set typically for primary key attributes and scalar + object references that are not a simple many-to-one. To set this + flag for any arbitrary mapped column, use the ``active_history`` + argument with :func:`.column_property`. + + :param instance: mapped instance to be tested for pending changes. + :param include_collections: Indicates if multivalued collections + should be included in the operation. Setting this to ``False`` is a + way to detect only local-column based properties (i.e. scalar columns + or many-to-one foreign keys) that would result in an UPDATE for this + instance upon flush. + + + """ # noqa: E501 + + return self._proxied.is_modified( + instance, include_collections=include_collections + ) + + def in_transaction(self) -> bool: + r"""Return True if this :class:`_orm.Session` has begun a transaction. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_orm.Session.is_active` + + + + """ # noqa: E501 + + return self._proxied.in_transaction() + + def in_nested_transaction(self) -> bool: + r"""Return True if this :class:`_orm.Session` has begun a nested + transaction, e.g. SAVEPOINT. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + .. versionadded:: 1.4 + + + """ # noqa: E501 + + return self._proxied.in_nested_transaction() + + @property + def dirty(self) -> Any: + r"""The set of all persistent instances considered dirty. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + E.g.:: + + some_mapped_object in session.dirty + + Instances are considered dirty when they were modified but not + deleted. + + Note that this 'dirty' calculation is 'optimistic'; most + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). + + To check if an instance has actionable net changes to its + attributes, use the :meth:`.Session.is_modified` method. + + + """ # noqa: E501 + + return self._proxied.dirty + + @property + def deleted(self) -> Any: + r"""The set of all instances marked as 'deleted' within this ``Session`` + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + """ # noqa: E501 + + return self._proxied.deleted + + @property + def new(self) -> Any: + r"""The set of all instances marked as 'new' within this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + """ # noqa: E501 + + return self._proxied.new + + @property + def identity_map(self) -> IdentityMap: + r"""Proxy for the :attr:`_orm.Session.identity_map` attribute + on behalf of the :class:`_asyncio.AsyncSession` class. + + """ # noqa: E501 + + return self._proxied.identity_map + + @identity_map.setter + def identity_map(self, attr: IdentityMap) -> None: + self._proxied.identity_map = attr + + @property + def is_active(self) -> Any: + r"""True if this :class:`.Session` not in "partial rollback" state. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins + a new transaction immediately, so this attribute will be False + when the :class:`_orm.Session` is first instantiated. + + "partial rollback" state typically indicates that the flush process + of the :class:`_orm.Session` has failed, and that the + :meth:`_orm.Session.rollback` method must be emitted in order to + fully roll back the transaction. + + If this :class:`_orm.Session` is not in a transaction at all, the + :class:`_orm.Session` will autobegin when it is first used, so in this + case :attr:`_orm.Session.is_active` will return True. + + Otherwise, if this :class:`_orm.Session` is within a transaction, + and that transaction has not been rolled back internally, the + :attr:`_orm.Session.is_active` will also return True. + + .. seealso:: + + :ref:`faq_session_rollback` + + :meth:`_orm.Session.in_transaction` + + + """ # noqa: E501 + + return self._proxied.is_active + + @property + def autoflush(self) -> bool: + r"""Proxy for the :attr:`_orm.Session.autoflush` attribute + on behalf of the :class:`_asyncio.AsyncSession` class. + + """ # noqa: E501 + + return self._proxied.autoflush + + @autoflush.setter + def autoflush(self, attr: bool) -> None: + self._proxied.autoflush = attr + + @property + def no_autoflush(self) -> Any: + r"""Return a context manager that disables autoflush. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + + """ # noqa: E501 + + return self._proxied.no_autoflush + + @property + def info(self) -> Any: + r"""A user-modifiable dictionary. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_asyncio.AsyncSession` class. + + The initial value of this dictionary can be populated using the + ``info`` argument to the :class:`.Session` constructor or + :class:`.sessionmaker` constructor or factory methods. The dictionary + here is always local to this :class:`.Session` and can be modified + independently of all other :class:`.Session` objects. + + + """ # noqa: E501 + + return self._proxied.info + + @classmethod + def object_session(cls, instance: object) -> Optional[Session]: + r"""Return the :class:`.Session` to which an object belongs. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is an alias of :func:`.object_session`. + + + """ # noqa: E501 + + return Session.object_session(instance) + + @classmethod + def identity_key( + cls, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[Any] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: + r"""Return an identity key. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_asyncio.AsyncSession` class. + + This is an alias of :func:`.util.identity_key`. + + + """ # noqa: E501 + + return Session.identity_key( + class_=class_, + ident=ident, + instance=instance, + row=row, + identity_token=identity_token, + ) + + # END PROXY METHODS AsyncSession + + +_AS = TypeVar("_AS", bound="AsyncSession") + + +class async_sessionmaker(Generic[_AS]): + """A configurable :class:`.AsyncSession` factory. + + The :class:`.async_sessionmaker` factory works in the same way as the + :class:`.sessionmaker` factory, to generate new :class:`.AsyncSession` + objects when called, creating them given + the configurational arguments established here. + + e.g.:: + + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.ext.asyncio import AsyncSession + from sqlalchemy.ext.asyncio import async_sessionmaker + + async def run_some_sql(async_session: async_sessionmaker[AsyncSession]) -> None: + async with async_session() as session: + session.add(SomeObject(data="object")) + session.add(SomeOtherObject(name="other object")) + await session.commit() + + async def main() -> None: + # an AsyncEngine, which the AsyncSession will use for connection + # resources + engine = create_async_engine('postgresql+asyncpg://scott:tiger@localhost/') + + # create a reusable factory for new AsyncSession instances + async_session = async_sessionmaker(engine) + + await run_some_sql(async_session) + + await engine.dispose() + + The :class:`.async_sessionmaker` is useful so that different parts + of a program can create new :class:`.AsyncSession` objects with a + fixed configuration established up front. Note that :class:`.AsyncSession` + objects may also be instantiated directly when not using + :class:`.async_sessionmaker`. + + .. versionadded:: 2.0 :class:`.async_sessionmaker` provides a + :class:`.sessionmaker` class that's dedicated to the + :class:`.AsyncSession` object, including pep-484 typing support. + + .. seealso:: + + :ref:`asyncio_orm` - shows example use + + :class:`.sessionmaker` - general overview of the + :class:`.sessionmaker` architecture + + + :ref:`session_getting` - introductory text on creating + sessions using :class:`.sessionmaker`. + + """ # noqa E501 + + class_: Type[_AS] + + @overload + def __init__( + self, + bind: Optional[_AsyncSessionBind] = ..., + *, + class_: Type[_AS], + autoflush: bool = ..., + expire_on_commit: bool = ..., + info: Optional[_InfoType] = ..., + **kw: Any, + ): ... + + @overload + def __init__( + self: "async_sessionmaker[AsyncSession]", + bind: Optional[_AsyncSessionBind] = ..., + *, + autoflush: bool = ..., + expire_on_commit: bool = ..., + info: Optional[_InfoType] = ..., + **kw: Any, + ): ... + + def __init__( + self, + bind: Optional[_AsyncSessionBind] = None, + *, + class_: Type[_AS] = AsyncSession, # type: ignore + autoflush: bool = True, + expire_on_commit: bool = True, + info: Optional[_InfoType] = None, + **kw: Any, + ): + r"""Construct a new :class:`.async_sessionmaker`. + + All arguments here except for ``class_`` correspond to arguments + accepted by :class:`.Session` directly. See the + :meth:`.AsyncSession.__init__` docstring for more details on + parameters. + + + """ + kw["bind"] = bind + kw["autoflush"] = autoflush + kw["expire_on_commit"] = expire_on_commit + if info is not None: + kw["info"] = info + self.kw = kw + self.class_ = class_ + + def begin(self) -> _AsyncSessionContextManager[_AS]: + """Produce a context manager that both provides a new + :class:`_orm.AsyncSession` as well as a transaction that commits. + + + e.g.:: + + async def main(): + Session = async_sessionmaker(some_engine) + + async with Session.begin() as session: + session.add(some_object) + + # commits transaction, closes session + + + """ + + session = self() + return session._maker_context_manager() + + def __call__(self, **local_kw: Any) -> _AS: + """Produce a new :class:`.AsyncSession` object using the configuration + established in this :class:`.async_sessionmaker`. + + In Python, the ``__call__`` method is invoked on an object when + it is "called" in the same way as a function:: + + AsyncSession = async_sessionmaker(async_engine, expire_on_commit=False) + session = AsyncSession() # invokes sessionmaker.__call__() + + """ # noqa E501 + for k, v in self.kw.items(): + if k == "info" and "info" in local_kw: + d = v.copy() + d.update(local_kw["info"]) + local_kw["info"] = d + else: + local_kw.setdefault(k, v) + return self.class_(**local_kw) + + def configure(self, **new_kw: Any) -> None: + """(Re)configure the arguments for this async_sessionmaker. + + e.g.:: + + AsyncSession = async_sessionmaker(some_engine) + + AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://')) + """ # noqa E501 + + self.kw.update(new_kw) + + def __repr__(self) -> str: + return "%s(class_=%r, %s)" % ( + self.__class__.__name__, + self.class_.__name__, + ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()), + ) + + +class _AsyncSessionContextManager(Generic[_AS]): + __slots__ = ("async_session", "trans") + + async_session: _AS + trans: AsyncSessionTransaction + + def __init__(self, async_session: _AS): + self.async_session = async_session + + async def __aenter__(self) -> _AS: + self.trans = self.async_session.begin() + await self.trans.__aenter__() + return self.async_session + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + async def go() -> None: + await self.trans.__aexit__(type_, value, traceback) + await self.async_session.__aexit__(type_, value, traceback) + + task = asyncio.create_task(go()) + await asyncio.shield(task) + + +class AsyncSessionTransaction( + ReversibleProxy[SessionTransaction], + StartableContext["AsyncSessionTransaction"], +): + """A wrapper for the ORM :class:`_orm.SessionTransaction` object. + + This object is provided so that a transaction-holding object + for the :meth:`_asyncio.AsyncSession.begin` may be returned. + + The object supports both explicit calls to + :meth:`_asyncio.AsyncSessionTransaction.commit` and + :meth:`_asyncio.AsyncSessionTransaction.rollback`, as well as use as an + async context manager. + + + .. versionadded:: 1.4 + + """ + + __slots__ = ("session", "sync_transaction", "nested") + + session: AsyncSession + sync_transaction: Optional[SessionTransaction] + + def __init__(self, session: AsyncSession, nested: bool = False): + self.session = session + self.nested = nested + self.sync_transaction = None + + @property + def is_active(self) -> bool: + return ( + self._sync_transaction() is not None + and self._sync_transaction().is_active + ) + + def _sync_transaction(self) -> SessionTransaction: + if not self.sync_transaction: + self._raise_for_not_started() + return self.sync_transaction + + async def rollback(self) -> None: + """Roll back this :class:`_asyncio.AsyncTransaction`.""" + await greenlet_spawn(self._sync_transaction().rollback) + + async def commit(self) -> None: + """Commit this :class:`_asyncio.AsyncTransaction`.""" + + await greenlet_spawn(self._sync_transaction().commit) + + async def start( + self, is_ctxmanager: bool = False + ) -> AsyncSessionTransaction: + self.sync_transaction = self._assign_proxied( + await greenlet_spawn( + self.session.sync_session.begin_nested # type: ignore + if self.nested + else self.session.sync_session.begin + ) + ) + if is_ctxmanager: + self.sync_transaction.__enter__() + return self + + async def __aexit__(self, type_: Any, value: Any, traceback: Any) -> None: + await greenlet_spawn( + self._sync_transaction().__exit__, type_, value, traceback + ) + + +def async_object_session(instance: object) -> Optional[AsyncSession]: + """Return the :class:`_asyncio.AsyncSession` to which the given instance + belongs. + + This function makes use of the sync-API function + :class:`_orm.object_session` to retrieve the :class:`_orm.Session` which + refers to the given instance, and from there links it to the original + :class:`_asyncio.AsyncSession`. + + If the :class:`_asyncio.AsyncSession` has been garbage collected, the + return value is ``None``. + + This functionality is also available from the + :attr:`_orm.InstanceState.async_session` accessor. + + :param instance: an ORM mapped instance + :return: an :class:`_asyncio.AsyncSession` object, or ``None``. + + .. versionadded:: 1.4.18 + + """ + + session = object_session(instance) + if session is not None: + return async_session(session) + else: + return None + + +def async_session(session: Session) -> Optional[AsyncSession]: + """Return the :class:`_asyncio.AsyncSession` which is proxying the given + :class:`_orm.Session` object, if any. + + :param session: a :class:`_orm.Session` instance. + :return: a :class:`_asyncio.AsyncSession` instance, or ``None``. + + .. versionadded:: 1.4.18 + + """ + return AsyncSession._retrieve_proxy_for_target(session, regenerate=False) + + +async def close_all_sessions() -> None: + """Close all :class:`_asyncio.AsyncSession` sessions. + + .. versionadded:: 2.0.23 + + .. seealso:: + + :func:`.session.close_all_sessions` + + """ + await greenlet_spawn(_sync_close_all_sessions) + + +_instance_state._async_provider = async_session # type: ignore diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/automap.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/automap.py new file mode 100644 index 0000000..bf6a5f2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/automap.py @@ -0,0 +1,1658 @@ +# ext/automap.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Define an extension to the :mod:`sqlalchemy.ext.declarative` system +which automatically generates mapped classes and relationships from a database +schema, typically though not necessarily one which is reflected. + +It is hoped that the :class:`.AutomapBase` system provides a quick +and modernized solution to the problem that the very famous +`SQLSoup `_ +also tries to solve, that of generating a quick and rudimentary object +model from an existing database on the fly. By addressing the issue strictly +at the mapper configuration level, and integrating fully with existing +Declarative class techniques, :class:`.AutomapBase` seeks to provide +a well-integrated approach to the issue of expediently auto-generating ad-hoc +mappings. + +.. tip:: The :ref:`automap_toplevel` extension is geared towards a + "zero declaration" approach, where a complete ORM model including classes + and pre-named relationships can be generated on the fly from a database + schema. For applications that still want to use explicit class declarations + including explicit relationship definitions in conjunction with reflection + of tables, the :class:`.DeferredReflection` class, described at + :ref:`orm_declarative_reflected_deferred_reflection`, is a better choice. + +.. _automap_basic_use: + +Basic Use +========= + +The simplest usage is to reflect an existing database into a new model. +We create a new :class:`.AutomapBase` class in a similar manner as to how +we create a declarative base class, using :func:`.automap_base`. +We then call :meth:`.AutomapBase.prepare` on the resulting base class, +asking it to reflect the schema and produce mappings:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy.orm import Session + from sqlalchemy import create_engine + + Base = automap_base() + + # engine, suppose it has two tables 'user' and 'address' set up + engine = create_engine("sqlite:///mydatabase.db") + + # reflect the tables + Base.prepare(autoload_with=engine) + + # mapped classes are now created with names by default + # matching that of the table name. + User = Base.classes.user + Address = Base.classes.address + + session = Session(engine) + + # rudimentary relationships are produced + session.add(Address(email_address="foo@bar.com", user=User(name="foo"))) + session.commit() + + # collection-based relationships are by default named + # "_collection" + u1 = session.query(User).first() + print (u1.address_collection) + +Above, calling :meth:`.AutomapBase.prepare` while passing along the +:paramref:`.AutomapBase.prepare.reflect` parameter indicates that the +:meth:`_schema.MetaData.reflect` +method will be called on this declarative base +classes' :class:`_schema.MetaData` collection; then, each **viable** +:class:`_schema.Table` within the :class:`_schema.MetaData` +will get a new mapped class +generated automatically. The :class:`_schema.ForeignKeyConstraint` +objects which +link the various tables together will be used to produce new, bidirectional +:func:`_orm.relationship` objects between classes. +The classes and relationships +follow along a default naming scheme that we can customize. At this point, +our basic mapping consisting of related ``User`` and ``Address`` classes is +ready to use in the traditional way. + +.. note:: By **viable**, we mean that for a table to be mapped, it must + specify a primary key. Additionally, if the table is detected as being + a pure association table between two other tables, it will not be directly + mapped and will instead be configured as a many-to-many table between + the mappings for the two referring tables. + +Generating Mappings from an Existing MetaData +============================================= + +We can pass a pre-declared :class:`_schema.MetaData` object to +:func:`.automap_base`. +This object can be constructed in any way, including programmatically, from +a serialized file, or from itself being reflected using +:meth:`_schema.MetaData.reflect`. +Below we illustrate a combination of reflection and +explicit table declaration:: + + from sqlalchemy import create_engine, MetaData, Table, Column, ForeignKey + from sqlalchemy.ext.automap import automap_base + engine = create_engine("sqlite:///mydatabase.db") + + # produce our own MetaData object + metadata = MetaData() + + # we can reflect it ourselves from a database, using options + # such as 'only' to limit what tables we look at... + metadata.reflect(engine, only=['user', 'address']) + + # ... or just define our own Table objects with it (or combine both) + Table('user_order', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', ForeignKey('user.id')) + ) + + # we can then produce a set of mappings from this MetaData. + Base = automap_base(metadata=metadata) + + # calling prepare() just sets up mapped classes and relationships. + Base.prepare() + + # mapped classes are ready + User, Address, Order = Base.classes.user, Base.classes.address,\ + Base.classes.user_order + +.. _automap_by_module: + +Generating Mappings from Multiple Schemas +========================================= + +The :meth:`.AutomapBase.prepare` method when used with reflection may reflect +tables from one schema at a time at most, using the +:paramref:`.AutomapBase.prepare.schema` parameter to indicate the name of a +schema to be reflected from. In order to populate the :class:`.AutomapBase` +with tables from multiple schemas, :meth:`.AutomapBase.prepare` may be invoked +multiple times, each time passing a different name to the +:paramref:`.AutomapBase.prepare.schema` parameter. The +:meth:`.AutomapBase.prepare` method keeps an internal list of +:class:`_schema.Table` objects that have already been mapped, and will add new +mappings only for those :class:`_schema.Table` objects that are new since the +last time :meth:`.AutomapBase.prepare` was run:: + + e = create_engine("postgresql://scott:tiger@localhost/test") + + Base.metadata.create_all(e) + + Base = automap_base() + + Base.prepare(e) + Base.prepare(e, schema="test_schema") + Base.prepare(e, schema="test_schema_2") + +.. versionadded:: 2.0 The :meth:`.AutomapBase.prepare` method may be called + any number of times; only newly added tables will be mapped + on each run. Previously in version 1.4 and earlier, multiple calls would + cause errors as it would attempt to re-map an already mapped class. + The previous workaround approach of invoking + :meth:`_schema.MetaData.reflect` directly remains available as well. + +Automapping same-named tables across multiple schemas +----------------------------------------------------- + +For the common case where multiple schemas may have same-named tables and +therefore would generate same-named classes, conflicts can be resolved either +through use of the :paramref:`.AutomapBase.prepare.classname_for_table` hook to +apply different classnames on a per-schema basis, or by using the +:paramref:`.AutomapBase.prepare.modulename_for_table` hook, which allows +disambiguation of same-named classes by changing their effective ``__module__`` +attribute. In the example below, this hook is used to create a ``__module__`` +attribute for all classes that is of the form ``mymodule.``, where +the schema name ``default`` is used if no schema is present:: + + e = create_engine("postgresql://scott:tiger@localhost/test") + + Base.metadata.create_all(e) + + def module_name_for_table(cls, tablename, table): + if table.schema is not None: + return f"mymodule.{table.schema}" + else: + return f"mymodule.default" + + Base = automap_base() + + Base.prepare(e, modulename_for_table=module_name_for_table) + Base.prepare(e, schema="test_schema", modulename_for_table=module_name_for_table) + Base.prepare(e, schema="test_schema_2", modulename_for_table=module_name_for_table) + + +The same named-classes are organized into a hierarchical collection available +at :attr:`.AutomapBase.by_module`. This collection is traversed using the +dot-separated name of a particular package/module down into the desired +class name. + +.. note:: When using the :paramref:`.AutomapBase.prepare.modulename_for_table` + hook to return a new ``__module__`` that is not ``None``, the class is + **not** placed into the :attr:`.AutomapBase.classes` collection; only + classes that were not given an explicit modulename are placed here, as the + collection cannot represent same-named classes individually. + +In the example above, if the database contained a table named ``accounts`` in +all three of the default schema, the ``test_schema`` schema, and the +``test_schema_2`` schema, three separate classes will be available as:: + + Base.by_module.mymodule.default.accounts + Base.by_module.mymodule.test_schema.accounts + Base.by_module.mymodule.test_schema_2.accounts + +The default module namespace generated for all :class:`.AutomapBase` classes is +``sqlalchemy.ext.automap``. If no +:paramref:`.AutomapBase.prepare.modulename_for_table` hook is used, the +contents of :attr:`.AutomapBase.by_module` will be entirely within the +``sqlalchemy.ext.automap`` namespace (e.g. +``MyBase.by_module.sqlalchemy.ext.automap.``), which would contain +the same series of classes as what would be seen in +:attr:`.AutomapBase.classes`. Therefore it's generally only necessary to use +:attr:`.AutomapBase.by_module` when explicit ``__module__`` conventions are +present. + +.. versionadded: 2.0 + + Added the :attr:`.AutomapBase.by_module` collection, which stores + classes within a named hierarchy based on dot-separated module names, + as well as the :paramref:`.Automap.prepare.modulename_for_table` parameter + which allows for custom ``__module__`` schemes for automapped + classes. + + + +Specifying Classes Explicitly +============================= + +.. tip:: If explicit classes are expected to be prominent in an application, + consider using :class:`.DeferredReflection` instead. + +The :mod:`.sqlalchemy.ext.automap` extension allows classes to be defined +explicitly, in a way similar to that of the :class:`.DeferredReflection` class. +Classes that extend from :class:`.AutomapBase` act like regular declarative +classes, but are not immediately mapped after their construction, and are +instead mapped when we call :meth:`.AutomapBase.prepare`. The +:meth:`.AutomapBase.prepare` method will make use of the classes we've +established based on the table name we use. If our schema contains tables +``user`` and ``address``, we can define one or both of the classes to be used:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy import create_engine + + # automap base + Base = automap_base() + + # pre-declare User for the 'user' table + class User(Base): + __tablename__ = 'user' + + # override schema elements like Columns + user_name = Column('name', String) + + # override relationships too, if desired. + # we must use the same name that automap would use for the + # relationship, and also must refer to the class name that automap will + # generate for "address" + address_collection = relationship("address", collection_class=set) + + # reflect + engine = create_engine("sqlite:///mydatabase.db") + Base.prepare(autoload_with=engine) + + # we still have Address generated from the tablename "address", + # but User is the same as Base.classes.User now + + Address = Base.classes.address + + u1 = session.query(User).first() + print (u1.address_collection) + + # the backref is still there: + a1 = session.query(Address).first() + print (a1.user) + +Above, one of the more intricate details is that we illustrated overriding +one of the :func:`_orm.relationship` objects that automap would have created. +To do this, we needed to make sure the names match up with what automap +would normally generate, in that the relationship name would be +``User.address_collection`` and the name of the class referred to, from +automap's perspective, is called ``address``, even though we are referring to +it as ``Address`` within our usage of this class. + +Overriding Naming Schemes +========================= + +:mod:`.sqlalchemy.ext.automap` is tasked with producing mapped classes and +relationship names based on a schema, which means it has decision points in how +these names are determined. These three decision points are provided using +functions which can be passed to the :meth:`.AutomapBase.prepare` method, and +are known as :func:`.classname_for_table`, +:func:`.name_for_scalar_relationship`, +and :func:`.name_for_collection_relationship`. Any or all of these +functions are provided as in the example below, where we use a "camel case" +scheme for class names and a "pluralizer" for collection names using the +`Inflect `_ package:: + + import re + import inflect + + def camelize_classname(base, tablename, table): + "Produce a 'camelized' class name, e.g. " + "'words_and_underscores' -> 'WordsAndUnderscores'" + + return str(tablename[0].upper() + \ + re.sub(r'_([a-z])', lambda m: m.group(1).upper(), tablename[1:])) + + _pluralizer = inflect.engine() + def pluralize_collection(base, local_cls, referred_cls, constraint): + "Produce an 'uncamelized', 'pluralized' class name, e.g. " + "'SomeTerm' -> 'some_terms'" + + referred_name = referred_cls.__name__ + uncamelized = re.sub(r'[A-Z]', + lambda m: "_%s" % m.group(0).lower(), + referred_name)[1:] + pluralized = _pluralizer.plural(uncamelized) + return pluralized + + from sqlalchemy.ext.automap import automap_base + + Base = automap_base() + + engine = create_engine("sqlite:///mydatabase.db") + + Base.prepare(autoload_with=engine, + classname_for_table=camelize_classname, + name_for_collection_relationship=pluralize_collection + ) + +From the above mapping, we would now have classes ``User`` and ``Address``, +where the collection from ``User`` to ``Address`` is called +``User.addresses``:: + + User, Address = Base.classes.User, Base.classes.Address + + u1 = User(addresses=[Address(email="foo@bar.com")]) + +Relationship Detection +====================== + +The vast majority of what automap accomplishes is the generation of +:func:`_orm.relationship` structures based on foreign keys. The mechanism +by which this works for many-to-one and one-to-many relationships is as +follows: + +1. A given :class:`_schema.Table`, known to be mapped to a particular class, + is examined for :class:`_schema.ForeignKeyConstraint` objects. + +2. From each :class:`_schema.ForeignKeyConstraint`, the remote + :class:`_schema.Table` + object present is matched up to the class to which it is to be mapped, + if any, else it is skipped. + +3. As the :class:`_schema.ForeignKeyConstraint` + we are examining corresponds to a + reference from the immediate mapped class, the relationship will be set up + as a many-to-one referring to the referred class; a corresponding + one-to-many backref will be created on the referred class referring + to this class. + +4. If any of the columns that are part of the + :class:`_schema.ForeignKeyConstraint` + are not nullable (e.g. ``nullable=False``), a + :paramref:`_orm.relationship.cascade` keyword argument + of ``all, delete-orphan`` will be added to the keyword arguments to + be passed to the relationship or backref. If the + :class:`_schema.ForeignKeyConstraint` reports that + :paramref:`_schema.ForeignKeyConstraint.ondelete` + is set to ``CASCADE`` for a not null or ``SET NULL`` for a nullable + set of columns, the option :paramref:`_orm.relationship.passive_deletes` + flag is set to ``True`` in the set of relationship keyword arguments. + Note that not all backends support reflection of ON DELETE. + +5. The names of the relationships are determined using the + :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` and + :paramref:`.AutomapBase.prepare.name_for_collection_relationship` + callable functions. It is important to note that the default relationship + naming derives the name from the **the actual class name**. If you've + given a particular class an explicit name by declaring it, or specified an + alternate class naming scheme, that's the name from which the relationship + name will be derived. + +6. The classes are inspected for an existing mapped property matching these + names. If one is detected on one side, but none on the other side, + :class:`.AutomapBase` attempts to create a relationship on the missing side, + then uses the :paramref:`_orm.relationship.back_populates` + parameter in order to + point the new relationship to the other side. + +7. In the usual case where no relationship is on either side, + :meth:`.AutomapBase.prepare` produces a :func:`_orm.relationship` on the + "many-to-one" side and matches it to the other using the + :paramref:`_orm.relationship.backref` parameter. + +8. Production of the :func:`_orm.relationship` and optionally the + :func:`.backref` + is handed off to the :paramref:`.AutomapBase.prepare.generate_relationship` + function, which can be supplied by the end-user in order to augment + the arguments passed to :func:`_orm.relationship` or :func:`.backref` or to + make use of custom implementations of these functions. + +Custom Relationship Arguments +----------------------------- + +The :paramref:`.AutomapBase.prepare.generate_relationship` hook can be used +to add parameters to relationships. For most cases, we can make use of the +existing :func:`.automap.generate_relationship` function to return +the object, after augmenting the given keyword dictionary with our own +arguments. + +Below is an illustration of how to send +:paramref:`_orm.relationship.cascade` and +:paramref:`_orm.relationship.passive_deletes` +options along to all one-to-many relationships:: + + from sqlalchemy.ext.automap import generate_relationship + + def _gen_relationship(base, direction, return_fn, + attrname, local_cls, referred_cls, **kw): + if direction is interfaces.ONETOMANY: + kw['cascade'] = 'all, delete-orphan' + kw['passive_deletes'] = True + # make use of the built-in function to actually return + # the result. + return generate_relationship(base, direction, return_fn, + attrname, local_cls, referred_cls, **kw) + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy import create_engine + + # automap base + Base = automap_base() + + engine = create_engine("sqlite:///mydatabase.db") + Base.prepare(autoload_with=engine, + generate_relationship=_gen_relationship) + +Many-to-Many relationships +-------------------------- + +:mod:`.sqlalchemy.ext.automap` will generate many-to-many relationships, e.g. +those which contain a ``secondary`` argument. The process for producing these +is as follows: + +1. A given :class:`_schema.Table` is examined for + :class:`_schema.ForeignKeyConstraint` + objects, before any mapped class has been assigned to it. + +2. If the table contains two and exactly two + :class:`_schema.ForeignKeyConstraint` + objects, and all columns within this table are members of these two + :class:`_schema.ForeignKeyConstraint` objects, the table is assumed to be a + "secondary" table, and will **not be mapped directly**. + +3. The two (or one, for self-referential) external tables to which the + :class:`_schema.Table` + refers to are matched to the classes to which they will be + mapped, if any. + +4. If mapped classes for both sides are located, a many-to-many bi-directional + :func:`_orm.relationship` / :func:`.backref` + pair is created between the two + classes. + +5. The override logic for many-to-many works the same as that of one-to-many/ + many-to-one; the :func:`.generate_relationship` function is called upon + to generate the structures and existing attributes will be maintained. + +Relationships with Inheritance +------------------------------ + +:mod:`.sqlalchemy.ext.automap` will not generate any relationships between +two classes that are in an inheritance relationship. That is, with two +classes given as follows:: + + class Employee(Base): + __tablename__ = 'employee' + id = Column(Integer, primary_key=True) + type = Column(String(50)) + __mapper_args__ = { + 'polymorphic_identity':'employee', 'polymorphic_on': type + } + + class Engineer(Employee): + __tablename__ = 'engineer' + id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + __mapper_args__ = { + 'polymorphic_identity':'engineer', + } + +The foreign key from ``Engineer`` to ``Employee`` is used not for a +relationship, but to establish joined inheritance between the two classes. + +Note that this means automap will not generate *any* relationships +for foreign keys that link from a subclass to a superclass. If a mapping +has actual relationships from subclass to superclass as well, those +need to be explicit. Below, as we have two separate foreign keys +from ``Engineer`` to ``Employee``, we need to set up both the relationship +we want as well as the ``inherit_condition``, as these are not things +SQLAlchemy can guess:: + + class Employee(Base): + __tablename__ = 'employee' + id = Column(Integer, primary_key=True) + type = Column(String(50)) + + __mapper_args__ = { + 'polymorphic_identity':'employee', 'polymorphic_on':type + } + + class Engineer(Employee): + __tablename__ = 'engineer' + id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + favorite_employee_id = Column(Integer, ForeignKey('employee.id')) + + favorite_employee = relationship(Employee, + foreign_keys=favorite_employee_id) + + __mapper_args__ = { + 'polymorphic_identity':'engineer', + 'inherit_condition': id == Employee.id + } + +Handling Simple Naming Conflicts +-------------------------------- + +In the case of naming conflicts during mapping, override any of +:func:`.classname_for_table`, :func:`.name_for_scalar_relationship`, +and :func:`.name_for_collection_relationship` as needed. For example, if +automap is attempting to name a many-to-one relationship the same as an +existing column, an alternate convention can be conditionally selected. Given +a schema: + +.. sourcecode:: sql + + CREATE TABLE table_a ( + id INTEGER PRIMARY KEY + ); + + CREATE TABLE table_b ( + id INTEGER PRIMARY KEY, + table_a INTEGER, + FOREIGN KEY(table_a) REFERENCES table_a(id) + ); + +The above schema will first automap the ``table_a`` table as a class named +``table_a``; it will then automap a relationship onto the class for ``table_b`` +with the same name as this related class, e.g. ``table_a``. This +relationship name conflicts with the mapping column ``table_b.table_a``, +and will emit an error on mapping. + +We can resolve this conflict by using an underscore as follows:: + + def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): + name = referred_cls.__name__.lower() + local_table = local_cls.__table__ + if name in local_table.columns: + newname = name + "_" + warnings.warn( + "Already detected name %s present. using %s" % + (name, newname)) + return newname + return name + + + Base.prepare(autoload_with=engine, + name_for_scalar_relationship=name_for_scalar_relationship) + +Alternatively, we can change the name on the column side. The columns +that are mapped can be modified using the technique described at +:ref:`mapper_column_distinct_names`, by assigning the column explicitly +to a new name:: + + Base = automap_base() + + class TableB(Base): + __tablename__ = 'table_b' + _table_a = Column('table_a', ForeignKey('table_a.id')) + + Base.prepare(autoload_with=engine) + + +Using Automap with Explicit Declarations +======================================== + +As noted previously, automap has no dependency on reflection, and can make +use of any collection of :class:`_schema.Table` objects within a +:class:`_schema.MetaData` +collection. From this, it follows that automap can also be used +generate missing relationships given an otherwise complete model that fully +defines table metadata:: + + from sqlalchemy.ext.automap import automap_base + from sqlalchemy import Column, Integer, String, ForeignKey + + Base = automap_base() + + class User(Base): + __tablename__ = 'user' + + id = Column(Integer, primary_key=True) + name = Column(String) + + class Address(Base): + __tablename__ = 'address' + + id = Column(Integer, primary_key=True) + email = Column(String) + user_id = Column(ForeignKey('user.id')) + + # produce relationships + Base.prepare() + + # mapping is complete, with "address_collection" and + # "user" relationships + a1 = Address(email='u1') + a2 = Address(email='u2') + u1 = User(address_collection=[a1, a2]) + assert a1.user is u1 + +Above, given mostly complete ``User`` and ``Address`` mappings, the +:class:`_schema.ForeignKey` which we defined on ``Address.user_id`` allowed a +bidirectional relationship pair ``Address.user`` and +``User.address_collection`` to be generated on the mapped classes. + +Note that when subclassing :class:`.AutomapBase`, +the :meth:`.AutomapBase.prepare` method is required; if not called, the classes +we've declared are in an un-mapped state. + + +.. _automap_intercepting_columns: + +Intercepting Column Definitions +=============================== + +The :class:`_schema.MetaData` and :class:`_schema.Table` objects support an +event hook :meth:`_events.DDLEvents.column_reflect` that may be used to intercept +the information reflected about a database column before the :class:`_schema.Column` +object is constructed. For example if we wanted to map columns using a +naming convention such as ``"attr_"``, the event could +be applied as:: + + @event.listens_for(Base.metadata, "column_reflect") + def column_reflect(inspector, table, column_info): + # set column.key = "attr_" + column_info['key'] = "attr_%s" % column_info['name'].lower() + + # run reflection + Base.prepare(autoload_with=engine) + +.. versionadded:: 1.4.0b2 the :meth:`_events.DDLEvents.column_reflect` event + may be applied to a :class:`_schema.MetaData` object. + +.. seealso:: + + :meth:`_events.DDLEvents.column_reflect` + + :ref:`mapper_automated_reflection_schemes` - in the ORM mapping documentation + + +""" # noqa +from __future__ import annotations + +import dataclasses +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import List +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .. import util +from ..orm import backref +from ..orm import declarative_base as _declarative_base +from ..orm import exc as orm_exc +from ..orm import interfaces +from ..orm import relationship +from ..orm.decl_base import _DeferredMapperConfig +from ..orm.mapper import _CONFIGURE_MUTEX +from ..schema import ForeignKeyConstraint +from ..sql import and_ +from ..util import Properties +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ..engine.base import Engine + from ..orm.base import RelationshipDirection + from ..orm.relationships import ORMBackrefArgument + from ..orm.relationships import Relationship + from ..sql.schema import Column + from ..sql.schema import MetaData + from ..sql.schema import Table + from ..util import immutabledict + + +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + + +class PythonNameForTableType(Protocol): + def __call__( + self, base: Type[Any], tablename: str, table: Table + ) -> str: ... + + +def classname_for_table( + base: Type[Any], + tablename: str, + table: Table, +) -> str: + """Return the class name that should be used, given the name + of a table. + + The default implementation is:: + + return str(tablename) + + Alternate implementations can be specified using the + :paramref:`.AutomapBase.prepare.classname_for_table` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param tablename: string name of the :class:`_schema.Table`. + + :param table: the :class:`_schema.Table` object itself. + + :return: a string class name. + + .. note:: + + In Python 2, the string used for the class name **must** be a + non-Unicode object, e.g. a ``str()`` object. The ``.name`` attribute + of :class:`_schema.Table` is typically a Python unicode subclass, + so the + ``str()`` function should be applied to this name, after accounting for + any non-ASCII characters. + + """ + return str(tablename) + + +class NameForScalarRelationshipType(Protocol): + def __call__( + self, + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, + ) -> str: ... + + +def name_for_scalar_relationship( + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, +) -> str: + """Return the attribute name that should be used to refer from one + class to another, for a scalar object reference. + + The default implementation is:: + + return referred_cls.__name__.lower() + + Alternate implementations can be specified using the + :paramref:`.AutomapBase.prepare.name_for_scalar_relationship` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param local_cls: the class to be mapped on the local side. + + :param referred_cls: the class to be mapped on the referring side. + + :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being + inspected to produce this relationship. + + """ + return referred_cls.__name__.lower() + + +class NameForCollectionRelationshipType(Protocol): + def __call__( + self, + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, + ) -> str: ... + + +def name_for_collection_relationship( + base: Type[Any], + local_cls: Type[Any], + referred_cls: Type[Any], + constraint: ForeignKeyConstraint, +) -> str: + """Return the attribute name that should be used to refer from one + class to another, for a collection reference. + + The default implementation is:: + + return referred_cls.__name__.lower() + "_collection" + + Alternate implementations + can be specified using the + :paramref:`.AutomapBase.prepare.name_for_collection_relationship` + parameter. + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param local_cls: the class to be mapped on the local side. + + :param referred_cls: the class to be mapped on the referring side. + + :param constraint: the :class:`_schema.ForeignKeyConstraint` that is being + inspected to produce this relationship. + + """ + return referred_cls.__name__.lower() + "_collection" + + +class GenerateRelationshipType(Protocol): + @overload + def __call__( + self, + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., Relationship[Any]], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, + ) -> Relationship[Any]: ... + + @overload + def __call__( + self, + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., ORMBackrefArgument], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, + ) -> ORMBackrefArgument: ... + + def __call__( + self, + base: Type[Any], + direction: RelationshipDirection, + return_fn: Union[ + Callable[..., Relationship[Any]], Callable[..., ORMBackrefArgument] + ], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, + ) -> Union[ORMBackrefArgument, Relationship[Any]]: ... + + +@overload +def generate_relationship( + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., Relationship[Any]], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, +) -> Relationship[Any]: ... + + +@overload +def generate_relationship( + base: Type[Any], + direction: RelationshipDirection, + return_fn: Callable[..., ORMBackrefArgument], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, +) -> ORMBackrefArgument: ... + + +def generate_relationship( + base: Type[Any], + direction: RelationshipDirection, + return_fn: Union[ + Callable[..., Relationship[Any]], Callable[..., ORMBackrefArgument] + ], + attrname: str, + local_cls: Type[Any], + referred_cls: Type[Any], + **kw: Any, +) -> Union[Relationship[Any], ORMBackrefArgument]: + r"""Generate a :func:`_orm.relationship` or :func:`.backref` + on behalf of two + mapped classes. + + An alternate implementation of this function can be specified using the + :paramref:`.AutomapBase.prepare.generate_relationship` parameter. + + The default implementation of this function is as follows:: + + if return_fn is backref: + return return_fn(attrname, **kw) + elif return_fn is relationship: + return return_fn(referred_cls, **kw) + else: + raise TypeError("Unknown relationship function: %s" % return_fn) + + :param base: the :class:`.AutomapBase` class doing the prepare. + + :param direction: indicate the "direction" of the relationship; this will + be one of :data:`.ONETOMANY`, :data:`.MANYTOONE`, :data:`.MANYTOMANY`. + + :param return_fn: the function that is used by default to create the + relationship. This will be either :func:`_orm.relationship` or + :func:`.backref`. The :func:`.backref` function's result will be used to + produce a new :func:`_orm.relationship` in a second step, + so it is critical + that user-defined implementations correctly differentiate between the two + functions, if a custom relationship function is being used. + + :param attrname: the attribute name to which this relationship is being + assigned. If the value of :paramref:`.generate_relationship.return_fn` is + the :func:`.backref` function, then this name is the name that is being + assigned to the backref. + + :param local_cls: the "local" class to which this relationship or backref + will be locally present. + + :param referred_cls: the "referred" class to which the relationship or + backref refers to. + + :param \**kw: all additional keyword arguments are passed along to the + function. + + :return: a :func:`_orm.relationship` or :func:`.backref` construct, + as dictated + by the :paramref:`.generate_relationship.return_fn` parameter. + + """ + + if return_fn is backref: + return return_fn(attrname, **kw) + elif return_fn is relationship: + return return_fn(referred_cls, **kw) + else: + raise TypeError("Unknown relationship function: %s" % return_fn) + + +ByModuleProperties = Properties[Union["ByModuleProperties", Type[Any]]] + + +class AutomapBase: + """Base class for an "automap" schema. + + The :class:`.AutomapBase` class can be compared to the "declarative base" + class that is produced by the :func:`.declarative.declarative_base` + function. In practice, the :class:`.AutomapBase` class is always used + as a mixin along with an actual declarative base. + + A new subclassable :class:`.AutomapBase` is typically instantiated + using the :func:`.automap_base` function. + + .. seealso:: + + :ref:`automap_toplevel` + + """ + + __abstract__ = True + + classes: ClassVar[Properties[Type[Any]]] + """An instance of :class:`.util.Properties` containing classes. + + This object behaves much like the ``.c`` collection on a table. Classes + are present under the name they were given, e.g.:: + + Base = automap_base() + Base.prepare(autoload_with=some_engine) + + User, Address = Base.classes.User, Base.classes.Address + + For class names that overlap with a method name of + :class:`.util.Properties`, such as ``items()``, the getitem form + is also supported:: + + Item = Base.classes["items"] + + """ + + by_module: ClassVar[ByModuleProperties] + """An instance of :class:`.util.Properties` containing a hierarchal + structure of dot-separated module names linked to classes. + + This collection is an alternative to the :attr:`.AutomapBase.classes` + collection that is useful when making use of the + :paramref:`.AutomapBase.prepare.modulename_for_table` parameter, which will + apply distinct ``__module__`` attributes to generated classes. + + The default ``__module__`` an automap-generated class is + ``sqlalchemy.ext.automap``; to access this namespace using + :attr:`.AutomapBase.by_module` looks like:: + + User = Base.by_module.sqlalchemy.ext.automap.User + + If a class had a ``__module__`` of ``mymodule.account``, accessing + this namespace looks like:: + + MyClass = Base.by_module.mymodule.account.MyClass + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`automap_by_module` + + """ + + metadata: ClassVar[MetaData] + """Refers to the :class:`_schema.MetaData` collection that will be used + for new :class:`_schema.Table` objects. + + .. seealso:: + + :ref:`orm_declarative_metadata` + + """ + + _sa_automapbase_bookkeeping: ClassVar[_Bookkeeping] + + @classmethod + @util.deprecated_params( + engine=( + "2.0", + "The :paramref:`_automap.AutomapBase.prepare.engine` parameter " + "is deprecated and will be removed in a future release. " + "Please use the " + ":paramref:`_automap.AutomapBase.prepare.autoload_with` " + "parameter.", + ), + reflect=( + "2.0", + "The :paramref:`_automap.AutomapBase.prepare.reflect` " + "parameter is deprecated and will be removed in a future " + "release. Reflection is enabled when " + ":paramref:`_automap.AutomapBase.prepare.autoload_with` " + "is passed.", + ), + ) + def prepare( + cls: Type[AutomapBase], + autoload_with: Optional[Engine] = None, + engine: Optional[Any] = None, + reflect: bool = False, + schema: Optional[str] = None, + classname_for_table: Optional[PythonNameForTableType] = None, + modulename_for_table: Optional[PythonNameForTableType] = None, + collection_class: Optional[Any] = None, + name_for_scalar_relationship: Optional[ + NameForScalarRelationshipType + ] = None, + name_for_collection_relationship: Optional[ + NameForCollectionRelationshipType + ] = None, + generate_relationship: Optional[GenerateRelationshipType] = None, + reflection_options: Union[ + Dict[_KT, _VT], immutabledict[_KT, _VT] + ] = util.EMPTY_DICT, + ) -> None: + """Extract mapped classes and relationships from the + :class:`_schema.MetaData` and perform mappings. + + For full documentation and examples see + :ref:`automap_basic_use`. + + :param autoload_with: an :class:`_engine.Engine` or + :class:`_engine.Connection` with which + to perform schema reflection; when specified, the + :meth:`_schema.MetaData.reflect` method will be invoked within + the scope of this method. + + :param engine: legacy; use :paramref:`.AutomapBase.autoload_with`. + Used to indicate the :class:`_engine.Engine` or + :class:`_engine.Connection` with which to reflect tables with, + if :paramref:`.AutomapBase.reflect` is True. + + :param reflect: legacy; use :paramref:`.AutomapBase.autoload_with`. + Indicates that :meth:`_schema.MetaData.reflect` should be invoked. + + :param classname_for_table: callable function which will be used to + produce new class names, given a table name. Defaults to + :func:`.classname_for_table`. + + :param modulename_for_table: callable function which will be used to + produce the effective ``__module__`` for an internally generated + class, to allow for multiple classes of the same name in a single + automap base which would be in different "modules". + + Defaults to ``None``, which will indicate that ``__module__`` will not + be set explicitly; the Python runtime will use the value + ``sqlalchemy.ext.automap`` for these classes. + + When assigning ``__module__`` to generated classes, they can be + accessed based on dot-separated module names using the + :attr:`.AutomapBase.by_module` collection. Classes that have + an explicit ``__module_`` assigned using this hook do **not** get + placed into the :attr:`.AutomapBase.classes` collection, only + into :attr:`.AutomapBase.by_module`. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`automap_by_module` + + :param name_for_scalar_relationship: callable function which will be + used to produce relationship names for scalar relationships. Defaults + to :func:`.name_for_scalar_relationship`. + + :param name_for_collection_relationship: callable function which will + be used to produce relationship names for collection-oriented + relationships. Defaults to :func:`.name_for_collection_relationship`. + + :param generate_relationship: callable function which will be used to + actually generate :func:`_orm.relationship` and :func:`.backref` + constructs. Defaults to :func:`.generate_relationship`. + + :param collection_class: the Python collection class that will be used + when a new :func:`_orm.relationship` + object is created that represents a + collection. Defaults to ``list``. + + :param schema: Schema name to reflect when reflecting tables using + the :paramref:`.AutomapBase.prepare.autoload_with` parameter. The name + is passed to the :paramref:`_schema.MetaData.reflect.schema` parameter + of :meth:`_schema.MetaData.reflect`. When omitted, the default schema + in use by the database connection is used. + + .. note:: The :paramref:`.AutomapBase.prepare.schema` + parameter supports reflection of a single schema at a time. + In order to include tables from many schemas, use + multiple calls to :meth:`.AutomapBase.prepare`. + + For an overview of multiple-schema automap including the use + of additional naming conventions to resolve table name + conflicts, see the section :ref:`automap_by_module`. + + .. versionadded:: 2.0 :meth:`.AutomapBase.prepare` supports being + directly invoked any number of times, keeping track of tables + that have already been processed to avoid processing them + a second time. + + :param reflection_options: When present, this dictionary of options + will be passed to :meth:`_schema.MetaData.reflect` + to supply general reflection-specific options like ``only`` and/or + dialect-specific options like ``oracle_resolve_synonyms``. + + .. versionadded:: 1.4 + + """ + + for mr in cls.__mro__: + if "_sa_automapbase_bookkeeping" in mr.__dict__: + automap_base = cast("Type[AutomapBase]", mr) + break + else: + assert False, "Can't locate automap base in class hierarchy" + + glbls = globals() + if classname_for_table is None: + classname_for_table = glbls["classname_for_table"] + if name_for_scalar_relationship is None: + name_for_scalar_relationship = glbls[ + "name_for_scalar_relationship" + ] + if name_for_collection_relationship is None: + name_for_collection_relationship = glbls[ + "name_for_collection_relationship" + ] + if generate_relationship is None: + generate_relationship = glbls["generate_relationship"] + if collection_class is None: + collection_class = list + + if autoload_with: + reflect = True + + if engine: + autoload_with = engine + + if reflect: + assert autoload_with + opts = dict( + schema=schema, + extend_existing=True, + autoload_replace=False, + ) + if reflection_options: + opts.update(reflection_options) + cls.metadata.reflect(autoload_with, **opts) # type: ignore[arg-type] # noqa: E501 + + with _CONFIGURE_MUTEX: + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ] = { + cast("Table", m.local_table): m + for m in _DeferredMapperConfig.classes_for_base( + cls, sort=False + ) + } + + many_to_many: List[ + Tuple[Table, Table, List[ForeignKeyConstraint], Table] + ] + many_to_many = [] + + bookkeeping = automap_base._sa_automapbase_bookkeeping + metadata_tables = cls.metadata.tables + + for table_key in set(metadata_tables).difference( + bookkeeping.table_keys + ): + table = metadata_tables[table_key] + bookkeeping.table_keys.add(table_key) + + lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table) + if lcl_m2m is not None: + assert rem_m2m is not None + assert m2m_const is not None + many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table)) + elif not table.primary_key: + continue + elif table not in table_to_map_config: + clsdict: Dict[str, Any] = {"__table__": table} + if modulename_for_table is not None: + new_module = modulename_for_table( + cls, table.name, table + ) + if new_module is not None: + clsdict["__module__"] = new_module + else: + new_module = None + + newname = classname_for_table(cls, table.name, table) + if new_module is None and newname in cls.classes: + util.warn( + "Ignoring duplicate class name " + f"'{newname}' " + "received in automap base for table " + f"{table.key} without " + "``__module__`` being set; consider using the " + "``modulename_for_table`` hook" + ) + continue + + mapped_cls = type( + newname, + (automap_base,), + clsdict, + ) + map_config = _DeferredMapperConfig.config_for_cls( + mapped_cls + ) + assert map_config.cls.__name__ == newname + if new_module is None: + cls.classes[newname] = mapped_cls + + by_module_properties: ByModuleProperties = cls.by_module + for token in map_config.cls.__module__.split("."): + if token not in by_module_properties: + by_module_properties[token] = util.Properties({}) + + props = by_module_properties[token] + + # we can assert this because the clsregistry + # module would have raised if there was a mismatch + # between modules/classes already. + # see test_cls_schema_name_conflict + assert isinstance(props, Properties) + by_module_properties = props + + by_module_properties[map_config.cls.__name__] = mapped_cls + + table_to_map_config[table] = map_config + + for map_config in table_to_map_config.values(): + _relationships_for_fks( + automap_base, + map_config, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) + + for lcl_m2m, rem_m2m, m2m_const, table in many_to_many: + _m2m_relationship( + automap_base, + lcl_m2m, + rem_m2m, + m2m_const, + table, + table_to_map_config, + collection_class, + name_for_scalar_relationship, + name_for_collection_relationship, + generate_relationship, + ) + + for map_config in _DeferredMapperConfig.classes_for_base( + automap_base + ): + map_config.map() + + _sa_decl_prepare = True + """Indicate that the mapping of classes should be deferred. + + The presence of this attribute name indicates to declarative + that the call to mapper() should not occur immediately; instead, + information about the table and attributes to be mapped are gathered + into an internal structure called _DeferredMapperConfig. These + objects can be collected later using classes_for_base(), additional + mapping decisions can be made, and then the map() method will actually + apply the mapping. + + The only real reason this deferral of the whole + thing is needed is to support primary key columns that aren't reflected + yet when the class is declared; everything else can theoretically be + added to the mapper later. However, the _DeferredMapperConfig is a + nice interface in any case which exists at that not usually exposed point + at which declarative has the class and the Table but hasn't called + mapper() yet. + + """ + + @classmethod + def _sa_raise_deferred_config(cls) -> NoReturn: + raise orm_exc.UnmappedClassError( + cls, + msg="Class %s is a subclass of AutomapBase. " + "Mappings are not produced until the .prepare() " + "method is called on the class hierarchy." + % orm_exc._safe_cls_name(cls), + ) + + +@dataclasses.dataclass +class _Bookkeeping: + __slots__ = ("table_keys",) + + table_keys: Set[str] + + +def automap_base( + declarative_base: Optional[Type[Any]] = None, **kw: Any +) -> Any: + r"""Produce a declarative automap base. + + This function produces a new base class that is a product of the + :class:`.AutomapBase` class as well a declarative base produced by + :func:`.declarative.declarative_base`. + + All parameters other than ``declarative_base`` are keyword arguments + that are passed directly to the :func:`.declarative.declarative_base` + function. + + :param declarative_base: an existing class produced by + :func:`.declarative.declarative_base`. When this is passed, the function + no longer invokes :func:`.declarative.declarative_base` itself, and all + other keyword arguments are ignored. + + :param \**kw: keyword arguments are passed along to + :func:`.declarative.declarative_base`. + + """ + if declarative_base is None: + Base = _declarative_base(**kw) + else: + Base = declarative_base + + return type( + Base.__name__, + (AutomapBase, Base), + { + "__abstract__": True, + "classes": util.Properties({}), + "by_module": util.Properties({}), + "_sa_automapbase_bookkeeping": _Bookkeeping(set()), + }, + ) + + +def _is_many_to_many( + automap_base: Type[Any], table: Table +) -> Tuple[ + Optional[Table], Optional[Table], Optional[list[ForeignKeyConstraint]] +]: + fk_constraints = [ + const + for const in table.constraints + if isinstance(const, ForeignKeyConstraint) + ] + if len(fk_constraints) != 2: + return None, None, None + + cols: List[Column[Any]] = sum( + [ + [fk.parent for fk in fk_constraint.elements] + for fk_constraint in fk_constraints + ], + [], + ) + + if set(cols) != set(table.c): + return None, None, None + + return ( + fk_constraints[0].elements[0].column.table, + fk_constraints[1].elements[0].column.table, + fk_constraints, + ) + + +def _relationships_for_fks( + automap_base: Type[Any], + map_config: _DeferredMapperConfig, + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ], + collection_class: type, + name_for_scalar_relationship: NameForScalarRelationshipType, + name_for_collection_relationship: NameForCollectionRelationshipType, + generate_relationship: GenerateRelationshipType, +) -> None: + local_table = cast("Optional[Table]", map_config.local_table) + local_cls = cast( + "Optional[Type[Any]]", map_config.cls + ) # derived from a weakref, may be None + + if local_table is None or local_cls is None: + return + for constraint in local_table.constraints: + if isinstance(constraint, ForeignKeyConstraint): + fks = constraint.elements + referred_table = fks[0].column.table + referred_cfg = table_to_map_config.get(referred_table, None) + if referred_cfg is None: + continue + referred_cls = referred_cfg.cls + + if local_cls is not referred_cls and issubclass( + local_cls, referred_cls + ): + continue + + relationship_name = name_for_scalar_relationship( + automap_base, local_cls, referred_cls, constraint + ) + backref_name = name_for_collection_relationship( + automap_base, referred_cls, local_cls, constraint + ) + + o2m_kws: Dict[str, Union[str, bool]] = {} + nullable = False not in {fk.parent.nullable for fk in fks} + if not nullable: + o2m_kws["cascade"] = "all, delete-orphan" + + if ( + constraint.ondelete + and constraint.ondelete.lower() == "cascade" + ): + o2m_kws["passive_deletes"] = True + else: + if ( + constraint.ondelete + and constraint.ondelete.lower() == "set null" + ): + o2m_kws["passive_deletes"] = True + + create_backref = backref_name not in referred_cfg.properties + + if relationship_name not in map_config.properties: + if create_backref: + backref_obj = generate_relationship( + automap_base, + interfaces.ONETOMANY, + backref, + backref_name, + referred_cls, + local_cls, + collection_class=collection_class, + **o2m_kws, + ) + else: + backref_obj = None + rel = generate_relationship( + automap_base, + interfaces.MANYTOONE, + relationship, + relationship_name, + local_cls, + referred_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + backref=backref_obj, + remote_side=[fk.column for fk in constraint.elements], + ) + if rel is not None: + map_config.properties[relationship_name] = rel + if not create_backref: + referred_cfg.properties[ + backref_name + ].back_populates = relationship_name # type: ignore[union-attr] # noqa: E501 + elif create_backref: + rel = generate_relationship( + automap_base, + interfaces.ONETOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + foreign_keys=[fk.parent for fk in constraint.elements], + back_populates=relationship_name, + collection_class=collection_class, + **o2m_kws, + ) + if rel is not None: + referred_cfg.properties[backref_name] = rel + map_config.properties[ + relationship_name + ].back_populates = backref_name # type: ignore[union-attr] + + +def _m2m_relationship( + automap_base: Type[Any], + lcl_m2m: Table, + rem_m2m: Table, + m2m_const: List[ForeignKeyConstraint], + table: Table, + table_to_map_config: Union[ + Dict[Optional[Table], _DeferredMapperConfig], + Dict[Table, _DeferredMapperConfig], + ], + collection_class: type, + name_for_scalar_relationship: NameForCollectionRelationshipType, + name_for_collection_relationship: NameForCollectionRelationshipType, + generate_relationship: GenerateRelationshipType, +) -> None: + map_config = table_to_map_config.get(lcl_m2m, None) + referred_cfg = table_to_map_config.get(rem_m2m, None) + if map_config is None or referred_cfg is None: + return + + local_cls = map_config.cls + referred_cls = referred_cfg.cls + + relationship_name = name_for_collection_relationship( + automap_base, local_cls, referred_cls, m2m_const[0] + ) + backref_name = name_for_collection_relationship( + automap_base, referred_cls, local_cls, m2m_const[1] + ) + + create_backref = backref_name not in referred_cfg.properties + + if table in table_to_map_config: + overlaps = "__*" + else: + overlaps = None + + if relationship_name not in map_config.properties: + if create_backref: + backref_obj = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + backref, + backref_name, + referred_cls, + local_cls, + collection_class=collection_class, + overlaps=overlaps, + ) + else: + backref_obj = None + + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + relationship_name, + local_cls, + referred_cls, + overlaps=overlaps, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), # type: ignore [arg-type] + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), # type: ignore [arg-type] + backref=backref_obj, + collection_class=collection_class, + ) + if rel is not None: + map_config.properties[relationship_name] = rel + + if not create_backref: + referred_cfg.properties[ + backref_name + ].back_populates = relationship_name # type: ignore[union-attr] # noqa: E501 + elif create_backref: + rel = generate_relationship( + automap_base, + interfaces.MANYTOMANY, + relationship, + backref_name, + referred_cls, + local_cls, + overlaps=overlaps, + secondary=table, + primaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[1].elements + ), # type: ignore [arg-type] + secondaryjoin=and_( + fk.column == fk.parent for fk in m2m_const[0].elements + ), # type: ignore [arg-type] + back_populates=relationship_name, + collection_class=collection_class, + ) + if rel is not None: + referred_cfg.properties[backref_name] = rel + map_config.properties[ + relationship_name + ].back_populates = backref_name # type: ignore[union-attr] diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/baked.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/baked.py new file mode 100644 index 0000000..60f7ae6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/baked.py @@ -0,0 +1,574 @@ +# ext/baked.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Baked query extension. + +Provides a creational pattern for the :class:`.query.Query` object which +allows the fully constructed object, Core select statement, and string +compiled result to be fully cached. + + +""" + +import collections.abc as collections_abc +import logging + +from .. import exc as sa_exc +from .. import util +from ..orm import exc as orm_exc +from ..orm.query import Query +from ..orm.session import Session +from ..sql import func +from ..sql import literal_column +from ..sql import util as sql_util + + +log = logging.getLogger(__name__) + + +class Bakery: + """Callable which returns a :class:`.BakedQuery`. + + This object is returned by the class method + :meth:`.BakedQuery.bakery`. It exists as an object + so that the "cache" can be easily inspected. + + .. versionadded:: 1.2 + + + """ + + __slots__ = "cls", "cache" + + def __init__(self, cls_, cache): + self.cls = cls_ + self.cache = cache + + def __call__(self, initial_fn, *args): + return self.cls(self.cache, initial_fn, args) + + +class BakedQuery: + """A builder object for :class:`.query.Query` objects.""" + + __slots__ = "steps", "_bakery", "_cache_key", "_spoiled" + + def __init__(self, bakery, initial_fn, args=()): + self._cache_key = () + self._update_cache_key(initial_fn, args) + self.steps = [initial_fn] + self._spoiled = False + self._bakery = bakery + + @classmethod + def bakery(cls, size=200, _size_alert=None): + """Construct a new bakery. + + :return: an instance of :class:`.Bakery` + + """ + + return Bakery(cls, util.LRUCache(size, size_alert=_size_alert)) + + def _clone(self): + b1 = BakedQuery.__new__(BakedQuery) + b1._cache_key = self._cache_key + b1.steps = list(self.steps) + b1._bakery = self._bakery + b1._spoiled = self._spoiled + return b1 + + def _update_cache_key(self, fn, args=()): + self._cache_key += (fn.__code__,) + args + + def __iadd__(self, other): + if isinstance(other, tuple): + self.add_criteria(*other) + else: + self.add_criteria(other) + return self + + def __add__(self, other): + if isinstance(other, tuple): + return self.with_criteria(*other) + else: + return self.with_criteria(other) + + def add_criteria(self, fn, *args): + """Add a criteria function to this :class:`.BakedQuery`. + + This is equivalent to using the ``+=`` operator to + modify a :class:`.BakedQuery` in-place. + + """ + self._update_cache_key(fn, args) + self.steps.append(fn) + return self + + def with_criteria(self, fn, *args): + """Add a criteria function to a :class:`.BakedQuery` cloned from this + one. + + This is equivalent to using the ``+`` operator to + produce a new :class:`.BakedQuery` with modifications. + + """ + return self._clone().add_criteria(fn, *args) + + def for_session(self, session): + """Return a :class:`_baked.Result` object for this + :class:`.BakedQuery`. + + This is equivalent to calling the :class:`.BakedQuery` as a + Python callable, e.g. ``result = my_baked_query(session)``. + + """ + return Result(self, session) + + def __call__(self, session): + return self.for_session(session) + + def spoil(self, full=False): + """Cancel any query caching that will occur on this BakedQuery object. + + The BakedQuery can continue to be used normally, however additional + creational functions will not be cached; they will be called + on every invocation. + + This is to support the case where a particular step in constructing + a baked query disqualifies the query from being cacheable, such + as a variant that relies upon some uncacheable value. + + :param full: if False, only functions added to this + :class:`.BakedQuery` object subsequent to the spoil step will be + non-cached; the state of the :class:`.BakedQuery` up until + this point will be pulled from the cache. If True, then the + entire :class:`_query.Query` object is built from scratch each + time, with all creational functions being called on each + invocation. + + """ + if not full and not self._spoiled: + _spoil_point = self._clone() + _spoil_point._cache_key += ("_query_only",) + self.steps = [_spoil_point._retrieve_baked_query] + self._spoiled = True + return self + + def _effective_key(self, session): + """Return the key that actually goes into the cache dictionary for + this :class:`.BakedQuery`, taking into account the given + :class:`.Session`. + + This basically means we also will include the session's query_class, + as the actual :class:`_query.Query` object is part of what's cached + and needs to match the type of :class:`_query.Query` that a later + session will want to use. + + """ + return self._cache_key + (session._query_cls,) + + def _with_lazyload_options(self, options, effective_path, cache_path=None): + """Cloning version of _add_lazyload_options.""" + q = self._clone() + q._add_lazyload_options(options, effective_path, cache_path=cache_path) + return q + + def _add_lazyload_options(self, options, effective_path, cache_path=None): + """Used by per-state lazy loaders to add options to the + "lazy load" query from a parent query. + + Creates a cache key based on given load path and query options; + if a repeatable cache key cannot be generated, the query is + "spoiled" so that it won't use caching. + + """ + + key = () + + if not cache_path: + cache_path = effective_path + + for opt in options: + if opt._is_legacy_option or opt._is_compile_state: + ck = opt._generate_cache_key() + if ck is None: + self.spoil(full=True) + else: + assert not ck[1], ( + "loader options with variable bound parameters " + "not supported with baked queries. Please " + "use new-style select() statements for cached " + "ORM queries." + ) + key += ck[0] + + self.add_criteria( + lambda q: q._with_current_path(effective_path).options(*options), + cache_path.path, + key, + ) + + def _retrieve_baked_query(self, session): + query = self._bakery.get(self._effective_key(session), None) + if query is None: + query = self._as_query(session) + self._bakery[self._effective_key(session)] = query.with_session( + None + ) + return query.with_session(session) + + def _bake(self, session): + query = self._as_query(session) + query.session = None + + # in 1.4, this is where before_compile() event is + # invoked + statement = query._statement_20() + + # if the query is not safe to cache, we still do everything as though + # we did cache it, since the receiver of _bake() assumes subqueryload + # context was set up, etc. + # + # note also we want to cache the statement itself because this + # allows the statement itself to hold onto its cache key that is + # used by the Connection, which in itself is more expensive to + # generate than what BakedQuery was able to provide in 1.3 and prior + + if statement._compile_options._bake_ok: + self._bakery[self._effective_key(session)] = ( + query, + statement, + ) + + return query, statement + + def to_query(self, query_or_session): + """Return the :class:`_query.Query` object for use as a subquery. + + This method should be used within the lambda callable being used + to generate a step of an enclosing :class:`.BakedQuery`. The + parameter should normally be the :class:`_query.Query` object that + is passed to the lambda:: + + sub_bq = self.bakery(lambda s: s.query(User.name)) + sub_bq += lambda q: q.filter( + User.id == Address.user_id).correlate(Address) + + main_bq = self.bakery(lambda s: s.query(Address)) + main_bq += lambda q: q.filter( + sub_bq.to_query(q).exists()) + + In the case where the subquery is used in the first callable against + a :class:`.Session`, the :class:`.Session` is also accepted:: + + sub_bq = self.bakery(lambda s: s.query(User.name)) + sub_bq += lambda q: q.filter( + User.id == Address.user_id).correlate(Address) + + main_bq = self.bakery( + lambda s: s.query( + Address.id, sub_bq.to_query(q).scalar_subquery()) + ) + + :param query_or_session: a :class:`_query.Query` object or a class + :class:`.Session` object, that is assumed to be within the context + of an enclosing :class:`.BakedQuery` callable. + + + .. versionadded:: 1.3 + + + """ + + if isinstance(query_or_session, Session): + session = query_or_session + elif isinstance(query_or_session, Query): + session = query_or_session.session + if session is None: + raise sa_exc.ArgumentError( + "Given Query needs to be associated with a Session" + ) + else: + raise TypeError( + "Query or Session object expected, got %r." + % type(query_or_session) + ) + return self._as_query(session) + + def _as_query(self, session): + query = self.steps[0](session) + + for step in self.steps[1:]: + query = step(query) + + return query + + +class Result: + """Invokes a :class:`.BakedQuery` against a :class:`.Session`. + + The :class:`_baked.Result` object is where the actual :class:`.query.Query` + object gets created, or retrieved from the cache, + against a target :class:`.Session`, and is then invoked for results. + + """ + + __slots__ = "bq", "session", "_params", "_post_criteria" + + def __init__(self, bq, session): + self.bq = bq + self.session = session + self._params = {} + self._post_criteria = [] + + def params(self, *args, **kw): + """Specify parameters to be replaced into the string SQL statement.""" + + if len(args) == 1: + kw.update(args[0]) + elif len(args) > 0: + raise sa_exc.ArgumentError( + "params() takes zero or one positional argument, " + "which is a dictionary." + ) + self._params.update(kw) + return self + + def _using_post_criteria(self, fns): + if fns: + self._post_criteria.extend(fns) + return self + + def with_post_criteria(self, fn): + """Add a criteria function that will be applied post-cache. + + This adds a function that will be run against the + :class:`_query.Query` object after it is retrieved from the + cache. This currently includes **only** the + :meth:`_query.Query.params` and :meth:`_query.Query.execution_options` + methods. + + .. warning:: :meth:`_baked.Result.with_post_criteria` + functions are applied + to the :class:`_query.Query` + object **after** the query's SQL statement + object has been retrieved from the cache. Only + :meth:`_query.Query.params` and + :meth:`_query.Query.execution_options` + methods should be used. + + + .. versionadded:: 1.2 + + + """ + return self._using_post_criteria([fn]) + + def _as_query(self): + q = self.bq._as_query(self.session).params(self._params) + for fn in self._post_criteria: + q = fn(q) + return q + + def __str__(self): + return str(self._as_query()) + + def __iter__(self): + return self._iter().__iter__() + + def _iter(self): + bq = self.bq + + if not self.session.enable_baked_queries or bq._spoiled: + return self._as_query()._iter() + + query, statement = bq._bakery.get( + bq._effective_key(self.session), (None, None) + ) + if query is None: + query, statement = bq._bake(self.session) + + if self._params: + q = query.params(self._params) + else: + q = query + for fn in self._post_criteria: + q = fn(q) + + params = q._params + execution_options = dict(q._execution_options) + execution_options.update( + { + "_sa_orm_load_options": q.load_options, + "compiled_cache": bq._bakery, + } + ) + + result = self.session.execute( + statement, params, execution_options=execution_options + ) + if result._attributes.get("is_single_entity", False): + result = result.scalars() + + if result._attributes.get("filtered", False): + result = result.unique() + + return result + + def count(self): + """return the 'count'. + + Equivalent to :meth:`_query.Query.count`. + + Note this uses a subquery to ensure an accurate count regardless + of the structure of the original statement. + + """ + + col = func.count(literal_column("*")) + bq = self.bq.with_criteria(lambda q: q._legacy_from_self(col)) + return bq.for_session(self.session).params(self._params).scalar() + + def scalar(self): + """Return the first element of the first result or None + if no rows present. If multiple rows are returned, + raises MultipleResultsFound. + + Equivalent to :meth:`_query.Query.scalar`. + + """ + try: + ret = self.one() + if not isinstance(ret, collections_abc.Sequence): + return ret + return ret[0] + except orm_exc.NoResultFound: + return None + + def first(self): + """Return the first row. + + Equivalent to :meth:`_query.Query.first`. + + """ + + bq = self.bq.with_criteria(lambda q: q.slice(0, 1)) + return ( + bq.for_session(self.session) + .params(self._params) + ._using_post_criteria(self._post_criteria) + ._iter() + .first() + ) + + def one(self): + """Return exactly one result or raise an exception. + + Equivalent to :meth:`_query.Query.one`. + + """ + return self._iter().one() + + def one_or_none(self): + """Return one or zero results, or raise an exception for multiple + rows. + + Equivalent to :meth:`_query.Query.one_or_none`. + + """ + return self._iter().one_or_none() + + def all(self): + """Return all rows. + + Equivalent to :meth:`_query.Query.all`. + + """ + return self._iter().all() + + def get(self, ident): + """Retrieve an object based on identity. + + Equivalent to :meth:`_query.Query.get`. + + """ + + query = self.bq.steps[0](self.session) + return query._get_impl(ident, self._load_on_pk_identity) + + def _load_on_pk_identity(self, session, query, primary_key_identity, **kw): + """Load the given primary key identity from the database.""" + + mapper = query._raw_columns[0]._annotations["parententity"] + + _get_clause, _get_params = mapper._get_clause + + def setup(query): + _lcl_get_clause = _get_clause + q = query._clone() + q._get_condition() + q._order_by = None + + # None present in ident - turn those comparisons + # into "IS NULL" + if None in primary_key_identity: + nones = { + _get_params[col].key + for col, value in zip( + mapper.primary_key, primary_key_identity + ) + if value is None + } + _lcl_get_clause = sql_util.adapt_criterion_to_null( + _lcl_get_clause, nones + ) + + # TODO: can mapper._get_clause be pre-adapted? + q._where_criteria = ( + sql_util._deep_annotate(_lcl_get_clause, {"_orm_adapt": True}), + ) + + for fn in self._post_criteria: + q = fn(q) + return q + + # cache the query against a key that includes + # which positions in the primary key are NULL + # (remember, we can map to an OUTER JOIN) + bq = self.bq + + # add the clause we got from mapper._get_clause to the cache + # key so that if a race causes multiple calls to _get_clause, + # we've cached on ours + bq = bq._clone() + bq._cache_key += (_get_clause,) + + bq = bq.with_criteria( + setup, tuple(elem is None for elem in primary_key_identity) + ) + + params = { + _get_params[primary_key].key: id_val + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + } + + result = list(bq.for_session(self.session).params(**params)) + l = len(result) + if l > 1: + raise orm_exc.MultipleResultsFound() + elif l: + return result[0] + else: + return None + + +bakery = BakedQuery.bakery diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/compiler.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/compiler.py new file mode 100644 index 0000000..01462ad --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/compiler.py @@ -0,0 +1,555 @@ +# ext/compiler.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +r"""Provides an API for creation of custom ClauseElements and compilers. + +Synopsis +======== + +Usage involves the creation of one or more +:class:`~sqlalchemy.sql.expression.ClauseElement` subclasses and one or +more callables defining its compilation:: + + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.sql.expression import ColumnClause + + class MyColumn(ColumnClause): + inherit_cache = True + + @compiles(MyColumn) + def compile_mycolumn(element, compiler, **kw): + return "[%s]" % element.name + +Above, ``MyColumn`` extends :class:`~sqlalchemy.sql.expression.ColumnClause`, +the base expression element for named column objects. The ``compiles`` +decorator registers itself with the ``MyColumn`` class so that it is invoked +when the object is compiled to a string:: + + from sqlalchemy import select + + s = select(MyColumn('x'), MyColumn('y')) + print(str(s)) + +Produces:: + + SELECT [x], [y] + +Dialect-specific compilation rules +================================== + +Compilers can also be made dialect-specific. The appropriate compiler will be +invoked for the dialect in use:: + + from sqlalchemy.schema import DDLElement + + class AlterColumn(DDLElement): + inherit_cache = False + + def __init__(self, column, cmd): + self.column = column + self.cmd = cmd + + @compiles(AlterColumn) + def visit_alter_column(element, compiler, **kw): + return "ALTER COLUMN %s ..." % element.column.name + + @compiles(AlterColumn, 'postgresql') + def visit_alter_column(element, compiler, **kw): + return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, + element.column.name) + +The second ``visit_alter_table`` will be invoked when any ``postgresql`` +dialect is used. + +.. _compilerext_compiling_subelements: + +Compiling sub-elements of a custom expression construct +======================================================= + +The ``compiler`` argument is the +:class:`~sqlalchemy.engine.interfaces.Compiled` object in use. This object +can be inspected for any information about the in-progress compilation, +including ``compiler.dialect``, ``compiler.statement`` etc. The +:class:`~sqlalchemy.sql.compiler.SQLCompiler` and +:class:`~sqlalchemy.sql.compiler.DDLCompiler` both include a ``process()`` +method which can be used for compilation of embedded attributes:: + + from sqlalchemy.sql.expression import Executable, ClauseElement + + class InsertFromSelect(Executable, ClauseElement): + inherit_cache = False + + def __init__(self, table, select): + self.table = table + self.select = select + + @compiles(InsertFromSelect) + def visit_insert_from_select(element, compiler, **kw): + return "INSERT INTO %s (%s)" % ( + compiler.process(element.table, asfrom=True, **kw), + compiler.process(element.select, **kw) + ) + + insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5)) + print(insert) + +Produces:: + + "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z + FROM mytable WHERE mytable.x > :x_1)" + +.. note:: + + The above ``InsertFromSelect`` construct is only an example, this actual + functionality is already available using the + :meth:`_expression.Insert.from_select` method. + + +Cross Compiling between SQL and DDL compilers +--------------------------------------------- + +SQL and DDL constructs are each compiled using different base compilers - +``SQLCompiler`` and ``DDLCompiler``. A common need is to access the +compilation rules of SQL expressions from within a DDL expression. The +``DDLCompiler`` includes an accessor ``sql_compiler`` for this reason, such as +below where we generate a CHECK constraint that embeds a SQL expression:: + + @compiles(MyConstraint) + def compile_my_constraint(constraint, ddlcompiler, **kw): + kw['literal_binds'] = True + return "CONSTRAINT %s CHECK (%s)" % ( + constraint.name, + ddlcompiler.sql_compiler.process( + constraint.expression, **kw) + ) + +Above, we add an additional flag to the process step as called by +:meth:`.SQLCompiler.process`, which is the ``literal_binds`` flag. This +indicates that any SQL expression which refers to a :class:`.BindParameter` +object or other "literal" object such as those which refer to strings or +integers should be rendered **in-place**, rather than being referred to as +a bound parameter; when emitting DDL, bound parameters are typically not +supported. + + +Changing the default compilation of existing constructs +======================================================= + +The compiler extension applies just as well to the existing constructs. When +overriding the compilation of a built in SQL construct, the @compiles +decorator is invoked upon the appropriate class (be sure to use the class, +i.e. ``Insert`` or ``Select``, instead of the creation function such +as ``insert()`` or ``select()``). + +Within the new compilation function, to get at the "original" compilation +routine, use the appropriate visit_XXX method - this +because compiler.process() will call upon the overriding routine and cause +an endless loop. Such as, to add "prefix" to all insert statements:: + + from sqlalchemy.sql.expression import Insert + + @compiles(Insert) + def prefix_inserts(insert, compiler, **kw): + return compiler.visit_insert(insert.prefix_with("some prefix"), **kw) + +The above compiler will prefix all INSERT statements with "some prefix" when +compiled. + +.. _type_compilation_extension: + +Changing Compilation of Types +============================= + +``compiler`` works for types, too, such as below where we implement the +MS-SQL specific 'max' keyword for ``String``/``VARCHAR``:: + + @compiles(String, 'mssql') + @compiles(VARCHAR, 'mssql') + def compile_varchar(element, compiler, **kw): + if element.length == 'max': + return "VARCHAR('max')" + else: + return compiler.visit_VARCHAR(element, **kw) + + foo = Table('foo', metadata, + Column('data', VARCHAR('max')) + ) + +Subclassing Guidelines +====================== + +A big part of using the compiler extension is subclassing SQLAlchemy +expression constructs. To make this easier, the expression and +schema packages feature a set of "bases" intended for common tasks. +A synopsis is as follows: + +* :class:`~sqlalchemy.sql.expression.ClauseElement` - This is the root + expression class. Any SQL expression can be derived from this base, and is + probably the best choice for longer constructs such as specialized INSERT + statements. + +* :class:`~sqlalchemy.sql.expression.ColumnElement` - The root of all + "column-like" elements. Anything that you'd place in the "columns" clause of + a SELECT statement (as well as order by and group by) can derive from this - + the object will automatically have Python "comparison" behavior. + + :class:`~sqlalchemy.sql.expression.ColumnElement` classes want to have a + ``type`` member which is expression's return type. This can be established + at the instance level in the constructor, or at the class level if its + generally constant:: + + class timestamp(ColumnElement): + type = TIMESTAMP() + inherit_cache = True + +* :class:`~sqlalchemy.sql.functions.FunctionElement` - This is a hybrid of a + ``ColumnElement`` and a "from clause" like object, and represents a SQL + function or stored procedure type of call. Since most databases support + statements along the line of "SELECT FROM " + ``FunctionElement`` adds in the ability to be used in the FROM clause of a + ``select()`` construct:: + + from sqlalchemy.sql.expression import FunctionElement + + class coalesce(FunctionElement): + name = 'coalesce' + inherit_cache = True + + @compiles(coalesce) + def compile(element, compiler, **kw): + return "coalesce(%s)" % compiler.process(element.clauses, **kw) + + @compiles(coalesce, 'oracle') + def compile(element, compiler, **kw): + if len(element.clauses) > 2: + raise TypeError("coalesce only supports two arguments on Oracle") + return "nvl(%s)" % compiler.process(element.clauses, **kw) + +* :class:`.ExecutableDDLElement` - The root of all DDL expressions, + like CREATE TABLE, ALTER TABLE, etc. Compilation of + :class:`.ExecutableDDLElement` subclasses is issued by a + :class:`.DDLCompiler` instead of a :class:`.SQLCompiler`. + :class:`.ExecutableDDLElement` can also be used as an event hook in + conjunction with event hooks like :meth:`.DDLEvents.before_create` and + :meth:`.DDLEvents.after_create`, allowing the construct to be invoked + automatically during CREATE TABLE and DROP TABLE sequences. + + .. seealso:: + + :ref:`metadata_ddl_toplevel` - contains examples of associating + :class:`.DDL` objects (which are themselves :class:`.ExecutableDDLElement` + instances) with :class:`.DDLEvents` event hooks. + +* :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which + should be used with any expression class that represents a "standalone" + SQL statement that can be passed directly to an ``execute()`` method. It + is already implicit within ``DDLElement`` and ``FunctionElement``. + +Most of the above constructs also respond to SQL statement caching. A +subclassed construct will want to define the caching behavior for the object, +which usually means setting the flag ``inherit_cache`` to the value of +``False`` or ``True``. See the next section :ref:`compilerext_caching` +for background. + + +.. _compilerext_caching: + +Enabling Caching Support for Custom Constructs +============================================== + +SQLAlchemy as of version 1.4 includes a +:ref:`SQL compilation caching facility ` which will allow +equivalent SQL constructs to cache their stringified form, along with other +structural information used to fetch results from the statement. + +For reasons discussed at :ref:`caching_caveats`, the implementation of this +caching system takes a conservative approach towards including custom SQL +constructs and/or subclasses within the caching system. This includes that +any user-defined SQL constructs, including all the examples for this +extension, will not participate in caching by default unless they positively +assert that they are able to do so. The :attr:`.HasCacheKey.inherit_cache` +attribute when set to ``True`` at the class level of a specific subclass +will indicate that instances of this class may be safely cached, using the +cache key generation scheme of the immediate superclass. This applies +for example to the "synopsis" example indicated previously:: + + class MyColumn(ColumnClause): + inherit_cache = True + + @compiles(MyColumn) + def compile_mycolumn(element, compiler, **kw): + return "[%s]" % element.name + +Above, the ``MyColumn`` class does not include any new state that +affects its SQL compilation; the cache key of ``MyColumn`` instances will +make use of that of the ``ColumnClause`` superclass, meaning it will take +into account the class of the object (``MyColumn``), the string name and +datatype of the object:: + + >>> MyColumn("some_name", String())._generate_cache_key() + CacheKey( + key=('0', , + 'name', 'some_name', + 'type', (, + ('length', None), ('collation', None)) + ), bindparams=[]) + +For objects that are likely to be **used liberally as components within many +larger statements**, such as :class:`_schema.Column` subclasses and custom SQL +datatypes, it's important that **caching be enabled as much as possible**, as +this may otherwise negatively affect performance. + +An example of an object that **does** contain state which affects its SQL +compilation is the one illustrated at :ref:`compilerext_compiling_subelements`; +this is an "INSERT FROM SELECT" construct that combines together a +:class:`_schema.Table` as well as a :class:`_sql.Select` construct, each of +which independently affect the SQL string generation of the construct. For +this class, the example illustrates that it simply does not participate in +caching:: + + class InsertFromSelect(Executable, ClauseElement): + inherit_cache = False + + def __init__(self, table, select): + self.table = table + self.select = select + + @compiles(InsertFromSelect) + def visit_insert_from_select(element, compiler, **kw): + return "INSERT INTO %s (%s)" % ( + compiler.process(element.table, asfrom=True, **kw), + compiler.process(element.select, **kw) + ) + +While it is also possible that the above ``InsertFromSelect`` could be made to +produce a cache key that is composed of that of the :class:`_schema.Table` and +:class:`_sql.Select` components together, the API for this is not at the moment +fully public. However, for an "INSERT FROM SELECT" construct, which is only +used by itself for specific operations, caching is not as critical as in the +previous example. + +For objects that are **used in relative isolation and are generally +standalone**, such as custom :term:`DML` constructs like an "INSERT FROM +SELECT", **caching is generally less critical** as the lack of caching for such +a construct will have only localized implications for that specific operation. + + +Further Examples +================ + +"UTC timestamp" function +------------------------- + +A function that works like "CURRENT_TIMESTAMP" except applies the +appropriate conversions so that the time is in UTC time. Timestamps are best +stored in relational databases as UTC, without time zones. UTC so that your +database doesn't think time has gone backwards in the hour when daylight +savings ends, without timezones because timezones are like character +encodings - they're best applied only at the endpoints of an application +(i.e. convert to UTC upon user input, re-apply desired timezone upon display). + +For PostgreSQL and Microsoft SQL Server:: + + from sqlalchemy.sql import expression + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.types import DateTime + + class utcnow(expression.FunctionElement): + type = DateTime() + inherit_cache = True + + @compiles(utcnow, 'postgresql') + def pg_utcnow(element, compiler, **kw): + return "TIMEZONE('utc', CURRENT_TIMESTAMP)" + + @compiles(utcnow, 'mssql') + def ms_utcnow(element, compiler, **kw): + return "GETUTCDATE()" + +Example usage:: + + from sqlalchemy import ( + Table, Column, Integer, String, DateTime, MetaData + ) + metadata = MetaData() + event = Table("event", metadata, + Column("id", Integer, primary_key=True), + Column("description", String(50), nullable=False), + Column("timestamp", DateTime, server_default=utcnow()) + ) + +"GREATEST" function +------------------- + +The "GREATEST" function is given any number of arguments and returns the one +that is of the highest value - its equivalent to Python's ``max`` +function. A SQL standard version versus a CASE based version which only +accommodates two arguments:: + + from sqlalchemy.sql import expression, case + from sqlalchemy.ext.compiler import compiles + from sqlalchemy.types import Numeric + + class greatest(expression.FunctionElement): + type = Numeric() + name = 'greatest' + inherit_cache = True + + @compiles(greatest) + def default_greatest(element, compiler, **kw): + return compiler.visit_function(element) + + @compiles(greatest, 'sqlite') + @compiles(greatest, 'mssql') + @compiles(greatest, 'oracle') + def case_greatest(element, compiler, **kw): + arg1, arg2 = list(element.clauses) + return compiler.process(case((arg1 > arg2, arg1), else_=arg2), **kw) + +Example usage:: + + Session.query(Account).\ + filter( + greatest( + Account.checking_balance, + Account.savings_balance) > 10000 + ) + +"false" expression +------------------ + +Render a "false" constant expression, rendering as "0" on platforms that +don't have a "false" constant:: + + from sqlalchemy.sql import expression + from sqlalchemy.ext.compiler import compiles + + class sql_false(expression.ColumnElement): + inherit_cache = True + + @compiles(sql_false) + def default_false(element, compiler, **kw): + return "false" + + @compiles(sql_false, 'mssql') + @compiles(sql_false, 'mysql') + @compiles(sql_false, 'oracle') + def int_false(element, compiler, **kw): + return "0" + +Example usage:: + + from sqlalchemy import select, union_all + + exp = union_all( + select(users.c.name, sql_false().label("enrolled")), + select(customers.c.name, customers.c.enrolled) + ) + +""" +from .. import exc +from ..sql import sqltypes + + +def compiles(class_, *specs): + """Register a function as a compiler for a + given :class:`_expression.ClauseElement` type.""" + + def decorate(fn): + # get an existing @compiles handler + existing = class_.__dict__.get("_compiler_dispatcher", None) + + # get the original handler. All ClauseElement classes have one + # of these, but some TypeEngine classes will not. + existing_dispatch = getattr(class_, "_compiler_dispatch", None) + + if not existing: + existing = _dispatcher() + + if existing_dispatch: + + def _wrap_existing_dispatch(element, compiler, **kw): + try: + return existing_dispatch(element, compiler, **kw) + except exc.UnsupportedCompilationError as uce: + raise exc.UnsupportedCompilationError( + compiler, + type(element), + message="%s construct has no default " + "compilation handler." % type(element), + ) from uce + + existing.specs["default"] = _wrap_existing_dispatch + + # TODO: why is the lambda needed ? + setattr( + class_, + "_compiler_dispatch", + lambda *arg, **kw: existing(*arg, **kw), + ) + setattr(class_, "_compiler_dispatcher", existing) + + if specs: + for s in specs: + existing.specs[s] = fn + + else: + existing.specs["default"] = fn + return fn + + return decorate + + +def deregister(class_): + """Remove all custom compilers associated with a given + :class:`_expression.ClauseElement` type. + + """ + + if hasattr(class_, "_compiler_dispatcher"): + class_._compiler_dispatch = class_._original_compiler_dispatch + del class_._compiler_dispatcher + + +class _dispatcher: + def __init__(self): + self.specs = {} + + def __call__(self, element, compiler, **kw): + # TODO: yes, this could also switch off of DBAPI in use. + fn = self.specs.get(compiler.dialect.name, None) + if not fn: + try: + fn = self.specs["default"] + except KeyError as ke: + raise exc.UnsupportedCompilationError( + compiler, + type(element), + message="%s construct has no default " + "compilation handler." % type(element), + ) from ke + + # if compilation includes add_to_result_map, collect add_to_result_map + # arguments from the user-defined callable, which are probably none + # because this is not public API. if it wasn't called, then call it + # ourselves. + arm = kw.get("add_to_result_map", None) + if arm: + arm_collection = [] + kw["add_to_result_map"] = lambda *args: arm_collection.append(args) + + expr = fn(element, compiler, **kw) + + if arm: + if not arm_collection: + arm_collection.append( + (None, None, (element,), sqltypes.NULLTYPE) + ) + for tup in arm_collection: + arm(*tup) + return expr diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__init__.py new file mode 100644 index 0000000..37da403 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__init__.py @@ -0,0 +1,65 @@ +# ext/declarative/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from .extensions import AbstractConcreteBase +from .extensions import ConcreteBase +from .extensions import DeferredReflection +from ... import util +from ...orm.decl_api import as_declarative as _as_declarative +from ...orm.decl_api import declarative_base as _declarative_base +from ...orm.decl_api import DeclarativeMeta +from ...orm.decl_api import declared_attr +from ...orm.decl_api import has_inherited_table as _has_inherited_table +from ...orm.decl_api import synonym_for as _synonym_for + + +@util.moved_20( + "The ``declarative_base()`` function is now available as " + ":func:`sqlalchemy.orm.declarative_base`." +) +def declarative_base(*arg, **kw): + return _declarative_base(*arg, **kw) + + +@util.moved_20( + "The ``as_declarative()`` function is now available as " + ":func:`sqlalchemy.orm.as_declarative`" +) +def as_declarative(*arg, **kw): + return _as_declarative(*arg, **kw) + + +@util.moved_20( + "The ``has_inherited_table()`` function is now available as " + ":func:`sqlalchemy.orm.has_inherited_table`." +) +def has_inherited_table(*arg, **kw): + return _has_inherited_table(*arg, **kw) + + +@util.moved_20( + "The ``synonym_for()`` function is now available as " + ":func:`sqlalchemy.orm.synonym_for`" +) +def synonym_for(*arg, **kw): + return _synonym_for(*arg, **kw) + + +__all__ = [ + "declarative_base", + "synonym_for", + "has_inherited_table", + "instrument_declarative", + "declared_attr", + "as_declarative", + "ConcreteBase", + "AbstractConcreteBase", + "DeclarativeMeta", + "DeferredReflection", +] diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..3b81c5f Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/extensions.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/extensions.cpython-311.pyc new file mode 100644 index 0000000..198346b Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/__pycache__/extensions.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/extensions.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/extensions.py new file mode 100644 index 0000000..c0f7e34 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/declarative/extensions.py @@ -0,0 +1,548 @@ +# ext/declarative/extensions.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Public API functions and helpers for declarative.""" +from __future__ import annotations + +import collections +import contextlib +from typing import Any +from typing import Callable +from typing import TYPE_CHECKING +from typing import Union + +from ... import exc as sa_exc +from ...engine import Connection +from ...engine import Engine +from ...orm import exc as orm_exc +from ...orm import relationships +from ...orm.base import _mapper_or_none +from ...orm.clsregistry import _resolver +from ...orm.decl_base import _DeferredMapperConfig +from ...orm.util import polymorphic_union +from ...schema import Table +from ...util import OrderedDict + +if TYPE_CHECKING: + from ...sql.schema import MetaData + + +class ConcreteBase: + """A helper class for 'concrete' declarative mappings. + + :class:`.ConcreteBase` will use the :func:`.polymorphic_union` + function automatically, against all tables mapped as a subclass + to this class. The function is called via the + ``__declare_last__()`` function, which is essentially + a hook for the :meth:`.after_configured` event. + + :class:`.ConcreteBase` produces a mapped + table for the class itself. Compare to :class:`.AbstractConcreteBase`, + which does not. + + Example:: + + from sqlalchemy.ext.declarative import ConcreteBase + + class Employee(ConcreteBase, Base): + __tablename__ = 'employee' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + __mapper_args__ = { + 'polymorphic_identity':'employee', + 'concrete':True} + + class Manager(Employee): + __tablename__ = 'manager' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + manager_data = Column(String(40)) + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True} + + + The name of the discriminator column used by :func:`.polymorphic_union` + defaults to the name ``type``. To suit the use case of a mapping where an + actual column in a mapped table is already named ``type``, the + discriminator name can be configured by setting the + ``_concrete_discriminator_name`` attribute:: + + class Employee(ConcreteBase, Base): + _concrete_discriminator_name = '_concrete_discriminator' + + .. versionadded:: 1.3.19 Added the ``_concrete_discriminator_name`` + attribute to :class:`_declarative.ConcreteBase` so that the + virtual discriminator column name can be customized. + + .. versionchanged:: 1.4.2 The ``_concrete_discriminator_name`` attribute + need only be placed on the basemost class to take correct effect for + all subclasses. An explicit error message is now raised if the + mapped column names conflict with the discriminator name, whereas + in the 1.3.x series there would be some warnings and then a non-useful + query would be generated. + + .. seealso:: + + :class:`.AbstractConcreteBase` + + :ref:`concrete_inheritance` + + + """ + + @classmethod + def _create_polymorphic_union(cls, mappers, discriminator_name): + return polymorphic_union( + OrderedDict( + (mp.polymorphic_identity, mp.local_table) for mp in mappers + ), + discriminator_name, + "pjoin", + ) + + @classmethod + def __declare_first__(cls): + m = cls.__mapper__ + if m.with_polymorphic: + return + + discriminator_name = ( + getattr(cls, "_concrete_discriminator_name", None) or "type" + ) + + mappers = list(m.self_and_descendants) + pjoin = cls._create_polymorphic_union(mappers, discriminator_name) + m._set_with_polymorphic(("*", pjoin)) + m._set_polymorphic_on(pjoin.c[discriminator_name]) + + +class AbstractConcreteBase(ConcreteBase): + """A helper class for 'concrete' declarative mappings. + + :class:`.AbstractConcreteBase` will use the :func:`.polymorphic_union` + function automatically, against all tables mapped as a subclass + to this class. The function is called via the + ``__declare_first__()`` function, which is essentially + a hook for the :meth:`.before_configured` event. + + :class:`.AbstractConcreteBase` applies :class:`_orm.Mapper` for its + immediately inheriting class, as would occur for any other + declarative mapped class. However, the :class:`_orm.Mapper` is not + mapped to any particular :class:`.Table` object. Instead, it's + mapped directly to the "polymorphic" selectable produced by + :func:`.polymorphic_union`, and performs no persistence operations on its + own. Compare to :class:`.ConcreteBase`, which maps its + immediately inheriting class to an actual + :class:`.Table` that stores rows directly. + + .. note:: + + The :class:`.AbstractConcreteBase` delays the mapper creation of the + base class until all the subclasses have been defined, + as it needs to create a mapping against a selectable that will include + all subclass tables. In order to achieve this, it waits for the + **mapper configuration event** to occur, at which point it scans + through all the configured subclasses and sets up a mapping that will + query against all subclasses at once. + + While this event is normally invoked automatically, in the case of + :class:`.AbstractConcreteBase`, it may be necessary to invoke it + explicitly after **all** subclass mappings are defined, if the first + operation is to be a query against this base class. To do so, once all + the desired classes have been configured, the + :meth:`_orm.registry.configure` method on the :class:`_orm.registry` + in use can be invoked, which is available in relation to a particular + declarative base class:: + + Base.registry.configure() + + Example:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.ext.declarative import AbstractConcreteBase + + class Base(DeclarativeBase): + pass + + class Employee(AbstractConcreteBase, Base): + pass + + class Manager(Employee): + __tablename__ = 'manager' + employee_id = Column(Integer, primary_key=True) + name = Column(String(50)) + manager_data = Column(String(40)) + + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True + } + + Base.registry.configure() + + The abstract base class is handled by declarative in a special way; + at class configuration time, it behaves like a declarative mixin + or an ``__abstract__`` base class. Once classes are configured + and mappings are produced, it then gets mapped itself, but + after all of its descendants. This is a very unique system of mapping + not found in any other SQLAlchemy API feature. + + Using this approach, we can specify columns and properties + that will take place on mapped subclasses, in the way that + we normally do as in :ref:`declarative_mixins`:: + + from sqlalchemy.ext.declarative import AbstractConcreteBase + + class Company(Base): + __tablename__ = 'company' + id = Column(Integer, primary_key=True) + + class Employee(AbstractConcreteBase, Base): + strict_attrs = True + + employee_id = Column(Integer, primary_key=True) + + @declared_attr + def company_id(cls): + return Column(ForeignKey('company.id')) + + @declared_attr + def company(cls): + return relationship("Company") + + class Manager(Employee): + __tablename__ = 'manager' + + name = Column(String(50)) + manager_data = Column(String(40)) + + __mapper_args__ = { + 'polymorphic_identity':'manager', + 'concrete':True + } + + Base.registry.configure() + + When we make use of our mappings however, both ``Manager`` and + ``Employee`` will have an independently usable ``.company`` attribute:: + + session.execute( + select(Employee).filter(Employee.company.has(id=5)) + ) + + :param strict_attrs: when specified on the base class, "strict" attribute + mode is enabled which attempts to limit ORM mapped attributes on the + base class to only those that are immediately present, while still + preserving "polymorphic" loading behavior. + + .. versionadded:: 2.0 + + .. seealso:: + + :class:`.ConcreteBase` + + :ref:`concrete_inheritance` + + :ref:`abstract_concrete_base` + + """ + + __no_table__ = True + + @classmethod + def __declare_first__(cls): + cls._sa_decl_prepare_nocascade() + + @classmethod + def _sa_decl_prepare_nocascade(cls): + if getattr(cls, "__mapper__", None): + return + + to_map = _DeferredMapperConfig.config_for_cls(cls) + + # can't rely on 'self_and_descendants' here + # since technically an immediate subclass + # might not be mapped, but a subclass + # may be. + mappers = [] + stack = list(cls.__subclasses__()) + while stack: + klass = stack.pop() + stack.extend(klass.__subclasses__()) + mn = _mapper_or_none(klass) + if mn is not None: + mappers.append(mn) + + discriminator_name = ( + getattr(cls, "_concrete_discriminator_name", None) or "type" + ) + pjoin = cls._create_polymorphic_union(mappers, discriminator_name) + + # For columns that were declared on the class, these + # are normally ignored with the "__no_table__" mapping, + # unless they have a different attribute key vs. col name + # and are in the properties argument. + # In that case, ensure we update the properties entry + # to the correct column from the pjoin target table. + declared_cols = set(to_map.declared_columns) + declared_col_keys = {c.key for c in declared_cols} + for k, v in list(to_map.properties.items()): + if v in declared_cols: + to_map.properties[k] = pjoin.c[v.key] + declared_col_keys.remove(v.key) + + to_map.local_table = pjoin + + strict_attrs = cls.__dict__.get("strict_attrs", False) + + m_args = to_map.mapper_args_fn or dict + + def mapper_args(): + args = m_args() + args["polymorphic_on"] = pjoin.c[discriminator_name] + args["polymorphic_abstract"] = True + if strict_attrs: + args["include_properties"] = ( + set(pjoin.primary_key) + | declared_col_keys + | {discriminator_name} + ) + args["with_polymorphic"] = ("*", pjoin) + return args + + to_map.mapper_args_fn = mapper_args + + to_map.map() + + stack = [cls] + while stack: + scls = stack.pop(0) + stack.extend(scls.__subclasses__()) + sm = _mapper_or_none(scls) + if sm and sm.concrete and sm.inherits is None: + for sup_ in scls.__mro__[1:]: + sup_sm = _mapper_or_none(sup_) + if sup_sm: + sm._set_concrete_base(sup_sm) + break + + @classmethod + def _sa_raise_deferred_config(cls): + raise orm_exc.UnmappedClassError( + cls, + msg="Class %s is a subclass of AbstractConcreteBase and " + "has a mapping pending until all subclasses are defined. " + "Call the sqlalchemy.orm.configure_mappers() function after " + "all subclasses have been defined to " + "complete the mapping of this class." + % orm_exc._safe_cls_name(cls), + ) + + +class DeferredReflection: + """A helper class for construction of mappings based on + a deferred reflection step. + + Normally, declarative can be used with reflection by + setting a :class:`_schema.Table` object using autoload_with=engine + as the ``__table__`` attribute on a declarative class. + The caveat is that the :class:`_schema.Table` must be fully + reflected, or at the very least have a primary key column, + at the point at which a normal declarative mapping is + constructed, meaning the :class:`_engine.Engine` must be available + at class declaration time. + + The :class:`.DeferredReflection` mixin moves the construction + of mappers to be at a later point, after a specific + method is called which first reflects all :class:`_schema.Table` + objects created so far. Classes can define it as such:: + + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.declarative import DeferredReflection + Base = declarative_base() + + class MyClass(DeferredReflection, Base): + __tablename__ = 'mytable' + + Above, ``MyClass`` is not yet mapped. After a series of + classes have been defined in the above fashion, all tables + can be reflected and mappings created using + :meth:`.prepare`:: + + engine = create_engine("someengine://...") + DeferredReflection.prepare(engine) + + The :class:`.DeferredReflection` mixin can be applied to individual + classes, used as the base for the declarative base itself, + or used in a custom abstract class. Using an abstract base + allows that only a subset of classes to be prepared for a + particular prepare step, which is necessary for applications + that use more than one engine. For example, if an application + has two engines, you might use two bases, and prepare each + separately, e.g.:: + + class ReflectedOne(DeferredReflection, Base): + __abstract__ = True + + class ReflectedTwo(DeferredReflection, Base): + __abstract__ = True + + class MyClass(ReflectedOne): + __tablename__ = 'mytable' + + class MyOtherClass(ReflectedOne): + __tablename__ = 'myothertable' + + class YetAnotherClass(ReflectedTwo): + __tablename__ = 'yetanothertable' + + # ... etc. + + Above, the class hierarchies for ``ReflectedOne`` and + ``ReflectedTwo`` can be configured separately:: + + ReflectedOne.prepare(engine_one) + ReflectedTwo.prepare(engine_two) + + .. seealso:: + + :ref:`orm_declarative_reflected_deferred_reflection` - in the + :ref:`orm_declarative_table_config_toplevel` section. + + """ + + @classmethod + def prepare( + cls, bind: Union[Engine, Connection], **reflect_kw: Any + ) -> None: + r"""Reflect all :class:`_schema.Table` objects for all current + :class:`.DeferredReflection` subclasses + + :param bind: :class:`_engine.Engine` or :class:`_engine.Connection` + instance + + ..versionchanged:: 2.0.16 a :class:`_engine.Connection` is also + accepted. + + :param \**reflect_kw: additional keyword arguments passed to + :meth:`_schema.MetaData.reflect`, such as + :paramref:`_schema.MetaData.reflect.views`. + + .. versionadded:: 2.0.16 + + """ + + to_map = _DeferredMapperConfig.classes_for_base(cls) + + metadata_to_table = collections.defaultdict(set) + + # first collect the primary __table__ for each class into a + # collection of metadata/schemaname -> table names + for thingy in to_map: + if thingy.local_table is not None: + metadata_to_table[ + (thingy.local_table.metadata, thingy.local_table.schema) + ].add(thingy.local_table.name) + + # then reflect all those tables into their metadatas + + if isinstance(bind, Connection): + conn = bind + ctx = contextlib.nullcontext(enter_result=conn) + elif isinstance(bind, Engine): + ctx = bind.connect() + else: + raise sa_exc.ArgumentError( + f"Expected Engine or Connection, got {bind!r}" + ) + + with ctx as conn: + for (metadata, schema), table_names in metadata_to_table.items(): + metadata.reflect( + conn, + only=table_names, + schema=schema, + extend_existing=True, + autoload_replace=False, + **reflect_kw, + ) + + metadata_to_table.clear() + + # .map() each class, then go through relationships and look + # for secondary + for thingy in to_map: + thingy.map() + + mapper = thingy.cls.__mapper__ + metadata = mapper.class_.metadata + + for rel in mapper._props.values(): + if ( + isinstance(rel, relationships.RelationshipProperty) + and rel._init_args.secondary._is_populated() + ): + secondary_arg = rel._init_args.secondary + + if isinstance(secondary_arg.argument, Table): + secondary_table = secondary_arg.argument + metadata_to_table[ + ( + secondary_table.metadata, + secondary_table.schema, + ) + ].add(secondary_table.name) + elif isinstance(secondary_arg.argument, str): + _, resolve_arg = _resolver(rel.parent.class_, rel) + + resolver = resolve_arg( + secondary_arg.argument, True + ) + metadata_to_table[ + (metadata, thingy.local_table.schema) + ].add(secondary_arg.argument) + + resolver._resolvers += ( + cls._sa_deferred_table_resolver(metadata), + ) + + secondary_arg.argument = resolver() + + for (metadata, schema), table_names in metadata_to_table.items(): + metadata.reflect( + conn, + only=table_names, + schema=schema, + extend_existing=True, + autoload_replace=False, + ) + + @classmethod + def _sa_deferred_table_resolver( + cls, metadata: MetaData + ) -> Callable[[str], Table]: + def _resolve(key: str) -> Table: + # reflection has already occurred so this Table would have + # its contents already + return Table(key, metadata) + + return _resolve + + _sa_decl_prepare = True + + @classmethod + def _sa_raise_deferred_config(cls): + raise orm_exc.UnmappedClassError( + cls, + msg="Class %s is a subclass of DeferredReflection. " + "Mappings are not produced until the .prepare() " + "method is called on the class hierarchy." + % orm_exc._safe_cls_name(cls), + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/horizontal_shard.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/horizontal_shard.py new file mode 100644 index 0000000..d8ee819 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/horizontal_shard.py @@ -0,0 +1,481 @@ +# ext/horizontal_shard.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Horizontal sharding support. + +Defines a rudimental 'horizontal sharding' system which allows a Session to +distribute queries and persistence operations across multiple databases. + +For a usage example, see the :ref:`examples_sharding` example included in +the source distribution. + +.. deepalchemy:: The horizontal sharding extension is an advanced feature, + involving a complex statement -> database interaction as well as + use of semi-public APIs for non-trivial cases. Simpler approaches to + refering to multiple database "shards", most commonly using a distinct + :class:`_orm.Session` per "shard", should always be considered first + before using this more complex and less-production-tested system. + + + +""" +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .. import event +from .. import exc +from .. import inspect +from .. import util +from ..orm import PassiveFlag +from ..orm._typing import OrmExecuteOptionsParameter +from ..orm.interfaces import ORMOption +from ..orm.mapper import Mapper +from ..orm.query import Query +from ..orm.session import _BindArguments +from ..orm.session import _PKIdentityArgument +from ..orm.session import Session +from ..util.typing import Protocol +from ..util.typing import Self + +if TYPE_CHECKING: + from ..engine.base import Connection + from ..engine.base import Engine + from ..engine.base import OptionEngine + from ..engine.result import IteratorResult + from ..engine.result import Result + from ..orm import LoaderCallableStatus + from ..orm._typing import _O + from ..orm.bulk_persistence import BulkUDCompileState + from ..orm.context import QueryContext + from ..orm.session import _EntityBindKey + from ..orm.session import _SessionBind + from ..orm.session import ORMExecuteState + from ..orm.state import InstanceState + from ..sql import Executable + from ..sql._typing import _TP + from ..sql.elements import ClauseElement + +__all__ = ["ShardedSession", "ShardedQuery"] + +_T = TypeVar("_T", bound=Any) + + +ShardIdentifier = str + + +class ShardChooser(Protocol): + def __call__( + self, + mapper: Optional[Mapper[_T]], + instance: Any, + clause: Optional[ClauseElement], + ) -> Any: ... + + +class IdentityChooser(Protocol): + def __call__( + self, + mapper: Mapper[_T], + primary_key: _PKIdentityArgument, + *, + lazy_loaded_from: Optional[InstanceState[Any]], + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + **kw: Any, + ) -> Any: ... + + +class ShardedQuery(Query[_T]): + """Query class used with :class:`.ShardedSession`. + + .. legacy:: The :class:`.ShardedQuery` is a subclass of the legacy + :class:`.Query` class. The :class:`.ShardedSession` now supports + 2.0 style execution via the :meth:`.ShardedSession.execute` method. + + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + assert isinstance(self.session, ShardedSession) + + self.identity_chooser = self.session.identity_chooser + self.execute_chooser = self.session.execute_chooser + self._shard_id = None + + def set_shard(self, shard_id: ShardIdentifier) -> Self: + """Return a new query, limited to a single shard ID. + + All subsequent operations with the returned query will + be against the single shard regardless of other state. + + The shard_id can be passed for a 2.0 style execution to the + bind_arguments dictionary of :meth:`.Session.execute`:: + + results = session.execute( + stmt, + bind_arguments={"shard_id": "my_shard"} + ) + + """ + return self.execution_options(_sa_shard_id=shard_id) + + +class ShardedSession(Session): + shard_chooser: ShardChooser + identity_chooser: IdentityChooser + execute_chooser: Callable[[ORMExecuteState], Iterable[Any]] + + def __init__( + self, + shard_chooser: ShardChooser, + identity_chooser: Optional[IdentityChooser] = None, + execute_chooser: Optional[ + Callable[[ORMExecuteState], Iterable[Any]] + ] = None, + shards: Optional[Dict[str, Any]] = None, + query_cls: Type[Query[_T]] = ShardedQuery, + *, + id_chooser: Optional[ + Callable[[Query[_T], Iterable[_T]], Iterable[Any]] + ] = None, + query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None, + **kwargs: Any, + ) -> None: + """Construct a ShardedSession. + + :param shard_chooser: A callable which, passed a Mapper, a mapped + instance, and possibly a SQL clause, returns a shard ID. This id + may be based off of the attributes present within the object, or on + some round-robin scheme. If the scheme is based on a selection, it + should set whatever state on the instance to mark it in the future as + participating in that shard. + + :param identity_chooser: A callable, passed a Mapper and primary key + argument, which should return a list of shard ids where this + primary key might reside. + + .. versionchanged:: 2.0 The ``identity_chooser`` parameter + supersedes the ``id_chooser`` parameter. + + :param execute_chooser: For a given :class:`.ORMExecuteState`, + returns the list of shard_ids + where the query should be issued. Results from all shards returned + will be combined together into a single listing. + + .. versionchanged:: 1.4 The ``execute_chooser`` parameter + supersedes the ``query_chooser`` parameter. + + :param shards: A dictionary of string shard names + to :class:`~sqlalchemy.engine.Engine` objects. + + """ + super().__init__(query_cls=query_cls, **kwargs) + + event.listen( + self, "do_orm_execute", execute_and_instances, retval=True + ) + self.shard_chooser = shard_chooser + + if id_chooser: + _id_chooser = id_chooser + util.warn_deprecated( + "The ``id_chooser`` parameter is deprecated; " + "please use ``identity_chooser``.", + "2.0", + ) + + def _legacy_identity_chooser( + mapper: Mapper[_T], + primary_key: _PKIdentityArgument, + *, + lazy_loaded_from: Optional[InstanceState[Any]], + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + **kw: Any, + ) -> Any: + q = self.query(mapper) + if lazy_loaded_from: + q = q._set_lazyload_from(lazy_loaded_from) + return _id_chooser(q, primary_key) + + self.identity_chooser = _legacy_identity_chooser + elif identity_chooser: + self.identity_chooser = identity_chooser + else: + raise exc.ArgumentError( + "identity_chooser or id_chooser is required" + ) + + if query_chooser: + _query_chooser = query_chooser + util.warn_deprecated( + "The ``query_chooser`` parameter is deprecated; " + "please use ``execute_chooser``.", + "1.4", + ) + if execute_chooser: + raise exc.ArgumentError( + "Can't pass query_chooser and execute_chooser " + "at the same time." + ) + + def _default_execute_chooser( + orm_context: ORMExecuteState, + ) -> Iterable[Any]: + return _query_chooser(orm_context.statement) + + if execute_chooser is None: + execute_chooser = _default_execute_chooser + + if execute_chooser is None: + raise exc.ArgumentError( + "execute_chooser or query_chooser is required" + ) + self.execute_chooser = execute_chooser + self.__shards: Dict[ShardIdentifier, _SessionBind] = {} + if shards is not None: + for k in shards: + self.bind_shard(k, shards[k]) + + def _identity_lookup( + self, + mapper: Mapper[_O], + primary_key_identity: Union[Any, Tuple[Any, ...]], + identity_token: Optional[Any] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + lazy_loaded_from: Optional[InstanceState[Any]] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Union[Optional[_O], LoaderCallableStatus]: + """override the default :meth:`.Session._identity_lookup` method so + that we search for a given non-token primary key identity across all + possible identity tokens (e.g. shard ids). + + .. versionchanged:: 1.4 Moved :meth:`.Session._identity_lookup` from + the :class:`_query.Query` object to the :class:`.Session`. + + """ + + if identity_token is not None: + obj = super()._identity_lookup( + mapper, + primary_key_identity, + identity_token=identity_token, + **kw, + ) + + return obj + else: + for shard_id in self.identity_chooser( + mapper, + primary_key_identity, + lazy_loaded_from=lazy_loaded_from, + execution_options=execution_options, + bind_arguments=dict(bind_arguments) if bind_arguments else {}, + ): + obj2 = super()._identity_lookup( + mapper, + primary_key_identity, + identity_token=shard_id, + lazy_loaded_from=lazy_loaded_from, + **kw, + ) + if obj2 is not None: + return obj2 + + return None + + def _choose_shard_and_assign( + self, + mapper: Optional[_EntityBindKey[_O]], + instance: Any, + **kw: Any, + ) -> Any: + if instance is not None: + state = inspect(instance) + if state.key: + token = state.key[2] + assert token is not None + return token + elif state.identity_token: + return state.identity_token + + assert isinstance(mapper, Mapper) + shard_id = self.shard_chooser(mapper, instance, **kw) + if instance is not None: + state.identity_token = shard_id + return shard_id + + def connection_callable( # type: ignore [override] + self, + mapper: Optional[Mapper[_T]] = None, + instance: Optional[Any] = None, + shard_id: Optional[ShardIdentifier] = None, + **kw: Any, + ) -> Connection: + """Provide a :class:`_engine.Connection` to use in the unit of work + flush process. + + """ + + if shard_id is None: + shard_id = self._choose_shard_and_assign(mapper, instance) + + if self.in_transaction(): + trans = self.get_transaction() + assert trans is not None + return trans.connection(mapper, shard_id=shard_id) + else: + bind = self.get_bind( + mapper=mapper, shard_id=shard_id, instance=instance + ) + + if isinstance(bind, Engine): + return bind.connect(**kw) + else: + assert isinstance(bind, Connection) + return bind + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + *, + shard_id: Optional[ShardIdentifier] = None, + instance: Optional[Any] = None, + clause: Optional[ClauseElement] = None, + **kw: Any, + ) -> _SessionBind: + if shard_id is None: + shard_id = self._choose_shard_and_assign( + mapper, instance=instance, clause=clause + ) + assert shard_id is not None + return self.__shards[shard_id] + + def bind_shard( + self, shard_id: ShardIdentifier, bind: Union[Engine, OptionEngine] + ) -> None: + self.__shards[shard_id] = bind + + +class set_shard_id(ORMOption): + """a loader option for statements to apply a specific shard id to the + primary query as well as for additional relationship and column + loaders. + + The :class:`_horizontal.set_shard_id` option may be applied using + the :meth:`_sql.Executable.options` method of any executable statement:: + + stmt = ( + select(MyObject). + where(MyObject.name == 'some name'). + options(set_shard_id("shard1")) + ) + + Above, the statement when invoked will limit to the "shard1" shard + identifier for the primary query as well as for all relationship and + column loading strategies, including eager loaders such as + :func:`_orm.selectinload`, deferred column loaders like :func:`_orm.defer`, + and the lazy relationship loader :func:`_orm.lazyload`. + + In this way, the :class:`_horizontal.set_shard_id` option has much wider + scope than using the "shard_id" argument within the + :paramref:`_orm.Session.execute.bind_arguments` dictionary. + + + .. versionadded:: 2.0.0 + + """ + + __slots__ = ("shard_id", "propagate_to_loaders") + + def __init__( + self, shard_id: ShardIdentifier, propagate_to_loaders: bool = True + ): + """Construct a :class:`_horizontal.set_shard_id` option. + + :param shard_id: shard identifier + :param propagate_to_loaders: if left at its default of ``True``, the + shard option will take place for lazy loaders such as + :func:`_orm.lazyload` and :func:`_orm.defer`; if False, the option + will not be propagated to loaded objects. Note that :func:`_orm.defer` + always limits to the shard_id of the parent row in any case, so the + parameter only has a net effect on the behavior of the + :func:`_orm.lazyload` strategy. + + """ + self.shard_id = shard_id + self.propagate_to_loaders = propagate_to_loaders + + +def execute_and_instances( + orm_context: ORMExecuteState, +) -> Union[Result[_T], IteratorResult[_TP]]: + active_options: Union[ + None, + QueryContext.default_load_options, + Type[QueryContext.default_load_options], + BulkUDCompileState.default_update_options, + Type[BulkUDCompileState.default_update_options], + ] + + if orm_context.is_select: + active_options = orm_context.load_options + + elif orm_context.is_update or orm_context.is_delete: + active_options = orm_context.update_delete_options + else: + active_options = None + + session = orm_context.session + assert isinstance(session, ShardedSession) + + def iter_for_shard( + shard_id: ShardIdentifier, + ) -> Union[Result[_T], IteratorResult[_TP]]: + bind_arguments = dict(orm_context.bind_arguments) + bind_arguments["shard_id"] = shard_id + + orm_context.update_execution_options(identity_token=shard_id) + return orm_context.invoke_statement(bind_arguments=bind_arguments) + + for orm_opt in orm_context._non_compile_orm_options: + # TODO: if we had an ORMOption that gets applied at ORM statement + # execution time, that would allow this to be more generalized. + # for now just iterate and look for our options + if isinstance(orm_opt, set_shard_id): + shard_id = orm_opt.shard_id + break + else: + if active_options and active_options._identity_token is not None: + shard_id = active_options._identity_token + elif "_sa_shard_id" in orm_context.execution_options: + shard_id = orm_context.execution_options["_sa_shard_id"] + elif "shard_id" in orm_context.bind_arguments: + shard_id = orm_context.bind_arguments["shard_id"] + else: + shard_id = None + + if shard_id is not None: + return iter_for_shard(shard_id) + else: + partial = [] + for shard_id in session.execute_chooser(orm_context): + result_ = iter_for_shard(shard_id) + partial.append(result_) + return partial[0].merge(*partial[1:]) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/hybrid.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/hybrid.py new file mode 100644 index 0000000..25b74d8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/hybrid.py @@ -0,0 +1,1514 @@ +# ext/hybrid.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Define attributes on ORM-mapped classes that have "hybrid" behavior. + +"hybrid" means the attribute has distinct behaviors defined at the +class level and at the instance level. + +The :mod:`~sqlalchemy.ext.hybrid` extension provides a special form of +method decorator and has minimal dependencies on the rest of SQLAlchemy. +Its basic theory of operation can work with any descriptor-based expression +system. + +Consider a mapping ``Interval``, representing integer ``start`` and ``end`` +values. We can define higher level functions on mapped classes that produce SQL +expressions at the class level, and Python expression evaluation at the +instance level. Below, each function decorated with :class:`.hybrid_method` or +:class:`.hybrid_property` may receive ``self`` as an instance of the class, or +may receive the class directly, depending on context:: + + from __future__ import annotations + + from sqlalchemy.ext.hybrid import hybrid_method + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + + class Base(DeclarativeBase): + pass + + class Interval(Base): + __tablename__ = 'interval' + + id: Mapped[int] = mapped_column(primary_key=True) + start: Mapped[int] + end: Mapped[int] + + def __init__(self, start: int, end: int): + self.start = start + self.end = end + + @hybrid_property + def length(self) -> int: + return self.end - self.start + + @hybrid_method + def contains(self, point: int) -> bool: + return (self.start <= point) & (point <= self.end) + + @hybrid_method + def intersects(self, other: Interval) -> bool: + return self.contains(other.start) | self.contains(other.end) + + +Above, the ``length`` property returns the difference between the +``end`` and ``start`` attributes. With an instance of ``Interval``, +this subtraction occurs in Python, using normal Python descriptor +mechanics:: + + >>> i1 = Interval(5, 10) + >>> i1.length + 5 + +When dealing with the ``Interval`` class itself, the :class:`.hybrid_property` +descriptor evaluates the function body given the ``Interval`` class as +the argument, which when evaluated with SQLAlchemy expression mechanics +returns a new SQL expression: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(Interval.length)) + {printsql}SELECT interval."end" - interval.start AS length + FROM interval{stop} + + + >>> print(select(Interval).filter(Interval.length > 10)) + {printsql}SELECT interval.id, interval.start, interval."end" + FROM interval + WHERE interval."end" - interval.start > :param_1 + +Filtering methods such as :meth:`.Select.filter_by` are supported +with hybrid attributes as well: + +.. sourcecode:: pycon+sql + + >>> print(select(Interval).filter_by(length=5)) + {printsql}SELECT interval.id, interval.start, interval."end" + FROM interval + WHERE interval."end" - interval.start = :param_1 + +The ``Interval`` class example also illustrates two methods, +``contains()`` and ``intersects()``, decorated with +:class:`.hybrid_method`. This decorator applies the same idea to +methods that :class:`.hybrid_property` applies to attributes. The +methods return boolean values, and take advantage of the Python ``|`` +and ``&`` bitwise operators to produce equivalent instance-level and +SQL expression-level boolean behavior: + +.. sourcecode:: pycon+sql + + >>> i1.contains(6) + True + >>> i1.contains(15) + False + >>> i1.intersects(Interval(7, 18)) + True + >>> i1.intersects(Interval(25, 29)) + False + + >>> print(select(Interval).filter(Interval.contains(15))) + {printsql}SELECT interval.id, interval.start, interval."end" + FROM interval + WHERE interval.start <= :start_1 AND interval."end" > :end_1{stop} + + >>> ia = aliased(Interval) + >>> print(select(Interval, ia).filter(Interval.intersects(ia))) + {printsql}SELECT interval.id, interval.start, + interval."end", interval_1.id AS interval_1_id, + interval_1.start AS interval_1_start, interval_1."end" AS interval_1_end + FROM interval, interval AS interval_1 + WHERE interval.start <= interval_1.start + AND interval."end" > interval_1.start + OR interval.start <= interval_1."end" + AND interval."end" > interval_1."end"{stop} + +.. _hybrid_distinct_expression: + +Defining Expression Behavior Distinct from Attribute Behavior +-------------------------------------------------------------- + +In the previous section, our usage of the ``&`` and ``|`` bitwise operators +within the ``Interval.contains`` and ``Interval.intersects`` methods was +fortunate, considering our functions operated on two boolean values to return a +new one. In many cases, the construction of an in-Python function and a +SQLAlchemy SQL expression have enough differences that two separate Python +expressions should be defined. The :mod:`~sqlalchemy.ext.hybrid` decorator +defines a **modifier** :meth:`.hybrid_property.expression` for this purpose. As an +example we'll define the radius of the interval, which requires the usage of +the absolute value function:: + + from sqlalchemy import ColumnElement + from sqlalchemy import Float + from sqlalchemy import func + from sqlalchemy import type_coerce + + class Interval(Base): + # ... + + @hybrid_property + def radius(self) -> float: + return abs(self.length) / 2 + + @radius.inplace.expression + @classmethod + def _radius_expression(cls) -> ColumnElement[float]: + return type_coerce(func.abs(cls.length) / 2, Float) + +In the above example, the :class:`.hybrid_property` first assigned to the +name ``Interval.radius`` is amended by a subsequent method called +``Interval._radius_expression``, using the decorator +``@radius.inplace.expression``, which chains together two modifiers +:attr:`.hybrid_property.inplace` and :attr:`.hybrid_property.expression`. +The use of :attr:`.hybrid_property.inplace` indicates that the +:meth:`.hybrid_property.expression` modifier should mutate the +existing hybrid object at ``Interval.radius`` in place, without creating a +new object. Notes on this modifier and its +rationale are discussed in the next section :ref:`hybrid_pep484_naming`. +The use of ``@classmethod`` is optional, and is strictly to give typing +tools a hint that ``cls`` in this case is expected to be the ``Interval`` +class, and not an instance of ``Interval``. + +.. note:: :attr:`.hybrid_property.inplace` as well as the use of ``@classmethod`` + for proper typing support are available as of SQLAlchemy 2.0.4, and will + not work in earlier versions. + +With ``Interval.radius`` now including an expression element, the SQL +function ``ABS()`` is returned when accessing ``Interval.radius`` +at the class level: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(Interval).filter(Interval.radius > 5)) + {printsql}SELECT interval.id, interval.start, interval."end" + FROM interval + WHERE abs(interval."end" - interval.start) / :abs_1 > :param_1 + + +.. _hybrid_pep484_naming: + +Using ``inplace`` to create pep-484 compliant hybrid properties +--------------------------------------------------------------- + +In the previous section, a :class:`.hybrid_property` decorator is illustrated +which includes two separate method-level functions being decorated, both +to produce a single object attribute referenced as ``Interval.radius``. +There are actually several different modifiers we can use for +:class:`.hybrid_property` including :meth:`.hybrid_property.expression`, +:meth:`.hybrid_property.setter` and :meth:`.hybrid_property.update_expression`. + +SQLAlchemy's :class:`.hybrid_property` decorator intends that adding on these +methods may be done in the identical manner as Python's built-in +``@property`` decorator, where idiomatic use is to continue to redefine the +attribute repeatedly, using the **same attribute name** each time, as in the +example below that illustrates the use of :meth:`.hybrid_property.setter` and +:meth:`.hybrid_property.expression` for the ``Interval.radius`` descriptor:: + + # correct use, however is not accepted by pep-484 tooling + + class Interval(Base): + # ... + + @hybrid_property + def radius(self): + return abs(self.length) / 2 + + @radius.setter + def radius(self, value): + self.length = value * 2 + + @radius.expression + def radius(cls): + return type_coerce(func.abs(cls.length) / 2, Float) + +Above, there are three ``Interval.radius`` methods, but as each are decorated, +first by the :class:`.hybrid_property` decorator and then by the +``@radius`` name itself, the end effect is that ``Interval.radius`` is +a single attribute with three different functions contained within it. +This style of use is taken from `Python's documented use of @property +`_. +It is important to note that the way both ``@property`` as well as +:class:`.hybrid_property` work, a **copy of the descriptor is made each time**. +That is, each call to ``@radius.expression``, ``@radius.setter`` etc. +make a new object entirely. This allows the attribute to be re-defined in +subclasses without issue (see :ref:`hybrid_reuse_subclass` later in this +section for how this is used). + +However, the above approach is not compatible with typing tools such as +mypy and pyright. Python's own ``@property`` decorator does not have this +limitation only because +`these tools hardcode the behavior of @property +`_, meaning this syntax +is not available to SQLAlchemy under :pep:`484` compliance. + +In order to produce a reasonable syntax while remaining typing compliant, +the :attr:`.hybrid_property.inplace` decorator allows the same +decorator to be re-used with different method names, while still producing +a single decorator under one name:: + + # correct use which is also accepted by pep-484 tooling + + class Interval(Base): + # ... + + @hybrid_property + def radius(self) -> float: + return abs(self.length) / 2 + + @radius.inplace.setter + def _radius_setter(self, value: float) -> None: + # for example only + self.length = value * 2 + + @radius.inplace.expression + @classmethod + def _radius_expression(cls) -> ColumnElement[float]: + return type_coerce(func.abs(cls.length) / 2, Float) + +Using :attr:`.hybrid_property.inplace` further qualifies the use of the +decorator that a new copy should not be made, thereby maintaining the +``Interval.radius`` name while allowing additional methods +``Interval._radius_setter`` and ``Interval._radius_expression`` to be +differently named. + + +.. versionadded:: 2.0.4 Added :attr:`.hybrid_property.inplace` to allow + less verbose construction of composite :class:`.hybrid_property` objects + while not having to use repeated method names. Additionally allowed the + use of ``@classmethod`` within :attr:`.hybrid_property.expression`, + :attr:`.hybrid_property.update_expression`, and + :attr:`.hybrid_property.comparator` to allow typing tools to identify + ``cls`` as a class and not an instance in the method signature. + + +Defining Setters +---------------- + +The :meth:`.hybrid_property.setter` modifier allows the construction of a +custom setter method, that can modify values on the object:: + + class Interval(Base): + # ... + + @hybrid_property + def length(self) -> int: + return self.end - self.start + + @length.inplace.setter + def _length_setter(self, value: int) -> None: + self.end = self.start + value + +The ``length(self, value)`` method is now called upon set:: + + >>> i1 = Interval(5, 10) + >>> i1.length + 5 + >>> i1.length = 12 + >>> i1.end + 17 + +.. _hybrid_bulk_update: + +Allowing Bulk ORM Update +------------------------ + +A hybrid can define a custom "UPDATE" handler for when using +ORM-enabled updates, allowing the hybrid to be used in the +SET clause of the update. + +Normally, when using a hybrid with :func:`_sql.update`, the SQL +expression is used as the column that's the target of the SET. If our +``Interval`` class had a hybrid ``start_point`` that linked to +``Interval.start``, this could be substituted directly:: + + from sqlalchemy import update + stmt = update(Interval).values({Interval.start_point: 10}) + +However, when using a composite hybrid like ``Interval.length``, this +hybrid represents more than one column. We can set up a handler that will +accommodate a value passed in the VALUES expression which can affect +this, using the :meth:`.hybrid_property.update_expression` decorator. +A handler that works similarly to our setter would be:: + + from typing import List, Tuple, Any + + class Interval(Base): + # ... + + @hybrid_property + def length(self) -> int: + return self.end - self.start + + @length.inplace.setter + def _length_setter(self, value: int) -> None: + self.end = self.start + value + + @length.inplace.update_expression + def _length_update_expression(cls, value: Any) -> List[Tuple[Any, Any]]: + return [ + (cls.end, cls.start + value) + ] + +Above, if we use ``Interval.length`` in an UPDATE expression, we get +a hybrid SET expression: + +.. sourcecode:: pycon+sql + + + >>> from sqlalchemy import update + >>> print(update(Interval).values({Interval.length: 25})) + {printsql}UPDATE interval SET "end"=(interval.start + :start_1) + +This SET expression is accommodated by the ORM automatically. + +.. seealso:: + + :ref:`orm_expression_update_delete` - includes background on ORM-enabled + UPDATE statements + + +Working with Relationships +-------------------------- + +There's no essential difference when creating hybrids that work with +related objects as opposed to column-based data. The need for distinct +expressions tends to be greater. The two variants we'll illustrate +are the "join-dependent" hybrid, and the "correlated subquery" hybrid. + +Join-Dependent Relationship Hybrid +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Consider the following declarative +mapping which relates a ``User`` to a ``SavingsAccount``:: + + from __future__ import annotations + + from decimal import Decimal + from typing import cast + from typing import List + from typing import Optional + + from sqlalchemy import ForeignKey + from sqlalchemy import Numeric + from sqlalchemy import String + from sqlalchemy import SQLColumnExpression + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy.orm import relationship + + + class Base(DeclarativeBase): + pass + + + class SavingsAccount(Base): + __tablename__ = 'account' + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey('user.id')) + balance: Mapped[Decimal] = mapped_column(Numeric(15, 5)) + + owner: Mapped[User] = relationship(back_populates="accounts") + + class User(Base): + __tablename__ = 'user' + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + + accounts: Mapped[List[SavingsAccount]] = relationship( + back_populates="owner", lazy="selectin" + ) + + @hybrid_property + def balance(self) -> Optional[Decimal]: + if self.accounts: + return self.accounts[0].balance + else: + return None + + @balance.inplace.setter + def _balance_setter(self, value: Optional[Decimal]) -> None: + assert value is not None + + if not self.accounts: + account = SavingsAccount(owner=self) + else: + account = self.accounts[0] + account.balance = value + + @balance.inplace.expression + @classmethod + def _balance_expression(cls) -> SQLColumnExpression[Optional[Decimal]]: + return cast("SQLColumnExpression[Optional[Decimal]]", SavingsAccount.balance) + +The above hybrid property ``balance`` works with the first +``SavingsAccount`` entry in the list of accounts for this user. The +in-Python getter/setter methods can treat ``accounts`` as a Python +list available on ``self``. + +.. tip:: The ``User.balance`` getter in the above example accesses the + ``self.acccounts`` collection, which will normally be loaded via the + :func:`.selectinload` loader strategy configured on the ``User.balance`` + :func:`_orm.relationship`. The default loader strategy when not otherwise + stated on :func:`_orm.relationship` is :func:`.lazyload`, which emits SQL on + demand. When using asyncio, on-demand loaders such as :func:`.lazyload` are + not supported, so care should be taken to ensure the ``self.accounts`` + collection is accessible to this hybrid accessor when using asyncio. + +At the expression level, it's expected that the ``User`` class will +be used in an appropriate context such that an appropriate join to +``SavingsAccount`` will be present: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(User, User.balance). + ... join(User.accounts).filter(User.balance > 5000)) + {printsql}SELECT "user".id AS user_id, "user".name AS user_name, + account.balance AS account_balance + FROM "user" JOIN account ON "user".id = account.user_id + WHERE account.balance > :balance_1 + +Note however, that while the instance level accessors need to worry +about whether ``self.accounts`` is even present, this issue expresses +itself differently at the SQL expression level, where we basically +would use an outer join: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> from sqlalchemy import or_ + >>> print (select(User, User.balance).outerjoin(User.accounts). + ... filter(or_(User.balance < 5000, User.balance == None))) + {printsql}SELECT "user".id AS user_id, "user".name AS user_name, + account.balance AS account_balance + FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id + WHERE account.balance < :balance_1 OR account.balance IS NULL + +Correlated Subquery Relationship Hybrid +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +We can, of course, forego being dependent on the enclosing query's usage +of joins in favor of the correlated subquery, which can portably be packed +into a single column expression. A correlated subquery is more portable, but +often performs more poorly at the SQL level. Using the same technique +illustrated at :ref:`mapper_column_property_sql_expressions`, +we can adjust our ``SavingsAccount`` example to aggregate the balances for +*all* accounts, and use a correlated subquery for the column expression:: + + from __future__ import annotations + + from decimal import Decimal + from typing import List + + from sqlalchemy import ForeignKey + from sqlalchemy import func + from sqlalchemy import Numeric + from sqlalchemy import select + from sqlalchemy import SQLColumnExpression + from sqlalchemy import String + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy.orm import relationship + + + class Base(DeclarativeBase): + pass + + + class SavingsAccount(Base): + __tablename__ = 'account' + id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(ForeignKey('user.id')) + balance: Mapped[Decimal] = mapped_column(Numeric(15, 5)) + + owner: Mapped[User] = relationship(back_populates="accounts") + + class User(Base): + __tablename__ = 'user' + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + + accounts: Mapped[List[SavingsAccount]] = relationship( + back_populates="owner", lazy="selectin" + ) + + @hybrid_property + def balance(self) -> Decimal: + return sum((acc.balance for acc in self.accounts), start=Decimal("0")) + + @balance.inplace.expression + @classmethod + def _balance_expression(cls) -> SQLColumnExpression[Decimal]: + return ( + select(func.sum(SavingsAccount.balance)) + .where(SavingsAccount.user_id == cls.id) + .label("total_balance") + ) + + +The above recipe will give us the ``balance`` column which renders +a correlated SELECT: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(User).filter(User.balance > 400)) + {printsql}SELECT "user".id, "user".name + FROM "user" + WHERE ( + SELECT sum(account.balance) AS sum_1 FROM account + WHERE account.user_id = "user".id + ) > :param_1 + + +.. _hybrid_custom_comparators: + +Building Custom Comparators +--------------------------- + +The hybrid property also includes a helper that allows construction of +custom comparators. A comparator object allows one to customize the +behavior of each SQLAlchemy expression operator individually. They +are useful when creating custom types that have some highly +idiosyncratic behavior on the SQL side. + +.. note:: The :meth:`.hybrid_property.comparator` decorator introduced + in this section **replaces** the use of the + :meth:`.hybrid_property.expression` decorator. + They cannot be used together. + +The example class below allows case-insensitive comparisons on the attribute +named ``word_insensitive``:: + + from __future__ import annotations + + from typing import Any + + from sqlalchemy import ColumnElement + from sqlalchemy import func + from sqlalchemy.ext.hybrid import Comparator + from sqlalchemy.ext.hybrid import hybrid_property + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + class Base(DeclarativeBase): + pass + + + class CaseInsensitiveComparator(Comparator[str]): + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return func.lower(self.__clause_element__()) == func.lower(other) + + class SearchWord(Base): + __tablename__ = 'searchword' + + id: Mapped[int] = mapped_column(primary_key=True) + word: Mapped[str] + + @hybrid_property + def word_insensitive(self) -> str: + return self.word.lower() + + @word_insensitive.inplace.comparator + @classmethod + def _word_insensitive_comparator(cls) -> CaseInsensitiveComparator: + return CaseInsensitiveComparator(cls.word) + +Above, SQL expressions against ``word_insensitive`` will apply the ``LOWER()`` +SQL function to both sides: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy import select + >>> print(select(SearchWord).filter_by(word_insensitive="Trucks")) + {printsql}SELECT searchword.id, searchword.word + FROM searchword + WHERE lower(searchword.word) = lower(:lower_1) + + +The ``CaseInsensitiveComparator`` above implements part of the +:class:`.ColumnOperators` interface. A "coercion" operation like +lowercasing can be applied to all comparison operations (i.e. ``eq``, +``lt``, ``gt``, etc.) using :meth:`.Operators.operate`:: + + class CaseInsensitiveComparator(Comparator): + def operate(self, op, other, **kwargs): + return op( + func.lower(self.__clause_element__()), + func.lower(other), + **kwargs, + ) + +.. _hybrid_reuse_subclass: + +Reusing Hybrid Properties across Subclasses +------------------------------------------- + +A hybrid can be referred to from a superclass, to allow modifying +methods like :meth:`.hybrid_property.getter`, :meth:`.hybrid_property.setter` +to be used to redefine those methods on a subclass. This is similar to +how the standard Python ``@property`` object works:: + + class FirstNameOnly(Base): + # ... + + first_name: Mapped[str] + + @hybrid_property + def name(self) -> str: + return self.first_name + + @name.inplace.setter + def _name_setter(self, value: str) -> None: + self.first_name = value + + class FirstNameLastName(FirstNameOnly): + # ... + + last_name: Mapped[str] + + # 'inplace' is not used here; calling getter creates a copy + # of FirstNameOnly.name that is local to FirstNameLastName + @FirstNameOnly.name.getter + def name(self) -> str: + return self.first_name + ' ' + self.last_name + + @name.inplace.setter + def _name_setter(self, value: str) -> None: + self.first_name, self.last_name = value.split(' ', 1) + +Above, the ``FirstNameLastName`` class refers to the hybrid from +``FirstNameOnly.name`` to repurpose its getter and setter for the subclass. + +When overriding :meth:`.hybrid_property.expression` and +:meth:`.hybrid_property.comparator` alone as the first reference to the +superclass, these names conflict with the same-named accessors on the class- +level :class:`.QueryableAttribute` object returned at the class level. To +override these methods when referring directly to the parent class descriptor, +add the special qualifier :attr:`.hybrid_property.overrides`, which will de- +reference the instrumented attribute back to the hybrid object:: + + class FirstNameLastName(FirstNameOnly): + # ... + + last_name: Mapped[str] + + @FirstNameOnly.name.overrides.expression + @classmethod + def name(cls): + return func.concat(cls.first_name, ' ', cls.last_name) + + +Hybrid Value Objects +-------------------- + +Note in our previous example, if we were to compare the ``word_insensitive`` +attribute of a ``SearchWord`` instance to a plain Python string, the plain +Python string would not be coerced to lower case - the +``CaseInsensitiveComparator`` we built, being returned by +``@word_insensitive.comparator``, only applies to the SQL side. + +A more comprehensive form of the custom comparator is to construct a *Hybrid +Value Object*. This technique applies the target value or expression to a value +object which is then returned by the accessor in all cases. The value object +allows control of all operations upon the value as well as how compared values +are treated, both on the SQL expression side as well as the Python value side. +Replacing the previous ``CaseInsensitiveComparator`` class with a new +``CaseInsensitiveWord`` class:: + + class CaseInsensitiveWord(Comparator): + "Hybrid value representing a lower case representation of a word." + + def __init__(self, word): + if isinstance(word, basestring): + self.word = word.lower() + elif isinstance(word, CaseInsensitiveWord): + self.word = word.word + else: + self.word = func.lower(word) + + def operate(self, op, other, **kwargs): + if not isinstance(other, CaseInsensitiveWord): + other = CaseInsensitiveWord(other) + return op(self.word, other.word, **kwargs) + + def __clause_element__(self): + return self.word + + def __str__(self): + return self.word + + key = 'word' + "Label to apply to Query tuple results" + +Above, the ``CaseInsensitiveWord`` object represents ``self.word``, which may +be a SQL function, or may be a Python native. By overriding ``operate()`` and +``__clause_element__()`` to work in terms of ``self.word``, all comparison +operations will work against the "converted" form of ``word``, whether it be +SQL side or Python side. Our ``SearchWord`` class can now deliver the +``CaseInsensitiveWord`` object unconditionally from a single hybrid call:: + + class SearchWord(Base): + __tablename__ = 'searchword' + id: Mapped[int] = mapped_column(primary_key=True) + word: Mapped[str] + + @hybrid_property + def word_insensitive(self) -> CaseInsensitiveWord: + return CaseInsensitiveWord(self.word) + +The ``word_insensitive`` attribute now has case-insensitive comparison behavior +universally, including SQL expression vs. Python expression (note the Python +value is converted to lower case on the Python side here): + +.. sourcecode:: pycon+sql + + >>> print(select(SearchWord).filter_by(word_insensitive="Trucks")) + {printsql}SELECT searchword.id AS searchword_id, searchword.word AS searchword_word + FROM searchword + WHERE lower(searchword.word) = :lower_1 + +SQL expression versus SQL expression: + +.. sourcecode:: pycon+sql + + >>> from sqlalchemy.orm import aliased + >>> sw1 = aliased(SearchWord) + >>> sw2 = aliased(SearchWord) + >>> print( + ... select(sw1.word_insensitive, sw2.word_insensitive).filter( + ... sw1.word_insensitive > sw2.word_insensitive + ... ) + ... ) + {printsql}SELECT lower(searchword_1.word) AS lower_1, + lower(searchword_2.word) AS lower_2 + FROM searchword AS searchword_1, searchword AS searchword_2 + WHERE lower(searchword_1.word) > lower(searchword_2.word) + +Python only expression:: + + >>> ws1 = SearchWord(word="SomeWord") + >>> ws1.word_insensitive == "sOmEwOrD" + True + >>> ws1.word_insensitive == "XOmEwOrX" + False + >>> print(ws1.word_insensitive) + someword + +The Hybrid Value pattern is very useful for any kind of value that may have +multiple representations, such as timestamps, time deltas, units of +measurement, currencies and encrypted passwords. + +.. seealso:: + + `Hybrids and Value Agnostic Types + `_ + - on the techspot.zzzeek.org blog + + `Value Agnostic Types, Part II + `_ - + on the techspot.zzzeek.org blog + + +""" # noqa + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import cast +from typing import Generic +from typing import List +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .. import util +from ..orm import attributes +from ..orm import InspectionAttrExtensionType +from ..orm import interfaces +from ..orm import ORMDescriptor +from ..orm.attributes import QueryableAttribute +from ..sql import roles +from ..sql._typing import is_has_clause_element +from ..sql.elements import ColumnElement +from ..sql.elements import SQLCoreOperations +from ..util.typing import Concatenate +from ..util.typing import Literal +from ..util.typing import ParamSpec +from ..util.typing import Protocol +from ..util.typing import Self + +if TYPE_CHECKING: + from ..orm.interfaces import MapperProperty + from ..orm.util import AliasedInsp + from ..sql import SQLColumnExpression + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _HasClauseElement + from ..sql._typing import _InfoType + from ..sql.operators import OperatorType + +_P = ParamSpec("_P") +_R = TypeVar("_R") +_T = TypeVar("_T", bound=Any) +_TE = TypeVar("_TE", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) +_T_con = TypeVar("_T_con", bound=Any, contravariant=True) + + +class HybridExtensionType(InspectionAttrExtensionType): + HYBRID_METHOD = "HYBRID_METHOD" + """Symbol indicating an :class:`InspectionAttr` that's + of type :class:`.hybrid_method`. + + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + + .. seealso:: + + :attr:`_orm.Mapper.all_orm_attributes` + + """ + + HYBRID_PROPERTY = "HYBRID_PROPERTY" + """Symbol indicating an :class:`InspectionAttr` that's + of type :class:`.hybrid_method`. + + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + + .. seealso:: + + :attr:`_orm.Mapper.all_orm_attributes` + + """ + + +class _HybridGetterType(Protocol[_T_co]): + def __call__(s, self: Any) -> _T_co: ... + + +class _HybridSetterType(Protocol[_T_con]): + def __call__(s, self: Any, value: _T_con) -> None: ... + + +class _HybridUpdaterType(Protocol[_T_con]): + def __call__( + s, + cls: Any, + value: Union[_T_con, _ColumnExpressionArgument[_T_con]], + ) -> List[Tuple[_DMLColumnArgument, Any]]: ... + + +class _HybridDeleterType(Protocol[_T_co]): + def __call__(s, self: Any) -> None: ... + + +class _HybridExprCallableType(Protocol[_T_co]): + def __call__( + s, cls: Any + ) -> Union[_HasClauseElement[_T_co], SQLColumnExpression[_T_co]]: ... + + +class _HybridComparatorCallableType(Protocol[_T]): + def __call__(self, cls: Any) -> Comparator[_T]: ... + + +class _HybridClassLevelAccessor(QueryableAttribute[_T]): + """Describe the object returned by a hybrid_property() when + called as a class-level descriptor. + + """ + + if TYPE_CHECKING: + + def getter( + self, fget: _HybridGetterType[_T] + ) -> hybrid_property[_T]: ... + + def setter( + self, fset: _HybridSetterType[_T] + ) -> hybrid_property[_T]: ... + + def deleter( + self, fdel: _HybridDeleterType[_T] + ) -> hybrid_property[_T]: ... + + @property + def overrides(self) -> hybrid_property[_T]: ... + + def update_expression( + self, meth: _HybridUpdaterType[_T] + ) -> hybrid_property[_T]: ... + + +class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): + """A decorator which allows definition of a Python object method with both + instance-level and class-level behavior. + + """ + + is_attribute = True + extension_type = HybridExtensionType.HYBRID_METHOD + + def __init__( + self, + func: Callable[Concatenate[Any, _P], _R], + expr: Optional[ + Callable[Concatenate[Any, _P], SQLCoreOperations[_R]] + ] = None, + ): + """Create a new :class:`.hybrid_method`. + + Usage is typically via decorator:: + + from sqlalchemy.ext.hybrid import hybrid_method + + class SomeClass: + @hybrid_method + def value(self, x, y): + return self._value + x + y + + @value.expression + @classmethod + def value(cls, x, y): + return func.some_function(cls._value, x, y) + + """ + self.func = func + if expr is not None: + self.expression(expr) + else: + self.expression(func) # type: ignore + + @property + def inplace(self) -> Self: + """Return the inplace mutator for this :class:`.hybrid_method`. + + The :class:`.hybrid_method` class already performs "in place" mutation + when the :meth:`.hybrid_method.expression` decorator is called, + so this attribute returns Self. + + .. versionadded:: 2.0.4 + + .. seealso:: + + :ref:`hybrid_pep484_naming` + + """ + return self + + @overload + def __get__( + self, instance: Literal[None], owner: Type[object] + ) -> Callable[_P, SQLCoreOperations[_R]]: ... + + @overload + def __get__( + self, instance: object, owner: Type[object] + ) -> Callable[_P, _R]: ... + + def __get__( + self, instance: Optional[object], owner: Type[object] + ) -> Union[Callable[_P, _R], Callable[_P, SQLCoreOperations[_R]]]: + if instance is None: + return self.expr.__get__(owner, owner) # type: ignore + else: + return self.func.__get__(instance, owner) # type: ignore + + def expression( + self, expr: Callable[Concatenate[Any, _P], SQLCoreOperations[_R]] + ) -> hybrid_method[_P, _R]: + """Provide a modifying decorator that defines a + SQL-expression producing method.""" + + self.expr = expr + if not self.expr.__doc__: + self.expr.__doc__ = self.func.__doc__ + return self + + +def _unwrap_classmethod(meth: _T) -> _T: + if isinstance(meth, classmethod): + return meth.__func__ # type: ignore + else: + return meth + + +class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): + """A decorator which allows definition of a Python descriptor with both + instance-level and class-level behavior. + + """ + + is_attribute = True + extension_type = HybridExtensionType.HYBRID_PROPERTY + + __name__: str + + def __init__( + self, + fget: _HybridGetterType[_T], + fset: Optional[_HybridSetterType[_T]] = None, + fdel: Optional[_HybridDeleterType[_T]] = None, + expr: Optional[_HybridExprCallableType[_T]] = None, + custom_comparator: Optional[Comparator[_T]] = None, + update_expr: Optional[_HybridUpdaterType[_T]] = None, + ): + """Create a new :class:`.hybrid_property`. + + Usage is typically via decorator:: + + from sqlalchemy.ext.hybrid import hybrid_property + + class SomeClass: + @hybrid_property + def value(self): + return self._value + + @value.setter + def value(self, value): + self._value = value + + """ + self.fget = fget + self.fset = fset + self.fdel = fdel + self.expr = _unwrap_classmethod(expr) + self.custom_comparator = _unwrap_classmethod(custom_comparator) + self.update_expr = _unwrap_classmethod(update_expr) + util.update_wrapper(self, fget) + + @overload + def __get__(self, instance: Any, owner: Literal[None]) -> Self: ... + + @overload + def __get__( + self, instance: Literal[None], owner: Type[object] + ) -> _HybridClassLevelAccessor[_T]: ... + + @overload + def __get__(self, instance: object, owner: Type[object]) -> _T: ... + + def __get__( + self, instance: Optional[object], owner: Optional[Type[object]] + ) -> Union[hybrid_property[_T], _HybridClassLevelAccessor[_T], _T]: + if owner is None: + return self + elif instance is None: + return self._expr_comparator(owner) + else: + return self.fget(instance) + + def __set__(self, instance: object, value: Any) -> None: + if self.fset is None: + raise AttributeError("can't set attribute") + self.fset(instance, value) + + def __delete__(self, instance: object) -> None: + if self.fdel is None: + raise AttributeError("can't delete attribute") + self.fdel(instance) + + def _copy(self, **kw: Any) -> hybrid_property[_T]: + defaults = { + key: value + for key, value in self.__dict__.items() + if not key.startswith("_") + } + defaults.update(**kw) + return type(self)(**defaults) + + @property + def overrides(self) -> Self: + """Prefix for a method that is overriding an existing attribute. + + The :attr:`.hybrid_property.overrides` accessor just returns + this hybrid object, which when called at the class level from + a parent class, will de-reference the "instrumented attribute" + normally returned at this level, and allow modifying decorators + like :meth:`.hybrid_property.expression` and + :meth:`.hybrid_property.comparator` + to be used without conflicting with the same-named attributes + normally present on the :class:`.QueryableAttribute`:: + + class SuperClass: + # ... + + @hybrid_property + def foobar(self): + return self._foobar + + class SubClass(SuperClass): + # ... + + @SuperClass.foobar.overrides.expression + def foobar(cls): + return func.subfoobar(self._foobar) + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`hybrid_reuse_subclass` + + """ + return self + + class _InPlace(Generic[_TE]): + """A builder helper for .hybrid_property. + + .. versionadded:: 2.0.4 + + """ + + __slots__ = ("attr",) + + def __init__(self, attr: hybrid_property[_TE]): + self.attr = attr + + def _set(self, **kw: Any) -> hybrid_property[_TE]: + for k, v in kw.items(): + setattr(self.attr, k, _unwrap_classmethod(v)) + return self.attr + + def getter(self, fget: _HybridGetterType[_TE]) -> hybrid_property[_TE]: + return self._set(fget=fget) + + def setter(self, fset: _HybridSetterType[_TE]) -> hybrid_property[_TE]: + return self._set(fset=fset) + + def deleter( + self, fdel: _HybridDeleterType[_TE] + ) -> hybrid_property[_TE]: + return self._set(fdel=fdel) + + def expression( + self, expr: _HybridExprCallableType[_TE] + ) -> hybrid_property[_TE]: + return self._set(expr=expr) + + def comparator( + self, comparator: _HybridComparatorCallableType[_TE] + ) -> hybrid_property[_TE]: + return self._set(custom_comparator=comparator) + + def update_expression( + self, meth: _HybridUpdaterType[_TE] + ) -> hybrid_property[_TE]: + return self._set(update_expr=meth) + + @property + def inplace(self) -> _InPlace[_T]: + """Return the inplace mutator for this :class:`.hybrid_property`. + + This is to allow in-place mutation of the hybrid, allowing the first + hybrid method of a certain name to be re-used in order to add + more methods without having to name those methods the same, e.g.:: + + class Interval(Base): + # ... + + @hybrid_property + def radius(self) -> float: + return abs(self.length) / 2 + + @radius.inplace.setter + def _radius_setter(self, value: float) -> None: + self.length = value * 2 + + @radius.inplace.expression + def _radius_expression(cls) -> ColumnElement[float]: + return type_coerce(func.abs(cls.length) / 2, Float) + + .. versionadded:: 2.0.4 + + .. seealso:: + + :ref:`hybrid_pep484_naming` + + """ + return hybrid_property._InPlace(self) + + def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a getter method. + + .. versionadded:: 1.2 + + """ + + return self._copy(fget=fget) + + def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a setter method.""" + + return self._copy(fset=fset) + + def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a deletion method.""" + + return self._copy(fdel=fdel) + + def expression( + self, expr: _HybridExprCallableType[_T] + ) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a SQL-expression + producing method. + + When a hybrid is invoked at the class level, the SQL expression given + here is wrapped inside of a specialized :class:`.QueryableAttribute`, + which is the same kind of object used by the ORM to represent other + mapped attributes. The reason for this is so that other class-level + attributes such as docstrings and a reference to the hybrid itself may + be maintained within the structure that's returned, without any + modifications to the original SQL expression passed in. + + .. note:: + + When referring to a hybrid property from an owning class (e.g. + ``SomeClass.some_hybrid``), an instance of + :class:`.QueryableAttribute` is returned, representing the + expression or comparator object as well as this hybrid object. + However, that object itself has accessors called ``expression`` and + ``comparator``; so when attempting to override these decorators on a + subclass, it may be necessary to qualify it using the + :attr:`.hybrid_property.overrides` modifier first. See that + modifier for details. + + .. seealso:: + + :ref:`hybrid_distinct_expression` + + """ + + return self._copy(expr=expr) + + def comparator( + self, comparator: _HybridComparatorCallableType[_T] + ) -> hybrid_property[_T]: + """Provide a modifying decorator that defines a custom + comparator producing method. + + The return value of the decorated method should be an instance of + :class:`~.hybrid.Comparator`. + + .. note:: The :meth:`.hybrid_property.comparator` decorator + **replaces** the use of the :meth:`.hybrid_property.expression` + decorator. They cannot be used together. + + When a hybrid is invoked at the class level, the + :class:`~.hybrid.Comparator` object given here is wrapped inside of a + specialized :class:`.QueryableAttribute`, which is the same kind of + object used by the ORM to represent other mapped attributes. The + reason for this is so that other class-level attributes such as + docstrings and a reference to the hybrid itself may be maintained + within the structure that's returned, without any modifications to the + original comparator object passed in. + + .. note:: + + When referring to a hybrid property from an owning class (e.g. + ``SomeClass.some_hybrid``), an instance of + :class:`.QueryableAttribute` is returned, representing the + expression or comparator object as this hybrid object. However, + that object itself has accessors called ``expression`` and + ``comparator``; so when attempting to override these decorators on a + subclass, it may be necessary to qualify it using the + :attr:`.hybrid_property.overrides` modifier first. See that + modifier for details. + + """ + return self._copy(custom_comparator=comparator) + + def update_expression( + self, meth: _HybridUpdaterType[_T] + ) -> hybrid_property[_T]: + """Provide a modifying decorator that defines an UPDATE tuple + producing method. + + The method accepts a single value, which is the value to be + rendered into the SET clause of an UPDATE statement. The method + should then process this value into individual column expressions + that fit into the ultimate SET clause, and return them as a + sequence of 2-tuples. Each tuple + contains a column expression as the key and a value to be rendered. + + E.g.:: + + class Person(Base): + # ... + + first_name = Column(String) + last_name = Column(String) + + @hybrid_property + def fullname(self): + return first_name + " " + last_name + + @fullname.update_expression + def fullname(cls, value): + fname, lname = value.split(" ", 1) + return [ + (cls.first_name, fname), + (cls.last_name, lname) + ] + + .. versionadded:: 1.2 + + """ + return self._copy(update_expr=meth) + + @util.memoized_property + def _expr_comparator( + self, + ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]: + if self.custom_comparator is not None: + return self._get_comparator(self.custom_comparator) + elif self.expr is not None: + return self._get_expr(self.expr) + else: + return self._get_expr(cast(_HybridExprCallableType[_T], self.fget)) + + def _get_expr( + self, expr: _HybridExprCallableType[_T] + ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]: + def _expr(cls: Any) -> ExprComparator[_T]: + return ExprComparator(cls, expr(cls), self) + + util.update_wrapper(_expr, expr) + + return self._get_comparator(_expr) + + def _get_comparator( + self, comparator: Any + ) -> Callable[[Any], _HybridClassLevelAccessor[_T]]: + proxy_attr = attributes.create_proxied_attribute(self) + + def expr_comparator( + owner: Type[object], + ) -> _HybridClassLevelAccessor[_T]: + # because this is the descriptor protocol, we don't really know + # what our attribute name is. so search for it through the + # MRO. + for lookup in owner.__mro__: + if self.__name__ in lookup.__dict__: + if lookup.__dict__[self.__name__] is self: + name = self.__name__ + break + else: + name = attributes._UNKNOWN_ATTR_KEY # type: ignore[assignment] + + return cast( + "_HybridClassLevelAccessor[_T]", + proxy_attr( + owner, + name, + self, + comparator(owner), + doc=comparator.__doc__ or self.__doc__, + ), + ) + + return expr_comparator + + +class Comparator(interfaces.PropComparator[_T]): + """A helper class that allows easy construction of custom + :class:`~.orm.interfaces.PropComparator` + classes for usage with hybrids.""" + + def __init__( + self, expression: Union[_HasClauseElement[_T], SQLColumnExpression[_T]] + ): + self.expression = expression + + def __clause_element__(self) -> roles.ColumnsClauseRole: + expr = self.expression + if is_has_clause_element(expr): + ret_expr = expr.__clause_element__() + else: + if TYPE_CHECKING: + assert isinstance(expr, ColumnElement) + ret_expr = expr + + if TYPE_CHECKING: + # see test_hybrid->test_expression_isnt_clause_element + # that exercises the usual place this is caught if not + # true + assert isinstance(ret_expr, ColumnElement) + return ret_expr + + @util.non_memoized_property + def property(self) -> interfaces.MapperProperty[_T]: + raise NotImplementedError() + + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> Comparator[_T]: + # interesting.... + return self + + +class ExprComparator(Comparator[_T]): + def __init__( + self, + cls: Type[Any], + expression: Union[_HasClauseElement[_T], SQLColumnExpression[_T]], + hybrid: hybrid_property[_T], + ): + self.cls = cls + self.expression = expression + self.hybrid = hybrid + + def __getattr__(self, key: str) -> Any: + return getattr(self.expression, key) + + @util.ro_non_memoized_property + def info(self) -> _InfoType: + return self.hybrid.info + + def _bulk_update_tuples( + self, value: Any + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: + if isinstance(self.expression, attributes.QueryableAttribute): + return self.expression._bulk_update_tuples(value) + elif self.hybrid.update_expr is not None: + return self.hybrid.update_expr(self.cls, value) + else: + return [(self.expression, value)] + + @util.non_memoized_property + def property(self) -> MapperProperty[_T]: + # this accessor is not normally used, however is accessed by things + # like ORM synonyms if the hybrid is used in this context; the + # .property attribute is not necessarily accessible + return self.expression.property # type: ignore + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.expression, *other, **kwargs) + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.expression, **kwargs) # type: ignore diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/indexable.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/indexable.py new file mode 100644 index 0000000..3c41930 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/indexable.py @@ -0,0 +1,341 @@ +# ext/indexable.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""Define attributes on ORM-mapped classes that have "index" attributes for +columns with :class:`_types.Indexable` types. + +"index" means the attribute is associated with an element of an +:class:`_types.Indexable` column with the predefined index to access it. +The :class:`_types.Indexable` types include types such as +:class:`_types.ARRAY`, :class:`_types.JSON` and +:class:`_postgresql.HSTORE`. + + + +The :mod:`~sqlalchemy.ext.indexable` extension provides +:class:`_schema.Column`-like interface for any element of an +:class:`_types.Indexable` typed column. In simple cases, it can be +treated as a :class:`_schema.Column` - mapped attribute. + +Synopsis +======== + +Given ``Person`` as a model with a primary key and JSON data field. +While this field may have any number of elements encoded within it, +we would like to refer to the element called ``name`` individually +as a dedicated attribute which behaves like a standalone column:: + + from sqlalchemy import Column, JSON, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.indexable import index_property + + Base = declarative_base() + + class Person(Base): + __tablename__ = 'person' + + id = Column(Integer, primary_key=True) + data = Column(JSON) + + name = index_property('data', 'name') + + +Above, the ``name`` attribute now behaves like a mapped column. We +can compose a new ``Person`` and set the value of ``name``:: + + >>> person = Person(name='Alchemist') + +The value is now accessible:: + + >>> person.name + 'Alchemist' + +Behind the scenes, the JSON field was initialized to a new blank dictionary +and the field was set:: + + >>> person.data + {"name": "Alchemist'} + +The field is mutable in place:: + + >>> person.name = 'Renamed' + >>> person.name + 'Renamed' + >>> person.data + {'name': 'Renamed'} + +When using :class:`.index_property`, the change that we make to the indexable +structure is also automatically tracked as history; we no longer need +to use :class:`~.mutable.MutableDict` in order to track this change +for the unit of work. + +Deletions work normally as well:: + + >>> del person.name + >>> person.data + {} + +Above, deletion of ``person.name`` deletes the value from the dictionary, +but not the dictionary itself. + +A missing key will produce ``AttributeError``:: + + >>> person = Person() + >>> person.name + ... + AttributeError: 'name' + +Unless you set a default value:: + + >>> class Person(Base): + >>> __tablename__ = 'person' + >>> + >>> id = Column(Integer, primary_key=True) + >>> data = Column(JSON) + >>> + >>> name = index_property('data', 'name', default=None) # See default + + >>> person = Person() + >>> print(person.name) + None + + +The attributes are also accessible at the class level. +Below, we illustrate ``Person.name`` used to generate +an indexed SQL criteria:: + + >>> from sqlalchemy.orm import Session + >>> session = Session() + >>> query = session.query(Person).filter(Person.name == 'Alchemist') + +The above query is equivalent to:: + + >>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist') + +Multiple :class:`.index_property` objects can be chained to produce +multiple levels of indexing:: + + from sqlalchemy import Column, JSON, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.ext.indexable import index_property + + Base = declarative_base() + + class Person(Base): + __tablename__ = 'person' + + id = Column(Integer, primary_key=True) + data = Column(JSON) + + birthday = index_property('data', 'birthday') + year = index_property('birthday', 'year') + month = index_property('birthday', 'month') + day = index_property('birthday', 'day') + +Above, a query such as:: + + q = session.query(Person).filter(Person.year == '1980') + +On a PostgreSQL backend, the above query will render as:: + + SELECT person.id, person.data + FROM person + WHERE person.data -> %(data_1)s -> %(param_1)s = %(param_2)s + +Default Values +============== + +:class:`.index_property` includes special behaviors for when the indexed +data structure does not exist, and a set operation is called: + +* For an :class:`.index_property` that is given an integer index value, + the default data structure will be a Python list of ``None`` values, + at least as long as the index value; the value is then set at its + place in the list. This means for an index value of zero, the list + will be initialized to ``[None]`` before setting the given value, + and for an index value of five, the list will be initialized to + ``[None, None, None, None, None]`` before setting the fifth element + to the given value. Note that an existing list is **not** extended + in place to receive a value. + +* for an :class:`.index_property` that is given any other kind of index + value (e.g. strings usually), a Python dictionary is used as the + default data structure. + +* The default data structure can be set to any Python callable using the + :paramref:`.index_property.datatype` parameter, overriding the previous + rules. + + +Subclassing +=========== + +:class:`.index_property` can be subclassed, in particular for the common +use case of providing coercion of values or SQL expressions as they are +accessed. Below is a common recipe for use with a PostgreSQL JSON type, +where we want to also include automatic casting plus ``astext()``:: + + class pg_json_property(index_property): + def __init__(self, attr_name, index, cast_type): + super(pg_json_property, self).__init__(attr_name, index) + self.cast_type = cast_type + + def expr(self, model): + expr = super(pg_json_property, self).expr(model) + return expr.astext.cast(self.cast_type) + +The above subclass can be used with the PostgreSQL-specific +version of :class:`_postgresql.JSON`:: + + from sqlalchemy import Column, Integer + from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.dialects.postgresql import JSON + + Base = declarative_base() + + class Person(Base): + __tablename__ = 'person' + + id = Column(Integer, primary_key=True) + data = Column(JSON) + + age = pg_json_property('data', 'age', Integer) + +The ``age`` attribute at the instance level works as before; however +when rendering SQL, PostgreSQL's ``->>`` operator will be used +for indexed access, instead of the usual index operator of ``->``:: + + >>> query = session.query(Person).filter(Person.age < 20) + +The above query will render:: + + SELECT person.id, person.data + FROM person + WHERE CAST(person.data ->> %(data_1)s AS INTEGER) < %(param_1)s + +""" # noqa +from .. import inspect +from ..ext.hybrid import hybrid_property +from ..orm.attributes import flag_modified + + +__all__ = ["index_property"] + + +class index_property(hybrid_property): # noqa + """A property generator. The generated property describes an object + attribute that corresponds to an :class:`_types.Indexable` + column. + + .. seealso:: + + :mod:`sqlalchemy.ext.indexable` + + """ + + _NO_DEFAULT_ARGUMENT = object() + + def __init__( + self, + attr_name, + index, + default=_NO_DEFAULT_ARGUMENT, + datatype=None, + mutable=True, + onebased=True, + ): + """Create a new :class:`.index_property`. + + :param attr_name: + An attribute name of an `Indexable` typed column, or other + attribute that returns an indexable structure. + :param index: + The index to be used for getting and setting this value. This + should be the Python-side index value for integers. + :param default: + A value which will be returned instead of `AttributeError` + when there is not a value at given index. + :param datatype: default datatype to use when the field is empty. + By default, this is derived from the type of index used; a + Python list for an integer index, or a Python dictionary for + any other style of index. For a list, the list will be + initialized to a list of None values that is at least + ``index`` elements long. + :param mutable: if False, writes and deletes to the attribute will + be disallowed. + :param onebased: assume the SQL representation of this value is + one-based; that is, the first index in SQL is 1, not zero. + """ + + if mutable: + super().__init__(self.fget, self.fset, self.fdel, self.expr) + else: + super().__init__(self.fget, None, None, self.expr) + self.attr_name = attr_name + self.index = index + self.default = default + is_numeric = isinstance(index, int) + onebased = is_numeric and onebased + + if datatype is not None: + self.datatype = datatype + else: + if is_numeric: + self.datatype = lambda: [None for x in range(index + 1)] + else: + self.datatype = dict + self.onebased = onebased + + def _fget_default(self, err=None): + if self.default == self._NO_DEFAULT_ARGUMENT: + raise AttributeError(self.attr_name) from err + else: + return self.default + + def fget(self, instance): + attr_name = self.attr_name + column_value = getattr(instance, attr_name) + if column_value is None: + return self._fget_default() + try: + value = column_value[self.index] + except (KeyError, IndexError) as err: + return self._fget_default(err) + else: + return value + + def fset(self, instance, value): + attr_name = self.attr_name + column_value = getattr(instance, attr_name, None) + if column_value is None: + column_value = self.datatype() + setattr(instance, attr_name, column_value) + column_value[self.index] = value + setattr(instance, attr_name, column_value) + if attr_name in inspect(instance).mapper.attrs: + flag_modified(instance, attr_name) + + def fdel(self, instance): + attr_name = self.attr_name + column_value = getattr(instance, attr_name) + if column_value is None: + raise AttributeError(self.attr_name) + try: + del column_value[self.index] + except KeyError as err: + raise AttributeError(self.attr_name) from err + else: + setattr(instance, attr_name, column_value) + flag_modified(instance, attr_name) + + def expr(self, model): + column = getattr(model, self.attr_name) + index = self.index + if self.onebased: + index += 1 + return column[index] diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/instrumentation.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/instrumentation.py new file mode 100644 index 0000000..5f3c712 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/instrumentation.py @@ -0,0 +1,450 @@ +# ext/instrumentation.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""Extensible class instrumentation. + +The :mod:`sqlalchemy.ext.instrumentation` package provides for alternate +systems of class instrumentation within the ORM. Class instrumentation +refers to how the ORM places attributes on the class which maintain +data and track changes to that data, as well as event hooks installed +on the class. + +.. note:: + The extension package is provided for the benefit of integration + with other object management packages, which already perform + their own instrumentation. It is not intended for general use. + +For examples of how the instrumentation extension is used, +see the example :ref:`examples_instrumentation`. + +""" +import weakref + +from .. import util +from ..orm import attributes +from ..orm import base as orm_base +from ..orm import collections +from ..orm import exc as orm_exc +from ..orm import instrumentation as orm_instrumentation +from ..orm import util as orm_util +from ..orm.instrumentation import _default_dict_getter +from ..orm.instrumentation import _default_manager_getter +from ..orm.instrumentation import _default_opt_manager_getter +from ..orm.instrumentation import _default_state_getter +from ..orm.instrumentation import ClassManager +from ..orm.instrumentation import InstrumentationFactory + + +INSTRUMENTATION_MANAGER = "__sa_instrumentation_manager__" +"""Attribute, elects custom instrumentation when present on a mapped class. + +Allows a class to specify a slightly or wildly different technique for +tracking changes made to mapped attributes and collections. + +Only one instrumentation implementation is allowed in a given object +inheritance hierarchy. + +The value of this attribute must be a callable and will be passed a class +object. The callable must return one of: + + - An instance of an :class:`.InstrumentationManager` or subclass + - An object implementing all or some of InstrumentationManager (TODO) + - A dictionary of callables, implementing all or some of the above (TODO) + - An instance of a :class:`.ClassManager` or subclass + +This attribute is consulted by SQLAlchemy instrumentation +resolution, once the :mod:`sqlalchemy.ext.instrumentation` module +has been imported. If custom finders are installed in the global +instrumentation_finders list, they may or may not choose to honor this +attribute. + +""" + + +def find_native_user_instrumentation_hook(cls): + """Find user-specified instrumentation management for a class.""" + return getattr(cls, INSTRUMENTATION_MANAGER, None) + + +instrumentation_finders = [find_native_user_instrumentation_hook] +"""An extensible sequence of callables which return instrumentation +implementations + +When a class is registered, each callable will be passed a class object. +If None is returned, the +next finder in the sequence is consulted. Otherwise the return must be an +instrumentation factory that follows the same guidelines as +sqlalchemy.ext.instrumentation.INSTRUMENTATION_MANAGER. + +By default, the only finder is find_native_user_instrumentation_hook, which +searches for INSTRUMENTATION_MANAGER. If all finders return None, standard +ClassManager instrumentation is used. + +""" + + +class ExtendedInstrumentationRegistry(InstrumentationFactory): + """Extends :class:`.InstrumentationFactory` with additional + bookkeeping, to accommodate multiple types of + class managers. + + """ + + _manager_finders = weakref.WeakKeyDictionary() + _state_finders = weakref.WeakKeyDictionary() + _dict_finders = weakref.WeakKeyDictionary() + _extended = False + + def _locate_extended_factory(self, class_): + for finder in instrumentation_finders: + factory = finder(class_) + if factory is not None: + manager = self._extended_class_manager(class_, factory) + return manager, factory + else: + return None, None + + def _check_conflicts(self, class_, factory): + existing_factories = self._collect_management_factories_for( + class_ + ).difference([factory]) + if existing_factories: + raise TypeError( + "multiple instrumentation implementations specified " + "in %s inheritance hierarchy: %r" + % (class_.__name__, list(existing_factories)) + ) + + def _extended_class_manager(self, class_, factory): + manager = factory(class_) + if not isinstance(manager, ClassManager): + manager = _ClassInstrumentationAdapter(class_, manager) + + if factory != ClassManager and not self._extended: + # somebody invoked a custom ClassManager. + # reinstall global "getter" functions with the more + # expensive ones. + self._extended = True + _install_instrumented_lookups() + + self._manager_finders[class_] = manager.manager_getter() + self._state_finders[class_] = manager.state_getter() + self._dict_finders[class_] = manager.dict_getter() + return manager + + def _collect_management_factories_for(self, cls): + """Return a collection of factories in play or specified for a + hierarchy. + + Traverses the entire inheritance graph of a cls and returns a + collection of instrumentation factories for those classes. Factories + are extracted from active ClassManagers, if available, otherwise + instrumentation_finders is consulted. + + """ + hierarchy = util.class_hierarchy(cls) + factories = set() + for member in hierarchy: + manager = self.opt_manager_of_class(member) + if manager is not None: + factories.add(manager.factory) + else: + for finder in instrumentation_finders: + factory = finder(member) + if factory is not None: + break + else: + factory = None + factories.add(factory) + factories.discard(None) + return factories + + def unregister(self, class_): + super().unregister(class_) + if class_ in self._manager_finders: + del self._manager_finders[class_] + del self._state_finders[class_] + del self._dict_finders[class_] + + def opt_manager_of_class(self, cls): + try: + finder = self._manager_finders.get( + cls, _default_opt_manager_getter + ) + except TypeError: + # due to weakref lookup on invalid object + return None + else: + return finder(cls) + + def manager_of_class(self, cls): + try: + finder = self._manager_finders.get(cls, _default_manager_getter) + except TypeError: + # due to weakref lookup on invalid object + raise orm_exc.UnmappedClassError( + cls, f"Can't locate an instrumentation manager for class {cls}" + ) + else: + manager = finder(cls) + if manager is None: + raise orm_exc.UnmappedClassError( + cls, + f"Can't locate an instrumentation manager for class {cls}", + ) + return manager + + def state_of(self, instance): + if instance is None: + raise AttributeError("None has no persistent state.") + return self._state_finders.get( + instance.__class__, _default_state_getter + )(instance) + + def dict_of(self, instance): + if instance is None: + raise AttributeError("None has no persistent state.") + return self._dict_finders.get( + instance.__class__, _default_dict_getter + )(instance) + + +orm_instrumentation._instrumentation_factory = _instrumentation_factory = ( + ExtendedInstrumentationRegistry() +) +orm_instrumentation.instrumentation_finders = instrumentation_finders + + +class InstrumentationManager: + """User-defined class instrumentation extension. + + :class:`.InstrumentationManager` can be subclassed in order + to change + how class instrumentation proceeds. This class exists for + the purposes of integration with other object management + frameworks which would like to entirely modify the + instrumentation methodology of the ORM, and is not intended + for regular usage. For interception of class instrumentation + events, see :class:`.InstrumentationEvents`. + + The API for this class should be considered as semi-stable, + and may change slightly with new releases. + + """ + + # r4361 added a mandatory (cls) constructor to this interface. + # given that, perhaps class_ should be dropped from all of these + # signatures. + + def __init__(self, class_): + pass + + def manage(self, class_, manager): + setattr(class_, "_default_class_manager", manager) + + def unregister(self, class_, manager): + delattr(class_, "_default_class_manager") + + def manager_getter(self, class_): + def get(cls): + return cls._default_class_manager + + return get + + def instrument_attribute(self, class_, key, inst): + pass + + def post_configure_attribute(self, class_, key, inst): + pass + + def install_descriptor(self, class_, key, inst): + setattr(class_, key, inst) + + def uninstall_descriptor(self, class_, key): + delattr(class_, key) + + def install_member(self, class_, key, implementation): + setattr(class_, key, implementation) + + def uninstall_member(self, class_, key): + delattr(class_, key) + + def instrument_collection_class(self, class_, key, collection_class): + return collections.prepare_instrumentation(collection_class) + + def get_instance_dict(self, class_, instance): + return instance.__dict__ + + def initialize_instance_dict(self, class_, instance): + pass + + def install_state(self, class_, instance, state): + setattr(instance, "_default_state", state) + + def remove_state(self, class_, instance): + delattr(instance, "_default_state") + + def state_getter(self, class_): + return lambda instance: getattr(instance, "_default_state") + + def dict_getter(self, class_): + return lambda inst: self.get_instance_dict(class_, inst) + + +class _ClassInstrumentationAdapter(ClassManager): + """Adapts a user-defined InstrumentationManager to a ClassManager.""" + + def __init__(self, class_, override): + self._adapted = override + self._get_state = self._adapted.state_getter(class_) + self._get_dict = self._adapted.dict_getter(class_) + + ClassManager.__init__(self, class_) + + def manage(self): + self._adapted.manage(self.class_, self) + + def unregister(self): + self._adapted.unregister(self.class_, self) + + def manager_getter(self): + return self._adapted.manager_getter(self.class_) + + def instrument_attribute(self, key, inst, propagated=False): + ClassManager.instrument_attribute(self, key, inst, propagated) + if not propagated: + self._adapted.instrument_attribute(self.class_, key, inst) + + def post_configure_attribute(self, key): + super().post_configure_attribute(key) + self._adapted.post_configure_attribute(self.class_, key, self[key]) + + def install_descriptor(self, key, inst): + self._adapted.install_descriptor(self.class_, key, inst) + + def uninstall_descriptor(self, key): + self._adapted.uninstall_descriptor(self.class_, key) + + def install_member(self, key, implementation): + self._adapted.install_member(self.class_, key, implementation) + + def uninstall_member(self, key): + self._adapted.uninstall_member(self.class_, key) + + def instrument_collection_class(self, key, collection_class): + return self._adapted.instrument_collection_class( + self.class_, key, collection_class + ) + + def initialize_collection(self, key, state, factory): + delegate = getattr(self._adapted, "initialize_collection", None) + if delegate: + return delegate(key, state, factory) + else: + return ClassManager.initialize_collection( + self, key, state, factory + ) + + def new_instance(self, state=None): + instance = self.class_.__new__(self.class_) + self.setup_instance(instance, state) + return instance + + def _new_state_if_none(self, instance): + """Install a default InstanceState if none is present. + + A private convenience method used by the __init__ decorator. + """ + if self.has_state(instance): + return False + else: + return self.setup_instance(instance) + + def setup_instance(self, instance, state=None): + self._adapted.initialize_instance_dict(self.class_, instance) + + if state is None: + state = self._state_constructor(instance, self) + + # the given instance is assumed to have no state + self._adapted.install_state(self.class_, instance, state) + return state + + def teardown_instance(self, instance): + self._adapted.remove_state(self.class_, instance) + + def has_state(self, instance): + try: + self._get_state(instance) + except orm_exc.NO_STATE: + return False + else: + return True + + def state_getter(self): + return self._get_state + + def dict_getter(self): + return self._get_dict + + +def _install_instrumented_lookups(): + """Replace global class/object management functions + with ExtendedInstrumentationRegistry implementations, which + allow multiple types of class managers to be present, + at the cost of performance. + + This function is called only by ExtendedInstrumentationRegistry + and unit tests specific to this behavior. + + The _reinstall_default_lookups() function can be called + after this one to re-establish the default functions. + + """ + _install_lookups( + dict( + instance_state=_instrumentation_factory.state_of, + instance_dict=_instrumentation_factory.dict_of, + manager_of_class=_instrumentation_factory.manager_of_class, + opt_manager_of_class=_instrumentation_factory.opt_manager_of_class, + ) + ) + + +def _reinstall_default_lookups(): + """Restore simplified lookups.""" + _install_lookups( + dict( + instance_state=_default_state_getter, + instance_dict=_default_dict_getter, + manager_of_class=_default_manager_getter, + opt_manager_of_class=_default_opt_manager_getter, + ) + ) + _instrumentation_factory._extended = False + + +def _install_lookups(lookups): + global instance_state, instance_dict + global manager_of_class, opt_manager_of_class + instance_state = lookups["instance_state"] + instance_dict = lookups["instance_dict"] + manager_of_class = lookups["manager_of_class"] + opt_manager_of_class = lookups["opt_manager_of_class"] + orm_base.instance_state = attributes.instance_state = ( + orm_instrumentation.instance_state + ) = instance_state + orm_base.instance_dict = attributes.instance_dict = ( + orm_instrumentation.instance_dict + ) = instance_dict + orm_base.manager_of_class = attributes.manager_of_class = ( + orm_instrumentation.manager_of_class + ) = manager_of_class + orm_base.opt_manager_of_class = orm_util.opt_manager_of_class = ( + attributes.opt_manager_of_class + ) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mutable.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mutable.py new file mode 100644 index 0000000..7da5075 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mutable.py @@ -0,0 +1,1073 @@ +# ext/mutable.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +r"""Provide support for tracking of in-place changes to scalar values, +which are propagated into ORM change events on owning parent objects. + +.. _mutable_scalars: + +Establishing Mutability on Scalar Column Values +=============================================== + +A typical example of a "mutable" structure is a Python dictionary. +Following the example introduced in :ref:`types_toplevel`, we +begin with a custom type that marshals Python dictionaries into +JSON strings before being persisted:: + + from sqlalchemy.types import TypeDecorator, VARCHAR + import json + + class JSONEncodedDict(TypeDecorator): + "Represents an immutable structure as a json-encoded string." + + impl = VARCHAR + + def process_bind_param(self, value, dialect): + if value is not None: + value = json.dumps(value) + return value + + def process_result_value(self, value, dialect): + if value is not None: + value = json.loads(value) + return value + +The usage of ``json`` is only for the purposes of example. The +:mod:`sqlalchemy.ext.mutable` extension can be used +with any type whose target Python type may be mutable, including +:class:`.PickleType`, :class:`_postgresql.ARRAY`, etc. + +When using the :mod:`sqlalchemy.ext.mutable` extension, the value itself +tracks all parents which reference it. Below, we illustrate a simple +version of the :class:`.MutableDict` dictionary object, which applies +the :class:`.Mutable` mixin to a plain Python dictionary:: + + from sqlalchemy.ext.mutable import Mutable + + class MutableDict(Mutable, dict): + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." + + if not isinstance(value, MutableDict): + if isinstance(value, dict): + return MutableDict(value) + + # this call will raise ValueError + return Mutable.coerce(key, value) + else: + return value + + def __setitem__(self, key, value): + "Detect dictionary set events and emit change events." + + dict.__setitem__(self, key, value) + self.changed() + + def __delitem__(self, key): + "Detect dictionary del events and emit change events." + + dict.__delitem__(self, key) + self.changed() + +The above dictionary class takes the approach of subclassing the Python +built-in ``dict`` to produce a dict +subclass which routes all mutation events through ``__setitem__``. There are +variants on this approach, such as subclassing ``UserDict.UserDict`` or +``collections.MutableMapping``; the part that's important to this example is +that the :meth:`.Mutable.changed` method is called whenever an in-place +change to the datastructure takes place. + +We also redefine the :meth:`.Mutable.coerce` method which will be used to +convert any values that are not instances of ``MutableDict``, such +as the plain dictionaries returned by the ``json`` module, into the +appropriate type. Defining this method is optional; we could just as well +created our ``JSONEncodedDict`` such that it always returns an instance +of ``MutableDict``, and additionally ensured that all calling code +uses ``MutableDict`` explicitly. When :meth:`.Mutable.coerce` is not +overridden, any values applied to a parent object which are not instances +of the mutable type will raise a ``ValueError``. + +Our new ``MutableDict`` type offers a class method +:meth:`~.Mutable.as_mutable` which we can use within column metadata +to associate with types. This method grabs the given type object or +class and associates a listener that will detect all future mappings +of this type, applying event listening instrumentation to the mapped +attribute. Such as, with classical table metadata:: + + from sqlalchemy import Table, Column, Integer + + my_data = Table('my_data', metadata, + Column('id', Integer, primary_key=True), + Column('data', MutableDict.as_mutable(JSONEncodedDict)) + ) + +Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict`` +(if the type object was not an instance already), which will intercept any +attributes which are mapped against this type. Below we establish a simple +mapping against the ``my_data`` table:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + class Base(DeclarativeBase): + pass + + class MyDataClass(Base): + __tablename__ = 'my_data' + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[dict[str, str]] = mapped_column(MutableDict.as_mutable(JSONEncodedDict)) + +The ``MyDataClass.data`` member will now be notified of in place changes +to its value. + +Any in-place changes to the ``MyDataClass.data`` member +will flag the attribute as "dirty" on the parent object:: + + >>> from sqlalchemy.orm import Session + + >>> sess = Session(some_engine) + >>> m1 = MyDataClass(data={'value1':'foo'}) + >>> sess.add(m1) + >>> sess.commit() + + >>> m1.data['value1'] = 'bar' + >>> assert m1 in sess.dirty + True + +The ``MutableDict`` can be associated with all future instances +of ``JSONEncodedDict`` in one step, using +:meth:`~.Mutable.associate_with`. This is similar to +:meth:`~.Mutable.as_mutable` except it will intercept all occurrences +of ``MutableDict`` in all mappings unconditionally, without +the need to declare it individually:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + MutableDict.associate_with(JSONEncodedDict) + + class Base(DeclarativeBase): + pass + + class MyDataClass(Base): + __tablename__ = 'my_data' + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[dict[str, str]] = mapped_column(JSONEncodedDict) + + +Supporting Pickling +-------------------- + +The key to the :mod:`sqlalchemy.ext.mutable` extension relies upon the +placement of a ``weakref.WeakKeyDictionary`` upon the value object, which +stores a mapping of parent mapped objects keyed to the attribute name under +which they are associated with this value. ``WeakKeyDictionary`` objects are +not picklable, due to the fact that they contain weakrefs and function +callbacks. In our case, this is a good thing, since if this dictionary were +picklable, it could lead to an excessively large pickle size for our value +objects that are pickled by themselves outside of the context of the parent. +The developer responsibility here is only to provide a ``__getstate__`` method +that excludes the :meth:`~MutableBase._parents` collection from the pickle +stream:: + + class MyMutableType(Mutable): + def __getstate__(self): + d = self.__dict__.copy() + d.pop('_parents', None) + return d + +With our dictionary example, we need to return the contents of the dict itself +(and also restore them on __setstate__):: + + class MutableDict(Mutable, dict): + # .... + + def __getstate__(self): + return dict(self) + + def __setstate__(self, state): + self.update(state) + +In the case that our mutable value object is pickled as it is attached to one +or more parent objects that are also part of the pickle, the :class:`.Mutable` +mixin will re-establish the :attr:`.Mutable._parents` collection on each value +object as the owning parents themselves are unpickled. + +Receiving Events +---------------- + +The :meth:`.AttributeEvents.modified` event handler may be used to receive +an event when a mutable scalar emits a change event. This event handler +is called when the :func:`.attributes.flag_modified` function is called +from within the mutable extension:: + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + from sqlalchemy import event + + class Base(DeclarativeBase): + pass + + class MyDataClass(Base): + __tablename__ = 'my_data' + id: Mapped[int] = mapped_column(primary_key=True) + data: Mapped[dict[str, str]] = mapped_column(MutableDict.as_mutable(JSONEncodedDict)) + + @event.listens_for(MyDataClass.data, "modified") + def modified_json(instance, initiator): + print("json value modified:", instance.data) + +.. _mutable_composites: + +Establishing Mutability on Composites +===================================== + +Composites are a special ORM feature which allow a single scalar attribute to +be assigned an object value which represents information "composed" from one +or more columns from the underlying mapped table. The usual example is that of +a geometric "point", and is introduced in :ref:`mapper_composite`. + +As is the case with :class:`.Mutable`, the user-defined composite class +subclasses :class:`.MutableComposite` as a mixin, and detects and delivers +change events to its parents via the :meth:`.MutableComposite.changed` method. +In the case of a composite class, the detection is usually via the usage of the +special Python method ``__setattr__()``. In the example below, we expand upon the ``Point`` +class introduced in :ref:`mapper_composite` to include +:class:`.MutableComposite` in its bases and to route attribute set events via +``__setattr__`` to the :meth:`.MutableComposite.changed` method:: + + import dataclasses + from sqlalchemy.ext.mutable import MutableComposite + + @dataclasses.dataclass + class Point(MutableComposite): + x: int + y: int + + def __setattr__(self, key, value): + "Intercept set events" + + # set the attribute + object.__setattr__(self, key, value) + + # alert all parents to the change + self.changed() + + +The :class:`.MutableComposite` class makes use of class mapping events to +automatically establish listeners for any usage of :func:`_orm.composite` that +specifies our ``Point`` type. Below, when ``Point`` is mapped to the ``Vertex`` +class, listeners are established which will route change events from ``Point`` +objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes:: + + from sqlalchemy.orm import DeclarativeBase, Mapped + from sqlalchemy.orm import composite, mapped_column + + class Base(DeclarativeBase): + pass + + + class Vertex(Base): + __tablename__ = "vertices" + + id: Mapped[int] = mapped_column(primary_key=True) + + start: Mapped[Point] = composite(mapped_column("x1"), mapped_column("y1")) + end: Mapped[Point] = composite(mapped_column("x2"), mapped_column("y2")) + + def __repr__(self): + return f"Vertex(start={self.start}, end={self.end})" + +Any in-place changes to the ``Vertex.start`` or ``Vertex.end`` members +will flag the attribute as "dirty" on the parent object: + +.. sourcecode:: python+sql + + >>> from sqlalchemy.orm import Session + >>> sess = Session(engine) + >>> v1 = Vertex(start=Point(3, 4), end=Point(12, 15)) + >>> sess.add(v1) + {sql}>>> sess.flush() + BEGIN (implicit) + INSERT INTO vertices (x1, y1, x2, y2) VALUES (?, ?, ?, ?) + [...] (3, 4, 12, 15) + + {stop}>>> v1.end.x = 8 + >>> assert v1 in sess.dirty + True + {sql}>>> sess.commit() + UPDATE vertices SET x2=? WHERE vertices.id = ? + [...] (8, 1) + COMMIT + +Coercing Mutable Composites +--------------------------- + +The :meth:`.MutableBase.coerce` method is also supported on composite types. +In the case of :class:`.MutableComposite`, the :meth:`.MutableBase.coerce` +method is only called for attribute set operations, not load operations. +Overriding the :meth:`.MutableBase.coerce` method is essentially equivalent +to using a :func:`.validates` validation routine for all attributes which +make use of the custom composite type:: + + @dataclasses.dataclass + class Point(MutableComposite): + # other Point methods + # ... + + def coerce(cls, key, value): + if isinstance(value, tuple): + value = Point(*value) + elif not isinstance(value, Point): + raise ValueError("tuple or Point expected") + return value + +Supporting Pickling +-------------------- + +As is the case with :class:`.Mutable`, the :class:`.MutableComposite` helper +class uses a ``weakref.WeakKeyDictionary`` available via the +:meth:`MutableBase._parents` attribute which isn't picklable. If we need to +pickle instances of ``Point`` or its owning class ``Vertex``, we at least need +to define a ``__getstate__`` that doesn't include the ``_parents`` dictionary. +Below we define both a ``__getstate__`` and a ``__setstate__`` that package up +the minimal form of our ``Point`` class:: + + @dataclasses.dataclass + class Point(MutableComposite): + # ... + + def __getstate__(self): + return self.x, self.y + + def __setstate__(self, state): + self.x, self.y = state + +As with :class:`.Mutable`, the :class:`.MutableComposite` augments the +pickling process of the parent's object-relational state so that the +:meth:`MutableBase._parents` collection is restored to all ``Point`` objects. + +""" # noqa: E501 + +from __future__ import annotations + +from collections import defaultdict +from typing import AbstractSet +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref +from weakref import WeakKeyDictionary + +from .. import event +from .. import inspect +from .. import types +from .. import util +from ..orm import Mapper +from ..orm._typing import _ExternalEntityType +from ..orm._typing import _O +from ..orm._typing import _T +from ..orm.attributes import AttributeEventToken +from ..orm.attributes import flag_modified +from ..orm.attributes import InstrumentedAttribute +from ..orm.attributes import QueryableAttribute +from ..orm.context import QueryContext +from ..orm.decl_api import DeclarativeAttributeIntercept +from ..orm.state import InstanceState +from ..orm.unitofwork import UOWTransaction +from ..sql.base import SchemaEventTarget +from ..sql.schema import Column +from ..sql.type_api import TypeEngine +from ..util import memoized_property +from ..util.typing import SupportsIndex +from ..util.typing import TypeGuard + +_KT = TypeVar("_KT") # Key type. +_VT = TypeVar("_VT") # Value type. + + +class MutableBase: + """Common base class to :class:`.Mutable` + and :class:`.MutableComposite`. + + """ + + @memoized_property + def _parents(self) -> WeakKeyDictionary[Any, Any]: + """Dictionary of parent object's :class:`.InstanceState`->attribute + name on the parent. + + This attribute is a so-called "memoized" property. It initializes + itself with a new ``weakref.WeakKeyDictionary`` the first time + it is accessed, returning the same object upon subsequent access. + + .. versionchanged:: 1.4 the :class:`.InstanceState` is now used + as the key in the weak dictionary rather than the instance + itself. + + """ + + return weakref.WeakKeyDictionary() + + @classmethod + def coerce(cls, key: str, value: Any) -> Optional[Any]: + """Given a value, coerce it into the target type. + + Can be overridden by custom subclasses to coerce incoming + data into a particular type. + + By default, raises ``ValueError``. + + This method is called in different scenarios depending on if + the parent class is of type :class:`.Mutable` or of type + :class:`.MutableComposite`. In the case of the former, it is called + for both attribute-set operations as well as during ORM loading + operations. For the latter, it is only called during attribute-set + operations; the mechanics of the :func:`.composite` construct + handle coercion during load operations. + + + :param key: string name of the ORM-mapped attribute being set. + :param value: the incoming value. + :return: the method should return the coerced value, or raise + ``ValueError`` if the coercion cannot be completed. + + """ + if value is None: + return None + msg = "Attribute '%s' does not accept objects of type %s" + raise ValueError(msg % (key, type(value))) + + @classmethod + def _get_listen_keys(cls, attribute: QueryableAttribute[Any]) -> Set[str]: + """Given a descriptor attribute, return a ``set()`` of the attribute + keys which indicate a change in the state of this attribute. + + This is normally just ``set([attribute.key])``, but can be overridden + to provide for additional keys. E.g. a :class:`.MutableComposite` + augments this set with the attribute keys associated with the columns + that comprise the composite value. + + This collection is consulted in the case of intercepting the + :meth:`.InstanceEvents.refresh` and + :meth:`.InstanceEvents.refresh_flush` events, which pass along a list + of attribute names that have been refreshed; the list is compared + against this set to determine if action needs to be taken. + + """ + return {attribute.key} + + @classmethod + def _listen_on_attribute( + cls, + attribute: QueryableAttribute[Any], + coerce: bool, + parent_cls: _ExternalEntityType[Any], + ) -> None: + """Establish this type as a mutation listener for the given + mapped descriptor. + + """ + key = attribute.key + if parent_cls is not attribute.class_: + return + + # rely on "propagate" here + parent_cls = attribute.class_ + + listen_keys = cls._get_listen_keys(attribute) + + def load(state: InstanceState[_O], *args: Any) -> None: + """Listen for objects loaded or refreshed. + + Wrap the target data member's value with + ``Mutable``. + + """ + val = state.dict.get(key, None) + if val is not None: + if coerce: + val = cls.coerce(key, val) + state.dict[key] = val + val._parents[state] = key + + def load_attrs( + state: InstanceState[_O], + ctx: Union[object, QueryContext, UOWTransaction], + attrs: Iterable[Any], + ) -> None: + if not attrs or listen_keys.intersection(attrs): + load(state) + + def set_( + target: InstanceState[_O], + value: MutableBase | None, + oldvalue: MutableBase | None, + initiator: AttributeEventToken, + ) -> MutableBase | None: + """Listen for set/replace events on the target + data member. + + Establish a weak reference to the parent object + on the incoming value, remove it for the one + outgoing. + + """ + if value is oldvalue: + return value + + if not isinstance(value, cls): + value = cls.coerce(key, value) + if value is not None: + value._parents[target] = key + if isinstance(oldvalue, cls): + oldvalue._parents.pop(inspect(target), None) + return value + + def pickle( + state: InstanceState[_O], state_dict: Dict[str, Any] + ) -> None: + val = state.dict.get(key, None) + if val is not None: + if "ext.mutable.values" not in state_dict: + state_dict["ext.mutable.values"] = defaultdict(list) + state_dict["ext.mutable.values"][key].append(val) + + def unpickle( + state: InstanceState[_O], state_dict: Dict[str, Any] + ) -> None: + if "ext.mutable.values" in state_dict: + collection = state_dict["ext.mutable.values"] + if isinstance(collection, list): + # legacy format + for val in collection: + val._parents[state] = key + else: + for val in state_dict["ext.mutable.values"][key]: + val._parents[state] = key + + event.listen( + parent_cls, + "_sa_event_merge_wo_load", + load, + raw=True, + propagate=True, + ) + + event.listen(parent_cls, "load", load, raw=True, propagate=True) + event.listen( + parent_cls, "refresh", load_attrs, raw=True, propagate=True + ) + event.listen( + parent_cls, "refresh_flush", load_attrs, raw=True, propagate=True + ) + event.listen( + attribute, "set", set_, raw=True, retval=True, propagate=True + ) + event.listen(parent_cls, "pickle", pickle, raw=True, propagate=True) + event.listen( + parent_cls, "unpickle", unpickle, raw=True, propagate=True + ) + + +class Mutable(MutableBase): + """Mixin that defines transparent propagation of change + events to a parent object. + + See the example in :ref:`mutable_scalars` for usage information. + + """ + + def changed(self) -> None: + """Subclasses should call this method whenever change events occur.""" + + for parent, key in self._parents.items(): + flag_modified(parent.obj(), key) + + @classmethod + def associate_with_attribute( + cls, attribute: InstrumentedAttribute[_O] + ) -> None: + """Establish this type as a mutation listener for the given + mapped descriptor. + + """ + cls._listen_on_attribute(attribute, True, attribute.class_) + + @classmethod + def associate_with(cls, sqltype: type) -> None: + """Associate this wrapper with all future mapped columns + of the given type. + + This is a convenience method that calls + ``associate_with_attribute`` automatically. + + .. warning:: + + The listeners established by this method are *global* + to all mappers, and are *not* garbage collected. Only use + :meth:`.associate_with` for types that are permanent to an + application, not with ad-hoc types else this will cause unbounded + growth in memory usage. + + """ + + def listen_for_type(mapper: Mapper[_O], class_: type) -> None: + if mapper.non_primary: + return + for prop in mapper.column_attrs: + if isinstance(prop.columns[0].type, sqltype): + cls.associate_with_attribute(getattr(class_, prop.key)) + + event.listen(Mapper, "mapper_configured", listen_for_type) + + @classmethod + def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]: + """Associate a SQL type with this mutable Python type. + + This establishes listeners that will detect ORM mappings against + the given type, adding mutation event trackers to those mappings. + + The type is returned, unconditionally as an instance, so that + :meth:`.as_mutable` can be used inline:: + + Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('data', MyMutableType.as_mutable(PickleType)) + ) + + Note that the returned type is always an instance, even if a class + is given, and that only columns which are declared specifically with + that type instance receive additional instrumentation. + + To associate a particular mutable type with all occurrences of a + particular type, use the :meth:`.Mutable.associate_with` classmethod + of the particular :class:`.Mutable` subclass to establish a global + association. + + .. warning:: + + The listeners established by this method are *global* + to all mappers, and are *not* garbage collected. Only use + :meth:`.as_mutable` for types that are permanent to an application, + not with ad-hoc types else this will cause unbounded growth + in memory usage. + + """ + sqltype = types.to_instance(sqltype) + + # a SchemaType will be copied when the Column is copied, + # and we'll lose our ability to link that type back to the original. + # so track our original type w/ columns + if isinstance(sqltype, SchemaEventTarget): + + @event.listens_for(sqltype, "before_parent_attach") + def _add_column_memo( + sqltyp: TypeEngine[Any], + parent: Column[_T], + ) -> None: + parent.info["_ext_mutable_orig_type"] = sqltyp + + schema_event_check = True + else: + schema_event_check = False + + def listen_for_type( + mapper: Mapper[_T], + class_: Union[DeclarativeAttributeIntercept, type], + ) -> None: + if mapper.non_primary: + return + _APPLIED_KEY = "_ext_mutable_listener_applied" + + for prop in mapper.column_attrs: + if ( + # all Mutable types refer to a Column that's mapped, + # since this is the only kind of Core target the ORM can + # "mutate" + isinstance(prop.expression, Column) + and ( + ( + schema_event_check + and prop.expression.info.get( + "_ext_mutable_orig_type" + ) + is sqltype + ) + or prop.expression.type is sqltype + ) + ): + if not prop.expression.info.get(_APPLIED_KEY, False): + prop.expression.info[_APPLIED_KEY] = True + cls.associate_with_attribute(getattr(class_, prop.key)) + + event.listen(Mapper, "mapper_configured", listen_for_type) + + return sqltype + + +class MutableComposite(MutableBase): + """Mixin that defines transparent propagation of change + events on a SQLAlchemy "composite" object to its + owning parent or parents. + + See the example in :ref:`mutable_composites` for usage information. + + """ + + @classmethod + def _get_listen_keys(cls, attribute: QueryableAttribute[_O]) -> Set[str]: + return {attribute.key}.union(attribute.property._attribute_keys) + + def changed(self) -> None: + """Subclasses should call this method whenever change events occur.""" + + for parent, key in self._parents.items(): + prop = parent.mapper.get_property(key) + for value, attr_name in zip( + prop._composite_values_from_instance(self), + prop._attribute_keys, + ): + setattr(parent.obj(), attr_name, value) + + +def _setup_composite_listener() -> None: + def _listen_for_type(mapper: Mapper[_T], class_: type) -> None: + for prop in mapper.iterate_properties: + if ( + hasattr(prop, "composite_class") + and isinstance(prop.composite_class, type) + and issubclass(prop.composite_class, MutableComposite) + ): + prop.composite_class._listen_on_attribute( + getattr(class_, prop.key), False, class_ + ) + + if not event.contains(Mapper, "mapper_configured", _listen_for_type): + event.listen(Mapper, "mapper_configured", _listen_for_type) + + +_setup_composite_listener() + + +class MutableDict(Mutable, Dict[_KT, _VT]): + """A dictionary type that implements :class:`.Mutable`. + + The :class:`.MutableDict` object implements a dictionary that will + emit change events to the underlying mapping when the contents of + the dictionary are altered, including when values are added or removed. + + Note that :class:`.MutableDict` does **not** apply mutable tracking to the + *values themselves* inside the dictionary. Therefore it is not a sufficient + solution for the use case of tracking deep changes to a *recursive* + dictionary structure, such as a JSON structure. To support this use case, + build a subclass of :class:`.MutableDict` that provides appropriate + coercion to the values placed in the dictionary so that they too are + "mutable", and emit events up to their parent structure. + + .. seealso:: + + :class:`.MutableList` + + :class:`.MutableSet` + + """ + + def __setitem__(self, key: _KT, value: _VT) -> None: + """Detect dictionary set events and emit change events.""" + super().__setitem__(key, value) + self.changed() + + if TYPE_CHECKING: + # from https://github.com/python/mypy/issues/14858 + + @overload + def setdefault( + self: MutableDict[_KT, Optional[_T]], key: _KT, value: None = None + ) -> Optional[_T]: ... + + @overload + def setdefault(self, key: _KT, value: _VT) -> _VT: ... + + def setdefault(self, key: _KT, value: object = None) -> object: ... + + else: + + def setdefault(self, *arg): # noqa: F811 + result = super().setdefault(*arg) + self.changed() + return result + + def __delitem__(self, key: _KT) -> None: + """Detect dictionary del events and emit change events.""" + super().__delitem__(key) + self.changed() + + def update(self, *a: Any, **kw: _VT) -> None: + super().update(*a, **kw) + self.changed() + + if TYPE_CHECKING: + + @overload + def pop(self, __key: _KT) -> _VT: ... + + @overload + def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T: ... + + def pop( + self, __key: _KT, __default: _VT | _T | None = None + ) -> _VT | _T: ... + + else: + + def pop(self, *arg): # noqa: F811 + result = super().pop(*arg) + self.changed() + return result + + def popitem(self) -> Tuple[_KT, _VT]: + result = super().popitem() + self.changed() + return result + + def clear(self) -> None: + super().clear() + self.changed() + + @classmethod + def coerce(cls, key: str, value: Any) -> MutableDict[_KT, _VT] | None: + """Convert plain dictionary to instance of this class.""" + if not isinstance(value, cls): + if isinstance(value, dict): + return cls(value) + return Mutable.coerce(key, value) + else: + return value + + def __getstate__(self) -> Dict[_KT, _VT]: + return dict(self) + + def __setstate__( + self, state: Union[Dict[str, int], Dict[str, str]] + ) -> None: + self.update(state) + + +class MutableList(Mutable, List[_T]): + """A list type that implements :class:`.Mutable`. + + The :class:`.MutableList` object implements a list that will + emit change events to the underlying mapping when the contents of + the list are altered, including when values are added or removed. + + Note that :class:`.MutableList` does **not** apply mutable tracking to the + *values themselves* inside the list. Therefore it is not a sufficient + solution for the use case of tracking deep changes to a *recursive* + mutable structure, such as a JSON structure. To support this use case, + build a subclass of :class:`.MutableList` that provides appropriate + coercion to the values placed in the dictionary so that they too are + "mutable", and emit events up to their parent structure. + + .. seealso:: + + :class:`.MutableDict` + + :class:`.MutableSet` + + """ + + def __reduce_ex__( + self, proto: SupportsIndex + ) -> Tuple[type, Tuple[List[int]]]: + return (self.__class__, (list(self),)) + + # needed for backwards compatibility with + # older pickles + def __setstate__(self, state: Iterable[_T]) -> None: + self[:] = state + + def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]: + return not util.is_non_string_iterable(value) + + def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]: + return util.is_non_string_iterable(value) + + def __setitem__( + self, index: SupportsIndex | slice, value: _T | Iterable[_T] + ) -> None: + """Detect list set events and emit change events.""" + if isinstance(index, SupportsIndex) and self.is_scalar(value): + super().__setitem__(index, value) + elif isinstance(index, slice) and self.is_iterable(value): + super().__setitem__(index, value) + self.changed() + + def __delitem__(self, index: SupportsIndex | slice) -> None: + """Detect list del events and emit change events.""" + super().__delitem__(index) + self.changed() + + def pop(self, *arg: SupportsIndex) -> _T: + result = super().pop(*arg) + self.changed() + return result + + def append(self, x: _T) -> None: + super().append(x) + self.changed() + + def extend(self, x: Iterable[_T]) -> None: + super().extend(x) + self.changed() + + def __iadd__(self, x: Iterable[_T]) -> MutableList[_T]: # type: ignore[override,misc] # noqa: E501 + self.extend(x) + return self + + def insert(self, i: SupportsIndex, x: _T) -> None: + super().insert(i, x) + self.changed() + + def remove(self, i: _T) -> None: + super().remove(i) + self.changed() + + def clear(self) -> None: + super().clear() + self.changed() + + def sort(self, **kw: Any) -> None: + super().sort(**kw) + self.changed() + + def reverse(self) -> None: + super().reverse() + self.changed() + + @classmethod + def coerce( + cls, key: str, value: MutableList[_T] | _T + ) -> Optional[MutableList[_T]]: + """Convert plain list to instance of this class.""" + if not isinstance(value, cls): + if isinstance(value, list): + return cls(value) + return Mutable.coerce(key, value) + else: + return value + + +class MutableSet(Mutable, Set[_T]): + """A set type that implements :class:`.Mutable`. + + The :class:`.MutableSet` object implements a set that will + emit change events to the underlying mapping when the contents of + the set are altered, including when values are added or removed. + + Note that :class:`.MutableSet` does **not** apply mutable tracking to the + *values themselves* inside the set. Therefore it is not a sufficient + solution for the use case of tracking deep changes to a *recursive* + mutable structure. To support this use case, + build a subclass of :class:`.MutableSet` that provides appropriate + coercion to the values placed in the dictionary so that they too are + "mutable", and emit events up to their parent structure. + + .. seealso:: + + :class:`.MutableDict` + + :class:`.MutableList` + + + """ + + def update(self, *arg: Iterable[_T]) -> None: + super().update(*arg) + self.changed() + + def intersection_update(self, *arg: Iterable[Any]) -> None: + super().intersection_update(*arg) + self.changed() + + def difference_update(self, *arg: Iterable[Any]) -> None: + super().difference_update(*arg) + self.changed() + + def symmetric_difference_update(self, *arg: Iterable[_T]) -> None: + super().symmetric_difference_update(*arg) + self.changed() + + def __ior__(self, other: AbstractSet[_T]) -> MutableSet[_T]: # type: ignore[override,misc] # noqa: E501 + self.update(other) + return self + + def __iand__(self, other: AbstractSet[object]) -> MutableSet[_T]: + self.intersection_update(other) + return self + + def __ixor__(self, other: AbstractSet[_T]) -> MutableSet[_T]: # type: ignore[override,misc] # noqa: E501 + self.symmetric_difference_update(other) + return self + + def __isub__(self, other: AbstractSet[object]) -> MutableSet[_T]: # type: ignore[misc] # noqa: E501 + self.difference_update(other) + return self + + def add(self, elem: _T) -> None: + super().add(elem) + self.changed() + + def remove(self, elem: _T) -> None: + super().remove(elem) + self.changed() + + def discard(self, elem: _T) -> None: + super().discard(elem) + self.changed() + + def pop(self, *arg: Any) -> _T: + result = super().pop(*arg) + self.changed() + return result + + def clear(self) -> None: + super().clear() + self.changed() + + @classmethod + def coerce(cls, index: str, value: Any) -> Optional[MutableSet[_T]]: + """Convert plain set to instance of this class.""" + if not isinstance(value, cls): + if isinstance(value, set): + return cls(value) + return Mutable.coerce(index, value) + else: + return value + + def __getstate__(self) -> Set[_T]: + return set(self) + + def __setstate__(self, state: Iterable[_T]) -> None: + self.update(state) + + def __reduce_ex__( + self, proto: SupportsIndex + ) -> Tuple[type, Tuple[List[int]]]: + return (self.__class__, (list(self),)) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__init__.py new file mode 100644 index 0000000..de2c02e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__init__.py @@ -0,0 +1,6 @@ +# ext/mypy/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..7ad6efd Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/apply.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/apply.cpython-311.pyc new file mode 100644 index 0000000..6072e1d Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/apply.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/decl_class.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/decl_class.cpython-311.pyc new file mode 100644 index 0000000..0b6844d Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/decl_class.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/infer.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/infer.cpython-311.pyc new file mode 100644 index 0000000..98231e9 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/infer.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/names.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/names.cpython-311.pyc new file mode 100644 index 0000000..41c9ba3 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/names.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/plugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/plugin.cpython-311.pyc new file mode 100644 index 0000000..30fab74 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/plugin.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/util.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000..ee8ba78 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/__pycache__/util.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/apply.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/apply.py new file mode 100644 index 0000000..eb90194 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/apply.py @@ -0,0 +1,320 @@ +# ext/mypy/apply.py +# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import List +from typing import Optional +from typing import Union + +from mypy.nodes import ARG_NAMED_OPT +from mypy.nodes import Argument +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import MDEF +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import RefExpr +from mypy.nodes import StrExpr +from mypy.nodes import SymbolTableNode +from mypy.nodes import TempNode +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.plugins.common import add_method_to_class +from mypy.types import AnyType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneTyp +from mypy.types import ProperType +from mypy.types import TypeOfAny +from mypy.types import UnboundType +from mypy.types import UnionType + +from . import infer +from . import util +from .names import expr_to_mapped_constructor +from .names import NAMED_TYPE_SQLA_MAPPED + + +def apply_mypy_mapped_attr( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + item: Union[NameExpr, StrExpr], + attributes: List[util.SQLAlchemyAttribute], +) -> None: + if isinstance(item, NameExpr): + name = item.name + elif isinstance(item, StrExpr): + name = item.value + else: + return None + + for stmt in cls.defs.body: + if ( + isinstance(stmt, AssignmentStmt) + and isinstance(stmt.lvalues[0], NameExpr) + and stmt.lvalues[0].name == name + ): + break + else: + util.fail(api, f"Can't find mapped attribute {name}", cls) + return None + + if stmt.type is None: + util.fail( + api, + "Statement linked from _mypy_mapped_attrs has no " + "typing information", + stmt, + ) + return None + + left_hand_explicit_type = get_proper_type(stmt.type) + assert isinstance( + left_hand_explicit_type, (Instance, UnionType, UnboundType) + ) + + attributes.append( + util.SQLAlchemyAttribute( + name=name, + line=item.line, + column=item.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) + + apply_type_to_mapped_statement( + api, stmt, stmt.lvalues[0], left_hand_explicit_type, None + ) + + +def re_apply_declarative_assignments( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """For multiple class passes, re-apply our left-hand side types as mypy + seems to reset them in place. + + """ + mapped_attr_lookup = {attr.name: attr for attr in attributes} + update_cls_metadata = False + + for stmt in cls.defs.body: + # for a re-apply, all of our statements are AssignmentStmt; + # @declared_attr calls will have been converted and this + # currently seems to be preserved by mypy (but who knows if this + # will change). + if ( + isinstance(stmt, AssignmentStmt) + and isinstance(stmt.lvalues[0], NameExpr) + and stmt.lvalues[0].name in mapped_attr_lookup + and isinstance(stmt.lvalues[0].node, Var) + ): + left_node = stmt.lvalues[0].node + + python_type_for_type = mapped_attr_lookup[ + stmt.lvalues[0].name + ].type + + left_node_proper_type = get_proper_type(left_node.type) + + # if we have scanned an UnboundType and now there's a more + # specific type than UnboundType, call the re-scan so we + # can get that set up correctly + if ( + isinstance(python_type_for_type, UnboundType) + and not isinstance(left_node_proper_type, UnboundType) + and ( + isinstance(stmt.rvalue, CallExpr) + and isinstance(stmt.rvalue.callee, MemberExpr) + and isinstance(stmt.rvalue.callee.expr, NameExpr) + and stmt.rvalue.callee.expr.node is not None + and stmt.rvalue.callee.expr.node.fullname + == NAMED_TYPE_SQLA_MAPPED + and stmt.rvalue.callee.name == "_empty_constructor" + and isinstance(stmt.rvalue.args[0], CallExpr) + and isinstance(stmt.rvalue.args[0].callee, RefExpr) + ) + ): + new_python_type_for_type = ( + infer.infer_type_from_right_hand_nameexpr( + api, + stmt, + left_node, + left_node_proper_type, + stmt.rvalue.args[0].callee, + ) + ) + + if new_python_type_for_type is not None and not isinstance( + new_python_type_for_type, UnboundType + ): + python_type_for_type = new_python_type_for_type + + # update the SQLAlchemyAttribute with the better + # information + mapped_attr_lookup[stmt.lvalues[0].name].type = ( + python_type_for_type + ) + + update_cls_metadata = True + + if ( + not isinstance(left_node.type, Instance) + or left_node.type.type.fullname != NAMED_TYPE_SQLA_MAPPED + ): + assert python_type_for_type is not None + left_node.type = api.named_type( + NAMED_TYPE_SQLA_MAPPED, [python_type_for_type] + ) + + if update_cls_metadata: + util.set_mapped_attributes(cls.info, attributes) + + +def apply_type_to_mapped_statement( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + lvalue: NameExpr, + left_hand_explicit_type: Optional[ProperType], + python_type_for_type: Optional[ProperType], +) -> None: + """Apply the Mapped[] annotation and right hand object to a + declarative assignment statement. + + This converts a Python declarative class statement such as:: + + class User(Base): + # ... + + attrname = Column(Integer) + + To one that describes the final Python behavior to Mypy:: + + class User(Base): + # ... + + attrname : Mapped[Optional[int]] = + + """ + left_node = lvalue.node + assert isinstance(left_node, Var) + + # to be completely honest I have no idea what the difference between + # left_node.type and stmt.type is, what it means if these are different + # vs. the same, why in order to get tests to pass I have to assign + # to stmt.type for the second case and not the first. this is complete + # trying every combination until it works stuff. + + if left_hand_explicit_type is not None: + lvalue.is_inferred_def = False + left_node.type = api.named_type( + NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type] + ) + else: + lvalue.is_inferred_def = False + left_node.type = api.named_type( + NAMED_TYPE_SQLA_MAPPED, + ( + [AnyType(TypeOfAny.special_form)] + if python_type_for_type is None + else [python_type_for_type] + ), + ) + + # so to have it skip the right side totally, we can do this: + # stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form)) + + # however, if we instead manufacture a new node that uses the old + # one, then we can still get type checking for the call itself, + # e.g. the Column, relationship() call, etc. + + # rewrite the node as: + # : Mapped[] = + # _sa_Mapped._empty_constructor() + # the original right-hand side is maintained so it gets type checked + # internally + stmt.rvalue = expr_to_mapped_constructor(stmt.rvalue) + + if stmt.type is not None and python_type_for_type is not None: + stmt.type = python_type_for_type + + +def add_additional_orm_attributes( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Apply __init__, __table__ and other attributes to the mapped class.""" + + info = util.info_for_cls(cls, api) + + if info is None: + return + + is_base = util.get_is_base(info) + + if "__init__" not in info.names and not is_base: + mapped_attr_names = {attr.name: attr.type for attr in attributes} + + for base in info.mro[1:-1]: + if "sqlalchemy" not in info.metadata: + continue + + base_cls_attributes = util.get_mapped_attributes(base, api) + if base_cls_attributes is None: + continue + + for attr in base_cls_attributes: + mapped_attr_names.setdefault(attr.name, attr.type) + + arguments = [] + for name, typ in mapped_attr_names.items(): + if typ is None: + typ = AnyType(TypeOfAny.special_form) + arguments.append( + Argument( + variable=Var(name, typ), + type_annotation=typ, + initializer=TempNode(typ), + kind=ARG_NAMED_OPT, + ) + ) + + add_method_to_class(api, cls, "__init__", arguments, NoneTyp()) + + if "__table__" not in info.names and util.get_has_table(info): + _apply_placeholder_attr_to_class( + api, cls, "sqlalchemy.sql.schema.Table", "__table__" + ) + if not is_base: + _apply_placeholder_attr_to_class( + api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__" + ) + + +def _apply_placeholder_attr_to_class( + api: SemanticAnalyzerPluginInterface, + cls: ClassDef, + qualified_name: str, + attrname: str, +) -> None: + sym = api.lookup_fully_qualified_or_none(qualified_name) + if sym: + assert isinstance(sym.node, TypeInfo) + type_: ProperType = Instance(sym.node, []) + else: + type_ = AnyType(TypeOfAny.special_form) + var = Var(attrname) + var._fullname = cls.fullname + "." + attrname + var.info = cls.info + var.type = type_ + cls.info.names[attrname] = SymbolTableNode(MDEF, var) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/decl_class.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/decl_class.py new file mode 100644 index 0000000..3d578b3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/decl_class.py @@ -0,0 +1,515 @@ +# ext/mypy/decl_class.py +# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import List +from typing import Optional +from typing import Union + +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import Decorator +from mypy.nodes import LambdaExpr +from mypy.nodes import ListExpr +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import PlaceholderNode +from mypy.nodes import RefExpr +from mypy.nodes import StrExpr +from mypy.nodes import SymbolNode +from mypy.nodes import SymbolTableNode +from mypy.nodes import TempNode +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import AnyType +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import ProperType +from mypy.types import Type +from mypy.types import TypeOfAny +from mypy.types import UnboundType +from mypy.types import UnionType + +from . import apply +from . import infer +from . import names +from . import util + + +def scan_declarative_assignments_and_apply_types( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + is_mixin_scan: bool = False, +) -> Optional[List[util.SQLAlchemyAttribute]]: + info = util.info_for_cls(cls, api) + + if info is None: + # this can occur during cached passes + return None + elif cls.fullname.startswith("builtins"): + return None + + mapped_attributes: Optional[List[util.SQLAlchemyAttribute]] = ( + util.get_mapped_attributes(info, api) + ) + + # used by assign.add_additional_orm_attributes among others + util.establish_as_sqlalchemy(info) + + if mapped_attributes is not None: + # ensure that a class that's mapped is always picked up by + # its mapped() decorator or declarative metaclass before + # it would be detected as an unmapped mixin class + + if not is_mixin_scan: + # mypy can call us more than once. it then *may* have reset the + # left hand side of everything, but not the right that we removed, + # removing our ability to re-scan. but we have the types + # here, so lets re-apply them, or if we have an UnboundType, + # we can re-scan + + apply.re_apply_declarative_assignments(cls, api, mapped_attributes) + + return mapped_attributes + + mapped_attributes = [] + + if not cls.defs.body: + # when we get a mixin class from another file, the body is + # empty (!) but the names are in the symbol table. so use that. + + for sym_name, sym in info.names.items(): + _scan_symbol_table_entry( + cls, api, sym_name, sym, mapped_attributes + ) + else: + for stmt in util.flatten_typechecking(cls.defs.body): + if isinstance(stmt, AssignmentStmt): + _scan_declarative_assignment_stmt( + cls, api, stmt, mapped_attributes + ) + elif isinstance(stmt, Decorator): + _scan_declarative_decorator_stmt( + cls, api, stmt, mapped_attributes + ) + _scan_for_mapped_bases(cls, api) + + if not is_mixin_scan: + apply.add_additional_orm_attributes(cls, api, mapped_attributes) + + util.set_mapped_attributes(info, mapped_attributes) + + return mapped_attributes + + +def _scan_symbol_table_entry( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + name: str, + value: SymbolTableNode, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Extract mapping information from a SymbolTableNode that's in the + type.names dictionary. + + """ + value_type = get_proper_type(value.type) + if not isinstance(value_type, Instance): + return + + left_hand_explicit_type = None + type_id = names.type_id_for_named_node(value_type.type) + # type_id = names._type_id_for_unbound_type(value.type.type, cls, api) + + err = False + + # TODO: this is nearly the same logic as that of + # _scan_declarative_decorator_stmt, likely can be merged + if type_id in { + names.MAPPED, + names.RELATIONSHIP, + names.COMPOSITE_PROPERTY, + names.MAPPER_PROPERTY, + names.SYNONYM_PROPERTY, + names.COLUMN_PROPERTY, + }: + if value_type.args: + left_hand_explicit_type = get_proper_type(value_type.args[0]) + else: + err = True + elif type_id is names.COLUMN: + if not value_type.args: + err = True + else: + typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type( + value_type.args[0] + ) + if isinstance(typeengine_arg, Instance): + typeengine_arg = typeengine_arg.type + + if isinstance(typeengine_arg, (UnboundType, TypeInfo)): + sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) + if sym is not None and isinstance(sym.node, TypeInfo): + if names.has_base_type_id(sym.node, names.TYPEENGINE): + left_hand_explicit_type = UnionType( + [ + infer.extract_python_type_from_typeengine( + api, sym.node, [] + ), + NoneType(), + ] + ) + else: + util.fail( + api, + "Column type should be a TypeEngine " + "subclass not '{}'".format(sym.node.fullname), + value_type, + ) + + if err: + msg = ( + "Can't infer type from attribute {} on class {}. " + "please specify a return type from this function that is " + "one of: Mapped[], relationship[], " + "Column[], MapperProperty[]" + ) + util.fail(api, msg.format(name, cls.name), cls) + + left_hand_explicit_type = AnyType(TypeOfAny.special_form) + + if left_hand_explicit_type is not None: + assert value.node is not None + attributes.append( + util.SQLAlchemyAttribute( + name=name, + line=value.node.line, + column=value.node.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) + + +def _scan_declarative_decorator_stmt( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + stmt: Decorator, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Extract mapping information from a @declared_attr in a declarative + class. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + @declared_attr + def updated_at(cls) -> Column[DateTime]: + return Column(DateTime) + + Will resolve in mypy as:: + + @reg.mapped + class MyClass: + # ... + + updated_at: Mapped[Optional[datetime.datetime]] + + """ + for dec in stmt.decorators: + if ( + isinstance(dec, (NameExpr, MemberExpr, SymbolNode)) + and names.type_id_for_named_node(dec) is names.DECLARED_ATTR + ): + break + else: + return + + dec_index = cls.defs.body.index(stmt) + + left_hand_explicit_type: Optional[ProperType] = None + + if util.name_is_dunder(stmt.name): + # for dunder names like __table_args__, __tablename__, + # __mapper_args__ etc., rewrite these as simple assignment + # statements; otherwise mypy doesn't like if the decorated + # function has an annotation like ``cls: Type[Foo]`` because + # it isn't @classmethod + any_ = AnyType(TypeOfAny.special_form) + left_node = NameExpr(stmt.var.name) + left_node.node = stmt.var + new_stmt = AssignmentStmt([left_node], TempNode(any_)) + new_stmt.type = left_node.node.type + cls.defs.body[dec_index] = new_stmt + return + elif isinstance(stmt.func.type, CallableType): + func_type = stmt.func.type.ret_type + if isinstance(func_type, UnboundType): + type_id = names.type_id_for_unbound_type(func_type, cls, api) + else: + # this does not seem to occur unless the type argument is + # incorrect + return + + if ( + type_id + in { + names.MAPPED, + names.RELATIONSHIP, + names.COMPOSITE_PROPERTY, + names.MAPPER_PROPERTY, + names.SYNONYM_PROPERTY, + names.COLUMN_PROPERTY, + } + and func_type.args + ): + left_hand_explicit_type = get_proper_type(func_type.args[0]) + elif type_id is names.COLUMN and func_type.args: + typeengine_arg = func_type.args[0] + if isinstance(typeengine_arg, UnboundType): + sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg) + if sym is not None and isinstance(sym.node, TypeInfo): + if names.has_base_type_id(sym.node, names.TYPEENGINE): + left_hand_explicit_type = UnionType( + [ + infer.extract_python_type_from_typeengine( + api, sym.node, [] + ), + NoneType(), + ] + ) + else: + util.fail( + api, + "Column type should be a TypeEngine " + "subclass not '{}'".format(sym.node.fullname), + func_type, + ) + + if left_hand_explicit_type is None: + # no type on the decorated function. our option here is to + # dig into the function body and get the return type, but they + # should just have an annotation. + msg = ( + "Can't infer type from @declared_attr on function '{}'; " + "please specify a return type from this function that is " + "one of: Mapped[], relationship[], " + "Column[], MapperProperty[]" + ) + util.fail(api, msg.format(stmt.var.name), stmt) + + left_hand_explicit_type = AnyType(TypeOfAny.special_form) + + left_node = NameExpr(stmt.var.name) + left_node.node = stmt.var + + # totally feeling around in the dark here as I don't totally understand + # the significance of UnboundType. It seems to be something that is + # not going to do what's expected when it is applied as the type of + # an AssignmentStatement. So do a feeling-around-in-the-dark version + # of converting it to the regular Instance/TypeInfo/UnionType structures + # we see everywhere else. + if isinstance(left_hand_explicit_type, UnboundType): + left_hand_explicit_type = get_proper_type( + util.unbound_to_instance(api, left_hand_explicit_type) + ) + + left_node.node.type = api.named_type( + names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type] + ) + + # this will ignore the rvalue entirely + # rvalue = TempNode(AnyType(TypeOfAny.special_form)) + + # rewrite the node as: + # : Mapped[] = + # _sa_Mapped._empty_constructor(lambda: ) + # the function body is maintained so it gets type checked internally + rvalue = names.expr_to_mapped_constructor( + LambdaExpr(stmt.func.arguments, stmt.func.body) + ) + + new_stmt = AssignmentStmt([left_node], rvalue) + new_stmt.type = left_node.node.type + + attributes.append( + util.SQLAlchemyAttribute( + name=left_node.name, + line=stmt.line, + column=stmt.column, + typ=left_hand_explicit_type, + info=cls.info, + ) + ) + cls.defs.body[dec_index] = new_stmt + + +def _scan_declarative_assignment_stmt( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + attributes: List[util.SQLAlchemyAttribute], +) -> None: + """Extract mapping information from an assignment statement in a + declarative class. + + """ + lvalue = stmt.lvalues[0] + if not isinstance(lvalue, NameExpr): + return + + sym = cls.info.names.get(lvalue.name) + + # this establishes that semantic analysis has taken place, which + # means the nodes are populated and we are called from an appropriate + # hook. + assert sym is not None + node = sym.node + + if isinstance(node, PlaceholderNode): + return + + assert node is lvalue.node + assert isinstance(node, Var) + + if node.name == "__abstract__": + if api.parse_bool(stmt.rvalue) is True: + util.set_is_base(cls.info) + return + elif node.name == "__tablename__": + util.set_has_table(cls.info) + elif node.name.startswith("__"): + return + elif node.name == "_mypy_mapped_attrs": + if not isinstance(stmt.rvalue, ListExpr): + util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt) + else: + for item in stmt.rvalue.items: + if isinstance(item, (NameExpr, StrExpr)): + apply.apply_mypy_mapped_attr(cls, api, item, attributes) + + left_hand_mapped_type: Optional[Type] = None + left_hand_explicit_type: Optional[ProperType] = None + + if node.is_inferred or node.type is None: + if isinstance(stmt.type, UnboundType): + # look for an explicit Mapped[] type annotation on the left + # side with nothing on the right + + # print(stmt.type) + # Mapped?[Optional?[A?]] + + left_hand_explicit_type = stmt.type + + if stmt.type.name == "Mapped": + mapped_sym = api.lookup_qualified("Mapped", cls) + if ( + mapped_sym is not None + and mapped_sym.node is not None + and names.type_id_for_named_node(mapped_sym.node) + is names.MAPPED + ): + left_hand_explicit_type = get_proper_type( + stmt.type.args[0] + ) + left_hand_mapped_type = stmt.type + + # TODO: do we need to convert from unbound for this case? + # left_hand_explicit_type = util._unbound_to_instance( + # api, left_hand_explicit_type + # ) + else: + node_type = get_proper_type(node.type) + if ( + isinstance(node_type, Instance) + and names.type_id_for_named_node(node_type.type) is names.MAPPED + ): + # print(node.type) + # sqlalchemy.orm.attributes.Mapped[] + left_hand_explicit_type = get_proper_type(node_type.args[0]) + left_hand_mapped_type = node_type + else: + # print(node.type) + # + left_hand_explicit_type = node_type + left_hand_mapped_type = None + + if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None: + # annotation without assignment and Mapped is present + # as type annotation + # equivalent to using _infer_type_from_left_hand_type_only. + + python_type_for_type = left_hand_explicit_type + elif isinstance(stmt.rvalue, CallExpr) and isinstance( + stmt.rvalue.callee, RefExpr + ): + python_type_for_type = infer.infer_type_from_right_hand_nameexpr( + api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee + ) + + if python_type_for_type is None: + return + + else: + return + + assert python_type_for_type is not None + + attributes.append( + util.SQLAlchemyAttribute( + name=node.name, + line=stmt.line, + column=stmt.column, + typ=python_type_for_type, + info=cls.info, + ) + ) + + apply.apply_type_to_mapped_statement( + api, + stmt, + lvalue, + left_hand_explicit_type, + python_type_for_type, + ) + + +def _scan_for_mapped_bases( + cls: ClassDef, + api: SemanticAnalyzerPluginInterface, +) -> None: + """Given a class, iterate through its superclass hierarchy to find + all other classes that are considered as ORM-significant. + + Locates non-mapped mixins and scans them for mapped attributes to be + applied to subclasses. + + """ + + info = util.info_for_cls(cls, api) + + if info is None: + return + + for base_info in info.mro[1:-1]: + if base_info.fullname.startswith("builtins"): + continue + + # scan each base for mapped attributes. if they are not already + # scanned (but have all their type info), that means they are unmapped + # mixins + scan_declarative_assignments_and_apply_types( + base_info.defn, api, is_mixin_scan=True + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/infer.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/infer.py new file mode 100644 index 0000000..09b3c44 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/infer.py @@ -0,0 +1,590 @@ +# ext/mypy/infer.py +# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Optional +from typing import Sequence + +from mypy.maptype import map_instance_to_supertype +from mypy.nodes import AssignmentStmt +from mypy.nodes import CallExpr +from mypy.nodes import Expression +from mypy.nodes import FuncDef +from mypy.nodes import LambdaExpr +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import RefExpr +from mypy.nodes import StrExpr +from mypy.nodes import TypeInfo +from mypy.nodes import Var +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.subtypes import is_subtype +from mypy.types import AnyType +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import ProperType +from mypy.types import TypeOfAny +from mypy.types import UnionType + +from . import names +from . import util + + +def infer_type_from_right_hand_nameexpr( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + infer_from_right_side: RefExpr, +) -> Optional[ProperType]: + type_id = names.type_id_for_callee(infer_from_right_side) + if type_id is None: + return None + elif type_id is names.MAPPED: + python_type_for_type = _infer_type_from_mapped( + api, stmt, node, left_hand_explicit_type, infer_from_right_side + ) + elif type_id is names.COLUMN: + python_type_for_type = _infer_type_from_decl_column( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.RELATIONSHIP: + python_type_for_type = _infer_type_from_relationship( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.COLUMN_PROPERTY: + python_type_for_type = _infer_type_from_decl_column_property( + api, stmt, node, left_hand_explicit_type + ) + elif type_id is names.SYNONYM_PROPERTY: + python_type_for_type = infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif type_id is names.COMPOSITE_PROPERTY: + python_type_for_type = _infer_type_from_decl_composite_property( + api, stmt, node, left_hand_explicit_type + ) + else: + return None + + return python_type_for_type + + +def _infer_type_from_relationship( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Infer the type of mapping from a relationship. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + addresses = relationship(Address, uselist=True) + + order: Mapped["Order"] = relationship("Order") + + Will resolve in mypy as:: + + @reg.mapped + class MyClass: + # ... + + addresses: Mapped[List[Address]] + + order: Mapped["Order"] + + """ + + assert isinstance(stmt.rvalue, CallExpr) + target_cls_arg = stmt.rvalue.args[0] + python_type_for_type: Optional[ProperType] = None + + if isinstance(target_cls_arg, NameExpr) and isinstance( + target_cls_arg.node, TypeInfo + ): + # type + related_object_type = target_cls_arg.node + python_type_for_type = Instance(related_object_type, []) + + # other cases not covered - an error message directs the user + # to set an explicit type annotation + # + # node.type == str, it's a string + # if isinstance(target_cls_arg, NameExpr) and isinstance( + # target_cls_arg.node, Var + # ) + # points to a type + # isinstance(target_cls_arg, NameExpr) and isinstance( + # target_cls_arg.node, TypeAlias + # ) + # string expression + # isinstance(target_cls_arg, StrExpr) + + uselist_arg = util.get_callexpr_kwarg(stmt.rvalue, "uselist") + collection_cls_arg: Optional[Expression] = util.get_callexpr_kwarg( + stmt.rvalue, "collection_class" + ) + type_is_a_collection = False + + # this can be used to determine Optional for a many-to-one + # in the same way nullable=False could be used, if we start supporting + # that. + # innerjoin_arg = util.get_callexpr_kwarg(stmt.rvalue, "innerjoin") + + if ( + uselist_arg is not None + and api.parse_bool(uselist_arg) is True + and collection_cls_arg is None + ): + type_is_a_collection = True + if python_type_for_type is not None: + python_type_for_type = api.named_type( + names.NAMED_TYPE_BUILTINS_LIST, [python_type_for_type] + ) + elif ( + uselist_arg is None or api.parse_bool(uselist_arg) is True + ) and collection_cls_arg is not None: + type_is_a_collection = True + if isinstance(collection_cls_arg, CallExpr): + collection_cls_arg = collection_cls_arg.callee + + if isinstance(collection_cls_arg, NameExpr) and isinstance( + collection_cls_arg.node, TypeInfo + ): + if python_type_for_type is not None: + # this can still be overridden by the left hand side + # within _infer_Type_from_left_and_inferred_right + python_type_for_type = Instance( + collection_cls_arg.node, [python_type_for_type] + ) + elif ( + isinstance(collection_cls_arg, NameExpr) + and isinstance(collection_cls_arg.node, FuncDef) + and collection_cls_arg.node.type is not None + ): + if python_type_for_type is not None: + # this can still be overridden by the left hand side + # within _infer_Type_from_left_and_inferred_right + + # TODO: handle mypy.types.Overloaded + if isinstance(collection_cls_arg.node.type, CallableType): + rt = get_proper_type(collection_cls_arg.node.type.ret_type) + + if isinstance(rt, CallableType): + callable_ret_type = get_proper_type(rt.ret_type) + if isinstance(callable_ret_type, Instance): + python_type_for_type = Instance( + callable_ret_type.type, + [python_type_for_type], + ) + else: + util.fail( + api, + "Expected Python collection type for " + "collection_class parameter", + stmt.rvalue, + ) + python_type_for_type = None + elif uselist_arg is not None and api.parse_bool(uselist_arg) is False: + if collection_cls_arg is not None: + util.fail( + api, + "Sending uselist=False and collection_class at the same time " + "does not make sense", + stmt.rvalue, + ) + if python_type_for_type is not None: + python_type_for_type = UnionType( + [python_type_for_type, NoneType()] + ) + + else: + if left_hand_explicit_type is None: + msg = ( + "Can't infer scalar or collection for ORM mapped expression " + "assigned to attribute '{}' if both 'uselist' and " + "'collection_class' arguments are absent from the " + "relationship(); please specify a " + "type annotation on the left hand side." + ) + util.fail(api, msg.format(node.name), node) + + if python_type_for_type is None: + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif left_hand_explicit_type is not None: + if type_is_a_collection: + assert isinstance(left_hand_explicit_type, Instance) + assert isinstance(python_type_for_type, Instance) + return _infer_collection_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + else: + return _infer_type_from_left_and_inferred_right( + api, + node, + left_hand_explicit_type, + python_type_for_type, + ) + else: + return python_type_for_type + + +def _infer_type_from_decl_composite_property( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Infer the type of mapping from a Composite.""" + + assert isinstance(stmt.rvalue, CallExpr) + target_cls_arg = stmt.rvalue.args[0] + python_type_for_type = None + + if isinstance(target_cls_arg, NameExpr) and isinstance( + target_cls_arg.node, TypeInfo + ): + related_object_type = target_cls_arg.node + python_type_for_type = Instance(related_object_type, []) + else: + python_type_for_type = None + + if python_type_for_type is None: + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + elif left_hand_explicit_type is not None: + return _infer_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + else: + return python_type_for_type + + +def _infer_type_from_mapped( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + infer_from_right_side: RefExpr, +) -> Optional[ProperType]: + """Infer the type of mapping from a right side expression + that returns Mapped. + + + """ + assert isinstance(stmt.rvalue, CallExpr) + + # (Pdb) print(stmt.rvalue.callee) + # NameExpr(query_expression [sqlalchemy.orm._orm_constructors.query_expression]) # noqa: E501 + # (Pdb) stmt.rvalue.callee.node + # + # (Pdb) stmt.rvalue.callee.node.type + # def [_T] (default_expr: sqlalchemy.sql.elements.ColumnElement[_T`-1] =) -> sqlalchemy.orm.base.Mapped[_T`-1] # noqa: E501 + # sqlalchemy.orm.base.Mapped[_T`-1] + # the_mapped_type = stmt.rvalue.callee.node.type.ret_type + + # TODO: look at generic ref and either use that, + # or reconcile w/ what's present, etc. + the_mapped_type = util.type_for_callee(infer_from_right_side) # noqa + + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_decl_column_property( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Infer the type of mapping from a ColumnProperty. + + This includes mappings against ``column_property()`` as well as the + ``deferred()`` function. + + """ + assert isinstance(stmt.rvalue, CallExpr) + + if stmt.rvalue.args: + first_prop_arg = stmt.rvalue.args[0] + + if isinstance(first_prop_arg, CallExpr): + type_id = names.type_id_for_callee(first_prop_arg.callee) + + # look for column_property() / deferred() etc with Column as first + # argument + if type_id is names.COLUMN: + return _infer_type_from_decl_column( + api, + stmt, + node, + left_hand_explicit_type, + right_hand_expression=first_prop_arg, + ) + + if isinstance(stmt.rvalue, CallExpr): + type_id = names.type_id_for_callee(stmt.rvalue.callee) + # this is probably not strictly necessary as we have to use the left + # hand type for query expression in any case. any other no-arg + # column prop objects would go here also + if type_id is names.QUERY_EXPRESSION: + return _infer_type_from_decl_column( + api, + stmt, + node, + left_hand_explicit_type, + ) + + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_decl_column( + api: SemanticAnalyzerPluginInterface, + stmt: AssignmentStmt, + node: Var, + left_hand_explicit_type: Optional[ProperType], + right_hand_expression: Optional[CallExpr] = None, +) -> Optional[ProperType]: + """Infer the type of mapping from a Column. + + E.g.:: + + @reg.mapped + class MyClass: + # ... + + a = Column(Integer) + + b = Column("b", String) + + c: Mapped[int] = Column(Integer) + + d: bool = Column(Boolean) + + Will resolve in MyPy as:: + + @reg.mapped + class MyClass: + # ... + + a : Mapped[int] + + b : Mapped[str] + + c: Mapped[int] + + d: Mapped[bool] + + """ + assert isinstance(node, Var) + + callee = None + + if right_hand_expression is None: + if not isinstance(stmt.rvalue, CallExpr): + return None + + right_hand_expression = stmt.rvalue + + for column_arg in right_hand_expression.args[0:2]: + if isinstance(column_arg, CallExpr): + if isinstance(column_arg.callee, RefExpr): + # x = Column(String(50)) + callee = column_arg.callee + type_args: Sequence[Expression] = column_arg.args + break + elif isinstance(column_arg, (NameExpr, MemberExpr)): + if isinstance(column_arg.node, TypeInfo): + # x = Column(String) + callee = column_arg + type_args = () + break + else: + # x = Column(some_name, String), go to next argument + continue + elif isinstance(column_arg, (StrExpr,)): + # x = Column("name", String), go to next argument + continue + elif isinstance(column_arg, (LambdaExpr,)): + # x = Column("name", String, default=lambda: uuid.uuid4()) + # go to next argument + continue + else: + assert False + + if callee is None: + return None + + if isinstance(callee.node, TypeInfo) and names.mro_has_id( + callee.node.mro, names.TYPEENGINE + ): + python_type_for_type = extract_python_type_from_typeengine( + api, callee.node, type_args + ) + + if left_hand_explicit_type is not None: + return _infer_type_from_left_and_inferred_right( + api, node, left_hand_explicit_type, python_type_for_type + ) + + else: + return UnionType([python_type_for_type, NoneType()]) + else: + # it's not TypeEngine, it's typically implicitly typed + # like ForeignKey. we can't infer from the right side. + return infer_type_from_left_hand_type_only( + api, node, left_hand_explicit_type + ) + + +def _infer_type_from_left_and_inferred_right( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: ProperType, + python_type_for_type: ProperType, + orig_left_hand_type: Optional[ProperType] = None, + orig_python_type_for_type: Optional[ProperType] = None, +) -> Optional[ProperType]: + """Validate type when a left hand annotation is present and we also + could infer the right hand side:: + + attrname: SomeType = Column(SomeDBType) + + """ + + if orig_left_hand_type is None: + orig_left_hand_type = left_hand_explicit_type + if orig_python_type_for_type is None: + orig_python_type_for_type = python_type_for_type + + if not is_subtype(left_hand_explicit_type, python_type_for_type): + effective_type = api.named_type( + names.NAMED_TYPE_SQLA_MAPPED, [orig_python_type_for_type] + ) + + msg = ( + "Left hand assignment '{}: {}' not compatible " + "with ORM mapped expression of type {}" + ) + util.fail( + api, + msg.format( + node.name, + util.format_type(orig_left_hand_type, api.options), + util.format_type(effective_type, api.options), + ), + node, + ) + + return orig_left_hand_type + + +def _infer_collection_type_from_left_and_inferred_right( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: Instance, + python_type_for_type: Instance, +) -> Optional[ProperType]: + orig_left_hand_type = left_hand_explicit_type + orig_python_type_for_type = python_type_for_type + + if left_hand_explicit_type.args: + left_hand_arg = get_proper_type(left_hand_explicit_type.args[0]) + python_type_arg = get_proper_type(python_type_for_type.args[0]) + else: + left_hand_arg = left_hand_explicit_type + python_type_arg = python_type_for_type + + assert isinstance(left_hand_arg, (Instance, UnionType)) + assert isinstance(python_type_arg, (Instance, UnionType)) + + return _infer_type_from_left_and_inferred_right( + api, + node, + left_hand_arg, + python_type_arg, + orig_left_hand_type=orig_left_hand_type, + orig_python_type_for_type=orig_python_type_for_type, + ) + + +def infer_type_from_left_hand_type_only( + api: SemanticAnalyzerPluginInterface, + node: Var, + left_hand_explicit_type: Optional[ProperType], +) -> Optional[ProperType]: + """Determine the type based on explicit annotation only. + + if no annotation were present, note that we need one there to know + the type. + + """ + if left_hand_explicit_type is None: + msg = ( + "Can't infer type from ORM mapped expression " + "assigned to attribute '{}'; please specify a " + "Python type or " + "Mapped[] on the left hand side." + ) + util.fail(api, msg.format(node.name), node) + + return api.named_type( + names.NAMED_TYPE_SQLA_MAPPED, [AnyType(TypeOfAny.special_form)] + ) + + else: + # use type from the left hand side + return left_hand_explicit_type + + +def extract_python_type_from_typeengine( + api: SemanticAnalyzerPluginInterface, + node: TypeInfo, + type_args: Sequence[Expression], +) -> ProperType: + if node.fullname == "sqlalchemy.sql.sqltypes.Enum" and type_args: + first_arg = type_args[0] + if isinstance(first_arg, RefExpr) and isinstance( + first_arg.node, TypeInfo + ): + for base_ in first_arg.node.mro: + if base_.fullname == "enum.Enum": + return Instance(first_arg.node, []) + # TODO: support other pep-435 types here + else: + return api.named_type(names.NAMED_TYPE_BUILTINS_STR, []) + + assert node.has_base("sqlalchemy.sql.type_api.TypeEngine"), ( + "could not extract Python type from node: %s" % node + ) + + type_engine_sym = api.lookup_fully_qualified_or_none( + "sqlalchemy.sql.type_api.TypeEngine" + ) + + assert type_engine_sym is not None and isinstance( + type_engine_sym.node, TypeInfo + ) + type_engine = map_instance_to_supertype( + Instance(node, []), + type_engine_sym.node, + ) + return get_proper_type(type_engine.args[-1]) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/names.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/names.py new file mode 100644 index 0000000..fc3d708 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/names.py @@ -0,0 +1,335 @@ +# ext/mypy/names.py +# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Dict +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Union + +from mypy.nodes import ARG_POS +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import Decorator +from mypy.nodes import Expression +from mypy.nodes import FuncDef +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import OverloadedFuncDef +from mypy.nodes import SymbolNode +from mypy.nodes import TypeAlias +from mypy.nodes import TypeInfo +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import UnboundType + +from ... import util + +COLUMN: int = util.symbol("COLUMN") +RELATIONSHIP: int = util.symbol("RELATIONSHIP") +REGISTRY: int = util.symbol("REGISTRY") +COLUMN_PROPERTY: int = util.symbol("COLUMN_PROPERTY") +TYPEENGINE: int = util.symbol("TYPEENGNE") +MAPPED: int = util.symbol("MAPPED") +DECLARATIVE_BASE: int = util.symbol("DECLARATIVE_BASE") +DECLARATIVE_META: int = util.symbol("DECLARATIVE_META") +MAPPED_DECORATOR: int = util.symbol("MAPPED_DECORATOR") +SYNONYM_PROPERTY: int = util.symbol("SYNONYM_PROPERTY") +COMPOSITE_PROPERTY: int = util.symbol("COMPOSITE_PROPERTY") +DECLARED_ATTR: int = util.symbol("DECLARED_ATTR") +MAPPER_PROPERTY: int = util.symbol("MAPPER_PROPERTY") +AS_DECLARATIVE: int = util.symbol("AS_DECLARATIVE") +AS_DECLARATIVE_BASE: int = util.symbol("AS_DECLARATIVE_BASE") +DECLARATIVE_MIXIN: int = util.symbol("DECLARATIVE_MIXIN") +QUERY_EXPRESSION: int = util.symbol("QUERY_EXPRESSION") + +# names that must succeed with mypy.api.named_type +NAMED_TYPE_BUILTINS_OBJECT = "builtins.object" +NAMED_TYPE_BUILTINS_STR = "builtins.str" +NAMED_TYPE_BUILTINS_LIST = "builtins.list" +NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped" + +_RelFullNames = { + "sqlalchemy.orm.relationships.Relationship", + "sqlalchemy.orm.relationships.RelationshipProperty", + "sqlalchemy.orm.relationships._RelationshipDeclared", + "sqlalchemy.orm.Relationship", + "sqlalchemy.orm.RelationshipProperty", +} + +_lookup: Dict[str, Tuple[int, Set[str]]] = { + "Column": ( + COLUMN, + { + "sqlalchemy.sql.schema.Column", + "sqlalchemy.sql.Column", + }, + ), + "Relationship": (RELATIONSHIP, _RelFullNames), + "RelationshipProperty": (RELATIONSHIP, _RelFullNames), + "_RelationshipDeclared": (RELATIONSHIP, _RelFullNames), + "registry": ( + REGISTRY, + { + "sqlalchemy.orm.decl_api.registry", + "sqlalchemy.orm.registry", + }, + ), + "ColumnProperty": ( + COLUMN_PROPERTY, + { + "sqlalchemy.orm.properties.MappedSQLExpression", + "sqlalchemy.orm.MappedSQLExpression", + "sqlalchemy.orm.properties.ColumnProperty", + "sqlalchemy.orm.ColumnProperty", + }, + ), + "MappedSQLExpression": ( + COLUMN_PROPERTY, + { + "sqlalchemy.orm.properties.MappedSQLExpression", + "sqlalchemy.orm.MappedSQLExpression", + "sqlalchemy.orm.properties.ColumnProperty", + "sqlalchemy.orm.ColumnProperty", + }, + ), + "Synonym": ( + SYNONYM_PROPERTY, + { + "sqlalchemy.orm.descriptor_props.Synonym", + "sqlalchemy.orm.Synonym", + "sqlalchemy.orm.descriptor_props.SynonymProperty", + "sqlalchemy.orm.SynonymProperty", + }, + ), + "SynonymProperty": ( + SYNONYM_PROPERTY, + { + "sqlalchemy.orm.descriptor_props.Synonym", + "sqlalchemy.orm.Synonym", + "sqlalchemy.orm.descriptor_props.SynonymProperty", + "sqlalchemy.orm.SynonymProperty", + }, + ), + "Composite": ( + COMPOSITE_PROPERTY, + { + "sqlalchemy.orm.descriptor_props.Composite", + "sqlalchemy.orm.Composite", + "sqlalchemy.orm.descriptor_props.CompositeProperty", + "sqlalchemy.orm.CompositeProperty", + }, + ), + "CompositeProperty": ( + COMPOSITE_PROPERTY, + { + "sqlalchemy.orm.descriptor_props.Composite", + "sqlalchemy.orm.Composite", + "sqlalchemy.orm.descriptor_props.CompositeProperty", + "sqlalchemy.orm.CompositeProperty", + }, + ), + "MapperProperty": ( + MAPPER_PROPERTY, + { + "sqlalchemy.orm.interfaces.MapperProperty", + "sqlalchemy.orm.MapperProperty", + }, + ), + "TypeEngine": (TYPEENGINE, {"sqlalchemy.sql.type_api.TypeEngine"}), + "Mapped": (MAPPED, {NAMED_TYPE_SQLA_MAPPED}), + "declarative_base": ( + DECLARATIVE_BASE, + { + "sqlalchemy.ext.declarative.declarative_base", + "sqlalchemy.orm.declarative_base", + "sqlalchemy.orm.decl_api.declarative_base", + }, + ), + "DeclarativeMeta": ( + DECLARATIVE_META, + { + "sqlalchemy.ext.declarative.DeclarativeMeta", + "sqlalchemy.orm.DeclarativeMeta", + "sqlalchemy.orm.decl_api.DeclarativeMeta", + }, + ), + "mapped": ( + MAPPED_DECORATOR, + { + "sqlalchemy.orm.decl_api.registry.mapped", + "sqlalchemy.orm.registry.mapped", + }, + ), + "as_declarative": ( + AS_DECLARATIVE, + { + "sqlalchemy.ext.declarative.as_declarative", + "sqlalchemy.orm.decl_api.as_declarative", + "sqlalchemy.orm.as_declarative", + }, + ), + "as_declarative_base": ( + AS_DECLARATIVE_BASE, + { + "sqlalchemy.orm.decl_api.registry.as_declarative_base", + "sqlalchemy.orm.registry.as_declarative_base", + }, + ), + "declared_attr": ( + DECLARED_ATTR, + { + "sqlalchemy.orm.decl_api.declared_attr", + "sqlalchemy.orm.declared_attr", + }, + ), + "declarative_mixin": ( + DECLARATIVE_MIXIN, + { + "sqlalchemy.orm.decl_api.declarative_mixin", + "sqlalchemy.orm.declarative_mixin", + }, + ), + "query_expression": ( + QUERY_EXPRESSION, + { + "sqlalchemy.orm.query_expression", + "sqlalchemy.orm._orm_constructors.query_expression", + }, + ), +} + + +def has_base_type_id(info: TypeInfo, type_id: int) -> bool: + for mr in info.mro: + check_type_id, fullnames = _lookup.get(mr.name, (None, None)) + if check_type_id == type_id: + break + else: + return False + + if fullnames is None: + return False + + return mr.fullname in fullnames + + +def mro_has_id(mro: List[TypeInfo], type_id: int) -> bool: + for mr in mro: + check_type_id, fullnames = _lookup.get(mr.name, (None, None)) + if check_type_id == type_id: + break + else: + return False + + if fullnames is None: + return False + + return mr.fullname in fullnames + + +def type_id_for_unbound_type( + type_: UnboundType, cls: ClassDef, api: SemanticAnalyzerPluginInterface +) -> Optional[int]: + sym = api.lookup_qualified(type_.name, type_) + if sym is not None: + if isinstance(sym.node, TypeAlias): + target_type = get_proper_type(sym.node.target) + if isinstance(target_type, Instance): + return type_id_for_named_node(target_type.type) + elif isinstance(sym.node, TypeInfo): + return type_id_for_named_node(sym.node) + + return None + + +def type_id_for_callee(callee: Expression) -> Optional[int]: + if isinstance(callee, (MemberExpr, NameExpr)): + if isinstance(callee.node, Decorator) and isinstance( + callee.node.func, FuncDef + ): + if callee.node.func.type and isinstance( + callee.node.func.type, CallableType + ): + ret_type = get_proper_type(callee.node.func.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + + elif isinstance(callee.node, OverloadedFuncDef): + if ( + callee.node.impl + and callee.node.impl.type + and isinstance(callee.node.impl.type, CallableType) + ): + ret_type = get_proper_type(callee.node.impl.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + elif isinstance(callee.node, FuncDef): + if callee.node.type and isinstance(callee.node.type, CallableType): + ret_type = get_proper_type(callee.node.type.ret_type) + + if isinstance(ret_type, Instance): + return type_id_for_fullname(ret_type.type.fullname) + + return None + elif isinstance(callee.node, TypeAlias): + target_type = get_proper_type(callee.node.target) + if isinstance(target_type, Instance): + return type_id_for_fullname(target_type.type.fullname) + elif isinstance(callee.node, TypeInfo): + return type_id_for_named_node(callee) + return None + + +def type_id_for_named_node( + node: Union[NameExpr, MemberExpr, SymbolNode] +) -> Optional[int]: + type_id, fullnames = _lookup.get(node.name, (None, None)) + + if type_id is None or fullnames is None: + return None + elif node.fullname in fullnames: + return type_id + else: + return None + + +def type_id_for_fullname(fullname: str) -> Optional[int]: + tokens = fullname.split(".") + immediate = tokens[-1] + + type_id, fullnames = _lookup.get(immediate, (None, None)) + + if type_id is None or fullnames is None: + return None + elif fullname in fullnames: + return type_id + else: + return None + + +def expr_to_mapped_constructor(expr: Expression) -> CallExpr: + column_descriptor = NameExpr("__sa_Mapped") + column_descriptor.fullname = NAMED_TYPE_SQLA_MAPPED + member_expr = MemberExpr(column_descriptor, "_empty_constructor") + return CallExpr( + member_expr, + [expr], + [ARG_POS], + ["arg1"], + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/plugin.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/plugin.py new file mode 100644 index 0000000..00eb4d1 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/plugin.py @@ -0,0 +1,303 @@ +# ext/mypy/plugin.py +# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +Mypy plugin for SQLAlchemy ORM. + +""" +from __future__ import annotations + +from typing import Callable +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type as TypingType +from typing import Union + +from mypy import nodes +from mypy.mro import calculate_mro +from mypy.mro import MroError +from mypy.nodes import Block +from mypy.nodes import ClassDef +from mypy.nodes import GDEF +from mypy.nodes import MypyFile +from mypy.nodes import NameExpr +from mypy.nodes import SymbolTable +from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeInfo +from mypy.plugin import AttributeContext +from mypy.plugin import ClassDefContext +from mypy.plugin import DynamicClassDefContext +from mypy.plugin import Plugin +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import Type + +from . import decl_class +from . import names +from . import util + +try: + __import__("sqlalchemy-stubs") +except ImportError: + pass +else: + raise ImportError( + "The SQLAlchemy mypy plugin in SQLAlchemy " + "2.0 does not work with sqlalchemy-stubs or " + "sqlalchemy2-stubs installed, as well as with any other third party " + "SQLAlchemy stubs. Please uninstall all SQLAlchemy stubs " + "packages." + ) + + +class SQLAlchemyPlugin(Plugin): + def get_dynamic_class_hook( + self, fullname: str + ) -> Optional[Callable[[DynamicClassDefContext], None]]: + if names.type_id_for_fullname(fullname) is names.DECLARATIVE_BASE: + return _dynamic_class_hook + return None + + def get_customize_class_mro_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return _fill_in_decorators + + def get_class_decorator_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + sym = self.lookup_fully_qualified(fullname) + + if sym is not None and sym.node is not None: + type_id = names.type_id_for_named_node(sym.node) + if type_id is names.MAPPED_DECORATOR: + return _cls_decorator_hook + elif type_id in ( + names.AS_DECLARATIVE, + names.AS_DECLARATIVE_BASE, + ): + return _base_cls_decorator_hook + elif type_id is names.DECLARATIVE_MIXIN: + return _declarative_mixin_hook + + return None + + def get_metaclass_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + if names.type_id_for_fullname(fullname) is names.DECLARATIVE_META: + # Set any classes that explicitly have metaclass=DeclarativeMeta + # as declarative so the check in `get_base_class_hook()` works + return _metaclass_cls_hook + + return None + + def get_base_class_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + sym = self.lookup_fully_qualified(fullname) + + if ( + sym + and isinstance(sym.node, TypeInfo) + and util.has_declarative_base(sym.node) + ): + return _base_cls_hook + + return None + + def get_attribute_hook( + self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: + if fullname.startswith( + "sqlalchemy.orm.attributes.QueryableAttribute." + ): + return _queryable_getattr_hook + + return None + + def get_additional_deps( + self, file: MypyFile + ) -> List[Tuple[int, str, int]]: + return [ + # + (10, "sqlalchemy.orm", -1), + (10, "sqlalchemy.orm.attributes", -1), + (10, "sqlalchemy.orm.decl_api", -1), + ] + + +def plugin(version: str) -> TypingType[SQLAlchemyPlugin]: + return SQLAlchemyPlugin + + +def _dynamic_class_hook(ctx: DynamicClassDefContext) -> None: + """Generate a declarative Base class when the declarative_base() function + is encountered.""" + + _add_globals(ctx) + + cls = ClassDef(ctx.name, Block([])) + cls.fullname = ctx.api.qualified_name(ctx.name) + + info = TypeInfo(SymbolTable(), cls, ctx.api.cur_mod_id) + cls.info = info + _set_declarative_metaclass(ctx.api, cls) + + cls_arg = util.get_callexpr_kwarg(ctx.call, "cls", expr_types=(NameExpr,)) + if cls_arg is not None and isinstance(cls_arg.node, TypeInfo): + util.set_is_base(cls_arg.node) + decl_class.scan_declarative_assignments_and_apply_types( + cls_arg.node.defn, ctx.api, is_mixin_scan=True + ) + info.bases = [Instance(cls_arg.node, [])] + else: + obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT) + + info.bases = [obj] + + try: + calculate_mro(info) + except MroError: + util.fail( + ctx.api, "Not able to calculate MRO for declarative base", ctx.call + ) + obj = ctx.api.named_type(names.NAMED_TYPE_BUILTINS_OBJECT) + info.bases = [obj] + info.fallback_to_any = True + + ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) + util.set_is_base(info) + + +def _fill_in_decorators(ctx: ClassDefContext) -> None: + for decorator in ctx.cls.decorators: + # set the ".fullname" attribute of a class decorator + # that is a MemberExpr. This causes the logic in + # semanal.py->apply_class_plugin_hooks to invoke the + # get_class_decorator_hook for our "registry.map_class()" + # and "registry.as_declarative_base()" methods. + # this seems like a bug in mypy that these decorators are otherwise + # skipped. + + if ( + isinstance(decorator, nodes.CallExpr) + and isinstance(decorator.callee, nodes.MemberExpr) + and decorator.callee.name == "as_declarative_base" + ): + target = decorator.callee + elif ( + isinstance(decorator, nodes.MemberExpr) + and decorator.name == "mapped" + ): + target = decorator + else: + continue + + if isinstance(target.expr, NameExpr): + sym = ctx.api.lookup_qualified( + target.expr.name, target, suppress_errors=True + ) + else: + continue + + if sym and sym.node: + sym_type = get_proper_type(sym.type) + if isinstance(sym_type, Instance): + target.fullname = f"{sym_type.type.fullname}.{target.name}" + else: + # if the registry is in the same file as where the + # decorator is used, it might not have semantic + # symbols applied and we can't get a fully qualified + # name or an inferred type, so we are actually going to + # flag an error in this case that they need to annotate + # it. The "registry" is declared just + # once (or few times), so they have to just not use + # type inference for its assignment in this one case. + util.fail( + ctx.api, + "Class decorator called %s(), but we can't " + "tell if it's from an ORM registry. Please " + "annotate the registry assignment, e.g. " + "my_registry: registry = registry()" % target.name, + sym.node, + ) + + +def _cls_decorator_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + assert isinstance(ctx.reason, nodes.MemberExpr) + expr = ctx.reason.expr + + assert isinstance(expr, nodes.RefExpr) and isinstance(expr.node, nodes.Var) + + node_type = get_proper_type(expr.node.type) + + assert ( + isinstance(node_type, Instance) + and names.type_id_for_named_node(node_type.type) is names.REGISTRY + ) + + decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) + + +def _base_cls_decorator_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + + cls = ctx.cls + + _set_declarative_metaclass(ctx.api, cls) + + util.set_is_base(ctx.cls.info) + decl_class.scan_declarative_assignments_and_apply_types( + cls, ctx.api, is_mixin_scan=True + ) + + +def _declarative_mixin_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + util.set_is_base(ctx.cls.info) + decl_class.scan_declarative_assignments_and_apply_types( + ctx.cls, ctx.api, is_mixin_scan=True + ) + + +def _metaclass_cls_hook(ctx: ClassDefContext) -> None: + util.set_is_base(ctx.cls.info) + + +def _base_cls_hook(ctx: ClassDefContext) -> None: + _add_globals(ctx) + decl_class.scan_declarative_assignments_and_apply_types(ctx.cls, ctx.api) + + +def _queryable_getattr_hook(ctx: AttributeContext) -> Type: + # how do I....tell it it has no attribute of a certain name? + # can't find any Type that seems to match that + return ctx.default_attr_type + + +def _add_globals(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> None: + """Add __sa_DeclarativeMeta and __sa_Mapped symbol to the global space + for all class defs + + """ + + util.add_global(ctx, "sqlalchemy.orm", "Mapped", "__sa_Mapped") + + +def _set_declarative_metaclass( + api: SemanticAnalyzerPluginInterface, target_cls: ClassDef +) -> None: + info = target_cls.info + sym = api.lookup_fully_qualified_or_none( + "sqlalchemy.orm.decl_api.DeclarativeMeta" + ) + assert sym is not None and isinstance(sym.node, TypeInfo) + info.declared_metaclass = info.metaclass_type = Instance(sym.node, []) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/util.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/util.py new file mode 100644 index 0000000..7f04c48 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/mypy/util.py @@ -0,0 +1,338 @@ +# ext/mypy/util.py +# Copyright (C) 2021-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import re +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type as TypingType +from typing import TypeVar +from typing import Union + +from mypy import version +from mypy.messages import format_type as _mypy_format_type +from mypy.nodes import CallExpr +from mypy.nodes import ClassDef +from mypy.nodes import CLASSDEF_NO_INFO +from mypy.nodes import Context +from mypy.nodes import Expression +from mypy.nodes import FuncDef +from mypy.nodes import IfStmt +from mypy.nodes import JsonDict +from mypy.nodes import MemberExpr +from mypy.nodes import NameExpr +from mypy.nodes import Statement +from mypy.nodes import SymbolTableNode +from mypy.nodes import TypeAlias +from mypy.nodes import TypeInfo +from mypy.options import Options +from mypy.plugin import ClassDefContext +from mypy.plugin import DynamicClassDefContext +from mypy.plugin import SemanticAnalyzerPluginInterface +from mypy.plugins.common import deserialize_and_fixup_type +from mypy.typeops import map_type_from_supertype +from mypy.types import CallableType +from mypy.types import get_proper_type +from mypy.types import Instance +from mypy.types import NoneType +from mypy.types import Type +from mypy.types import TypeVarType +from mypy.types import UnboundType +from mypy.types import UnionType + +_vers = tuple( + [int(x) for x in version.__version__.split(".") if re.match(r"^\d+$", x)] +) +mypy_14 = _vers >= (1, 4) + + +_TArgType = TypeVar("_TArgType", bound=Union[CallExpr, NameExpr]) + + +class SQLAlchemyAttribute: + def __init__( + self, + name: str, + line: int, + column: int, + typ: Optional[Type], + info: TypeInfo, + ) -> None: + self.name = name + self.line = line + self.column = column + self.type = typ + self.info = info + + def serialize(self) -> JsonDict: + assert self.type + return { + "name": self.name, + "line": self.line, + "column": self.column, + "type": self.type.serialize(), + } + + def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: + """Expands type vars in the context of a subtype when an attribute is + inherited from a generic super type. + """ + if not isinstance(self.type, TypeVarType): + return + + self.type = map_type_from_supertype(self.type, sub_type, self.info) + + @classmethod + def deserialize( + cls, + info: TypeInfo, + data: JsonDict, + api: SemanticAnalyzerPluginInterface, + ) -> SQLAlchemyAttribute: + data = data.copy() + typ = deserialize_and_fixup_type(data.pop("type"), api) + return cls(typ=typ, info=info, **data) + + +def name_is_dunder(name: str) -> bool: + return bool(re.match(r"^__.+?__$", name)) + + +def _set_info_metadata(info: TypeInfo, key: str, data: Any) -> None: + info.metadata.setdefault("sqlalchemy", {})[key] = data + + +def _get_info_metadata(info: TypeInfo, key: str) -> Optional[Any]: + return info.metadata.get("sqlalchemy", {}).get(key, None) + + +def _get_info_mro_metadata(info: TypeInfo, key: str) -> Optional[Any]: + if info.mro: + for base in info.mro: + metadata = _get_info_metadata(base, key) + if metadata is not None: + return metadata + return None + + +def establish_as_sqlalchemy(info: TypeInfo) -> None: + info.metadata.setdefault("sqlalchemy", {}) + + +def set_is_base(info: TypeInfo) -> None: + _set_info_metadata(info, "is_base", True) + + +def get_is_base(info: TypeInfo) -> bool: + is_base = _get_info_metadata(info, "is_base") + return is_base is True + + +def has_declarative_base(info: TypeInfo) -> bool: + is_base = _get_info_mro_metadata(info, "is_base") + return is_base is True + + +def set_has_table(info: TypeInfo) -> None: + _set_info_metadata(info, "has_table", True) + + +def get_has_table(info: TypeInfo) -> bool: + is_base = _get_info_metadata(info, "has_table") + return is_base is True + + +def get_mapped_attributes( + info: TypeInfo, api: SemanticAnalyzerPluginInterface +) -> Optional[List[SQLAlchemyAttribute]]: + mapped_attributes: Optional[List[JsonDict]] = _get_info_metadata( + info, "mapped_attributes" + ) + if mapped_attributes is None: + return None + + attributes: List[SQLAlchemyAttribute] = [] + + for data in mapped_attributes: + attr = SQLAlchemyAttribute.deserialize(info, data, api) + attr.expand_typevar_from_subtype(info) + attributes.append(attr) + + return attributes + + +def format_type(typ_: Type, options: Options) -> str: + if mypy_14: + return _mypy_format_type(typ_, options) + else: + return _mypy_format_type(typ_) # type: ignore + + +def set_mapped_attributes( + info: TypeInfo, attributes: List[SQLAlchemyAttribute] +) -> None: + _set_info_metadata( + info, + "mapped_attributes", + [attribute.serialize() for attribute in attributes], + ) + + +def fail(api: SemanticAnalyzerPluginInterface, msg: str, ctx: Context) -> None: + msg = "[SQLAlchemy Mypy plugin] %s" % msg + return api.fail(msg, ctx) + + +def add_global( + ctx: Union[ClassDefContext, DynamicClassDefContext], + module: str, + symbol_name: str, + asname: str, +) -> None: + module_globals = ctx.api.modules[ctx.api.cur_mod_id].names + + if asname not in module_globals: + lookup_sym: SymbolTableNode = ctx.api.modules[module].names[ + symbol_name + ] + + module_globals[asname] = lookup_sym + + +@overload +def get_callexpr_kwarg( + callexpr: CallExpr, name: str, *, expr_types: None = ... +) -> Optional[Union[CallExpr, NameExpr]]: ... + + +@overload +def get_callexpr_kwarg( + callexpr: CallExpr, + name: str, + *, + expr_types: Tuple[TypingType[_TArgType], ...], +) -> Optional[_TArgType]: ... + + +def get_callexpr_kwarg( + callexpr: CallExpr, + name: str, + *, + expr_types: Optional[Tuple[TypingType[Any], ...]] = None, +) -> Optional[Any]: + try: + arg_idx = callexpr.arg_names.index(name) + except ValueError: + return None + + kwarg = callexpr.args[arg_idx] + if isinstance( + kwarg, expr_types if expr_types is not None else (NameExpr, CallExpr) + ): + return kwarg + + return None + + +def flatten_typechecking(stmts: Iterable[Statement]) -> Iterator[Statement]: + for stmt in stmts: + if ( + isinstance(stmt, IfStmt) + and isinstance(stmt.expr[0], NameExpr) + and stmt.expr[0].fullname == "typing.TYPE_CHECKING" + ): + yield from stmt.body[0].body + else: + yield stmt + + +def type_for_callee(callee: Expression) -> Optional[Union[Instance, TypeInfo]]: + if isinstance(callee, (MemberExpr, NameExpr)): + if isinstance(callee.node, FuncDef): + if callee.node.type and isinstance(callee.node.type, CallableType): + ret_type = get_proper_type(callee.node.type.ret_type) + + if isinstance(ret_type, Instance): + return ret_type + + return None + elif isinstance(callee.node, TypeAlias): + target_type = get_proper_type(callee.node.target) + if isinstance(target_type, Instance): + return target_type + elif isinstance(callee.node, TypeInfo): + return callee.node + return None + + +def unbound_to_instance( + api: SemanticAnalyzerPluginInterface, typ: Type +) -> Type: + """Take the UnboundType that we seem to get as the ret_type from a FuncDef + and convert it into an Instance/TypeInfo kind of structure that seems + to work as the left-hand type of an AssignmentStatement. + + """ + + if not isinstance(typ, UnboundType): + return typ + + # TODO: figure out a more robust way to check this. The node is some + # kind of _SpecialForm, there's a typing.Optional that's _SpecialForm, + # but I can't figure out how to get them to match up + if typ.name == "Optional": + # convert from "Optional?" to the more familiar + # UnionType[..., NoneType()] + return unbound_to_instance( + api, + UnionType( + [unbound_to_instance(api, typ_arg) for typ_arg in typ.args] + + [NoneType()] + ), + ) + + node = api.lookup_qualified(typ.name, typ) + + if ( + node is not None + and isinstance(node, SymbolTableNode) + and isinstance(node.node, TypeInfo) + ): + bound_type = node.node + + return Instance( + bound_type, + [ + ( + unbound_to_instance(api, arg) + if isinstance(arg, UnboundType) + else arg + ) + for arg in typ.args + ], + ) + else: + return typ + + +def info_for_cls( + cls: ClassDef, api: SemanticAnalyzerPluginInterface +) -> Optional[TypeInfo]: + if cls.info is CLASSDEF_NO_INFO: + sym = api.lookup_qualified(cls.name, cls) + if sym is None: + return None + assert sym and isinstance(sym.node, TypeInfo) + return sym.node + + return cls.info diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/orderinglist.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/orderinglist.py new file mode 100644 index 0000000..1a12cf3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/orderinglist.py @@ -0,0 +1,416 @@ +# ext/orderinglist.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""A custom list that manages index/position information for contained +elements. + +:author: Jason Kirtland + +``orderinglist`` is a helper for mutable ordered relationships. It will +intercept list operations performed on a :func:`_orm.relationship`-managed +collection and +automatically synchronize changes in list position onto a target scalar +attribute. + +Example: A ``slide`` table, where each row refers to zero or more entries +in a related ``bullet`` table. The bullets within a slide are +displayed in order based on the value of the ``position`` column in the +``bullet`` table. As entries are reordered in memory, the value of the +``position`` attribute should be updated to reflect the new sort order:: + + + Base = declarative_base() + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position") + + class Bullet(Base): + __tablename__ = 'bullet' + id = Column(Integer, primary_key=True) + slide_id = Column(Integer, ForeignKey('slide.id')) + position = Column(Integer) + text = Column(String) + +The standard relationship mapping will produce a list-like attribute on each +``Slide`` containing all related ``Bullet`` objects, +but coping with changes in ordering is not handled automatically. +When appending a ``Bullet`` into ``Slide.bullets``, the ``Bullet.position`` +attribute will remain unset until manually assigned. When the ``Bullet`` +is inserted into the middle of the list, the following ``Bullet`` objects +will also need to be renumbered. + +The :class:`.OrderingList` object automates this task, managing the +``position`` attribute on all ``Bullet`` objects in the collection. It is +constructed using the :func:`.ordering_list` factory:: + + from sqlalchemy.ext.orderinglist import ordering_list + + Base = declarative_base() + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) + + class Bullet(Base): + __tablename__ = 'bullet' + id = Column(Integer, primary_key=True) + slide_id = Column(Integer, ForeignKey('slide.id')) + position = Column(Integer) + text = Column(String) + +With the above mapping the ``Bullet.position`` attribute is managed:: + + s = Slide() + s.bullets.append(Bullet()) + s.bullets.append(Bullet()) + s.bullets[1].position + >>> 1 + s.bullets.insert(1, Bullet()) + s.bullets[2].position + >>> 2 + +The :class:`.OrderingList` construct only works with **changes** to a +collection, and not the initial load from the database, and requires that the +list be sorted when loaded. Therefore, be sure to specify ``order_by`` on the +:func:`_orm.relationship` against the target ordering attribute, so that the +ordering is correct when first loaded. + +.. warning:: + + :class:`.OrderingList` only provides limited functionality when a primary + key column or unique column is the target of the sort. Operations + that are unsupported or are problematic include: + + * two entries must trade values. This is not supported directly in the + case of a primary key or unique constraint because it means at least + one row would need to be temporarily removed first, or changed to + a third, neutral value while the switch occurs. + + * an entry must be deleted in order to make room for a new entry. + SQLAlchemy's unit of work performs all INSERTs before DELETEs within a + single flush. In the case of a primary key, it will trade + an INSERT/DELETE of the same primary key for an UPDATE statement in order + to lessen the impact of this limitation, however this does not take place + for a UNIQUE column. + A future feature will allow the "DELETE before INSERT" behavior to be + possible, alleviating this limitation, though this feature will require + explicit configuration at the mapper level for sets of columns that + are to be handled in this way. + +:func:`.ordering_list` takes the name of the related object's ordering +attribute as an argument. By default, the zero-based integer index of the +object's position in the :func:`.ordering_list` is synchronized with the +ordering attribute: index 0 will get position 0, index 1 position 1, etc. To +start numbering at 1 or some other integer, provide ``count_from=1``. + + +""" +from __future__ import annotations + +from typing import Callable +from typing import List +from typing import Optional +from typing import Sequence +from typing import TypeVar + +from ..orm.collections import collection +from ..orm.collections import collection_adapter + +_T = TypeVar("_T") +OrderingFunc = Callable[[int, Sequence[_T]], int] + + +__all__ = ["ordering_list"] + + +def ordering_list( + attr: str, + count_from: Optional[int] = None, + ordering_func: Optional[OrderingFunc] = None, + reorder_on_append: bool = False, +) -> Callable[[], OrderingList]: + """Prepares an :class:`OrderingList` factory for use in mapper definitions. + + Returns an object suitable for use as an argument to a Mapper + relationship's ``collection_class`` option. e.g.:: + + from sqlalchemy.ext.orderinglist import ordering_list + + class Slide(Base): + __tablename__ = 'slide' + + id = Column(Integer, primary_key=True) + name = Column(String) + + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) + + :param attr: + Name of the mapped attribute to use for storage and retrieval of + ordering information + + :param count_from: + Set up an integer-based ordering, starting at ``count_from``. For + example, ``ordering_list('pos', count_from=1)`` would create a 1-based + list in SQL, storing the value in the 'pos' column. Ignored if + ``ordering_func`` is supplied. + + Additional arguments are passed to the :class:`.OrderingList` constructor. + + """ + + kw = _unsugar_count_from( + count_from=count_from, + ordering_func=ordering_func, + reorder_on_append=reorder_on_append, + ) + return lambda: OrderingList(attr, **kw) + + +# Ordering utility functions + + +def count_from_0(index, collection): + """Numbering function: consecutive integers starting at 0.""" + + return index + + +def count_from_1(index, collection): + """Numbering function: consecutive integers starting at 1.""" + + return index + 1 + + +def count_from_n_factory(start): + """Numbering function: consecutive integers starting at arbitrary start.""" + + def f(index, collection): + return index + start + + try: + f.__name__ = "count_from_%i" % start + except TypeError: + pass + return f + + +def _unsugar_count_from(**kw): + """Builds counting functions from keyword arguments. + + Keyword argument filter, prepares a simple ``ordering_func`` from a + ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. + """ + + count_from = kw.pop("count_from", None) + if kw.get("ordering_func", None) is None and count_from is not None: + if count_from == 0: + kw["ordering_func"] = count_from_0 + elif count_from == 1: + kw["ordering_func"] = count_from_1 + else: + kw["ordering_func"] = count_from_n_factory(count_from) + return kw + + +class OrderingList(List[_T]): + """A custom list that manages position information for its children. + + The :class:`.OrderingList` object is normally set up using the + :func:`.ordering_list` factory function, used in conjunction with + the :func:`_orm.relationship` function. + + """ + + ordering_attr: str + ordering_func: OrderingFunc + reorder_on_append: bool + + def __init__( + self, + ordering_attr: Optional[str] = None, + ordering_func: Optional[OrderingFunc] = None, + reorder_on_append: bool = False, + ): + """A custom list that manages position information for its children. + + ``OrderingList`` is a ``collection_class`` list implementation that + syncs position in a Python list with a position attribute on the + mapped objects. + + This implementation relies on the list starting in the proper order, + so be **sure** to put an ``order_by`` on your relationship. + + :param ordering_attr: + Name of the attribute that stores the object's order in the + relationship. + + :param ordering_func: Optional. A function that maps the position in + the Python list to a value to store in the + ``ordering_attr``. Values returned are usually (but need not be!) + integers. + + An ``ordering_func`` is called with two positional parameters: the + index of the element in the list, and the list itself. + + If omitted, Python list indexes are used for the attribute values. + Two basic pre-built numbering functions are provided in this module: + ``count_from_0`` and ``count_from_1``. For more exotic examples + like stepped numbering, alphabetical and Fibonacci numbering, see + the unit tests. + + :param reorder_on_append: + Default False. When appending an object with an existing (non-None) + ordering value, that value will be left untouched unless + ``reorder_on_append`` is true. This is an optimization to avoid a + variety of dangerous unexpected database writes. + + SQLAlchemy will add instances to the list via append() when your + object loads. If for some reason the result set from the database + skips a step in the ordering (say, row '1' is missing but you get + '2', '3', and '4'), reorder_on_append=True would immediately + renumber the items to '1', '2', '3'. If you have multiple sessions + making changes, any of whom happen to load this collection even in + passing, all of the sessions would try to "clean up" the numbering + in their commits, possibly causing all but one to fail with a + concurrent modification error. + + Recommend leaving this with the default of False, and just call + ``reorder()`` if you're doing ``append()`` operations with + previously ordered instances or when doing some housekeeping after + manual sql operations. + + """ + self.ordering_attr = ordering_attr + if ordering_func is None: + ordering_func = count_from_0 + self.ordering_func = ordering_func + self.reorder_on_append = reorder_on_append + + # More complex serialization schemes (multi column, e.g.) are possible by + # subclassing and reimplementing these two methods. + def _get_order_value(self, entity): + return getattr(entity, self.ordering_attr) + + def _set_order_value(self, entity, value): + setattr(entity, self.ordering_attr, value) + + def reorder(self) -> None: + """Synchronize ordering for the entire collection. + + Sweeps through the list and ensures that each object has accurate + ordering information set. + + """ + for index, entity in enumerate(self): + self._order_entity(index, entity, True) + + # As of 0.5, _reorder is no longer semi-private + _reorder = reorder + + def _order_entity(self, index, entity, reorder=True): + have = self._get_order_value(entity) + + # Don't disturb existing ordering if reorder is False + if have is not None and not reorder: + return + + should_be = self.ordering_func(index, self) + if have != should_be: + self._set_order_value(entity, should_be) + + def append(self, entity): + super().append(entity) + self._order_entity(len(self) - 1, entity, self.reorder_on_append) + + def _raw_append(self, entity): + """Append without any ordering behavior.""" + + super().append(entity) + + _raw_append = collection.adds(1)(_raw_append) + + def insert(self, index, entity): + super().insert(index, entity) + self._reorder() + + def remove(self, entity): + super().remove(entity) + + adapter = collection_adapter(self) + if adapter and adapter._referenced_by_owner: + self._reorder() + + def pop(self, index=-1): + entity = super().pop(index) + self._reorder() + return entity + + def __setitem__(self, index, entity): + if isinstance(index, slice): + step = index.step or 1 + start = index.start or 0 + if start < 0: + start += len(self) + stop = index.stop or len(self) + if stop < 0: + stop += len(self) + + for i in range(start, stop, step): + self.__setitem__(i, entity[i]) + else: + self._order_entity(index, entity, True) + super().__setitem__(index, entity) + + def __delitem__(self, index): + super().__delitem__(index) + self._reorder() + + def __setslice__(self, start, end, values): + super().__setslice__(start, end, values) + self._reorder() + + def __delslice__(self, start, end): + super().__delslice__(start, end) + self._reorder() + + def __reduce__(self): + return _reconstitute, (self.__class__, self.__dict__, list(self)) + + for func_name, func in list(locals().items()): + if ( + callable(func) + and func.__name__ == func_name + and not func.__doc__ + and hasattr(list, func_name) + ): + func.__doc__ = getattr(list, func_name).__doc__ + del func_name, func + + +def _reconstitute(cls, dict_, items): + """Reconstitute an :class:`.OrderingList`. + + This is the adjoint to :meth:`.OrderingList.__reduce__`. It is used for + unpickling :class:`.OrderingList` objects. + + """ + obj = cls.__new__(cls) + obj.__dict__.update(dict_) + list.extend(obj, items) + return obj diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/ext/serializer.py b/venv/lib/python3.11/site-packages/sqlalchemy/ext/serializer.py new file mode 100644 index 0000000..f21e997 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/ext/serializer.py @@ -0,0 +1,185 @@ +# ext/serializer.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""Serializer/Deserializer objects for usage with SQLAlchemy query structures, +allowing "contextual" deserialization. + +.. legacy:: + + The serializer extension is **legacy** and should not be used for + new development. + +Any SQLAlchemy query structure, either based on sqlalchemy.sql.* +or sqlalchemy.orm.* can be used. The mappers, Tables, Columns, Session +etc. which are referenced by the structure are not persisted in serialized +form, but are instead re-associated with the query structure +when it is deserialized. + +.. warning:: The serializer extension uses pickle to serialize and + deserialize objects, so the same security consideration mentioned + in the `python documentation + `_ apply. + +Usage is nearly the same as that of the standard Python pickle module:: + + from sqlalchemy.ext.serializer import loads, dumps + metadata = MetaData(bind=some_engine) + Session = scoped_session(sessionmaker()) + + # ... define mappers + + query = Session.query(MyClass). + filter(MyClass.somedata=='foo').order_by(MyClass.sortkey) + + # pickle the query + serialized = dumps(query) + + # unpickle. Pass in metadata + scoped_session + query2 = loads(serialized, metadata, Session) + + print query2.all() + +Similar restrictions as when using raw pickle apply; mapped classes must be +themselves be pickleable, meaning they are importable from a module-level +namespace. + +The serializer module is only appropriate for query structures. It is not +needed for: + +* instances of user-defined classes. These contain no references to engines, + sessions or expression constructs in the typical case and can be serialized + directly. + +* Table metadata that is to be loaded entirely from the serialized structure + (i.e. is not already declared in the application). Regular + pickle.loads()/dumps() can be used to fully dump any ``MetaData`` object, + typically one which was reflected from an existing database at some previous + point in time. The serializer module is specifically for the opposite case, + where the Table metadata is already present in memory. + +""" + +from io import BytesIO +import pickle +import re + +from .. import Column +from .. import Table +from ..engine import Engine +from ..orm import class_mapper +from ..orm.interfaces import MapperProperty +from ..orm.mapper import Mapper +from ..orm.session import Session +from ..util import b64decode +from ..util import b64encode + + +__all__ = ["Serializer", "Deserializer", "dumps", "loads"] + + +def Serializer(*args, **kw): + pickler = pickle.Pickler(*args, **kw) + + def persistent_id(obj): + # print "serializing:", repr(obj) + if isinstance(obj, Mapper) and not obj.non_primary: + id_ = "mapper:" + b64encode(pickle.dumps(obj.class_)) + elif isinstance(obj, MapperProperty) and not obj.parent.non_primary: + id_ = ( + "mapperprop:" + + b64encode(pickle.dumps(obj.parent.class_)) + + ":" + + obj.key + ) + elif isinstance(obj, Table): + if "parententity" in obj._annotations: + id_ = "mapper_selectable:" + b64encode( + pickle.dumps(obj._annotations["parententity"].class_) + ) + else: + id_ = f"table:{obj.key}" + elif isinstance(obj, Column) and isinstance(obj.table, Table): + id_ = f"column:{obj.table.key}:{obj.key}" + elif isinstance(obj, Session): + id_ = "session:" + elif isinstance(obj, Engine): + id_ = "engine:" + else: + return None + return id_ + + pickler.persistent_id = persistent_id + return pickler + + +our_ids = re.compile( + r"(mapperprop|mapper|mapper_selectable|table|column|" + r"session|attribute|engine):(.*)" +) + + +def Deserializer(file, metadata=None, scoped_session=None, engine=None): + unpickler = pickle.Unpickler(file) + + def get_engine(): + if engine: + return engine + elif scoped_session and scoped_session().bind: + return scoped_session().bind + elif metadata and metadata.bind: + return metadata.bind + else: + return None + + def persistent_load(id_): + m = our_ids.match(str(id_)) + if not m: + return None + else: + type_, args = m.group(1, 2) + if type_ == "attribute": + key, clsarg = args.split(":") + cls = pickle.loads(b64decode(clsarg)) + return getattr(cls, key) + elif type_ == "mapper": + cls = pickle.loads(b64decode(args)) + return class_mapper(cls) + elif type_ == "mapper_selectable": + cls = pickle.loads(b64decode(args)) + return class_mapper(cls).__clause_element__() + elif type_ == "mapperprop": + mapper, keyname = args.split(":") + cls = pickle.loads(b64decode(mapper)) + return class_mapper(cls).attrs[keyname] + elif type_ == "table": + return metadata.tables[args] + elif type_ == "column": + table, colname = args.split(":") + return metadata.tables[table].c[colname] + elif type_ == "session": + return scoped_session() + elif type_ == "engine": + return get_engine() + else: + raise Exception("Unknown token: %s" % type_) + + unpickler.persistent_load = persistent_load + return unpickler + + +def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL): + buf = BytesIO() + pickler = Serializer(buf, protocol) + pickler.dump(obj) + return buf.getvalue() + + +def loads(data, metadata=None, scoped_session=None, engine=None): + buf = BytesIO(data) + unpickler = Deserializer(buf, metadata, scoped_session, engine) + return unpickler.load() diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/future/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/future/__init__.py new file mode 100644 index 0000000..8ce36cc --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/future/__init__.py @@ -0,0 +1,16 @@ +# future/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""2.0 API features. + +this module is legacy as 2.0 APIs are now standard. + +""" +from .engine import Connection as Connection +from .engine import create_engine as create_engine +from .engine import Engine as Engine +from ..sql._selectable_constructors import select as select diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/future/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/future/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..7cc433f Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/future/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/future/__pycache__/engine.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/future/__pycache__/engine.cpython-311.pyc new file mode 100644 index 0000000..8bf4ed3 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/future/__pycache__/engine.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/future/engine.py b/venv/lib/python3.11/site-packages/sqlalchemy/future/engine.py new file mode 100644 index 0000000..b55cda0 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/future/engine.py @@ -0,0 +1,15 @@ +# future/engine.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +"""2.0 API features. + +this module is legacy as 2.0 APIs are now standard. + +""" + +from ..engine import Connection as Connection # noqa: F401 +from ..engine import create_engine as create_engine # noqa: F401 +from ..engine import Engine as Engine # noqa: F401 diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/inspection.py b/venv/lib/python3.11/site-packages/sqlalchemy/inspection.py new file mode 100644 index 0000000..30d5319 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/inspection.py @@ -0,0 +1,174 @@ +# inspection.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""The inspection module provides the :func:`_sa.inspect` function, +which delivers runtime information about a wide variety +of SQLAlchemy objects, both within the Core as well as the +ORM. + +The :func:`_sa.inspect` function is the entry point to SQLAlchemy's +public API for viewing the configuration and construction +of in-memory objects. Depending on the type of object +passed to :func:`_sa.inspect`, the return value will either be +a related object which provides a known interface, or in many +cases it will return the object itself. + +The rationale for :func:`_sa.inspect` is twofold. One is that +it replaces the need to be aware of a large variety of "information +getting" functions in SQLAlchemy, such as +:meth:`_reflection.Inspector.from_engine` (deprecated in 1.4), +:func:`.orm.attributes.instance_state`, :func:`_orm.class_mapper`, +and others. The other is that the return value of :func:`_sa.inspect` +is guaranteed to obey a documented API, thus allowing third party +tools which build on top of SQLAlchemy configurations to be constructed +in a forwards-compatible way. + +""" +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import Optional +from typing import overload +from typing import Type +from typing import TypeVar +from typing import Union + +from . import exc +from .util.typing import Literal +from .util.typing import Protocol + +_T = TypeVar("_T", bound=Any) +_TCov = TypeVar("_TCov", bound=Any, covariant=True) +_F = TypeVar("_F", bound=Callable[..., Any]) + +_IN = TypeVar("_IN", bound=Any) + +_registrars: Dict[type, Union[Literal[True], Callable[[Any], Any]]] = {} + + +class Inspectable(Generic[_T]): + """define a class as inspectable. + + This allows typing to set up a linkage between an object that + can be inspected and the type of inspection it returns. + + Unfortunately we cannot at the moment get all classes that are + returned by inspection to suit this interface as we get into + MRO issues. + + """ + + __slots__ = () + + +class _InspectableTypeProtocol(Protocol[_TCov]): + """a protocol defining a method that's used when a type (ie the class + itself) is passed to inspect(). + + """ + + def _sa_inspect_type(self) -> _TCov: ... + + +class _InspectableProtocol(Protocol[_TCov]): + """a protocol defining a method that's used when an instance is + passed to inspect(). + + """ + + def _sa_inspect_instance(self) -> _TCov: ... + + +@overload +def inspect( + subject: Type[_InspectableTypeProtocol[_IN]], raiseerr: bool = True +) -> _IN: ... + + +@overload +def inspect( + subject: _InspectableProtocol[_IN], raiseerr: bool = True +) -> _IN: ... + + +@overload +def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN: ... + + +@overload +def inspect(subject: Any, raiseerr: Literal[False] = ...) -> Optional[Any]: ... + + +@overload +def inspect(subject: Any, raiseerr: bool = True) -> Any: ... + + +def inspect(subject: Any, raiseerr: bool = True) -> Any: + """Produce an inspection object for the given target. + + The returned value in some cases may be the + same object as the one given, such as if a + :class:`_orm.Mapper` object is passed. In other + cases, it will be an instance of the registered + inspection type for the given object, such as + if an :class:`_engine.Engine` is passed, an + :class:`_reflection.Inspector` object is returned. + + :param subject: the subject to be inspected. + :param raiseerr: When ``True``, if the given subject + does not + correspond to a known SQLAlchemy inspected type, + :class:`sqlalchemy.exc.NoInspectionAvailable` + is raised. If ``False``, ``None`` is returned. + + """ + type_ = type(subject) + for cls in type_.__mro__: + if cls in _registrars: + reg = _registrars.get(cls, None) + if reg is None: + continue + elif reg is True: + return subject + ret = reg(subject) + if ret is not None: + return ret + else: + reg = ret = None + + if raiseerr and (reg is None or ret is None): + raise exc.NoInspectionAvailable( + "No inspection system is " + "available for object of type %s" % type_ + ) + return ret + + +def _inspects( + *types: Type[Any], +) -> Callable[[_F], _F]: + def decorate(fn_or_cls: _F) -> _F: + for type_ in types: + if type_ in _registrars: + raise AssertionError("Type %s is already registered" % type_) + _registrars[type_] = fn_or_cls + return fn_or_cls + + return decorate + + +_TT = TypeVar("_TT", bound="Type[Any]") + + +def _self_inspects(cls: _TT) -> _TT: + if cls in _registrars: + raise AssertionError("Type %s is already registered" % cls) + _registrars[cls] = True + return cls diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/log.py b/venv/lib/python3.11/site-packages/sqlalchemy/log.py new file mode 100644 index 0000000..e6922b8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/log.py @@ -0,0 +1,288 @@ +# log.py +# Copyright (C) 2006-2024 the SQLAlchemy authors and contributors +# +# Includes alterations by Vinay Sajip vinay_sajip@yahoo.co.uk +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Logging control and utilities. + +Control of logging for SA can be performed from the regular python logging +module. The regular dotted module namespace is used, starting at +'sqlalchemy'. For class-level logging, the class name is appended. + +The "echo" keyword parameter, available on SQLA :class:`_engine.Engine` +and :class:`_pool.Pool` objects, corresponds to a logger specific to that +instance only. + +""" +from __future__ import annotations + +import logging +import sys +from typing import Any +from typing import Optional +from typing import overload +from typing import Set +from typing import Type +from typing import TypeVar +from typing import Union + +from .util import py311 +from .util import py38 +from .util.typing import Literal + + +if py38: + STACKLEVEL = True + # needed as of py3.11.0b1 + # #8019 + STACKLEVEL_OFFSET = 2 if py311 else 1 +else: + STACKLEVEL = False + STACKLEVEL_OFFSET = 0 + +_IT = TypeVar("_IT", bound="Identified") + +_EchoFlagType = Union[None, bool, Literal["debug"]] + +# set initial level to WARN. This so that +# log statements don't occur in the absence of explicit +# logging being enabled for 'sqlalchemy'. +rootlogger = logging.getLogger("sqlalchemy") +if rootlogger.level == logging.NOTSET: + rootlogger.setLevel(logging.WARN) + + +def _add_default_handler(logger: logging.Logger) -> None: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s") + ) + logger.addHandler(handler) + + +_logged_classes: Set[Type[Identified]] = set() + + +def _qual_logger_name_for_cls(cls: Type[Identified]) -> str: + return ( + getattr(cls, "_sqla_logger_namespace", None) + or cls.__module__ + "." + cls.__name__ + ) + + +def class_logger(cls: Type[_IT]) -> Type[_IT]: + logger = logging.getLogger(_qual_logger_name_for_cls(cls)) + cls._should_log_debug = lambda self: logger.isEnabledFor( # type: ignore[method-assign] # noqa: E501 + logging.DEBUG + ) + cls._should_log_info = lambda self: logger.isEnabledFor( # type: ignore[method-assign] # noqa: E501 + logging.INFO + ) + cls.logger = logger + _logged_classes.add(cls) + return cls + + +_IdentifiedLoggerType = Union[logging.Logger, "InstanceLogger"] + + +class Identified: + __slots__ = () + + logging_name: Optional[str] = None + + logger: _IdentifiedLoggerType + + _echo: _EchoFlagType + + def _should_log_debug(self) -> bool: + return self.logger.isEnabledFor(logging.DEBUG) + + def _should_log_info(self) -> bool: + return self.logger.isEnabledFor(logging.INFO) + + +class InstanceLogger: + """A logger adapter (wrapper) for :class:`.Identified` subclasses. + + This allows multiple instances (e.g. Engine or Pool instances) + to share a logger, but have its verbosity controlled on a + per-instance basis. + + The basic functionality is to return a logging level + which is based on an instance's echo setting. + + Default implementation is: + + 'debug' -> logging.DEBUG + True -> logging.INFO + False -> Effective level of underlying logger ( + logging.WARNING by default) + None -> same as False + """ + + # Map echo settings to logger levels + _echo_map = { + None: logging.NOTSET, + False: logging.NOTSET, + True: logging.INFO, + "debug": logging.DEBUG, + } + + _echo: _EchoFlagType + + __slots__ = ("echo", "logger") + + def __init__(self, echo: _EchoFlagType, name: str): + self.echo = echo + self.logger = logging.getLogger(name) + + # if echo flag is enabled and no handlers, + # add a handler to the list + if self._echo_map[echo] <= logging.INFO and not self.logger.handlers: + _add_default_handler(self.logger) + + # + # Boilerplate convenience methods + # + def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: + """Delegate a debug call to the underlying logger.""" + + self.log(logging.DEBUG, msg, *args, **kwargs) + + def info(self, msg: str, *args: Any, **kwargs: Any) -> None: + """Delegate an info call to the underlying logger.""" + + self.log(logging.INFO, msg, *args, **kwargs) + + def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: + """Delegate a warning call to the underlying logger.""" + + self.log(logging.WARNING, msg, *args, **kwargs) + + warn = warning + + def error(self, msg: str, *args: Any, **kwargs: Any) -> None: + """ + Delegate an error call to the underlying logger. + """ + self.log(logging.ERROR, msg, *args, **kwargs) + + def exception(self, msg: str, *args: Any, **kwargs: Any) -> None: + """Delegate an exception call to the underlying logger.""" + + kwargs["exc_info"] = 1 + self.log(logging.ERROR, msg, *args, **kwargs) + + def critical(self, msg: str, *args: Any, **kwargs: Any) -> None: + """Delegate a critical call to the underlying logger.""" + + self.log(logging.CRITICAL, msg, *args, **kwargs) + + def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: + """Delegate a log call to the underlying logger. + + The level here is determined by the echo + flag as well as that of the underlying logger, and + logger._log() is called directly. + + """ + + # inline the logic from isEnabledFor(), + # getEffectiveLevel(), to avoid overhead. + + if self.logger.manager.disable >= level: + return + + selected_level = self._echo_map[self.echo] + if selected_level == logging.NOTSET: + selected_level = self.logger.getEffectiveLevel() + + if level >= selected_level: + if STACKLEVEL: + kwargs["stacklevel"] = ( + kwargs.get("stacklevel", 1) + STACKLEVEL_OFFSET + ) + + self.logger._log(level, msg, args, **kwargs) + + def isEnabledFor(self, level: int) -> bool: + """Is this logger enabled for level 'level'?""" + + if self.logger.manager.disable >= level: + return False + return level >= self.getEffectiveLevel() + + def getEffectiveLevel(self) -> int: + """What's the effective level for this logger?""" + + level = self._echo_map[self.echo] + if level == logging.NOTSET: + level = self.logger.getEffectiveLevel() + return level + + +def instance_logger( + instance: Identified, echoflag: _EchoFlagType = None +) -> None: + """create a logger for an instance that implements :class:`.Identified`.""" + + if instance.logging_name: + name = "%s.%s" % ( + _qual_logger_name_for_cls(instance.__class__), + instance.logging_name, + ) + else: + name = _qual_logger_name_for_cls(instance.__class__) + + instance._echo = echoflag # type: ignore + + logger: Union[logging.Logger, InstanceLogger] + + if echoflag in (False, None): + # if no echo setting or False, return a Logger directly, + # avoiding overhead of filtering + logger = logging.getLogger(name) + else: + # if a specified echo flag, return an EchoLogger, + # which checks the flag, overrides normal log + # levels by calling logger._log() + logger = InstanceLogger(echoflag, name) + + instance.logger = logger # type: ignore + + +class echo_property: + __doc__ = """\ + When ``True``, enable log output for this element. + + This has the effect of setting the Python logging level for the namespace + of this element's class and object reference. A value of boolean ``True`` + indicates that the loglevel ``logging.INFO`` will be set for the logger, + whereas the string value ``debug`` will set the loglevel to + ``logging.DEBUG``. + """ + + @overload + def __get__( + self, instance: Literal[None], owner: Type[Identified] + ) -> echo_property: ... + + @overload + def __get__( + self, instance: Identified, owner: Type[Identified] + ) -> _EchoFlagType: ... + + def __get__( + self, instance: Optional[Identified], owner: Type[Identified] + ) -> Union[echo_property, _EchoFlagType]: + if instance is None: + return self + else: + return instance._echo + + def __set__(self, instance: Identified, value: _EchoFlagType) -> None: + instance_logger(instance, echoflag=value) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__init__.py new file mode 100644 index 0000000..70a1129 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__init__.py @@ -0,0 +1,170 @@ +# orm/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" +Functional constructs for ORM configuration. + +See the SQLAlchemy object relational tutorial and mapper configuration +documentation for an overview of how this module is used. + +""" + +from __future__ import annotations + +from typing import Any + +from . import exc as exc +from . import mapper as mapperlib +from . import strategy_options as strategy_options +from ._orm_constructors import _mapper_fn as mapper +from ._orm_constructors import aliased as aliased +from ._orm_constructors import backref as backref +from ._orm_constructors import clear_mappers as clear_mappers +from ._orm_constructors import column_property as column_property +from ._orm_constructors import composite as composite +from ._orm_constructors import contains_alias as contains_alias +from ._orm_constructors import create_session as create_session +from ._orm_constructors import deferred as deferred +from ._orm_constructors import dynamic_loader as dynamic_loader +from ._orm_constructors import join as join +from ._orm_constructors import mapped_column as mapped_column +from ._orm_constructors import orm_insert_sentinel as orm_insert_sentinel +from ._orm_constructors import outerjoin as outerjoin +from ._orm_constructors import query_expression as query_expression +from ._orm_constructors import relationship as relationship +from ._orm_constructors import synonym as synonym +from ._orm_constructors import with_loader_criteria as with_loader_criteria +from ._orm_constructors import with_polymorphic as with_polymorphic +from .attributes import AttributeEventToken as AttributeEventToken +from .attributes import InstrumentedAttribute as InstrumentedAttribute +from .attributes import QueryableAttribute as QueryableAttribute +from .base import class_mapper as class_mapper +from .base import DynamicMapped as DynamicMapped +from .base import InspectionAttrExtensionType as InspectionAttrExtensionType +from .base import LoaderCallableStatus as LoaderCallableStatus +from .base import Mapped as Mapped +from .base import NotExtension as NotExtension +from .base import ORMDescriptor as ORMDescriptor +from .base import PassiveFlag as PassiveFlag +from .base import SQLORMExpression as SQLORMExpression +from .base import WriteOnlyMapped as WriteOnlyMapped +from .context import FromStatement as FromStatement +from .context import QueryContext as QueryContext +from .decl_api import add_mapped_attribute as add_mapped_attribute +from .decl_api import as_declarative as as_declarative +from .decl_api import declarative_base as declarative_base +from .decl_api import declarative_mixin as declarative_mixin +from .decl_api import DeclarativeBase as DeclarativeBase +from .decl_api import DeclarativeBaseNoMeta as DeclarativeBaseNoMeta +from .decl_api import DeclarativeMeta as DeclarativeMeta +from .decl_api import declared_attr as declared_attr +from .decl_api import has_inherited_table as has_inherited_table +from .decl_api import MappedAsDataclass as MappedAsDataclass +from .decl_api import registry as registry +from .decl_api import synonym_for as synonym_for +from .decl_base import MappedClassProtocol as MappedClassProtocol +from .descriptor_props import Composite as Composite +from .descriptor_props import CompositeProperty as CompositeProperty +from .descriptor_props import Synonym as Synonym +from .descriptor_props import SynonymProperty as SynonymProperty +from .dynamic import AppenderQuery as AppenderQuery +from .events import AttributeEvents as AttributeEvents +from .events import InstanceEvents as InstanceEvents +from .events import InstrumentationEvents as InstrumentationEvents +from .events import MapperEvents as MapperEvents +from .events import QueryEvents as QueryEvents +from .events import SessionEvents as SessionEvents +from .identity import IdentityMap as IdentityMap +from .instrumentation import ClassManager as ClassManager +from .interfaces import EXT_CONTINUE as EXT_CONTINUE +from .interfaces import EXT_SKIP as EXT_SKIP +from .interfaces import EXT_STOP as EXT_STOP +from .interfaces import InspectionAttr as InspectionAttr +from .interfaces import InspectionAttrInfo as InspectionAttrInfo +from .interfaces import MANYTOMANY as MANYTOMANY +from .interfaces import MANYTOONE as MANYTOONE +from .interfaces import MapperProperty as MapperProperty +from .interfaces import NO_KEY as NO_KEY +from .interfaces import NO_VALUE as NO_VALUE +from .interfaces import ONETOMANY as ONETOMANY +from .interfaces import PropComparator as PropComparator +from .interfaces import RelationshipDirection as RelationshipDirection +from .interfaces import UserDefinedOption as UserDefinedOption +from .loading import merge_frozen_result as merge_frozen_result +from .loading import merge_result as merge_result +from .mapped_collection import attribute_keyed_dict as attribute_keyed_dict +from .mapped_collection import ( + attribute_mapped_collection as attribute_mapped_collection, +) +from .mapped_collection import column_keyed_dict as column_keyed_dict +from .mapped_collection import ( + column_mapped_collection as column_mapped_collection, +) +from .mapped_collection import keyfunc_mapping as keyfunc_mapping +from .mapped_collection import KeyFuncDict as KeyFuncDict +from .mapped_collection import mapped_collection as mapped_collection +from .mapped_collection import MappedCollection as MappedCollection +from .mapper import configure_mappers as configure_mappers +from .mapper import Mapper as Mapper +from .mapper import reconstructor as reconstructor +from .mapper import validates as validates +from .properties import ColumnProperty as ColumnProperty +from .properties import MappedColumn as MappedColumn +from .properties import MappedSQLExpression as MappedSQLExpression +from .query import AliasOption as AliasOption +from .query import Query as Query +from .relationships import foreign as foreign +from .relationships import Relationship as Relationship +from .relationships import RelationshipProperty as RelationshipProperty +from .relationships import remote as remote +from .scoping import QueryPropertyDescriptor as QueryPropertyDescriptor +from .scoping import scoped_session as scoped_session +from .session import close_all_sessions as close_all_sessions +from .session import make_transient as make_transient +from .session import make_transient_to_detached as make_transient_to_detached +from .session import object_session as object_session +from .session import ORMExecuteState as ORMExecuteState +from .session import Session as Session +from .session import sessionmaker as sessionmaker +from .session import SessionTransaction as SessionTransaction +from .session import SessionTransactionOrigin as SessionTransactionOrigin +from .state import AttributeState as AttributeState +from .state import InstanceState as InstanceState +from .strategy_options import contains_eager as contains_eager +from .strategy_options import defaultload as defaultload +from .strategy_options import defer as defer +from .strategy_options import immediateload as immediateload +from .strategy_options import joinedload as joinedload +from .strategy_options import lazyload as lazyload +from .strategy_options import Load as Load +from .strategy_options import load_only as load_only +from .strategy_options import noload as noload +from .strategy_options import raiseload as raiseload +from .strategy_options import selectin_polymorphic as selectin_polymorphic +from .strategy_options import selectinload as selectinload +from .strategy_options import subqueryload as subqueryload +from .strategy_options import undefer as undefer +from .strategy_options import undefer_group as undefer_group +from .strategy_options import with_expression as with_expression +from .unitofwork import UOWTransaction as UOWTransaction +from .util import Bundle as Bundle +from .util import CascadeOptions as CascadeOptions +from .util import LoaderCriteriaOption as LoaderCriteriaOption +from .util import object_mapper as object_mapper +from .util import polymorphic_union as polymorphic_union +from .util import was_deleted as was_deleted +from .util import with_parent as with_parent +from .writeonly import WriteOnlyCollection as WriteOnlyCollection +from .. import util as _sa_util + + +def __go(lcls: Any) -> None: + _sa_util.preloaded.import_prefix("sqlalchemy.orm") + _sa_util.preloaded.import_prefix("sqlalchemy.ext") + + +__go(locals()) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..8a0b109 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/_orm_constructors.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/_orm_constructors.cpython-311.pyc new file mode 100644 index 0000000..2fa0118 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/_orm_constructors.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/_typing.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/_typing.cpython-311.pyc new file mode 100644 index 0000000..90d7d47 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/_typing.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/attributes.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/attributes.cpython-311.pyc new file mode 100644 index 0000000..d534677 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/attributes.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..3333b81 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/base.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/bulk_persistence.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/bulk_persistence.cpython-311.pyc new file mode 100644 index 0000000..858b764 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/bulk_persistence.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/clsregistry.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/clsregistry.cpython-311.pyc new file mode 100644 index 0000000..929ab36 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/clsregistry.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/collections.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/collections.cpython-311.pyc new file mode 100644 index 0000000..64bba3c Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/collections.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/context.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/context.cpython-311.pyc new file mode 100644 index 0000000..a9888f5 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/context.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/decl_api.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/decl_api.cpython-311.pyc new file mode 100644 index 0000000..6aba2cb Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/decl_api.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/decl_base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/decl_base.cpython-311.pyc new file mode 100644 index 0000000..7fbe0f4 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/decl_base.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/dependency.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/dependency.cpython-311.pyc new file mode 100644 index 0000000..de66fd6 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/dependency.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/descriptor_props.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/descriptor_props.cpython-311.pyc new file mode 100644 index 0000000..2d3c764 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/descriptor_props.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/dynamic.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/dynamic.cpython-311.pyc new file mode 100644 index 0000000..af2c1c3 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/dynamic.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/evaluator.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/evaluator.cpython-311.pyc new file mode 100644 index 0000000..dd10afb Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/evaluator.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/events.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/events.cpython-311.pyc new file mode 100644 index 0000000..383eaf6 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/events.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/exc.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/exc.cpython-311.pyc new file mode 100644 index 0000000..a9f7995 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/exc.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/identity.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/identity.cpython-311.pyc new file mode 100644 index 0000000..799b3f1 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/identity.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/instrumentation.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/instrumentation.cpython-311.pyc new file mode 100644 index 0000000..e2986cd Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/instrumentation.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/interfaces.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/interfaces.cpython-311.pyc new file mode 100644 index 0000000..15154d7 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/interfaces.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/loading.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/loading.cpython-311.pyc new file mode 100644 index 0000000..c5396e8 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/loading.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/mapped_collection.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/mapped_collection.cpython-311.pyc new file mode 100644 index 0000000..44aea8f Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/mapped_collection.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/mapper.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/mapper.cpython-311.pyc new file mode 100644 index 0000000..58c19aa Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/mapper.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/path_registry.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/path_registry.cpython-311.pyc new file mode 100644 index 0000000..8d8ba5f Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/path_registry.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/persistence.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/persistence.cpython-311.pyc new file mode 100644 index 0000000..566049c Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/persistence.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/properties.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/properties.cpython-311.pyc new file mode 100644 index 0000000..3754ca1 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/properties.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/query.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/query.cpython-311.pyc new file mode 100644 index 0000000..d2df40d Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/query.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/relationships.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/relationships.cpython-311.pyc new file mode 100644 index 0000000..c66e5a4 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/relationships.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/scoping.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/scoping.cpython-311.pyc new file mode 100644 index 0000000..e815007 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/scoping.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/session.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/session.cpython-311.pyc new file mode 100644 index 0000000..f051751 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/session.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/state.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/state.cpython-311.pyc new file mode 100644 index 0000000..a70802e Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/state.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/state_changes.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/state_changes.cpython-311.pyc new file mode 100644 index 0000000..bb62b49 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/state_changes.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/strategies.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/strategies.cpython-311.pyc new file mode 100644 index 0000000..57424ef Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/strategies.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/strategy_options.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/strategy_options.cpython-311.pyc new file mode 100644 index 0000000..e1fd9c9 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/strategy_options.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/sync.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/sync.cpython-311.pyc new file mode 100644 index 0000000..b1d21cc Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/sync.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/unitofwork.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/unitofwork.cpython-311.pyc new file mode 100644 index 0000000..433eae1 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/unitofwork.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/util.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000..17dd961 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/util.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/writeonly.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/writeonly.cpython-311.pyc new file mode 100644 index 0000000..c3a0b5b Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/orm/__pycache__/writeonly.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/_orm_constructors.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/_orm_constructors.py new file mode 100644 index 0000000..7cb536b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/_orm_constructors.py @@ -0,0 +1,2471 @@ +# orm/_orm_constructors.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import Collection +from typing import Iterable +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from . import mapperlib as mapperlib +from ._typing import _O +from .descriptor_props import Composite +from .descriptor_props import Synonym +from .interfaces import _AttributeOptions +from .properties import MappedColumn +from .properties import MappedSQLExpression +from .query import AliasOption +from .relationships import _RelationshipArgumentType +from .relationships import _RelationshipDeclared +from .relationships import _RelationshipSecondaryArgument +from .relationships import RelationshipProperty +from .session import Session +from .util import _ORMJoin +from .util import AliasedClass +from .util import AliasedInsp +from .util import LoaderCriteriaOption +from .. import sql +from .. import util +from ..exc import InvalidRequestError +from ..sql._typing import _no_kw +from ..sql.base import _NoArg +from ..sql.base import SchemaEventTarget +from ..sql.schema import _InsertSentinelColumnDefault +from ..sql.schema import SchemaConst +from ..sql.selectable import FromClause +from ..util.typing import Annotated +from ..util.typing import Literal + +if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ORMColumnExprArgument + from .descriptor_props import _CC + from .descriptor_props import _CompositeAttrType + from .interfaces import PropComparator + from .mapper import Mapper + from .query import Query + from .relationships import _LazyLoadArgumentType + from .relationships import _ORMColCollectionArgument + from .relationships import _ORMOrderByArgument + from .relationships import _RelationshipJoinConditionArgument + from .relationships import ORMBackrefArgument + from .session import _SessionBind + from ..sql._typing import _AutoIncrementType + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _FromClauseArgument + from ..sql._typing import _InfoType + from ..sql._typing import _OnClauseArgument + from ..sql._typing import _TypeEngineArgument + from ..sql.elements import ColumnElement + from ..sql.schema import _ServerDefaultArgument + from ..sql.schema import FetchedValue + from ..sql.selectable import Alias + from ..sql.selectable import Subquery + + +_T = typing.TypeVar("_T") + + +@util.deprecated( + "1.4", + "The :class:`.AliasOption` object is not necessary " + "for entities to be matched up to a query that is established " + "via :meth:`.Query.from_statement` and now does nothing.", + enable_warnings=False, # AliasOption itself warns +) +def contains_alias(alias: Union[Alias, Subquery]) -> AliasOption: + r"""Return a :class:`.MapperOption` that will indicate to the + :class:`_query.Query` + that the main table has been aliased. + + """ + return AliasOption(alias) + + +def mapped_column( + __name_pos: Optional[ + Union[str, _TypeEngineArgument[Any], SchemaEventTarget] + ] = None, + __type_pos: Optional[ + Union[_TypeEngineArgument[Any], SchemaEventTarget] + ] = None, + *args: SchemaEventTarget, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + nullable: Optional[ + Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] + ] = SchemaConst.NULL_UNSPECIFIED, + primary_key: Optional[bool] = False, + deferred: Union[_NoArg, bool] = _NoArg.NO_ARG, + deferred_group: Optional[str] = None, + deferred_raiseload: Optional[bool] = None, + use_existing_column: bool = False, + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[Any]] = None, + autoincrement: _AutoIncrementType = "auto", + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + onupdate: Optional[Any] = None, + insert_default: Optional[Any] = _NoArg.NO_ARG, + server_default: Optional[_ServerDefaultArgument] = None, + server_onupdate: Optional[FetchedValue] = None, + active_history: bool = False, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + sort_order: Union[_NoArg, int] = _NoArg.NO_ARG, + **kw: Any, +) -> MappedColumn[Any]: + r"""declare a new ORM-mapped :class:`_schema.Column` construct + for use within :ref:`Declarative Table ` + configuration. + + The :func:`_orm.mapped_column` function provides an ORM-aware and + Python-typing-compatible construct which is used with + :ref:`declarative ` mappings to indicate an + attribute that's mapped to a Core :class:`_schema.Column` object. It + provides the equivalent feature as mapping an attribute to a + :class:`_schema.Column` object directly when using Declarative, + specifically when using :ref:`Declarative Table ` + configuration. + + .. versionadded:: 2.0 + + :func:`_orm.mapped_column` is normally used with explicit typing along with + the :class:`_orm.Mapped` annotation type, where it can derive the SQL + type and nullability for the column based on what's present within the + :class:`_orm.Mapped` annotation. It also may be used without annotations + as a drop-in replacement for how :class:`_schema.Column` is used in + Declarative mappings in SQLAlchemy 1.x style. + + For usage examples of :func:`_orm.mapped_column`, see the documentation + at :ref:`orm_declarative_table`. + + .. seealso:: + + :ref:`orm_declarative_table` - complete documentation + + :ref:`whatsnew_20_orm_declarative_typing` - migration notes for + Declarative mappings using 1.x style mappings + + :param __name: String name to give to the :class:`_schema.Column`. This + is an optional, positional only argument that if present must be the + first positional argument passed. If omitted, the attribute name to + which the :func:`_orm.mapped_column` is mapped will be used as the SQL + column name. + :param __type: :class:`_types.TypeEngine` type or instance which will + indicate the datatype to be associated with the :class:`_schema.Column`. + This is an optional, positional-only argument that if present must + immediately follow the ``__name`` parameter if present also, or otherwise + be the first positional parameter. If omitted, the ultimate type for + the column may be derived either from the annotated type, or if a + :class:`_schema.ForeignKey` is present, from the datatype of the + referenced column. + :param \*args: Additional positional arguments include constructs such + as :class:`_schema.ForeignKey`, :class:`_schema.CheckConstraint`, + and :class:`_schema.Identity`, which are passed through to the constructed + :class:`_schema.Column`. + :param nullable: Optional bool, whether the column should be "NULL" or + "NOT NULL". If omitted, the nullability is derived from the type + annotation based on whether or not ``typing.Optional`` is present. + ``nullable`` defaults to ``True`` otherwise for non-primary key columns, + and ``False`` for primary key columns. + :param primary_key: optional bool, indicates the :class:`_schema.Column` + would be part of the table's primary key or not. + :param deferred: Optional bool - this keyword argument is consumed by the + ORM declarative process, and is not part of the :class:`_schema.Column` + itself; instead, it indicates that this column should be "deferred" for + loading as though mapped by :func:`_orm.deferred`. + + .. seealso:: + + :ref:`orm_queryguide_deferred_declarative` + + :param deferred_group: Implies :paramref:`_orm.mapped_column.deferred` + to ``True``, and set the :paramref:`_orm.deferred.group` parameter. + + .. seealso:: + + :ref:`orm_queryguide_deferred_group` + + :param deferred_raiseload: Implies :paramref:`_orm.mapped_column.deferred` + to ``True``, and set the :paramref:`_orm.deferred.raiseload` parameter. + + .. seealso:: + + :ref:`orm_queryguide_deferred_raiseload` + + :param use_existing_column: if True, will attempt to locate the given + column name on an inherited superclass (typically single inheriting + superclass), and if present, will not produce a new column, mapping + to the superclass column as though it were omitted from this class. + This is used for mixins that add new columns to an inherited superclass. + + .. seealso:: + + :ref:`orm_inheritance_column_conflicts` + + .. versionadded:: 2.0.0b4 + + :param default: Passed directly to the + :paramref:`_schema.Column.default` parameter if the + :paramref:`_orm.mapped_column.insert_default` parameter is not present. + Additionally, when used with :ref:`orm_declarative_native_dataclasses`, + indicates a default Python value that should be applied to the keyword + constructor within the generated ``__init__()`` method. + + Note that in the case of dataclass generation when + :paramref:`_orm.mapped_column.insert_default` is not present, this means + the :paramref:`_orm.mapped_column.default` value is used in **two** + places, both the ``__init__()`` method as well as the + :paramref:`_schema.Column.default` parameter. While this behavior may + change in a future release, for the moment this tends to "work out"; a + default of ``None`` will mean that the :class:`_schema.Column` gets no + default generator, whereas a default that refers to a non-``None`` Python + or SQL expression value will be assigned up front on the object when + ``__init__()`` is called, which is the same value that the Core + :class:`_sql.Insert` construct would use in any case, leading to the same + end result. + + .. note:: When using Core level column defaults that are callables to + be interpreted by the underlying :class:`_schema.Column` in conjunction + with :ref:`ORM-mapped dataclasses + `, especially those that are + :ref:`context-aware default functions `, + **the** :paramref:`_orm.mapped_column.insert_default` **parameter must + be used instead**. This is necessary to disambiguate the callable from + being interpreted as a dataclass level default. + + :param insert_default: Passed directly to the + :paramref:`_schema.Column.default` parameter; will supersede the value + of :paramref:`_orm.mapped_column.default` when present, however + :paramref:`_orm.mapped_column.default` will always apply to the + constructor default for a dataclasses mapping. + + :param sort_order: An integer that indicates how this mapped column + should be sorted compared to the others when the ORM is creating a + :class:`_schema.Table`. Among mapped columns that have the same + value the default ordering is used, placing first the mapped columns + defined in the main class, then the ones in the super classes. + Defaults to 0. The sort is ascending. + + .. versionadded:: 2.0.4 + + :param active_history=False: + + When ``True``, indicates that the "previous" value for a + scalar attribute should be loaded when replaced, if not + already loaded. Normally, history tracking logic for + simple non-primary-key scalar values only needs to be + aware of the "new" value in order to perform a flush. This + flag is available for applications that make use of + :func:`.attributes.get_history` or :meth:`.Session.is_modified` + which also need to know the "previous" value of the attribute. + + .. versionadded:: 2.0.10 + + + :param init: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__init__()`` + method as generated by the dataclass process. + :param repr: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__repr__()`` + method as generated by the dataclass process. + :param default_factory: Specific to + :ref:`orm_declarative_native_dataclasses`, + specifies a default-value generation function that will take place + as part of the ``__init__()`` + method as generated by the dataclass process. + :param compare: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be included in comparison operations when generating the + ``__eq__()`` and ``__ne__()`` methods for the mapped class. + + .. versionadded:: 2.0.0b4 + + :param kw_only: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be marked as keyword-only when generating the ``__init__()``. + + :param \**kw: All remaining keyword arguments are passed through to the + constructor for the :class:`_schema.Column`. + + """ + + return MappedColumn( + __name_pos, + __type_pos, + *args, + name=name, + type_=type_, + autoincrement=autoincrement, + insert_default=insert_default, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + doc=doc, + key=key, + index=index, + unique=unique, + info=info, + active_history=active_history, + nullable=nullable, + onupdate=onupdate, + primary_key=primary_key, + server_default=server_default, + server_onupdate=server_onupdate, + use_existing_column=use_existing_column, + quote=quote, + comment=comment, + system=system, + deferred=deferred, + deferred_group=deferred_group, + deferred_raiseload=deferred_raiseload, + sort_order=sort_order, + **kw, + ) + + +def orm_insert_sentinel( + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[Any]] = None, + *, + default: Optional[Any] = None, + omit_from_statements: bool = True, +) -> MappedColumn[Any]: + """Provides a surrogate :func:`_orm.mapped_column` that generates + a so-called :term:`sentinel` column, allowing efficient bulk + inserts with deterministic RETURNING sorting for tables that don't + otherwise have qualifying primary key configurations. + + Use of :func:`_orm.orm_insert_sentinel` is analogous to the use of the + :func:`_schema.insert_sentinel` construct within a Core + :class:`_schema.Table` construct. + + Guidelines for adding this construct to a Declarative mapped class + are the same as that of the :func:`_schema.insert_sentinel` construct; + the database table itself also needs to have a column with this name + present. + + For background on how this object is used, see the section + :ref:`engine_insertmanyvalues_sentinel_columns` as part of the + section :ref:`engine_insertmanyvalues`. + + .. seealso:: + + :func:`_schema.insert_sentinel` + + :ref:`engine_insertmanyvalues` + + :ref:`engine_insertmanyvalues_sentinel_columns` + + + .. versionadded:: 2.0.10 + + """ + + return mapped_column( + name=name, + default=( + default if default is not None else _InsertSentinelColumnDefault() + ), + _omit_from_statements=omit_from_statements, + insert_sentinel=True, + use_existing_column=True, + nullable=True, + ) + + +@util.deprecated_params( + **{ + arg: ( + "2.0", + f"The :paramref:`_orm.column_property.{arg}` parameter is " + "deprecated for :func:`_orm.column_property`. This parameter " + "applies to a writeable-attribute in a Declarative Dataclasses " + "configuration only, and :func:`_orm.column_property` is treated " + "as a read-only attribute in this context.", + ) + for arg in ("init", "kw_only", "default", "default_factory") + } +) +def column_property( + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> MappedSQLExpression[_T]: + r"""Provide a column-level property for use with a mapping. + + With Declarative mappings, :func:`_orm.column_property` is used to + map read-only SQL expressions to a mapped class. + + When using Imperative mappings, :func:`_orm.column_property` also + takes on the role of mapping table columns with additional features. + When using fully Declarative mappings, the :func:`_orm.mapped_column` + construct should be used for this purpose. + + With Declarative Dataclass mappings, :func:`_orm.column_property` + is considered to be **read only**, and will not be included in the + Dataclass ``__init__()`` constructor. + + The :func:`_orm.column_property` function returns an instance of + :class:`.ColumnProperty`. + + .. seealso:: + + :ref:`mapper_column_property_sql_expressions` - general use of + :func:`_orm.column_property` to map SQL expressions + + :ref:`orm_imperative_table_column_options` - usage of + :func:`_orm.column_property` with Imperative Table mappings to apply + additional options to a plain :class:`_schema.Column` object + + :param \*cols: + list of Column objects to be mapped. + + :param active_history=False: + + Used only for Imperative Table mappings, or legacy-style Declarative + mappings (i.e. which have not been upgraded to + :func:`_orm.mapped_column`), for column-based attributes that are + expected to be writeable; use :func:`_orm.mapped_column` with + :paramref:`_orm.mapped_column.active_history` for Declarative mappings. + See that parameter for functional details. + + :param comparator_factory: a class which extends + :class:`.ColumnProperty.Comparator` which provides custom SQL + clause generation for comparison operations. + + :param group: + a group name for this property when marked as deferred. + + :param deferred: + when True, the column property is "deferred", meaning that + it does not load immediately, and is instead loaded when the + attribute is first accessed on an instance. See also + :func:`~sqlalchemy.orm.deferred`. + + :param doc: + optional string that will be applied as the doc on the + class-bound descriptor. + + :param expire_on_flush=True: + Disable expiry on flush. A column_property() which refers + to a SQL expression (and not a single table-bound column) + is considered to be a "read only" property; populating it + has no effect on the state of data, and it can only return + database state. For this reason a column_property()'s value + is expired whenever the parent object is involved in a + flush, that is, has any kind of "dirty" state within a flush. + Setting this parameter to ``False`` will have the effect of + leaving any existing value present after the flush proceeds. + Note that the :class:`.Session` with default expiration + settings still expires + all attributes after a :meth:`.Session.commit` call, however. + + :param info: Optional data dictionary which will be populated into the + :attr:`.MapperProperty.info` attribute of this object. + + :param raiseload: if True, indicates the column should raise an error + when undeferred, rather than loading the value. This can be + altered at query time by using the :func:`.deferred` option with + raiseload=False. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`orm_queryguide_deferred_raiseload` + + :param init: + + :param default: + + :param default_factory: + + :param kw_only: + + """ + return MappedSQLExpression( + column, + *additional_columns, + attribute_options=_AttributeOptions( + False if init is _NoArg.NO_ARG else init, + repr, + default, + default_factory, + compare, + kw_only, + ), + group=group, + deferred=deferred, + raiseload=raiseload, + comparator_factory=comparator_factory, + active_history=active_history, + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + _assume_readonly_dc_attributes=True, + ) + + +@overload +def composite( + _class_or_attr: _CompositeAttrType[Any], + *attrs: _CompositeAttrType[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, +) -> Composite[Any]: ... + + +@overload +def composite( + _class_or_attr: Type[_CC], + *attrs: _CompositeAttrType[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, +) -> Composite[_CC]: ... + + +@overload +def composite( + _class_or_attr: Callable[..., _CC], + *attrs: _CompositeAttrType[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, +) -> Composite[_CC]: ... + + +def composite( + _class_or_attr: Union[ + None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any] + ] = None, + *attrs: _CompositeAttrType[Any], + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[Composite.Comparator[_T]]] = None, + active_history: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + **__kw: Any, +) -> Composite[Any]: + r"""Return a composite column-based property for use with a Mapper. + + See the mapping documentation section :ref:`mapper_composite` for a + full usage example. + + The :class:`.MapperProperty` returned by :func:`.composite` + is the :class:`.Composite`. + + :param class\_: + The "composite type" class, or any classmethod or callable which + will produce a new instance of the composite object given the + column values in order. + + :param \*attrs: + List of elements to be mapped, which may include: + + * :class:`_schema.Column` objects + * :func:`_orm.mapped_column` constructs + * string names of other attributes on the mapped class, which may be + any other SQL or object-mapped attribute. This can for + example allow a composite that refers to a many-to-one relationship + + :param active_history=False: + When ``True``, indicates that the "previous" value for a + scalar attribute should be loaded when replaced, if not + already loaded. See the same flag on :func:`.column_property`. + + :param group: + A group name for this property when marked as deferred. + + :param deferred: + When True, the column property is "deferred", meaning that it does + not load immediately, and is instead loaded when the attribute is + first accessed on an instance. See also + :func:`~sqlalchemy.orm.deferred`. + + :param comparator_factory: a class which extends + :class:`.Composite.Comparator` which provides custom SQL + clause generation for comparison operations. + + :param doc: + optional string that will be applied as the doc on the + class-bound descriptor. + + :param info: Optional data dictionary which will be populated into the + :attr:`.MapperProperty.info` attribute of this object. + + :param init: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__init__()`` + method as generated by the dataclass process. + :param repr: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__repr__()`` + method as generated by the dataclass process. + :param default_factory: Specific to + :ref:`orm_declarative_native_dataclasses`, + specifies a default-value generation function that will take place + as part of the ``__init__()`` + method as generated by the dataclass process. + + :param compare: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be included in comparison operations when generating the + ``__eq__()`` and ``__ne__()`` methods for the mapped class. + + .. versionadded:: 2.0.0b4 + + :param kw_only: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be marked as keyword-only when generating the ``__init__()``. + + """ + if __kw: + raise _no_kw() + + return Composite( + _class_or_attr, + *attrs, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + group=group, + deferred=deferred, + raiseload=raiseload, + comparator_factory=comparator_factory, + active_history=active_history, + info=info, + doc=doc, + ) + + +def with_loader_criteria( + entity_or_base: _EntityType[Any], + where_criteria: Union[ + _ColumnExpressionArgument[bool], + Callable[[Any], _ColumnExpressionArgument[bool]], + ], + loader_only: bool = False, + include_aliases: bool = False, + propagate_to_loaders: bool = True, + track_closure_variables: bool = True, +) -> LoaderCriteriaOption: + """Add additional WHERE criteria to the load for all occurrences of + a particular entity. + + .. versionadded:: 1.4 + + The :func:`_orm.with_loader_criteria` option is intended to add + limiting criteria to a particular kind of entity in a query, + **globally**, meaning it will apply to the entity as it appears + in the SELECT query as well as within any subqueries, join + conditions, and relationship loads, including both eager and lazy + loaders, without the need for it to be specified in any particular + part of the query. The rendering logic uses the same system used by + single table inheritance to ensure a certain discriminator is applied + to a table. + + E.g., using :term:`2.0-style` queries, we can limit the way the + ``User.addresses`` collection is loaded, regardless of the kind + of loading used:: + + from sqlalchemy.orm import with_loader_criteria + + stmt = select(User).options( + selectinload(User.addresses), + with_loader_criteria(Address, Address.email_address != 'foo')) + ) + + Above, the "selectinload" for ``User.addresses`` will apply the + given filtering criteria to the WHERE clause. + + Another example, where the filtering will be applied to the + ON clause of the join, in this example using :term:`1.x style` + queries:: + + q = session.query(User).outerjoin(User.addresses).options( + with_loader_criteria(Address, Address.email_address != 'foo')) + ) + + The primary purpose of :func:`_orm.with_loader_criteria` is to use + it in the :meth:`_orm.SessionEvents.do_orm_execute` event handler + to ensure that all occurrences of a particular entity are filtered + in a certain way, such as filtering for access control roles. It + also can be used to apply criteria to relationship loads. In the + example below, we can apply a certain set of rules to all queries + emitted by a particular :class:`_orm.Session`:: + + session = Session(bind=engine) + + @event.listens_for("do_orm_execute", session) + def _add_filtering_criteria(execute_state): + + if ( + execute_state.is_select + and not execute_state.is_column_load + and not execute_state.is_relationship_load + ): + execute_state.statement = execute_state.statement.options( + with_loader_criteria( + SecurityRole, + lambda cls: cls.role.in_(['some_role']), + include_aliases=True + ) + ) + + In the above example, the :meth:`_orm.SessionEvents.do_orm_execute` + event will intercept all queries emitted using the + :class:`_orm.Session`. For those queries which are SELECT statements + and are not attribute or relationship loads a custom + :func:`_orm.with_loader_criteria` option is added to the query. The + :func:`_orm.with_loader_criteria` option will be used in the given + statement and will also be automatically propagated to all relationship + loads that descend from this query. + + The criteria argument given is a ``lambda`` that accepts a ``cls`` + argument. The given class will expand to include all mapped subclass + and need not itself be a mapped class. + + .. tip:: + + When using :func:`_orm.with_loader_criteria` option in + conjunction with the :func:`_orm.contains_eager` loader option, + it's important to note that :func:`_orm.with_loader_criteria` only + affects the part of the query that determines what SQL is rendered + in terms of the WHERE and FROM clauses. The + :func:`_orm.contains_eager` option does not affect the rendering of + the SELECT statement outside of the columns clause, so does not have + any interaction with the :func:`_orm.with_loader_criteria` option. + However, the way things "work" is that :func:`_orm.contains_eager` + is meant to be used with a query that is already selecting from the + additional entities in some way, where + :func:`_orm.with_loader_criteria` can apply it's additional + criteria. + + In the example below, assuming a mapping relationship as + ``A -> A.bs -> B``, the given :func:`_orm.with_loader_criteria` + option will affect the way in which the JOIN is rendered:: + + stmt = select(A).join(A.bs).options( + contains_eager(A.bs), + with_loader_criteria(B, B.flag == 1) + ) + + Above, the given :func:`_orm.with_loader_criteria` option will + affect the ON clause of the JOIN that is specified by + ``.join(A.bs)``, so is applied as expected. The + :func:`_orm.contains_eager` option has the effect that columns from + ``B`` are added to the columns clause:: + + SELECT + b.id, b.a_id, b.data, b.flag, + a.id AS id_1, + a.data AS data_1 + FROM a JOIN b ON a.id = b.a_id AND b.flag = :flag_1 + + + The use of the :func:`_orm.contains_eager` option within the above + statement has no effect on the behavior of the + :func:`_orm.with_loader_criteria` option. If the + :func:`_orm.contains_eager` option were omitted, the SQL would be + the same as regards the FROM and WHERE clauses, where + :func:`_orm.with_loader_criteria` continues to add its criteria to + the ON clause of the JOIN. The addition of + :func:`_orm.contains_eager` only affects the columns clause, in that + additional columns against ``b`` are added which are then consumed + by the ORM to produce ``B`` instances. + + .. warning:: The use of a lambda inside of the call to + :func:`_orm.with_loader_criteria` is only invoked **once per unique + class**. Custom functions should not be invoked within this lambda. + See :ref:`engine_lambda_caching` for an overview of the "lambda SQL" + feature, which is for advanced use only. + + :param entity_or_base: a mapped class, or a class that is a super + class of a particular set of mapped classes, to which the rule + will apply. + + :param where_criteria: a Core SQL expression that applies limiting + criteria. This may also be a "lambda:" or Python function that + accepts a target class as an argument, when the given class is + a base with many different mapped subclasses. + + .. note:: To support pickling, use a module-level Python function to + produce the SQL expression instead of a lambda or a fixed SQL + expression, which tend to not be picklable. + + :param include_aliases: if True, apply the rule to :func:`_orm.aliased` + constructs as well. + + :param propagate_to_loaders: defaults to True, apply to relationship + loaders such as lazy loaders. This indicates that the + option object itself including SQL expression is carried along with + each loaded instance. Set to ``False`` to prevent the object from + being assigned to individual instances. + + + .. seealso:: + + :ref:`examples_session_orm_events` - includes examples of using + :func:`_orm.with_loader_criteria`. + + :ref:`do_orm_execute_global_criteria` - basic example on how to + combine :func:`_orm.with_loader_criteria` with the + :meth:`_orm.SessionEvents.do_orm_execute` event. + + :param track_closure_variables: when False, closure variables inside + of a lambda expression will not be used as part of + any cache key. This allows more complex expressions to be used + inside of a lambda expression but requires that the lambda ensures + it returns the identical SQL every time given a particular class. + + .. versionadded:: 1.4.0b2 + + """ + return LoaderCriteriaOption( + entity_or_base, + where_criteria, + loader_only, + include_aliases, + propagate_to_loaders, + track_closure_variables, + ) + + +def relationship( + argument: Optional[_RelationshipArgumentType[Any]] = None, + secondary: Optional[_RelationshipSecondaryArgument] = None, + *, + uselist: Optional[bool] = None, + collection_class: Optional[ + Union[Type[Collection[Any]], Callable[[], Collection[Any]]] + ] = None, + primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + back_populates: Optional[str] = None, + order_by: _ORMOrderByArgument = False, + backref: Optional[ORMBackrefArgument] = None, + overlaps: Optional[str] = None, + post_update: bool = False, + cascade: str = "save-update, merge", + viewonly: bool = False, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + lazy: _LazyLoadArgumentType = "select", + passive_deletes: Union[Literal["all"], bool] = False, + passive_updates: bool = True, + active_history: bool = False, + enable_typechecks: bool = True, + foreign_keys: Optional[_ORMColCollectionArgument] = None, + remote_side: Optional[_ORMColCollectionArgument] = None, + join_depth: Optional[int] = None, + comparator_factory: Optional[ + Type[RelationshipProperty.Comparator[Any]] + ] = None, + single_parent: bool = False, + innerjoin: bool = False, + distinct_target_key: Optional[bool] = None, + load_on_pending: bool = False, + query_class: Optional[Type[Query[Any]]] = None, + info: Optional[_InfoType] = None, + omit_join: Literal[None, False] = None, + sync_backref: Optional[bool] = None, + **kw: Any, +) -> _RelationshipDeclared[Any]: + """Provide a relationship between two mapped classes. + + This corresponds to a parent-child or associative table relationship. + The constructed class is an instance of :class:`.Relationship`. + + .. seealso:: + + :ref:`tutorial_orm_related_objects` - tutorial introduction + to :func:`_orm.relationship` in the :ref:`unified_tutorial` + + :ref:`relationship_config_toplevel` - narrative documentation + + :param argument: + This parameter refers to the class that is to be related. It + accepts several forms, including a direct reference to the target + class itself, the :class:`_orm.Mapper` instance for the target class, + a Python callable / lambda that will return a reference to the + class or :class:`_orm.Mapper` when called, and finally a string + name for the class, which will be resolved from the + :class:`_orm.registry` in use in order to locate the class, e.g.:: + + class SomeClass(Base): + # ... + + related = relationship("RelatedClass") + + The :paramref:`_orm.relationship.argument` may also be omitted from the + :func:`_orm.relationship` construct entirely, and instead placed inside + a :class:`_orm.Mapped` annotation on the left side, which should + include a Python collection type if the relationship is expected + to be a collection, such as:: + + class SomeClass(Base): + # ... + + related_items: Mapped[List["RelatedItem"]] = relationship() + + Or for a many-to-one or one-to-one relationship:: + + class SomeClass(Base): + # ... + + related_item: Mapped["RelatedItem"] = relationship() + + .. seealso:: + + :ref:`orm_declarative_properties` - further detail + on relationship configuration when using Declarative. + + :param secondary: + For a many-to-many relationship, specifies the intermediary + table, and is typically an instance of :class:`_schema.Table`. + In less common circumstances, the argument may also be specified + as an :class:`_expression.Alias` construct, or even a + :class:`_expression.Join` construct. + + :paramref:`_orm.relationship.secondary` may + also be passed as a callable function which is evaluated at + mapper initialization time. When using Declarative, it may also + be a string argument noting the name of a :class:`_schema.Table` + that is + present in the :class:`_schema.MetaData` + collection associated with the + parent-mapped :class:`_schema.Table`. + + .. warning:: When passed as a Python-evaluable string, the + argument is interpreted using Python's ``eval()`` function. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + See :ref:`declarative_relationship_eval` for details on + declarative evaluation of :func:`_orm.relationship` arguments. + + The :paramref:`_orm.relationship.secondary` keyword argument is + typically applied in the case where the intermediary + :class:`_schema.Table` + is not otherwise expressed in any direct class mapping. If the + "secondary" table is also explicitly mapped elsewhere (e.g. as in + :ref:`association_pattern`), one should consider applying the + :paramref:`_orm.relationship.viewonly` flag so that this + :func:`_orm.relationship` + is not used for persistence operations which + may conflict with those of the association object pattern. + + .. seealso:: + + :ref:`relationships_many_to_many` - Reference example of "many + to many". + + :ref:`self_referential_many_to_many` - Specifics on using + many-to-many in a self-referential case. + + :ref:`declarative_many_to_many` - Additional options when using + Declarative. + + :ref:`association_pattern` - an alternative to + :paramref:`_orm.relationship.secondary` + when composing association + table relationships, allowing additional attributes to be + specified on the association table. + + :ref:`composite_secondary_join` - a lesser-used pattern which + in some cases can enable complex :func:`_orm.relationship` SQL + conditions to be used. + + :param active_history=False: + When ``True``, indicates that the "previous" value for a + many-to-one reference should be loaded when replaced, if + not already loaded. Normally, history tracking logic for + simple many-to-ones only needs to be aware of the "new" + value in order to perform a flush. This flag is available + for applications that make use of + :func:`.attributes.get_history` which also need to know + the "previous" value of the attribute. + + :param backref: + A reference to a string relationship name, or a :func:`_orm.backref` + construct, which will be used to automatically generate a new + :func:`_orm.relationship` on the related class, which then refers to this + one using a bi-directional :paramref:`_orm.relationship.back_populates` + configuration. + + In modern Python, explicit use of :func:`_orm.relationship` + with :paramref:`_orm.relationship.back_populates` should be preferred, + as it is more robust in terms of mapper configuration as well as + more conceptually straightforward. It also integrates with + new :pep:`484` typing features introduced in SQLAlchemy 2.0 which + is not possible with dynamically generated attributes. + + .. seealso:: + + :ref:`relationships_backref` - notes on using + :paramref:`_orm.relationship.backref` + + :ref:`tutorial_orm_related_objects` - in the :ref:`unified_tutorial`, + presents an overview of bi-directional relationship configuration + and behaviors using :paramref:`_orm.relationship.back_populates` + + :func:`.backref` - allows control over :func:`_orm.relationship` + configuration when using :paramref:`_orm.relationship.backref`. + + + :param back_populates: + Indicates the name of a :func:`_orm.relationship` on the related + class that will be synchronized with this one. It is usually + expected that the :func:`_orm.relationship` on the related class + also refer to this one. This allows objects on both sides of + each :func:`_orm.relationship` to synchronize in-Python state + changes and also provides directives to the :term:`unit of work` + flush process how changes along these relationships should + be persisted. + + .. seealso:: + + :ref:`tutorial_orm_related_objects` - in the :ref:`unified_tutorial`, + presents an overview of bi-directional relationship configuration + and behaviors. + + :ref:`relationship_patterns` - includes many examples of + :paramref:`_orm.relationship.back_populates`. + + :paramref:`_orm.relationship.backref` - legacy form which allows + more succinct configuration, but does not support explicit typing + + :param overlaps: + A string name or comma-delimited set of names of other relationships + on either this mapper, a descendant mapper, or a target mapper with + which this relationship may write to the same foreign keys upon + persistence. The only effect this has is to eliminate the + warning that this relationship will conflict with another upon + persistence. This is used for such relationships that are truly + capable of conflicting with each other on write, but the application + will ensure that no such conflicts occur. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`error_qzyx` - usage example + + :param cascade: + A comma-separated list of cascade rules which determines how + Session operations should be "cascaded" from parent to child. + This defaults to ``False``, which means the default cascade + should be used - this default cascade is ``"save-update, merge"``. + + The available cascades are ``save-update``, ``merge``, + ``expunge``, ``delete``, ``delete-orphan``, and ``refresh-expire``. + An additional option, ``all`` indicates shorthand for + ``"save-update, merge, refresh-expire, + expunge, delete"``, and is often used as in ``"all, delete-orphan"`` + to indicate that related objects should follow along with the + parent object in all cases, and be deleted when de-associated. + + .. seealso:: + + :ref:`unitofwork_cascades` - Full detail on each of the available + cascade options. + + :param cascade_backrefs=False: + Legacy; this flag is always False. + + .. versionchanged:: 2.0 "cascade_backrefs" functionality has been + removed. + + :param collection_class: + A class or callable that returns a new list-holding object. will + be used in place of a plain list for storing elements. + + .. seealso:: + + :ref:`custom_collections` - Introductory documentation and + examples. + + :param comparator_factory: + A class which extends :class:`.Relationship.Comparator` + which provides custom SQL clause generation for comparison + operations. + + .. seealso:: + + :class:`.PropComparator` - some detail on redefining comparators + at this level. + + :ref:`custom_comparators` - Brief intro to this feature. + + + :param distinct_target_key=None: + Indicate if a "subquery" eager load should apply the DISTINCT + keyword to the innermost SELECT statement. When left as ``None``, + the DISTINCT keyword will be applied in those cases when the target + columns do not comprise the full primary key of the target table. + When set to ``True``, the DISTINCT keyword is applied to the + innermost SELECT unconditionally. + + It may be desirable to set this flag to False when the DISTINCT is + reducing performance of the innermost subquery beyond that of what + duplicate innermost rows may be causing. + + .. seealso:: + + :ref:`loading_toplevel` - includes an introduction to subquery + eager loading. + + :param doc: + Docstring which will be applied to the resulting descriptor. + + :param foreign_keys: + + A list of columns which are to be used as "foreign key" + columns, or columns which refer to the value in a remote + column, within the context of this :func:`_orm.relationship` + object's :paramref:`_orm.relationship.primaryjoin` condition. + That is, if the :paramref:`_orm.relationship.primaryjoin` + condition of this :func:`_orm.relationship` is ``a.id == + b.a_id``, and the values in ``b.a_id`` are required to be + present in ``a.id``, then the "foreign key" column of this + :func:`_orm.relationship` is ``b.a_id``. + + In normal cases, the :paramref:`_orm.relationship.foreign_keys` + parameter is **not required.** :func:`_orm.relationship` will + automatically determine which columns in the + :paramref:`_orm.relationship.primaryjoin` condition are to be + considered "foreign key" columns based on those + :class:`_schema.Column` objects that specify + :class:`_schema.ForeignKey`, + or are otherwise listed as referencing columns in a + :class:`_schema.ForeignKeyConstraint` construct. + :paramref:`_orm.relationship.foreign_keys` is only needed when: + + 1. There is more than one way to construct a join from the local + table to the remote table, as there are multiple foreign key + references present. Setting ``foreign_keys`` will limit the + :func:`_orm.relationship` + to consider just those columns specified + here as "foreign". + + 2. The :class:`_schema.Table` being mapped does not actually have + :class:`_schema.ForeignKey` or + :class:`_schema.ForeignKeyConstraint` + constructs present, often because the table + was reflected from a database that does not support foreign key + reflection (MySQL MyISAM). + + 3. The :paramref:`_orm.relationship.primaryjoin` + argument is used to + construct a non-standard join condition, which makes use of + columns or expressions that do not normally refer to their + "parent" column, such as a join condition expressed by a + complex comparison using a SQL function. + + The :func:`_orm.relationship` construct will raise informative + error messages that suggest the use of the + :paramref:`_orm.relationship.foreign_keys` parameter when + presented with an ambiguous condition. In typical cases, + if :func:`_orm.relationship` doesn't raise any exceptions, the + :paramref:`_orm.relationship.foreign_keys` parameter is usually + not needed. + + :paramref:`_orm.relationship.foreign_keys` may also be passed as a + callable function which is evaluated at mapper initialization time, + and may be passed as a Python-evaluable string when using + Declarative. + + .. warning:: When passed as a Python-evaluable string, the + argument is interpreted using Python's ``eval()`` function. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + See :ref:`declarative_relationship_eval` for details on + declarative evaluation of :func:`_orm.relationship` arguments. + + .. seealso:: + + :ref:`relationship_foreign_keys` + + :ref:`relationship_custom_foreign` + + :func:`.foreign` - allows direct annotation of the "foreign" + columns within a :paramref:`_orm.relationship.primaryjoin` + condition. + + :param info: Optional data dictionary which will be populated into the + :attr:`.MapperProperty.info` attribute of this object. + + :param innerjoin=False: + When ``True``, joined eager loads will use an inner join to join + against related tables instead of an outer join. The purpose + of this option is generally one of performance, as inner joins + generally perform better than outer joins. + + This flag can be set to ``True`` when the relationship references an + object via many-to-one using local foreign keys that are not + nullable, or when the reference is one-to-one or a collection that + is guaranteed to have one or at least one entry. + + The option supports the same "nested" and "unnested" options as + that of :paramref:`_orm.joinedload.innerjoin`. See that flag + for details on nested / unnested behaviors. + + .. seealso:: + + :paramref:`_orm.joinedload.innerjoin` - the option as specified by + loader option, including detail on nesting behavior. + + :ref:`what_kind_of_loading` - Discussion of some details of + various loader options. + + + :param join_depth: + When non-``None``, an integer value indicating how many levels + deep "eager" loaders should join on a self-referring or cyclical + relationship. The number counts how many times the same Mapper + shall be present in the loading condition along a particular join + branch. When left at its default of ``None``, eager loaders + will stop chaining when they encounter a the same target mapper + which is already higher up in the chain. This option applies + both to joined- and subquery- eager loaders. + + .. seealso:: + + :ref:`self_referential_eager_loading` - Introductory documentation + and examples. + + :param lazy='select': specifies + How the related items should be loaded. Default value is + ``select``. Values include: + + * ``select`` - items should be loaded lazily when the property is + first accessed, using a separate SELECT statement, or identity map + fetch for simple many-to-one references. + + * ``immediate`` - items should be loaded as the parents are loaded, + using a separate SELECT statement, or identity map fetch for + simple many-to-one references. + + * ``joined`` - items should be loaded "eagerly" in the same query as + that of the parent, using a JOIN or LEFT OUTER JOIN. Whether + the join is "outer" or not is determined by the + :paramref:`_orm.relationship.innerjoin` parameter. + + * ``subquery`` - items should be loaded "eagerly" as the parents are + loaded, using one additional SQL statement, which issues a JOIN to + a subquery of the original statement, for each collection + requested. + + * ``selectin`` - items should be loaded "eagerly" as the parents + are loaded, using one or more additional SQL statements, which + issues a JOIN to the immediate parent object, specifying primary + key identifiers using an IN clause. + + * ``noload`` - no loading should occur at any time. The related + collection will remain empty. The ``noload`` strategy is not + recommended for general use. For a general use "never load" + approach, see :ref:`write_only_relationship` + + * ``raise`` - lazy loading is disallowed; accessing + the attribute, if its value were not already loaded via eager + loading, will raise an :exc:`~sqlalchemy.exc.InvalidRequestError`. + This strategy can be used when objects are to be detached from + their attached :class:`.Session` after they are loaded. + + * ``raise_on_sql`` - lazy loading that emits SQL is disallowed; + accessing the attribute, if its value were not already loaded via + eager loading, will raise an + :exc:`~sqlalchemy.exc.InvalidRequestError`, **if the lazy load + needs to emit SQL**. If the lazy load can pull the related value + from the identity map or determine that it should be None, the + value is loaded. This strategy can be used when objects will + remain associated with the attached :class:`.Session`, however + additional SELECT statements should be blocked. + + * ``write_only`` - the attribute will be configured with a special + "virtual collection" that may receive + :meth:`_orm.WriteOnlyCollection.add` and + :meth:`_orm.WriteOnlyCollection.remove` commands to add or remove + individual objects, but will not under any circumstances load or + iterate the full set of objects from the database directly. Instead, + methods such as :meth:`_orm.WriteOnlyCollection.select`, + :meth:`_orm.WriteOnlyCollection.insert`, + :meth:`_orm.WriteOnlyCollection.update` and + :meth:`_orm.WriteOnlyCollection.delete` are provided which generate SQL + constructs that may be used to load and modify rows in bulk. Used for + large collections that are never appropriate to load at once into + memory. + + The ``write_only`` loader style is configured automatically when + the :class:`_orm.WriteOnlyMapped` annotation is provided on the + left hand side within a Declarative mapping. See the section + :ref:`write_only_relationship` for examples. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`write_only_relationship` - in the :ref:`queryguide_toplevel` + + * ``dynamic`` - the attribute will return a pre-configured + :class:`_query.Query` object for all read + operations, onto which further filtering operations can be + applied before iterating the results. + + The ``dynamic`` loader style is configured automatically when + the :class:`_orm.DynamicMapped` annotation is provided on the + left hand side within a Declarative mapping. See the section + :ref:`dynamic_relationship` for examples. + + .. legacy:: The "dynamic" lazy loader strategy is the legacy form of + what is now the "write_only" strategy described in the section + :ref:`write_only_relationship`. + + .. seealso:: + + :ref:`dynamic_relationship` - in the :ref:`queryguide_toplevel` + + :ref:`write_only_relationship` - more generally useful approach + for large collections that should not fully load into memory + + * True - a synonym for 'select' + + * False - a synonym for 'joined' + + * None - a synonym for 'noload' + + .. seealso:: + + :ref:`orm_queryguide_relationship_loaders` - Full documentation on + relationship loader configuration in the :ref:`queryguide_toplevel`. + + + :param load_on_pending=False: + Indicates loading behavior for transient or pending parent objects. + + When set to ``True``, causes the lazy-loader to + issue a query for a parent object that is not persistent, meaning it + has never been flushed. This may take effect for a pending object + when autoflush is disabled, or for a transient object that has been + "attached" to a :class:`.Session` but is not part of its pending + collection. + + The :paramref:`_orm.relationship.load_on_pending` + flag does not improve + behavior when the ORM is used normally - object references should be + constructed at the object level, not at the foreign key level, so + that they are present in an ordinary way before a flush proceeds. + This flag is not not intended for general use. + + .. seealso:: + + :meth:`.Session.enable_relationship_loading` - this method + establishes "load on pending" behavior for the whole object, and + also allows loading on objects that remain transient or + detached. + + :param order_by: + Indicates the ordering that should be applied when loading these + items. :paramref:`_orm.relationship.order_by` + is expected to refer to + one of the :class:`_schema.Column` + objects to which the target class is + mapped, or the attribute itself bound to the target class which + refers to the column. + + :paramref:`_orm.relationship.order_by` + may also be passed as a callable + function which is evaluated at mapper initialization time, and may + be passed as a Python-evaluable string when using Declarative. + + .. warning:: When passed as a Python-evaluable string, the + argument is interpreted using Python's ``eval()`` function. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + See :ref:`declarative_relationship_eval` for details on + declarative evaluation of :func:`_orm.relationship` arguments. + + :param passive_deletes=False: + Indicates loading behavior during delete operations. + + A value of True indicates that unloaded child items should not + be loaded during a delete operation on the parent. Normally, + when a parent item is deleted, all child items are loaded so + that they can either be marked as deleted, or have their + foreign key to the parent set to NULL. Marking this flag as + True usually implies an ON DELETE rule is in + place which will handle updating/deleting child rows on the + database side. + + Additionally, setting the flag to the string value 'all' will + disable the "nulling out" of the child foreign keys, when the parent + object is deleted and there is no delete or delete-orphan cascade + enabled. This is typically used when a triggering or error raise + scenario is in place on the database side. Note that the foreign + key attributes on in-session child objects will not be changed after + a flush occurs so this is a very special use-case setting. + Additionally, the "nulling out" will still occur if the child + object is de-associated with the parent. + + .. seealso:: + + :ref:`passive_deletes` - Introductory documentation + and examples. + + :param passive_updates=True: + Indicates the persistence behavior to take when a referenced + primary key value changes in place, indicating that the referencing + foreign key columns will also need their value changed. + + When True, it is assumed that ``ON UPDATE CASCADE`` is configured on + the foreign key in the database, and that the database will + handle propagation of an UPDATE from a source column to + dependent rows. When False, the SQLAlchemy + :func:`_orm.relationship` + construct will attempt to emit its own UPDATE statements to + modify related targets. However note that SQLAlchemy **cannot** + emit an UPDATE for more than one level of cascade. Also, + setting this flag to False is not compatible in the case where + the database is in fact enforcing referential integrity, unless + those constraints are explicitly "deferred", if the target backend + supports it. + + It is highly advised that an application which is employing + mutable primary keys keeps ``passive_updates`` set to True, + and instead uses the referential integrity features of the database + itself in order to handle the change efficiently and fully. + + .. seealso:: + + :ref:`passive_updates` - Introductory documentation and + examples. + + :paramref:`.mapper.passive_updates` - a similar flag which + takes effect for joined-table inheritance mappings. + + :param post_update: + This indicates that the relationship should be handled by a + second UPDATE statement after an INSERT or before a + DELETE. This flag is used to handle saving bi-directional + dependencies between two individual rows (i.e. each row + references the other), where it would otherwise be impossible to + INSERT or DELETE both rows fully since one row exists before the + other. Use this flag when a particular mapping arrangement will + incur two rows that are dependent on each other, such as a table + that has a one-to-many relationship to a set of child rows, and + also has a column that references a single child row within that + list (i.e. both tables contain a foreign key to each other). If + a flush operation returns an error that a "cyclical + dependency" was detected, this is a cue that you might want to + use :paramref:`_orm.relationship.post_update` to "break" the cycle. + + .. seealso:: + + :ref:`post_update` - Introductory documentation and examples. + + :param primaryjoin: + A SQL expression that will be used as the primary + join of the child object against the parent object, or in a + many-to-many relationship the join of the parent object to the + association table. By default, this value is computed based on the + foreign key relationships of the parent and child tables (or + association table). + + :paramref:`_orm.relationship.primaryjoin` may also be passed as a + callable function which is evaluated at mapper initialization time, + and may be passed as a Python-evaluable string when using + Declarative. + + .. warning:: When passed as a Python-evaluable string, the + argument is interpreted using Python's ``eval()`` function. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + See :ref:`declarative_relationship_eval` for details on + declarative evaluation of :func:`_orm.relationship` arguments. + + .. seealso:: + + :ref:`relationship_primaryjoin` + + :param remote_side: + Used for self-referential relationships, indicates the column or + list of columns that form the "remote side" of the relationship. + + :paramref:`_orm.relationship.remote_side` may also be passed as a + callable function which is evaluated at mapper initialization time, + and may be passed as a Python-evaluable string when using + Declarative. + + .. warning:: When passed as a Python-evaluable string, the + argument is interpreted using Python's ``eval()`` function. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + See :ref:`declarative_relationship_eval` for details on + declarative evaluation of :func:`_orm.relationship` arguments. + + .. seealso:: + + :ref:`self_referential` - in-depth explanation of how + :paramref:`_orm.relationship.remote_side` + is used to configure self-referential relationships. + + :func:`.remote` - an annotation function that accomplishes the + same purpose as :paramref:`_orm.relationship.remote_side`, + typically + when a custom :paramref:`_orm.relationship.primaryjoin` condition + is used. + + :param query_class: + A :class:`_query.Query` + subclass that will be used internally by the + ``AppenderQuery`` returned by a "dynamic" relationship, that + is, a relationship that specifies ``lazy="dynamic"`` or was + otherwise constructed using the :func:`_orm.dynamic_loader` + function. + + .. seealso:: + + :ref:`dynamic_relationship` - Introduction to "dynamic" + relationship loaders. + + :param secondaryjoin: + A SQL expression that will be used as the join of + an association table to the child object. By default, this value is + computed based on the foreign key relationships of the association + and child tables. + + :paramref:`_orm.relationship.secondaryjoin` may also be passed as a + callable function which is evaluated at mapper initialization time, + and may be passed as a Python-evaluable string when using + Declarative. + + .. warning:: When passed as a Python-evaluable string, the + argument is interpreted using Python's ``eval()`` function. + **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. + See :ref:`declarative_relationship_eval` for details on + declarative evaluation of :func:`_orm.relationship` arguments. + + .. seealso:: + + :ref:`relationship_primaryjoin` + + :param single_parent: + When True, installs a validator which will prevent objects + from being associated with more than one parent at a time. + This is used for many-to-one or many-to-many relationships that + should be treated either as one-to-one or one-to-many. Its usage + is optional, except for :func:`_orm.relationship` constructs which + are many-to-one or many-to-many and also + specify the ``delete-orphan`` cascade option. The + :func:`_orm.relationship` construct itself will raise an error + instructing when this option is required. + + .. seealso:: + + :ref:`unitofwork_cascades` - includes detail on when the + :paramref:`_orm.relationship.single_parent` + flag may be appropriate. + + :param uselist: + A boolean that indicates if this property should be loaded as a + list or a scalar. In most cases, this value is determined + automatically by :func:`_orm.relationship` at mapper configuration + time. When using explicit :class:`_orm.Mapped` annotations, + :paramref:`_orm.relationship.uselist` may be derived from the + whether or not the annotation within :class:`_orm.Mapped` contains + a collection class. + Otherwise, :paramref:`_orm.relationship.uselist` may be derived from + the type and direction + of the relationship - one to many forms a list, many to one + forms a scalar, many to many is a list. If a scalar is desired + where normally a list would be present, such as a bi-directional + one-to-one relationship, use an appropriate :class:`_orm.Mapped` + annotation or set :paramref:`_orm.relationship.uselist` to False. + + The :paramref:`_orm.relationship.uselist` + flag is also available on an + existing :func:`_orm.relationship` + construct as a read-only attribute, + which can be used to determine if this :func:`_orm.relationship` + deals + with collections or scalar attributes:: + + >>> User.addresses.property.uselist + True + + .. seealso:: + + :ref:`relationships_one_to_one` - Introduction to the "one to + one" relationship pattern, which is typically when an alternate + setting for :paramref:`_orm.relationship.uselist` is involved. + + :param viewonly=False: + When set to ``True``, the relationship is used only for loading + objects, and not for any persistence operation. A + :func:`_orm.relationship` which specifies + :paramref:`_orm.relationship.viewonly` can work + with a wider range of SQL operations within the + :paramref:`_orm.relationship.primaryjoin` condition, including + operations that feature the use of a variety of comparison operators + as well as SQL functions such as :func:`_expression.cast`. The + :paramref:`_orm.relationship.viewonly` + flag is also of general use when defining any kind of + :func:`_orm.relationship` that doesn't represent + the full set of related objects, to prevent modifications of the + collection from resulting in persistence operations. + + .. seealso:: + + :ref:`relationship_viewonly_notes` - more details on best practices + when using :paramref:`_orm.relationship.viewonly`. + + :param sync_backref: + A boolean that enables the events used to synchronize the in-Python + attributes when this relationship is target of either + :paramref:`_orm.relationship.backref` or + :paramref:`_orm.relationship.back_populates`. + + Defaults to ``None``, which indicates that an automatic value should + be selected based on the value of the + :paramref:`_orm.relationship.viewonly` flag. When left at its + default, changes in state will be back-populated only if neither + sides of a relationship is viewonly. + + .. versionadded:: 1.3.17 + + .. versionchanged:: 1.4 - A relationship that specifies + :paramref:`_orm.relationship.viewonly` automatically implies + that :paramref:`_orm.relationship.sync_backref` is ``False``. + + .. seealso:: + + :paramref:`_orm.relationship.viewonly` + + :param omit_join: + Allows manual control over the "selectin" automatic join + optimization. Set to ``False`` to disable the "omit join" feature + added in SQLAlchemy 1.3; or leave as ``None`` to leave automatic + optimization in place. + + .. note:: This flag may only be set to ``False``. It is not + necessary to set it to ``True`` as the "omit_join" optimization is + automatically detected; if it is not detected, then the + optimization is not supported. + + .. versionchanged:: 1.3.11 setting ``omit_join`` to True will now + emit a warning as this was not the intended use of this flag. + + .. versionadded:: 1.3 + + :param init: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__init__()`` + method as generated by the dataclass process. + :param repr: Specific to :ref:`orm_declarative_native_dataclasses`, + specifies if the mapped attribute should be part of the ``__repr__()`` + method as generated by the dataclass process. + :param default_factory: Specific to + :ref:`orm_declarative_native_dataclasses`, + specifies a default-value generation function that will take place + as part of the ``__init__()`` + method as generated by the dataclass process. + :param compare: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be included in comparison operations when generating the + ``__eq__()`` and ``__ne__()`` methods for the mapped class. + + .. versionadded:: 2.0.0b4 + + :param kw_only: Specific to + :ref:`orm_declarative_native_dataclasses`, indicates if this field + should be marked as keyword-only when generating the ``__init__()``. + + + """ + + return _RelationshipDeclared( + argument, + secondary=secondary, + uselist=uselist, + collection_class=collection_class, + primaryjoin=primaryjoin, + secondaryjoin=secondaryjoin, + back_populates=back_populates, + order_by=order_by, + backref=backref, + overlaps=overlaps, + post_update=post_update, + cascade=cascade, + viewonly=viewonly, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + lazy=lazy, + passive_deletes=passive_deletes, + passive_updates=passive_updates, + active_history=active_history, + enable_typechecks=enable_typechecks, + foreign_keys=foreign_keys, + remote_side=remote_side, + join_depth=join_depth, + comparator_factory=comparator_factory, + single_parent=single_parent, + innerjoin=innerjoin, + distinct_target_key=distinct_target_key, + load_on_pending=load_on_pending, + query_class=query_class, + info=info, + omit_join=omit_join, + sync_backref=sync_backref, + **kw, + ) + + +def synonym( + name: str, + *, + map_column: Optional[bool] = None, + descriptor: Optional[Any] = None, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Union[_NoArg, _T] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> Synonym[Any]: + """Denote an attribute name as a synonym to a mapped property, + in that the attribute will mirror the value and expression behavior + of another attribute. + + e.g.:: + + class MyClass(Base): + __tablename__ = 'my_table' + + id = Column(Integer, primary_key=True) + job_status = Column(String(50)) + + status = synonym("job_status") + + + :param name: the name of the existing mapped property. This + can refer to the string name ORM-mapped attribute + configured on the class, including column-bound attributes + and relationships. + + :param descriptor: a Python :term:`descriptor` that will be used + as a getter (and potentially a setter) when this attribute is + accessed at the instance level. + + :param map_column: **For classical mappings and mappings against + an existing Table object only**. if ``True``, the :func:`.synonym` + construct will locate the :class:`_schema.Column` + object upon the mapped + table that would normally be associated with the attribute name of + this synonym, and produce a new :class:`.ColumnProperty` that instead + maps this :class:`_schema.Column` + to the alternate name given as the "name" + argument of the synonym; in this way, the usual step of redefining + the mapping of the :class:`_schema.Column` + to be under a different name is + unnecessary. This is usually intended to be used when a + :class:`_schema.Column` + is to be replaced with an attribute that also uses a + descriptor, that is, in conjunction with the + :paramref:`.synonym.descriptor` parameter:: + + my_table = Table( + "my_table", metadata, + Column('id', Integer, primary_key=True), + Column('job_status', String(50)) + ) + + class MyClass: + @property + def _job_status_descriptor(self): + return "Status: %s" % self._job_status + + + mapper( + MyClass, my_table, properties={ + "job_status": synonym( + "_job_status", map_column=True, + descriptor=MyClass._job_status_descriptor) + } + ) + + Above, the attribute named ``_job_status`` is automatically + mapped to the ``job_status`` column:: + + >>> j1 = MyClass() + >>> j1._job_status = "employed" + >>> j1.job_status + Status: employed + + When using Declarative, in order to provide a descriptor in + conjunction with a synonym, use the + :func:`sqlalchemy.ext.declarative.synonym_for` helper. However, + note that the :ref:`hybrid properties ` feature + should usually be preferred, particularly when redefining attribute + behavior. + + :param info: Optional data dictionary which will be populated into the + :attr:`.InspectionAttr.info` attribute of this object. + + :param comparator_factory: A subclass of :class:`.PropComparator` + that will provide custom comparison behavior at the SQL expression + level. + + .. note:: + + For the use case of providing an attribute which redefines both + Python-level and SQL-expression level behavior of an attribute, + please refer to the Hybrid attribute introduced at + :ref:`mapper_hybrids` for a more effective technique. + + .. seealso:: + + :ref:`synonyms` - Overview of synonyms + + :func:`.synonym_for` - a helper oriented towards Declarative + + :ref:`mapper_hybrids` - The Hybrid Attribute extension provides an + updated approach to augmenting attribute behavior more flexibly + than can be achieved with synonyms. + + """ + return Synonym( + name, + map_column=map_column, + descriptor=descriptor, + comparator_factory=comparator_factory, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + doc=doc, + info=info, + ) + + +def create_session( + bind: Optional[_SessionBind] = None, **kwargs: Any +) -> Session: + r"""Create a new :class:`.Session` + with no automation enabled by default. + + This function is used primarily for testing. The usual + route to :class:`.Session` creation is via its constructor + or the :func:`.sessionmaker` function. + + :param bind: optional, a single Connectable to use for all + database access in the created + :class:`~sqlalchemy.orm.session.Session`. + + :param \*\*kwargs: optional, passed through to the + :class:`.Session` constructor. + + :returns: an :class:`~sqlalchemy.orm.session.Session` instance + + The defaults of create_session() are the opposite of that of + :func:`sessionmaker`; ``autoflush`` and ``expire_on_commit`` are + False. + + Usage:: + + >>> from sqlalchemy.orm import create_session + >>> session = create_session() + + It is recommended to use :func:`sessionmaker` instead of + create_session(). + + """ + + kwargs.setdefault("autoflush", False) + kwargs.setdefault("expire_on_commit", False) + return Session(bind=bind, **kwargs) + + +def _mapper_fn(*arg: Any, **kw: Any) -> NoReturn: + """Placeholder for the now-removed ``mapper()`` function. + + Classical mappings should be performed using the + :meth:`_orm.registry.map_imperatively` method. + + This symbol remains in SQLAlchemy 2.0 to suit the deprecated use case + of using the ``mapper()`` function as a target for ORM event listeners, + which failed to be marked as deprecated in the 1.4 series. + + Global ORM mapper listeners should instead use the :class:`_orm.Mapper` + class as the target. + + .. versionchanged:: 2.0 The ``mapper()`` function was removed; the + symbol remains temporarily as a placeholder for the event listening + use case. + + """ + raise InvalidRequestError( + "The 'sqlalchemy.orm.mapper()' function is removed as of " + "SQLAlchemy 2.0. Use the " + "'sqlalchemy.orm.registry.map_imperatively()` " + "method of the ``sqlalchemy.orm.registry`` class to perform " + "classical mapping." + ) + + +def dynamic_loader( + argument: Optional[_RelationshipArgumentType[Any]] = None, **kw: Any +) -> RelationshipProperty[Any]: + """Construct a dynamically-loading mapper property. + + This is essentially the same as + using the ``lazy='dynamic'`` argument with :func:`relationship`:: + + dynamic_loader(SomeClass) + + # is the same as + + relationship(SomeClass, lazy="dynamic") + + See the section :ref:`dynamic_relationship` for more details + on dynamic loading. + + """ + kw["lazy"] = "dynamic" + return relationship(argument, **kw) + + +def backref(name: str, **kwargs: Any) -> ORMBackrefArgument: + """When using the :paramref:`_orm.relationship.backref` parameter, + provides specific parameters to be used when the new + :func:`_orm.relationship` is generated. + + E.g.:: + + 'items':relationship( + SomeItem, backref=backref('parent', lazy='subquery')) + + The :paramref:`_orm.relationship.backref` parameter is generally + considered to be legacy; for modern applications, using + explicit :func:`_orm.relationship` constructs linked together using + the :paramref:`_orm.relationship.back_populates` parameter should be + preferred. + + .. seealso:: + + :ref:`relationships_backref` - background on backrefs + + """ + + return (name, kwargs) + + +def deferred( + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + group: Optional[str] = None, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + default: Optional[Any] = _NoArg.NO_ARG, + default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> MappedSQLExpression[_T]: + r"""Indicate a column-based mapped attribute that by default will + not load unless accessed. + + When using :func:`_orm.mapped_column`, the same functionality as + that of :func:`_orm.deferred` construct is provided by using the + :paramref:`_orm.mapped_column.deferred` parameter. + + :param \*columns: columns to be mapped. This is typically a single + :class:`_schema.Column` object, + however a collection is supported in order + to support multiple columns mapped under the same attribute. + + :param raiseload: boolean, if True, indicates an exception should be raised + if the load operation is to take place. + + .. versionadded:: 1.4 + + + Additional arguments are the same as that of :func:`_orm.column_property`. + + .. seealso:: + + :ref:`orm_queryguide_deferred_imperative` + + """ + return MappedSQLExpression( + column, + *additional_columns, + attribute_options=_AttributeOptions( + init, repr, default, default_factory, compare, kw_only + ), + group=group, + deferred=True, + raiseload=raiseload, + comparator_factory=comparator_factory, + active_history=active_history, + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + ) + + +def query_expression( + default_expr: _ORMColumnExprArgument[_T] = sql.null(), + *, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + compare: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, +) -> MappedSQLExpression[_T]: + """Indicate an attribute that populates from a query-time SQL expression. + + :param default_expr: Optional SQL expression object that will be used in + all cases if not assigned later with :func:`_orm.with_expression`. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`orm_queryguide_with_expression` - background and usage examples + + """ + prop = MappedSQLExpression( + default_expr, + attribute_options=_AttributeOptions( + False, + repr, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + compare, + _NoArg.NO_ARG, + ), + expire_on_flush=expire_on_flush, + info=info, + doc=doc, + _assume_readonly_dc_attributes=True, + ) + + prop.strategy_key = (("query_expression", True),) + return prop + + +def clear_mappers() -> None: + """Remove all mappers from all classes. + + .. versionchanged:: 1.4 This function now locates all + :class:`_orm.registry` objects and calls upon the + :meth:`_orm.registry.dispose` method of each. + + This function removes all instrumentation from classes and disposes + of their associated mappers. Once called, the classes are unmapped + and can be later re-mapped with new mappers. + + :func:`.clear_mappers` is *not* for normal use, as there is literally no + valid usage for it outside of very specific testing scenarios. Normally, + mappers are permanent structural components of user-defined classes, and + are never discarded independently of their class. If a mapped class + itself is garbage collected, its mapper is automatically disposed of as + well. As such, :func:`.clear_mappers` is only for usage in test suites + that re-use the same classes with different mappings, which is itself an + extremely rare use case - the only such use case is in fact SQLAlchemy's + own test suite, and possibly the test suites of other ORM extension + libraries which intend to test various combinations of mapper construction + upon a fixed set of classes. + + """ + + mapperlib._dispose_registries(mapperlib._all_registries(), False) + + +# I would really like a way to get the Type[] here that shows up +# in a different way in typing tools, however there is no current method +# that is accepted by mypy (subclass of Type[_O] works in pylance, rejected +# by mypy). +AliasedType = Annotated[Type[_O], "aliased"] + + +@overload +def aliased( + element: Type[_O], + alias: Optional[FromClause] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> AliasedType[_O]: ... + + +@overload +def aliased( + element: Union[AliasedClass[_O], Mapper[_O], AliasedInsp[_O]], + alias: Optional[FromClause] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> AliasedClass[_O]: ... + + +@overload +def aliased( + element: FromClause, + alias: None = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> FromClause: ... + + +def aliased( + element: Union[_EntityType[_O], FromClause], + alias: Optional[FromClause] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, +) -> Union[AliasedClass[_O], FromClause, AliasedType[_O]]: + """Produce an alias of the given element, usually an :class:`.AliasedClass` + instance. + + E.g.:: + + my_alias = aliased(MyClass) + + stmt = select(MyClass, my_alias).filter(MyClass.id > my_alias.id) + result = session.execute(stmt) + + The :func:`.aliased` function is used to create an ad-hoc mapping of a + mapped class to a new selectable. By default, a selectable is generated + from the normally mapped selectable (typically a :class:`_schema.Table` + ) using the + :meth:`_expression.FromClause.alias` method. However, :func:`.aliased` + can also be + used to link the class to a new :func:`_expression.select` statement. + Also, the :func:`.with_polymorphic` function is a variant of + :func:`.aliased` that is intended to specify a so-called "polymorphic + selectable", that corresponds to the union of several joined-inheritance + subclasses at once. + + For convenience, the :func:`.aliased` function also accepts plain + :class:`_expression.FromClause` constructs, such as a + :class:`_schema.Table` or + :func:`_expression.select` construct. In those cases, the + :meth:`_expression.FromClause.alias` + method is called on the object and the new + :class:`_expression.Alias` object returned. The returned + :class:`_expression.Alias` is not + ORM-mapped in this case. + + .. seealso:: + + :ref:`tutorial_orm_entity_aliases` - in the :ref:`unified_tutorial` + + :ref:`orm_queryguide_orm_aliases` - in the :ref:`queryguide_toplevel` + + :param element: element to be aliased. Is normally a mapped class, + but for convenience can also be a :class:`_expression.FromClause` + element. + + :param alias: Optional selectable unit to map the element to. This is + usually used to link the object to a subquery, and should be an aliased + select construct as one would produce from the + :meth:`_query.Query.subquery` method or + the :meth:`_expression.Select.subquery` or + :meth:`_expression.Select.alias` methods of the :func:`_expression.select` + construct. + + :param name: optional string name to use for the alias, if not specified + by the ``alias`` parameter. The name, among other things, forms the + attribute name that will be accessible via tuples returned by a + :class:`_query.Query` object. Not supported when creating aliases + of :class:`_sql.Join` objects. + + :param flat: Boolean, will be passed through to the + :meth:`_expression.FromClause.alias` call so that aliases of + :class:`_expression.Join` objects will alias the individual tables + inside the join, rather than creating a subquery. This is generally + supported by all modern databases with regards to right-nested joins + and generally produces more efficient queries. + + :param adapt_on_names: if True, more liberal "matching" will be used when + mapping the mapped columns of the ORM entity to those of the + given selectable - a name-based match will be performed if the + given selectable doesn't otherwise have a column that corresponds + to one on the entity. The use case for this is when associating + an entity with some derived selectable such as one that uses + aggregate functions:: + + class UnitPrice(Base): + __tablename__ = 'unit_price' + ... + unit_id = Column(Integer) + price = Column(Numeric) + + aggregated_unit_price = Session.query( + func.sum(UnitPrice.price).label('price') + ).group_by(UnitPrice.unit_id).subquery() + + aggregated_unit_price = aliased(UnitPrice, + alias=aggregated_unit_price, adapt_on_names=True) + + Above, functions on ``aggregated_unit_price`` which refer to + ``.price`` will return the + ``func.sum(UnitPrice.price).label('price')`` column, as it is + matched on the name "price". Ordinarily, the "price" function + wouldn't have any "column correspondence" to the actual + ``UnitPrice.price`` column as it is not a proxy of the original. + + """ + return AliasedInsp._alias_factory( + element, + alias=alias, + name=name, + flat=flat, + adapt_on_names=adapt_on_names, + ) + + +def with_polymorphic( + base: Union[Type[_O], Mapper[_O]], + classes: Union[Literal["*"], Iterable[Type[Any]]], + selectable: Union[Literal[False, None], FromClause] = False, + flat: bool = False, + polymorphic_on: Optional[ColumnElement[Any]] = None, + aliased: bool = False, + innerjoin: bool = False, + adapt_on_names: bool = False, + _use_mapper_path: bool = False, +) -> AliasedClass[_O]: + """Produce an :class:`.AliasedClass` construct which specifies + columns for descendant mappers of the given base. + + Using this method will ensure that each descendant mapper's + tables are included in the FROM clause, and will allow filter() + criterion to be used against those tables. The resulting + instances will also have those columns already loaded so that + no "post fetch" of those columns will be required. + + .. seealso:: + + :ref:`with_polymorphic` - full discussion of + :func:`_orm.with_polymorphic`. + + :param base: Base class to be aliased. + + :param classes: a single class or mapper, or list of + class/mappers, which inherit from the base class. + Alternatively, it may also be the string ``'*'``, in which case + all descending mapped classes will be added to the FROM clause. + + :param aliased: when True, the selectable will be aliased. For a + JOIN, this means the JOIN will be SELECTed from inside of a subquery + unless the :paramref:`_orm.with_polymorphic.flat` flag is set to + True, which is recommended for simpler use cases. + + :param flat: Boolean, will be passed through to the + :meth:`_expression.FromClause.alias` call so that aliases of + :class:`_expression.Join` objects will alias the individual tables + inside the join, rather than creating a subquery. This is generally + supported by all modern databases with regards to right-nested joins + and generally produces more efficient queries. Setting this flag is + recommended as long as the resulting SQL is functional. + + :param selectable: a table or subquery that will + be used in place of the generated FROM clause. This argument is + required if any of the desired classes use concrete table + inheritance, since SQLAlchemy currently cannot generate UNIONs + among tables automatically. If used, the ``selectable`` argument + must represent the full set of tables and columns mapped by every + mapped class. Otherwise, the unaccounted mapped columns will + result in their table being appended directly to the FROM clause + which will usually lead to incorrect results. + + When left at its default value of ``False``, the polymorphic + selectable assigned to the base mapper is used for selecting rows. + However, it may also be passed as ``None``, which will bypass the + configured polymorphic selectable and instead construct an ad-hoc + selectable for the target classes given; for joined table inheritance + this will be a join that includes all target mappers and their + subclasses. + + :param polymorphic_on: a column to be used as the "discriminator" + column for the given selectable. If not given, the polymorphic_on + attribute of the base classes' mapper will be used, if any. This + is useful for mappings that don't have polymorphic loading + behavior by default. + + :param innerjoin: if True, an INNER JOIN will be used. This should + only be specified if querying for one specific subtype only + + :param adapt_on_names: Passes through the + :paramref:`_orm.aliased.adapt_on_names` + parameter to the aliased object. This may be useful in situations where + the given selectable is not directly related to the existing mapped + selectable. + + .. versionadded:: 1.4.33 + + """ + return AliasedInsp._with_polymorphic_factory( + base, + classes, + selectable=selectable, + flat=flat, + polymorphic_on=polymorphic_on, + adapt_on_names=adapt_on_names, + aliased=aliased, + innerjoin=innerjoin, + _use_mapper_path=_use_mapper_path, + ) + + +def join( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, +) -> _ORMJoin: + r"""Produce an inner join between left and right clauses. + + :func:`_orm.join` is an extension to the core join interface + provided by :func:`_expression.join()`, where the + left and right selectable may be not only core selectable + objects such as :class:`_schema.Table`, but also mapped classes or + :class:`.AliasedClass` instances. The "on" clause can + be a SQL expression or an ORM mapped attribute + referencing a configured :func:`_orm.relationship`. + + :func:`_orm.join` is not commonly needed in modern usage, + as its functionality is encapsulated within that of the + :meth:`_sql.Select.join` and :meth:`_query.Query.join` + methods. which feature a + significant amount of automation beyond :func:`_orm.join` + by itself. Explicit use of :func:`_orm.join` + with ORM-enabled SELECT statements involves use of the + :meth:`_sql.Select.select_from` method, as in:: + + from sqlalchemy.orm import join + stmt = select(User).\ + select_from(join(User, Address, User.addresses)).\ + filter(Address.email_address=='foo@bar.com') + + In modern SQLAlchemy the above join can be written more + succinctly as:: + + stmt = select(User).\ + join(User.addresses).\ + filter(Address.email_address=='foo@bar.com') + + .. warning:: using :func:`_orm.join` directly may not work properly + with modern ORM options such as :func:`_orm.with_loader_criteria`. + It is strongly recommended to use the idiomatic join patterns + provided by methods such as :meth:`.Select.join` and + :meth:`.Select.join_from` when creating ORM joins. + + .. seealso:: + + :ref:`orm_queryguide_joins` - in the :ref:`queryguide_toplevel` for + background on idiomatic ORM join patterns + + """ + return _ORMJoin(left, right, onclause, isouter, full) + + +def outerjoin( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + full: bool = False, +) -> _ORMJoin: + """Produce a left outer join between left and right clauses. + + This is the "outer join" version of the :func:`_orm.join` function, + featuring the same behavior except that an OUTER JOIN is generated. + See that function's documentation for other usage details. + + """ + return _ORMJoin(left, right, onclause, True, full) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/_typing.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/_typing.py new file mode 100644 index 0000000..f8ac059 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/_typing.py @@ -0,0 +1,179 @@ +# orm/_typing.py +# Copyright (C) 2022-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import operator +from typing import Any +from typing import Dict +from typing import Mapping +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from ..engine.interfaces import _CoreKnownExecutionOptions +from ..sql import roles +from ..sql._orm_types import DMLStrategyArgument as DMLStrategyArgument +from ..sql._orm_types import ( + SynchronizeSessionArgument as SynchronizeSessionArgument, +) +from ..sql._typing import _HasClauseElement +from ..sql.elements import ColumnElement +from ..util.typing import Protocol +from ..util.typing import TypeGuard + +if TYPE_CHECKING: + from .attributes import AttributeImpl + from .attributes import CollectionAttributeImpl + from .attributes import HasCollectionAdapter + from .attributes import QueryableAttribute + from .base import PassiveFlag + from .decl_api import registry as _registry_type + from .interfaces import InspectionAttr + from .interfaces import MapperProperty + from .interfaces import ORMOption + from .interfaces import UserDefinedOption + from .mapper import Mapper + from .relationships import RelationshipProperty + from .state import InstanceState + from .util import AliasedClass + from .util import AliasedInsp + from ..sql._typing import _CE + from ..sql.base import ExecutableOption + +_T = TypeVar("_T", bound=Any) + + +_T_co = TypeVar("_T_co", bound=Any, covariant=True) + +_O = TypeVar("_O", bound=object) +"""The 'ORM mapped object' type. + +""" + + +if TYPE_CHECKING: + _RegistryType = _registry_type + +_InternalEntityType = Union["Mapper[_T]", "AliasedInsp[_T]"] + +_ExternalEntityType = Union[Type[_T], "AliasedClass[_T]"] + +_EntityType = Union[ + Type[_T], "AliasedClass[_T]", "Mapper[_T]", "AliasedInsp[_T]" +] + + +_ClassDict = Mapping[str, Any] +_InstanceDict = Dict[str, Any] + +_IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]] + +_ORMColumnExprArgument = Union[ + ColumnElement[_T], + _HasClauseElement[_T], + roles.ExpressionElementRole[_T], +] + + +_ORMCOLEXPR = TypeVar("_ORMCOLEXPR", bound=ColumnElement[Any]) + + +class _OrmKnownExecutionOptions(_CoreKnownExecutionOptions, total=False): + populate_existing: bool + autoflush: bool + synchronize_session: SynchronizeSessionArgument + dml_strategy: DMLStrategyArgument + is_delete_using: bool + is_update_from: bool + render_nulls: bool + + +OrmExecuteOptionsParameter = Union[ + _OrmKnownExecutionOptions, Mapping[str, Any] +] + + +class _ORMAdapterProto(Protocol): + """protocol for the :class:`.AliasedInsp._orm_adapt_element` method + which is a synonym for :class:`.AliasedInsp._adapt_element`. + + + """ + + def __call__(self, obj: _CE, key: Optional[str] = None) -> _CE: ... + + +class _LoaderCallable(Protocol): + def __call__( + self, state: InstanceState[Any], passive: PassiveFlag + ) -> Any: ... + + +def is_orm_option( + opt: ExecutableOption, +) -> TypeGuard[ORMOption]: + return not opt._is_core + + +def is_user_defined_option( + opt: ExecutableOption, +) -> TypeGuard[UserDefinedOption]: + return not opt._is_core and opt._is_user_defined # type: ignore + + +def is_composite_class(obj: Any) -> bool: + # inlining is_dataclass(obj) + return hasattr(obj, "__composite_values__") or hasattr( + obj, "__dataclass_fields__" + ) + + +if TYPE_CHECKING: + + def insp_is_mapper_property( + obj: Any, + ) -> TypeGuard[MapperProperty[Any]]: ... + + def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]: ... + + def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]: ... + + def insp_is_attribute( + obj: InspectionAttr, + ) -> TypeGuard[QueryableAttribute[Any]]: ... + + def attr_is_internal_proxy( + obj: InspectionAttr, + ) -> TypeGuard[QueryableAttribute[Any]]: ... + + def prop_is_relationship( + prop: MapperProperty[Any], + ) -> TypeGuard[RelationshipProperty[Any]]: ... + + def is_collection_impl( + impl: AttributeImpl, + ) -> TypeGuard[CollectionAttributeImpl]: ... + + def is_has_collection_adapter( + impl: AttributeImpl, + ) -> TypeGuard[HasCollectionAdapter]: ... + +else: + insp_is_mapper_property = operator.attrgetter("is_property") + insp_is_mapper = operator.attrgetter("is_mapper") + insp_is_aliased_class = operator.attrgetter("is_aliased_class") + insp_is_attribute = operator.attrgetter("is_attribute") + attr_is_internal_proxy = operator.attrgetter("_is_internal_proxy") + is_collection_impl = operator.attrgetter("collection") + prop_is_relationship = operator.attrgetter("_is_relationship") + is_has_collection_adapter = operator.attrgetter( + "_is_has_collection_adapter" + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/attributes.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/attributes.py new file mode 100644 index 0000000..5b16ce3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/attributes.py @@ -0,0 +1,2835 @@ +# orm/attributes.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Defines instrumentation for class attributes and their interaction +with instances. + +This module is usually not directly visible to user applications, but +defines a large part of the ORM's interactivity. + + +""" + +from __future__ import annotations + +import dataclasses +import operator +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import Iterable +from typing import List +from typing import NamedTuple +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import collections +from . import exc as orm_exc +from . import interfaces +from ._typing import insp_is_aliased_class +from .base import _DeclarativeMapped +from .base import ATTR_EMPTY +from .base import ATTR_WAS_SET +from .base import CALLABLES_OK +from .base import DEFERRED_HISTORY_LOAD +from .base import INCLUDE_PENDING_MUTATIONS # noqa +from .base import INIT_OK +from .base import instance_dict as instance_dict +from .base import instance_state as instance_state +from .base import instance_str +from .base import LOAD_AGAINST_COMMITTED +from .base import LoaderCallableStatus +from .base import manager_of_class as manager_of_class +from .base import Mapped as Mapped # noqa +from .base import NEVER_SET # noqa +from .base import NO_AUTOFLUSH +from .base import NO_CHANGE # noqa +from .base import NO_KEY +from .base import NO_RAISE +from .base import NO_VALUE +from .base import NON_PERSISTENT_OK # noqa +from .base import opt_manager_of_class as opt_manager_of_class +from .base import PASSIVE_CLASS_MISMATCH # noqa +from .base import PASSIVE_NO_FETCH +from .base import PASSIVE_NO_FETCH_RELATED # noqa +from .base import PASSIVE_NO_INITIALIZE +from .base import PASSIVE_NO_RESULT +from .base import PASSIVE_OFF +from .base import PASSIVE_ONLY_PERSISTENT +from .base import PASSIVE_RETURN_NO_VALUE +from .base import PassiveFlag +from .base import RELATED_OBJECT_OK # noqa +from .base import SQL_OK # noqa +from .base import SQLORMExpression +from .base import state_str +from .. import event +from .. import exc +from .. import inspection +from .. import util +from ..event import dispatcher +from ..event import EventTarget +from ..sql import base as sql_base +from ..sql import cache_key +from ..sql import coercions +from ..sql import roles +from ..sql import visitors +from ..sql.cache_key import HasCacheKey +from ..sql.visitors import _TraverseInternalsType +from ..sql.visitors import InternalTraversal +from ..util.typing import Literal +from ..util.typing import Self +from ..util.typing import TypeGuard + +if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _LoaderCallable + from ._typing import _O + from .collections import _AdaptedCollectionProtocol + from .collections import CollectionAdapter + from .interfaces import MapperProperty + from .relationships import RelationshipProperty + from .state import InstanceState + from .util import AliasedInsp + from .writeonly import WriteOnlyAttributeImpl + from ..event.base import _Dispatch + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _InfoType + from ..sql._typing import _PropagateAttrsType + from ..sql.annotation import _AnnotationDict + from ..sql.elements import ColumnElement + from ..sql.elements import Label + from ..sql.operators import OperatorType + from ..sql.selectable import FromClause + + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", bound=Any, covariant=True) + + +_AllPendingType = Sequence[ + Tuple[Optional["InstanceState[Any]"], Optional[object]] +] + + +_UNKNOWN_ATTR_KEY = object() + + +@inspection._self_inspects +class QueryableAttribute( + _DeclarativeMapped[_T_co], + SQLORMExpression[_T_co], + interfaces.InspectionAttr, + interfaces.PropComparator[_T_co], + roles.JoinTargetRole, + roles.OnClauseRole, + sql_base.Immutable, + cache_key.SlotsMemoizedHasCacheKey, + util.MemoizedSlots, + EventTarget, +): + """Base class for :term:`descriptor` objects that intercept + attribute events on behalf of a :class:`.MapperProperty` + object. The actual :class:`.MapperProperty` is accessible + via the :attr:`.QueryableAttribute.property` + attribute. + + + .. seealso:: + + :class:`.InstrumentedAttribute` + + :class:`.MapperProperty` + + :attr:`_orm.Mapper.all_orm_descriptors` + + :attr:`_orm.Mapper.attrs` + """ + + __slots__ = ( + "class_", + "key", + "impl", + "comparator", + "property", + "parent", + "expression", + "_of_type", + "_extra_criteria", + "_slots_dispatch", + "_propagate_attrs", + "_doc", + ) + + is_attribute = True + + dispatch: dispatcher[QueryableAttribute[_T_co]] + + class_: _ExternalEntityType[Any] + key: str + parententity: _InternalEntityType[Any] + impl: AttributeImpl + comparator: interfaces.PropComparator[_T_co] + _of_type: Optional[_InternalEntityType[Any]] + _extra_criteria: Tuple[ColumnElement[bool], ...] + _doc: Optional[str] + + # PropComparator has a __visit_name__ to participate within + # traversals. Disambiguate the attribute vs. a comparator. + __visit_name__ = "orm_instrumented_attribute" + + def __init__( + self, + class_: _ExternalEntityType[_O], + key: str, + parententity: _InternalEntityType[_O], + comparator: interfaces.PropComparator[_T_co], + impl: Optional[AttributeImpl] = None, + of_type: Optional[_InternalEntityType[Any]] = None, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), + ): + self.class_ = class_ + self.key = key + + self._parententity = self.parent = parententity + + # this attribute is non-None after mappers are set up, however in the + # interim class manager setup, there's a check for None to see if it + # needs to be populated, so we assign None here leaving the attribute + # in a temporarily not-type-correct state + self.impl = impl # type: ignore + + assert comparator is not None + self.comparator = comparator + self._of_type = of_type + self._extra_criteria = extra_criteria + self._doc = None + + manager = opt_manager_of_class(class_) + # manager is None in the case of AliasedClass + if manager: + # propagate existing event listeners from + # immediate superclass + for base in manager._bases: + if key in base: + self.dispatch._update(base[key].dispatch) + if base[key].dispatch._active_history: + self.dispatch._active_history = True # type: ignore + + _cache_key_traversal = [ + ("key", visitors.ExtendedInternalTraversal.dp_string), + ("_parententity", visitors.ExtendedInternalTraversal.dp_multi), + ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), + ] + + def __reduce__(self) -> Any: + # this method is only used in terms of the + # sqlalchemy.ext.serializer extension + return ( + _queryable_attribute_unreduce, + ( + self.key, + self._parententity.mapper.class_, + self._parententity, + self._parententity.entity, + ), + ) + + @property + def _impl_uses_objects(self) -> bool: + return self.impl.uses_objects + + def get_history( + self, instance: Any, passive: PassiveFlag = PASSIVE_OFF + ) -> History: + return self.impl.get_history( + instance_state(instance), instance_dict(instance), passive + ) + + @property + def info(self) -> _InfoType: + """Return the 'info' dictionary for the underlying SQL element. + + The behavior here is as follows: + + * If the attribute is a column-mapped property, i.e. + :class:`.ColumnProperty`, which is mapped directly + to a schema-level :class:`_schema.Column` object, this attribute + will return the :attr:`.SchemaItem.info` dictionary associated + with the core-level :class:`_schema.Column` object. + + * If the attribute is a :class:`.ColumnProperty` but is mapped to + any other kind of SQL expression other than a + :class:`_schema.Column`, + the attribute will refer to the :attr:`.MapperProperty.info` + dictionary associated directly with the :class:`.ColumnProperty`, + assuming the SQL expression itself does not have its own ``.info`` + attribute (which should be the case, unless a user-defined SQL + construct has defined one). + + * If the attribute refers to any other kind of + :class:`.MapperProperty`, including :class:`.Relationship`, + the attribute will refer to the :attr:`.MapperProperty.info` + dictionary associated with that :class:`.MapperProperty`. + + * To access the :attr:`.MapperProperty.info` dictionary of the + :class:`.MapperProperty` unconditionally, including for a + :class:`.ColumnProperty` that's associated directly with a + :class:`_schema.Column`, the attribute can be referred to using + :attr:`.QueryableAttribute.property` attribute, as + ``MyClass.someattribute.property.info``. + + .. seealso:: + + :attr:`.SchemaItem.info` + + :attr:`.MapperProperty.info` + + """ + return self.comparator.info + + parent: _InternalEntityType[Any] + """Return an inspection instance representing the parent. + + This will be either an instance of :class:`_orm.Mapper` + or :class:`.AliasedInsp`, depending upon the nature + of the parent entity which this attribute is associated + with. + + """ + + expression: ColumnElement[_T_co] + """The SQL expression object represented by this + :class:`.QueryableAttribute`. + + This will typically be an instance of a :class:`_sql.ColumnElement` + subclass representing a column expression. + + """ + + def _memoized_attr_expression(self) -> ColumnElement[_T]: + annotations: _AnnotationDict + + # applies only to Proxy() as used by hybrid. + # currently is an exception to typing rather than feeding through + # non-string keys. + # ideally Proxy() would have a separate set of methods to deal + # with this case. + entity_namespace = self._entity_namespace + assert isinstance(entity_namespace, HasCacheKey) + + if self.key is _UNKNOWN_ATTR_KEY: + annotations = {"entity_namespace": entity_namespace} + else: + annotations = { + "proxy_key": self.key, + "proxy_owner": self._parententity, + "entity_namespace": entity_namespace, + } + + ce = self.comparator.__clause_element__() + try: + if TYPE_CHECKING: + assert isinstance(ce, ColumnElement) + anno = ce._annotate + except AttributeError as ae: + raise exc.InvalidRequestError( + 'When interpreting attribute "%s" as a SQL expression, ' + "expected __clause_element__() to return " + "a ClauseElement object, got: %r" % (self, ce) + ) from ae + else: + return anno(annotations) + + def _memoized_attr__propagate_attrs(self) -> _PropagateAttrsType: + # this suits the case in coercions where we don't actually + # call ``__clause_element__()`` but still need to get + # resolved._propagate_attrs. See #6558. + return util.immutabledict( + { + "compile_state_plugin": "orm", + "plugin_subject": self._parentmapper, + } + ) + + @property + def _entity_namespace(self) -> _InternalEntityType[Any]: + return self._parententity + + @property + def _annotations(self) -> _AnnotationDict: + return self.__clause_element__()._annotations + + def __clause_element__(self) -> ColumnElement[_T_co]: + return self.expression + + @property + def _from_objects(self) -> List[FromClause]: + return self.expression._from_objects + + def _bulk_update_tuples( + self, value: Any + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: + """Return setter tuples for a bulk UPDATE.""" + + return self.comparator._bulk_update_tuples(value) + + def adapt_to_entity(self, adapt_to_entity: AliasedInsp[Any]) -> Self: + assert not self._of_type + return self.__class__( + adapt_to_entity.entity, + self.key, + impl=self.impl, + comparator=self.comparator.adapt_to_entity(adapt_to_entity), + parententity=adapt_to_entity, + ) + + def of_type(self, entity: _EntityType[Any]) -> QueryableAttribute[_T]: + return QueryableAttribute( + self.class_, + self.key, + self._parententity, + impl=self.impl, + comparator=self.comparator.of_type(entity), + of_type=inspection.inspect(entity), + extra_criteria=self._extra_criteria, + ) + + def and_( + self, *clauses: _ColumnExpressionArgument[bool] + ) -> QueryableAttribute[bool]: + if TYPE_CHECKING: + assert isinstance(self.comparator, RelationshipProperty.Comparator) + + exprs = tuple( + coercions.expect(roles.WhereHavingRole, clause) + for clause in util.coerce_generator_arg(clauses) + ) + + return QueryableAttribute( + self.class_, + self.key, + self._parententity, + impl=self.impl, + comparator=self.comparator.and_(*exprs), + of_type=self._of_type, + extra_criteria=self._extra_criteria + exprs, + ) + + def _clone(self, **kw: Any) -> QueryableAttribute[_T]: + return QueryableAttribute( + self.class_, + self.key, + self._parententity, + impl=self.impl, + comparator=self.comparator, + of_type=self._of_type, + extra_criteria=self._extra_criteria, + ) + + def label(self, name: Optional[str]) -> Label[_T_co]: + return self.__clause_element__().label(name) + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.comparator, *other, **kwargs) # type: ignore[no-any-return] # noqa: E501 + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.comparator, **kwargs) # type: ignore[no-any-return] # noqa: E501 + + def hasparent( + self, state: InstanceState[Any], optimistic: bool = False + ) -> bool: + return self.impl.hasparent(state, optimistic=optimistic) is not False + + def __getattr__(self, key: str) -> Any: + try: + return util.MemoizedSlots.__getattr__(self, key) + except AttributeError: + pass + + try: + return getattr(self.comparator, key) + except AttributeError as err: + raise AttributeError( + "Neither %r object nor %r object associated with %s " + "has an attribute %r" + % ( + type(self).__name__, + type(self.comparator).__name__, + self, + key, + ) + ) from err + + def __str__(self) -> str: + return f"{self.class_.__name__}.{self.key}" + + def _memoized_attr_property(self) -> Optional[MapperProperty[Any]]: + return self.comparator.property + + +def _queryable_attribute_unreduce( + key: str, + mapped_class: Type[_O], + parententity: _InternalEntityType[_O], + entity: _ExternalEntityType[Any], +) -> Any: + # this method is only used in terms of the + # sqlalchemy.ext.serializer extension + if insp_is_aliased_class(parententity): + return entity._get_from_serialized(key, mapped_class, parententity) + else: + return getattr(entity, key) + + +class InstrumentedAttribute(QueryableAttribute[_T_co]): + """Class bound instrumented attribute which adds basic + :term:`descriptor` methods. + + See :class:`.QueryableAttribute` for a description of most features. + + + """ + + __slots__ = () + + inherit_cache = True + """:meta private:""" + + # hack to make __doc__ writeable on instances of + # InstrumentedAttribute, while still keeping classlevel + # __doc__ correct + + @util.rw_hybridproperty + def __doc__(self) -> Optional[str]: + return self._doc + + @__doc__.setter # type: ignore + def __doc__(self, value: Optional[str]) -> None: + self._doc = value + + @__doc__.classlevel # type: ignore + def __doc__(cls) -> Optional[str]: + return super().__doc__ + + def __set__(self, instance: object, value: Any) -> None: + self.impl.set( + instance_state(instance), instance_dict(instance), value, None + ) + + def __delete__(self, instance: object) -> None: + self.impl.delete(instance_state(instance), instance_dict(instance)) + + @overload + def __get__( + self, instance: None, owner: Any + ) -> InstrumentedAttribute[_T_co]: ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T_co: ... + + def __get__( + self, instance: Optional[object], owner: Any + ) -> Union[InstrumentedAttribute[_T_co], _T_co]: + if instance is None: + return self + + dict_ = instance_dict(instance) + if self.impl.supports_population and self.key in dict_: + return dict_[self.key] # type: ignore[no-any-return] + else: + try: + state = instance_state(instance) + except AttributeError as err: + raise orm_exc.UnmappedInstanceError(instance) from err + return self.impl.get(state, dict_) # type: ignore[no-any-return] + + +@dataclasses.dataclass(frozen=True) +class AdHocHasEntityNamespace(HasCacheKey): + _traverse_internals: ClassVar[_TraverseInternalsType] = [ + ("_entity_namespace", InternalTraversal.dp_has_cache_key), + ] + + # py37 compat, no slots=True on dataclass + __slots__ = ("_entity_namespace",) + _entity_namespace: _InternalEntityType[Any] + is_mapper: ClassVar[bool] = False + is_aliased_class: ClassVar[bool] = False + + @property + def entity_namespace(self): + return self._entity_namespace.entity_namespace + + +def create_proxied_attribute( + descriptor: Any, +) -> Callable[..., QueryableAttribute[Any]]: + """Create an QueryableAttribute / user descriptor hybrid. + + Returns a new QueryableAttribute type that delegates descriptor + behavior and getattr() to the given descriptor. + """ + + # TODO: can move this to descriptor_props if the need for this + # function is removed from ext/hybrid.py + + class Proxy(QueryableAttribute[Any]): + """Presents the :class:`.QueryableAttribute` interface as a + proxy on top of a Python descriptor / :class:`.PropComparator` + combination. + + """ + + _extra_criteria = () + + # the attribute error catches inside of __getattr__ basically create a + # singularity if you try putting slots on this too + # __slots__ = ("descriptor", "original_property", "_comparator") + + def __init__( + self, + class_, + key, + descriptor, + comparator, + adapt_to_entity=None, + doc=None, + original_property=None, + ): + self.class_ = class_ + self.key = key + self.descriptor = descriptor + self.original_property = original_property + self._comparator = comparator + self._adapt_to_entity = adapt_to_entity + self._doc = self.__doc__ = doc + + @property + def _parententity(self): + return inspection.inspect(self.class_, raiseerr=False) + + @property + def parent(self): + return inspection.inspect(self.class_, raiseerr=False) + + _is_internal_proxy = True + + _cache_key_traversal = [ + ("key", visitors.ExtendedInternalTraversal.dp_string), + ("_parententity", visitors.ExtendedInternalTraversal.dp_multi), + ] + + @property + def _impl_uses_objects(self): + return ( + self.original_property is not None + and getattr(self.class_, self.key).impl.uses_objects + ) + + @property + def _entity_namespace(self): + if hasattr(self._comparator, "_parententity"): + return self._comparator._parententity + else: + # used by hybrid attributes which try to remain + # agnostic of any ORM concepts like mappers + return AdHocHasEntityNamespace(self._parententity) + + @property + def property(self): + return self.comparator.property + + @util.memoized_property + def comparator(self): + if callable(self._comparator): + self._comparator = self._comparator() + if self._adapt_to_entity: + self._comparator = self._comparator.adapt_to_entity( + self._adapt_to_entity + ) + return self._comparator + + def adapt_to_entity(self, adapt_to_entity): + return self.__class__( + adapt_to_entity.entity, + self.key, + self.descriptor, + self._comparator, + adapt_to_entity, + ) + + def _clone(self, **kw): + return self.__class__( + self.class_, + self.key, + self.descriptor, + self._comparator, + adapt_to_entity=self._adapt_to_entity, + original_property=self.original_property, + ) + + def __get__(self, instance, owner): + retval = self.descriptor.__get__(instance, owner) + # detect if this is a plain Python @property, which just returns + # itself for class level access. If so, then return us. + # Otherwise, return the object returned by the descriptor. + if retval is self.descriptor and instance is None: + return self + else: + return retval + + def __str__(self) -> str: + return f"{self.class_.__name__}.{self.key}" + + def __getattr__(self, attribute): + """Delegate __getattr__ to the original descriptor and/or + comparator.""" + + # this is unfortunately very complicated, and is easily prone + # to recursion overflows when implementations of related + # __getattr__ schemes are changed + + try: + return util.MemoizedSlots.__getattr__(self, attribute) + except AttributeError: + pass + + try: + return getattr(descriptor, attribute) + except AttributeError as err: + if attribute == "comparator": + raise AttributeError("comparator") from err + try: + # comparator itself might be unreachable + comparator = self.comparator + except AttributeError as err2: + raise AttributeError( + "Neither %r object nor unconfigured comparator " + "object associated with %s has an attribute %r" + % (type(descriptor).__name__, self, attribute) + ) from err2 + else: + try: + return getattr(comparator, attribute) + except AttributeError as err3: + raise AttributeError( + "Neither %r object nor %r object " + "associated with %s has an attribute %r" + % ( + type(descriptor).__name__, + type(comparator).__name__, + self, + attribute, + ) + ) from err3 + + Proxy.__name__ = type(descriptor).__name__ + "Proxy" + + util.monkeypatch_proxied_specials( + Proxy, type(descriptor), name="descriptor", from_instance=descriptor + ) + return Proxy + + +OP_REMOVE = util.symbol("REMOVE") +OP_APPEND = util.symbol("APPEND") +OP_REPLACE = util.symbol("REPLACE") +OP_BULK_REPLACE = util.symbol("BULK_REPLACE") +OP_MODIFIED = util.symbol("MODIFIED") + + +class AttributeEventToken: + """A token propagated throughout the course of a chain of attribute + events. + + Serves as an indicator of the source of the event and also provides + a means of controlling propagation across a chain of attribute + operations. + + The :class:`.Event` object is sent as the ``initiator`` argument + when dealing with events such as :meth:`.AttributeEvents.append`, + :meth:`.AttributeEvents.set`, + and :meth:`.AttributeEvents.remove`. + + The :class:`.Event` object is currently interpreted by the backref + event handlers, and is used to control the propagation of operations + across two mutually-dependent attributes. + + .. versionchanged:: 2.0 Changed the name from ``AttributeEvent`` + to ``AttributeEventToken``. + + :attribute impl: The :class:`.AttributeImpl` which is the current event + initiator. + + :attribute op: The symbol :attr:`.OP_APPEND`, :attr:`.OP_REMOVE`, + :attr:`.OP_REPLACE`, or :attr:`.OP_BULK_REPLACE`, indicating the + source operation. + + """ + + __slots__ = "impl", "op", "parent_token" + + def __init__(self, attribute_impl: AttributeImpl, op: util.symbol): + self.impl = attribute_impl + self.op = op + self.parent_token = self.impl.parent_token + + def __eq__(self, other): + return ( + isinstance(other, AttributeEventToken) + and other.impl is self.impl + and other.op == self.op + ) + + @property + def key(self): + return self.impl.key + + def hasparent(self, state): + return self.impl.hasparent(state) + + +AttributeEvent = AttributeEventToken # legacy +Event = AttributeEventToken # legacy + + +class AttributeImpl: + """internal implementation for instrumented attributes.""" + + collection: bool + default_accepts_scalar_loader: bool + uses_objects: bool + supports_population: bool + dynamic: bool + + _is_has_collection_adapter = False + + _replace_token: AttributeEventToken + _remove_token: AttributeEventToken + _append_token: AttributeEventToken + + def __init__( + self, + class_: _ExternalEntityType[_O], + key: str, + callable_: Optional[_LoaderCallable], + dispatch: _Dispatch[QueryableAttribute[Any]], + trackparent: bool = False, + compare_function: Optional[Callable[..., bool]] = None, + active_history: bool = False, + parent_token: Optional[AttributeEventToken] = None, + load_on_unexpire: bool = True, + send_modified_events: bool = True, + accepts_scalar_loader: Optional[bool] = None, + **kwargs: Any, + ): + r"""Construct an AttributeImpl. + + :param \class_: associated class + + :param key: string name of the attribute + + :param \callable_: + optional function which generates a callable based on a parent + instance, which produces the "default" values for a scalar or + collection attribute when it's first accessed, if not present + already. + + :param trackparent: + if True, attempt to track if an instance has a parent attached + to it via this attribute. + + :param compare_function: + a function that compares two values which are normally + assignable to this attribute. + + :param active_history: + indicates that get_history() should always return the "old" value, + even if it means executing a lazy callable upon attribute change. + + :param parent_token: + Usually references the MapperProperty, used as a key for + the hasparent() function to identify an "owning" attribute. + Allows multiple AttributeImpls to all match a single + owner attribute. + + :param load_on_unexpire: + if False, don't include this attribute in a load-on-expired + operation, i.e. the "expired_attribute_loader" process. + The attribute can still be in the "expired" list and be + considered to be "expired". Previously, this flag was called + "expire_missing" and is only used by a deferred column + attribute. + + :param send_modified_events: + if False, the InstanceState._modified_event method will have no + effect; this means the attribute will never show up as changed in a + history entry. + + """ + self.class_ = class_ + self.key = key + self.callable_ = callable_ + self.dispatch = dispatch + self.trackparent = trackparent + self.parent_token = parent_token or self + self.send_modified_events = send_modified_events + if compare_function is None: + self.is_equal = operator.eq + else: + self.is_equal = compare_function + + if accepts_scalar_loader is not None: + self.accepts_scalar_loader = accepts_scalar_loader + else: + self.accepts_scalar_loader = self.default_accepts_scalar_loader + + _deferred_history = kwargs.pop("_deferred_history", False) + self._deferred_history = _deferred_history + + if active_history: + self.dispatch._active_history = True + + self.load_on_unexpire = load_on_unexpire + self._modified_token = AttributeEventToken(self, OP_MODIFIED) + + __slots__ = ( + "class_", + "key", + "callable_", + "dispatch", + "trackparent", + "parent_token", + "send_modified_events", + "is_equal", + "load_on_unexpire", + "_modified_token", + "accepts_scalar_loader", + "_deferred_history", + ) + + def __str__(self) -> str: + return f"{self.class_.__name__}.{self.key}" + + def _get_active_history(self): + """Backwards compat for impl.active_history""" + + return self.dispatch._active_history + + def _set_active_history(self, value): + self.dispatch._active_history = value + + active_history = property(_get_active_history, _set_active_history) + + def hasparent( + self, state: InstanceState[Any], optimistic: bool = False + ) -> bool: + """Return the boolean value of a `hasparent` flag attached to + the given state. + + The `optimistic` flag determines what the default return value + should be if no `hasparent` flag can be located. + + As this function is used to determine if an instance is an + *orphan*, instances that were loaded from storage should be + assumed to not be orphans, until a True/False value for this + flag is set. + + An instance attribute that is loaded by a callable function + will also not have a `hasparent` flag. + + """ + msg = "This AttributeImpl is not configured to track parents." + assert self.trackparent, msg + + return ( + state.parents.get(id(self.parent_token), optimistic) is not False + ) + + def sethasparent( + self, + state: InstanceState[Any], + parent_state: InstanceState[Any], + value: bool, + ) -> None: + """Set a boolean flag on the given item corresponding to + whether or not it is attached to a parent object via the + attribute represented by this ``InstrumentedAttribute``. + + """ + msg = "This AttributeImpl is not configured to track parents." + assert self.trackparent, msg + + id_ = id(self.parent_token) + if value: + state.parents[id_] = parent_state + else: + if id_ in state.parents: + last_parent = state.parents[id_] + + if ( + last_parent is not False + and last_parent.key != parent_state.key + ): + if last_parent.obj() is None: + raise orm_exc.StaleDataError( + "Removing state %s from parent " + "state %s along attribute '%s', " + "but the parent record " + "has gone stale, can't be sure this " + "is the most recent parent." + % ( + state_str(state), + state_str(parent_state), + self.key, + ) + ) + + return + + state.parents[id_] = False + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: + raise NotImplementedError() + + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_NO_INITIALIZE, + ) -> _AllPendingType: + """Return a list of tuples of (state, obj) + for all objects in this attribute's current state + + history. + + Only applies to object-based attributes. + + This is an inlining of existing functionality + which roughly corresponds to: + + get_state_history( + state, + key, + passive=PASSIVE_NO_INITIALIZE).sum() + + """ + raise NotImplementedError() + + def _default_value( + self, state: InstanceState[Any], dict_: _InstanceDict + ) -> Any: + """Produce an empty value for an uninitialized scalar attribute.""" + + assert self.key not in dict_, ( + "_default_value should only be invoked for an " + "uninitialized or expired attribute" + ) + + value = None + for fn in self.dispatch.init_scalar: + ret = fn(state, value, dict_) + if ret is not ATTR_EMPTY: + value = ret + + return value + + def get( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Any: + """Retrieve a value from the given object. + If a callable is assembled on this object's attribute, and + passive is False, the callable will be executed and the + resulting value will be set as the new value for this attribute. + """ + if self.key in dict_: + return dict_[self.key] + else: + # if history present, don't load + key = self.key + if ( + key not in state.committed_state + or state.committed_state[key] is NO_VALUE + ): + if not passive & CALLABLES_OK: + return PASSIVE_NO_RESULT + + value = self._fire_loader_callables(state, key, passive) + + if value is PASSIVE_NO_RESULT or value is NO_VALUE: + return value + elif value is ATTR_WAS_SET: + try: + return dict_[key] + except KeyError as err: + # TODO: no test coverage here. + raise KeyError( + "Deferred loader for attribute " + "%r failed to populate " + "correctly" % key + ) from err + elif value is not ATTR_EMPTY: + return self.set_committed_value(state, dict_, value) + + if not passive & INIT_OK: + return NO_VALUE + else: + return self._default_value(state, dict_) + + def _fire_loader_callables( + self, state: InstanceState[Any], key: str, passive: PassiveFlag + ) -> Any: + if ( + self.accepts_scalar_loader + and self.load_on_unexpire + and key in state.expired_attributes + ): + return state._load_expired(state, passive) + elif key in state.callables: + callable_ = state.callables[key] + return callable_(state, passive) + elif self.callable_: + return self.callable_(state, passive) + else: + return ATTR_EMPTY + + def append( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: + self.set(state, dict_, value, initiator, passive=passive) + + def remove( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: + self.set( + state, dict_, None, initiator, passive=passive, check_old=value + ) + + def pop( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: + self.set( + state, + dict_, + None, + initiator, + passive=passive, + check_old=value, + pop=True, + ) + + def set( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + ) -> None: + raise NotImplementedError() + + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: + raise NotImplementedError() + + def get_committed_value( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Any: + """return the unchanged value of this attribute""" + + if self.key in state.committed_state: + value = state.committed_state[self.key] + if value is NO_VALUE: + return None + else: + return value + else: + return self.get(state, dict_, passive=passive) + + def set_committed_value(self, state, dict_, value): + """set an attribute value on the given instance and 'commit' it.""" + + dict_[self.key] = value + state._commit(dict_, [self.key]) + return value + + +class ScalarAttributeImpl(AttributeImpl): + """represents a scalar value-holding InstrumentedAttribute.""" + + default_accepts_scalar_loader = True + uses_objects = False + supports_population = True + collection = False + dynamic = False + + __slots__ = "_replace_token", "_append_token", "_remove_token" + + def __init__(self, *arg, **kw): + super().__init__(*arg, **kw) + self._replace_token = self._append_token = AttributeEventToken( + self, OP_REPLACE + ) + self._remove_token = AttributeEventToken(self, OP_REMOVE) + + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: + if self.dispatch._active_history: + old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE) + else: + old = dict_.get(self.key, NO_VALUE) + + if self.dispatch.remove: + self.fire_remove_event(state, dict_, old, self._remove_token) + state._modified_event(dict_, self, old) + + existing = dict_.pop(self.key, NO_VALUE) + if ( + existing is NO_VALUE + and old is NO_VALUE + and not state.expired + and self.key not in state.expired_attributes + ): + raise AttributeError("%s object does not have a value" % self) + + def get_history( + self, + state: InstanceState[Any], + dict_: Dict[str, Any], + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: + if self.key in dict_: + return History.from_scalar_attribute(self, state, dict_[self.key]) + elif self.key in state.committed_state: + return History.from_scalar_attribute(self, state, NO_VALUE) + else: + if passive & INIT_OK: + passive ^= INIT_OK + current = self.get(state, dict_, passive=passive) + if current is PASSIVE_NO_RESULT: + return HISTORY_BLANK + else: + return History.from_scalar_attribute(self, state, current) + + def set( + self, + state: InstanceState[Any], + dict_: Dict[str, Any], + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PASSIVE_OFF, + check_old: Optional[object] = None, + pop: bool = False, + ) -> None: + if self.dispatch._active_history: + old = self.get(state, dict_, PASSIVE_RETURN_NO_VALUE) + else: + old = dict_.get(self.key, NO_VALUE) + + if self.dispatch.set: + value = self.fire_replace_event( + state, dict_, value, old, initiator + ) + state._modified_event(dict_, self, old) + dict_[self.key] = value + + def fire_replace_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + previous: Any, + initiator: Optional[AttributeEventToken], + ) -> _T: + for fn in self.dispatch.set: + value = fn( + state, value, previous, initiator or self._replace_token + ) + return value + + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + ) -> None: + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token) + + +class ScalarObjectAttributeImpl(ScalarAttributeImpl): + """represents a scalar-holding InstrumentedAttribute, + where the target object is also instrumented. + + Adds events to delete/set operations. + + """ + + default_accepts_scalar_loader = False + uses_objects = True + supports_population = True + collection = False + + __slots__ = () + + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: + if self.dispatch._active_history: + old = self.get( + state, + dict_, + passive=PASSIVE_ONLY_PERSISTENT + | NO_AUTOFLUSH + | LOAD_AGAINST_COMMITTED, + ) + else: + old = self.get( + state, + dict_, + passive=PASSIVE_NO_FETCH ^ INIT_OK + | LOAD_AGAINST_COMMITTED + | NO_RAISE, + ) + + self.fire_remove_event(state, dict_, old, self._remove_token) + + existing = dict_.pop(self.key, NO_VALUE) + + # if the attribute is expired, we currently have no way to tell + # that an object-attribute was expired vs. not loaded. So + # for this test, we look to see if the object has a DB identity. + if ( + existing is NO_VALUE + and old is not PASSIVE_NO_RESULT + and state.key is None + ): + raise AttributeError("%s object does not have a value" % self) + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: + if self.key in dict_: + current = dict_[self.key] + else: + if passive & INIT_OK: + passive ^= INIT_OK + current = self.get(state, dict_, passive=passive) + if current is PASSIVE_NO_RESULT: + return HISTORY_BLANK + + if not self._deferred_history: + return History.from_object_attribute(self, state, current) + else: + original = state.committed_state.get(self.key, _NO_HISTORY) + if original is PASSIVE_NO_RESULT: + loader_passive = passive | ( + PASSIVE_ONLY_PERSISTENT + | NO_AUTOFLUSH + | LOAD_AGAINST_COMMITTED + | NO_RAISE + | DEFERRED_HISTORY_LOAD + ) + original = self._fire_loader_callables( + state, self.key, loader_passive + ) + return History.from_object_attribute( + self, state, current, original=original + ) + + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_NO_INITIALIZE, + ) -> _AllPendingType: + if self.key in dict_: + current = dict_[self.key] + elif passive & CALLABLES_OK: + current = self.get(state, dict_, passive=passive) + else: + return [] + + ret: _AllPendingType + + # can't use __hash__(), can't use __eq__() here + if ( + current is not None + and current is not PASSIVE_NO_RESULT + and current is not NO_VALUE + ): + ret = [(instance_state(current), current)] + else: + ret = [(None, None)] + + if self.key in state.committed_state: + original = state.committed_state[self.key] + if ( + original is not None + and original is not PASSIVE_NO_RESULT + and original is not NO_VALUE + and original is not current + ): + ret.append((instance_state(original), original)) + return ret + + def set( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + ) -> None: + """Set a value on the given InstanceState.""" + + if self.dispatch._active_history: + old = self.get( + state, + dict_, + passive=PASSIVE_ONLY_PERSISTENT + | NO_AUTOFLUSH + | LOAD_AGAINST_COMMITTED, + ) + else: + old = self.get( + state, + dict_, + passive=PASSIVE_NO_FETCH ^ INIT_OK + | LOAD_AGAINST_COMMITTED + | NO_RAISE, + ) + + if ( + check_old is not None + and old is not PASSIVE_NO_RESULT + and check_old is not old + ): + if pop: + return + else: + raise ValueError( + "Object %s not associated with %s on attribute '%s'" + % (instance_str(check_old), state_str(state), self.key) + ) + + value = self.fire_replace_event(state, dict_, value, old, initiator) + dict_[self.key] = value + + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + ) -> None: + if self.trackparent and value not in ( + None, + PASSIVE_NO_RESULT, + NO_VALUE, + ): + self.sethasparent(instance_state(value), state, False) + + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token) + + state._modified_event(dict_, self, value) + + def fire_replace_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + previous: Any, + initiator: Optional[AttributeEventToken], + ) -> _T: + if self.trackparent: + if previous is not value and previous not in ( + None, + PASSIVE_NO_RESULT, + NO_VALUE, + ): + self.sethasparent(instance_state(previous), state, False) + + for fn in self.dispatch.set: + value = fn( + state, value, previous, initiator or self._replace_token + ) + + state._modified_event(dict_, self, previous) + + if self.trackparent: + if value is not None: + self.sethasparent(instance_state(value), state, True) + + return value + + +class HasCollectionAdapter: + __slots__ = () + + collection: bool + _is_has_collection_adapter = True + + def _dispose_previous_collection( + self, + state: InstanceState[Any], + collection: _AdaptedCollectionProtocol, + adapter: CollectionAdapter, + fire_event: bool, + ) -> None: + raise NotImplementedError() + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Literal[None] = ..., + passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., + ) -> CollectionAdapter: ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: _AdaptedCollectionProtocol = ..., + passive: PassiveFlag = ..., + ) -> CollectionAdapter: ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = ..., + passive: PassiveFlag = ..., + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: ... + + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + raise NotImplementedError() + + def set( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + _adapt: bool = True, + ) -> None: + raise NotImplementedError() + + +if TYPE_CHECKING: + + def _is_collection_attribute_impl( + impl: AttributeImpl, + ) -> TypeGuard[CollectionAttributeImpl]: ... + +else: + _is_collection_attribute_impl = operator.attrgetter("collection") + + +class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): + """A collection-holding attribute that instruments changes in membership. + + Only handles collections of instrumented objects. + + InstrumentedCollectionAttribute holds an arbitrary, user-specified + container object (defaulting to a list) and brokers access to the + CollectionAdapter, a "view" onto that object that presents consistent bag + semantics to the orm layer independent of the user data implementation. + + """ + + uses_objects = True + collection = True + default_accepts_scalar_loader = False + supports_population = True + dynamic = False + + _bulk_replace_token: AttributeEventToken + + __slots__ = ( + "copy", + "collection_factory", + "_append_token", + "_remove_token", + "_bulk_replace_token", + "_duck_typed_as", + ) + + def __init__( + self, + class_, + key, + callable_, + dispatch, + typecallable=None, + trackparent=False, + copy_function=None, + compare_function=None, + **kwargs, + ): + super().__init__( + class_, + key, + callable_, + dispatch, + trackparent=trackparent, + compare_function=compare_function, + **kwargs, + ) + + if copy_function is None: + copy_function = self.__copy + self.copy = copy_function + self.collection_factory = typecallable + self._append_token = AttributeEventToken(self, OP_APPEND) + self._remove_token = AttributeEventToken(self, OP_REMOVE) + self._bulk_replace_token = AttributeEventToken(self, OP_BULK_REPLACE) + self._duck_typed_as = util.duck_type_collection( + self.collection_factory() + ) + + if getattr(self.collection_factory, "_sa_linker", None): + + @event.listens_for(self, "init_collection") + def link(target, collection, collection_adapter): + collection._sa_linker(collection_adapter) + + @event.listens_for(self, "dispose_collection") + def unlink(target, collection, collection_adapter): + collection._sa_linker(None) + + def __copy(self, item): + return [y for y in collections.collection_adapter(item)] + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_OFF, + ) -> History: + current = self.get(state, dict_, passive=passive) + + if current is PASSIVE_NO_RESULT: + if ( + passive & PassiveFlag.INCLUDE_PENDING_MUTATIONS + and self.key in state._pending_mutations + ): + pending = state._pending_mutations[self.key] + return pending.merge_with_history(HISTORY_BLANK) + else: + return HISTORY_BLANK + else: + if passive & PassiveFlag.INCLUDE_PENDING_MUTATIONS: + # this collection is loaded / present. should not be any + # pending mutations + assert self.key not in state._pending_mutations + + return History.from_collection(self, state, current) + + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PASSIVE_NO_INITIALIZE, + ) -> _AllPendingType: + # NOTE: passive is ignored here at the moment + + if self.key not in dict_: + return [] + + current = dict_[self.key] + current = getattr(current, "_sa_adapter") + + if self.key in state.committed_state: + original = state.committed_state[self.key] + if original is not NO_VALUE: + current_states = [ + ((c is not None) and instance_state(c) or None, c) + for c in current + ] + original_states = [ + ((c is not None) and instance_state(c) or None, c) + for c in original + ] + + current_set = dict(current_states) + original_set = dict(original_states) + + return ( + [ + (s, o) + for s, o in current_states + if s not in original_set + ] + + [(s, o) for s, o in current_states if s in original_set] + + [ + (s, o) + for s, o in original_states + if s not in current_set + ] + ) + + return [(instance_state(o), o) for o in current] + + def fire_append_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + initiator: Optional[AttributeEventToken], + key: Optional[Any], + ) -> _T: + for fn in self.dispatch.append: + value = fn(state, value, initiator or self._append_token, key=key) + + state._modified_event(dict_, self, NO_VALUE, True) + + if self.trackparent and value is not None: + self.sethasparent(instance_state(value), state, True) + + return value + + def fire_append_wo_mutation_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: _T, + initiator: Optional[AttributeEventToken], + key: Optional[Any], + ) -> _T: + for fn in self.dispatch.append_wo_mutation: + value = fn(state, value, initiator or self._append_token, key=key) + + return value + + def fire_pre_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + initiator: Optional[AttributeEventToken], + key: Optional[Any], + ) -> None: + """A special event used for pop() operations. + + The "remove" event needs to have the item to be removed passed to + it, which in the case of pop from a set, we don't have a way to access + the item before the operation. the event is used for all pop() + operations (even though set.pop is the one where it is really needed). + + """ + state._modified_event(dict_, self, NO_VALUE, True) + + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + key: Optional[Any], + ) -> None: + if self.trackparent and value is not None: + self.sethasparent(instance_state(value), state, False) + + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token, key=key) + + state._modified_event(dict_, self, NO_VALUE, True) + + def delete(self, state: InstanceState[Any], dict_: _InstanceDict) -> None: + if self.key not in dict_: + return + + state._modified_event(dict_, self, NO_VALUE, True) + + collection = self.get_collection(state, state.dict) + collection.clear_with_event() + + # key is always present because we checked above. e.g. + # del is a no-op if collection not present. + del dict_[self.key] + + def _default_value( + self, state: InstanceState[Any], dict_: _InstanceDict + ) -> _AdaptedCollectionProtocol: + """Produce an empty collection for an un-initialized attribute""" + + assert self.key not in dict_, ( + "_default_value should only be invoked for an " + "uninitialized or expired attribute" + ) + + if self.key in state._empty_collections: + return state._empty_collections[self.key] + + adapter, user_data = self._initialize_collection(state) + adapter._set_empty(user_data) + return user_data + + def _initialize_collection( + self, state: InstanceState[Any] + ) -> Tuple[CollectionAdapter, _AdaptedCollectionProtocol]: + adapter, collection = state.manager.initialize_collection( + self.key, state, self.collection_factory + ) + + self.dispatch.init_collection(state, collection, adapter) + + return adapter, collection + + def append( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: + collection = self.get_collection( + state, dict_, user_data=None, passive=passive + ) + if collection is PASSIVE_NO_RESULT: + value = self.fire_append_event( + state, dict_, value, initiator, key=NO_KEY + ) + assert ( + self.key not in dict_ + ), "Collection was loaded during event handling." + state._get_pending_mutation(self.key).append(value) + else: + if TYPE_CHECKING: + assert isinstance(collection, CollectionAdapter) + collection.append_with_event(value, initiator) + + def remove( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: + collection = self.get_collection( + state, state.dict, user_data=None, passive=passive + ) + if collection is PASSIVE_NO_RESULT: + self.fire_remove_event(state, dict_, value, initiator, key=NO_KEY) + assert ( + self.key not in dict_ + ), "Collection was loaded during event handling." + state._get_pending_mutation(self.key).remove(value) + else: + if TYPE_CHECKING: + assert isinstance(collection, CollectionAdapter) + collection.remove_with_event(value, initiator) + + def pop( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PASSIVE_OFF, + ) -> None: + try: + # TODO: better solution here would be to add + # a "popper" role to collections.py to complement + # "remover". + self.remove(state, dict_, value, initiator, passive=passive) + except (ValueError, KeyError, IndexError): + pass + + def set( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + _adapt: bool = True, + ) -> None: + iterable = orig_iterable = value + new_keys = None + + # pulling a new collection first so that an adaptation exception does + # not trigger a lazy load of the old collection. + new_collection, user_data = self._initialize_collection(state) + if _adapt: + if new_collection._converter is not None: + iterable = new_collection._converter(iterable) + else: + setting_type = util.duck_type_collection(iterable) + receiving_type = self._duck_typed_as + + if setting_type is not receiving_type: + given = ( + iterable is None + and "None" + or iterable.__class__.__name__ + ) + wanted = self._duck_typed_as.__name__ + raise TypeError( + "Incompatible collection type: %s is not %s-like" + % (given, wanted) + ) + + # If the object is an adapted collection, return the (iterable) + # adapter. + if hasattr(iterable, "_sa_iterator"): + iterable = iterable._sa_iterator() + elif setting_type is dict: + new_keys = list(iterable) + iterable = iterable.values() + else: + iterable = iter(iterable) + elif util.duck_type_collection(iterable) is dict: + new_keys = list(value) + + new_values = list(iterable) + + evt = self._bulk_replace_token + + self.dispatch.bulk_replace(state, new_values, evt, keys=new_keys) + + # propagate NO_RAISE in passive through to the get() for the + # existing object (ticket #8862) + old = self.get( + state, + dict_, + passive=PASSIVE_ONLY_PERSISTENT ^ (passive & PassiveFlag.NO_RAISE), + ) + if old is PASSIVE_NO_RESULT: + old = self._default_value(state, dict_) + elif old is orig_iterable: + # ignore re-assignment of the current collection, as happens + # implicitly with in-place operators (foo.collection |= other) + return + + # place a copy of "old" in state.committed_state + state._modified_event(dict_, self, old, True) + + old_collection = old._sa_adapter + + dict_[self.key] = user_data + + collections.bulk_replace( + new_values, old_collection, new_collection, initiator=evt + ) + + self._dispose_previous_collection(state, old, old_collection, True) + + def _dispose_previous_collection( + self, + state: InstanceState[Any], + collection: _AdaptedCollectionProtocol, + adapter: CollectionAdapter, + fire_event: bool, + ) -> None: + del collection._sa_adapter + + # discarding old collection make sure it is not referenced in empty + # collections. + state._empty_collections.pop(self.key, None) + if fire_event: + self.dispatch.dispose_collection(state, collection, adapter) + + def _invalidate_collection( + self, collection: _AdaptedCollectionProtocol + ) -> None: + adapter = getattr(collection, "_sa_adapter") + adapter.invalidated = True + + def set_committed_value( + self, state: InstanceState[Any], dict_: _InstanceDict, value: Any + ) -> _AdaptedCollectionProtocol: + """Set an attribute value on the given instance and 'commit' it.""" + + collection, user_data = self._initialize_collection(state) + + if value: + collection.append_multiple_without_event(value) + + state.dict[self.key] = user_data + + state._commit(dict_, [self.key]) + + if self.key in state._pending_mutations: + # pending items exist. issue a modified event, + # add/remove new items. + state._modified_event(dict_, self, user_data, True) + + pending = state._pending_mutations.pop(self.key) + added = pending.added_items + removed = pending.deleted_items + for item in added: + collection.append_without_event(item) + for item in removed: + collection.remove_without_event(item) + + return user_data + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Literal[None] = ..., + passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., + ) -> CollectionAdapter: ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: _AdaptedCollectionProtocol = ..., + passive: PassiveFlag = ..., + ) -> CollectionAdapter: ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = ..., + passive: PassiveFlag = PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: ... + + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + """Retrieve the CollectionAdapter associated with the given state. + + if user_data is None, retrieves it from the state using normal + "get()" rules, which will fire lazy callables or return the "empty" + collection value. + + """ + if user_data is None: + fetch_user_data = self.get(state, dict_, passive=passive) + if fetch_user_data is LoaderCallableStatus.PASSIVE_NO_RESULT: + return fetch_user_data + else: + user_data = cast("_AdaptedCollectionProtocol", fetch_user_data) + + return user_data._sa_adapter + + +def backref_listeners( + attribute: QueryableAttribute[Any], key: str, uselist: bool +) -> None: + """Apply listeners to synchronize a two-way relationship.""" + + # use easily recognizable names for stack traces. + + # in the sections marked "tokens to test for a recursive loop", + # this is somewhat brittle and very performance-sensitive logic + # that is specific to how we might arrive at each event. a marker + # that can target us directly to arguments being invoked against + # the impl might be simpler, but could interfere with other systems. + + parent_token = attribute.impl.parent_token + parent_impl = attribute.impl + + def _acceptable_key_err(child_state, initiator, child_impl): + raise ValueError( + "Bidirectional attribute conflict detected: " + 'Passing object %s to attribute "%s" ' + 'triggers a modify event on attribute "%s" ' + 'via the backref "%s".' + % ( + state_str(child_state), + initiator.parent_token, + child_impl.parent_token, + attribute.impl.parent_token, + ) + ) + + def emit_backref_from_scalar_set_event( + state, child, oldchild, initiator, **kw + ): + if oldchild is child: + return child + if ( + oldchild is not None + and oldchild is not PASSIVE_NO_RESULT + and oldchild is not NO_VALUE + ): + # With lazy=None, there's no guarantee that the full collection is + # present when updating via a backref. + old_state, old_dict = ( + instance_state(oldchild), + instance_dict(oldchild), + ) + impl = old_state.manager[key].impl + + # tokens to test for a recursive loop. + if not impl.collection and not impl.dynamic: + check_recursive_token = impl._replace_token + else: + check_recursive_token = impl._remove_token + + if initiator is not check_recursive_token: + impl.pop( + old_state, + old_dict, + state.obj(), + parent_impl._append_token, + passive=PASSIVE_NO_FETCH, + ) + + if child is not None: + child_state, child_dict = ( + instance_state(child), + instance_dict(child), + ) + child_impl = child_state.manager[key].impl + + if ( + initiator.parent_token is not parent_token + and initiator.parent_token is not child_impl.parent_token + ): + _acceptable_key_err(state, initiator, child_impl) + + # tokens to test for a recursive loop. + check_append_token = child_impl._append_token + check_bulk_replace_token = ( + child_impl._bulk_replace_token + if _is_collection_attribute_impl(child_impl) + else None + ) + + if ( + initiator is not check_append_token + and initiator is not check_bulk_replace_token + ): + child_impl.append( + child_state, + child_dict, + state.obj(), + initiator, + passive=PASSIVE_NO_FETCH, + ) + return child + + def emit_backref_from_collection_append_event( + state, child, initiator, **kw + ): + if child is None: + return + + child_state, child_dict = instance_state(child), instance_dict(child) + child_impl = child_state.manager[key].impl + + if ( + initiator.parent_token is not parent_token + and initiator.parent_token is not child_impl.parent_token + ): + _acceptable_key_err(state, initiator, child_impl) + + # tokens to test for a recursive loop. + check_append_token = child_impl._append_token + check_bulk_replace_token = ( + child_impl._bulk_replace_token + if _is_collection_attribute_impl(child_impl) + else None + ) + + if ( + initiator is not check_append_token + and initiator is not check_bulk_replace_token + ): + child_impl.append( + child_state, + child_dict, + state.obj(), + initiator, + passive=PASSIVE_NO_FETCH, + ) + return child + + def emit_backref_from_collection_remove_event( + state, child, initiator, **kw + ): + if ( + child is not None + and child is not PASSIVE_NO_RESULT + and child is not NO_VALUE + ): + child_state, child_dict = ( + instance_state(child), + instance_dict(child), + ) + child_impl = child_state.manager[key].impl + + check_replace_token: Optional[AttributeEventToken] + + # tokens to test for a recursive loop. + if not child_impl.collection and not child_impl.dynamic: + check_remove_token = child_impl._remove_token + check_replace_token = child_impl._replace_token + check_for_dupes_on_remove = uselist and not parent_impl.dynamic + else: + check_remove_token = child_impl._remove_token + check_replace_token = ( + child_impl._bulk_replace_token + if _is_collection_attribute_impl(child_impl) + else None + ) + check_for_dupes_on_remove = False + + if ( + initiator is not check_remove_token + and initiator is not check_replace_token + ): + if not check_for_dupes_on_remove or not util.has_dupes( + # when this event is called, the item is usually + # present in the list, except for a pop() operation. + state.dict[parent_impl.key], + child, + ): + child_impl.pop( + child_state, + child_dict, + state.obj(), + initiator, + passive=PASSIVE_NO_FETCH, + ) + + if uselist: + event.listen( + attribute, + "append", + emit_backref_from_collection_append_event, + retval=True, + raw=True, + include_key=True, + ) + else: + event.listen( + attribute, + "set", + emit_backref_from_scalar_set_event, + retval=True, + raw=True, + include_key=True, + ) + # TODO: need coverage in test/orm/ of remove event + event.listen( + attribute, + "remove", + emit_backref_from_collection_remove_event, + retval=True, + raw=True, + include_key=True, + ) + + +_NO_HISTORY = util.symbol("NO_HISTORY") +_NO_STATE_SYMBOLS = frozenset([id(PASSIVE_NO_RESULT), id(NO_VALUE)]) + + +class History(NamedTuple): + """A 3-tuple of added, unchanged and deleted values, + representing the changes which have occurred on an instrumented + attribute. + + The easiest way to get a :class:`.History` object for a particular + attribute on an object is to use the :func:`_sa.inspect` function:: + + from sqlalchemy import inspect + + hist = inspect(myobject).attrs.myattribute.history + + Each tuple member is an iterable sequence: + + * ``added`` - the collection of items added to the attribute (the first + tuple element). + + * ``unchanged`` - the collection of items that have not changed on the + attribute (the second tuple element). + + * ``deleted`` - the collection of items that have been removed from the + attribute (the third tuple element). + + """ + + added: Union[Tuple[()], List[Any]] + unchanged: Union[Tuple[()], List[Any]] + deleted: Union[Tuple[()], List[Any]] + + def __bool__(self) -> bool: + return self != HISTORY_BLANK + + def empty(self) -> bool: + """Return True if this :class:`.History` has no changes + and no existing, unchanged state. + + """ + + return not bool((self.added or self.deleted) or self.unchanged) + + def sum(self) -> Sequence[Any]: + """Return a collection of added + unchanged + deleted.""" + + return ( + (self.added or []) + (self.unchanged or []) + (self.deleted or []) + ) + + def non_deleted(self) -> Sequence[Any]: + """Return a collection of added + unchanged.""" + + return (self.added or []) + (self.unchanged or []) + + def non_added(self) -> Sequence[Any]: + """Return a collection of unchanged + deleted.""" + + return (self.unchanged or []) + (self.deleted or []) + + def has_changes(self) -> bool: + """Return True if this :class:`.History` has changes.""" + + return bool(self.added or self.deleted) + + def _merge(self, added: Iterable[Any], deleted: Iterable[Any]) -> History: + return History( + list(self.added) + list(added), + self.unchanged, + list(self.deleted) + list(deleted), + ) + + def as_state(self) -> History: + return History( + [ + (c is not None) and instance_state(c) or None + for c in self.added + ], + [ + (c is not None) and instance_state(c) or None + for c in self.unchanged + ], + [ + (c is not None) and instance_state(c) or None + for c in self.deleted + ], + ) + + @classmethod + def from_scalar_attribute( + cls, + attribute: ScalarAttributeImpl, + state: InstanceState[Any], + current: Any, + ) -> History: + original = state.committed_state.get(attribute.key, _NO_HISTORY) + + deleted: Union[Tuple[()], List[Any]] + + if original is _NO_HISTORY: + if current is NO_VALUE: + return cls((), (), ()) + else: + return cls((), [current], ()) + # don't let ClauseElement expressions here trip things up + elif ( + current is not NO_VALUE + and attribute.is_equal(current, original) is True + ): + return cls((), [current], ()) + else: + # current convention on native scalars is to not + # include information + # about missing previous value in "deleted", but + # we do include None, which helps in some primary + # key situations + if id(original) in _NO_STATE_SYMBOLS: + deleted = () + # indicate a "del" operation occurred when we don't have + # the previous value as: ([None], (), ()) + if id(current) in _NO_STATE_SYMBOLS: + current = None + else: + deleted = [original] + if current is NO_VALUE: + return cls((), (), deleted) + else: + return cls([current], (), deleted) + + @classmethod + def from_object_attribute( + cls, + attribute: ScalarObjectAttributeImpl, + state: InstanceState[Any], + current: Any, + original: Any = _NO_HISTORY, + ) -> History: + deleted: Union[Tuple[()], List[Any]] + + if original is _NO_HISTORY: + original = state.committed_state.get(attribute.key, _NO_HISTORY) + + if original is _NO_HISTORY: + if current is NO_VALUE: + return cls((), (), ()) + else: + return cls((), [current], ()) + elif current is original and current is not NO_VALUE: + return cls((), [current], ()) + else: + # current convention on related objects is to not + # include information + # about missing previous value in "deleted", and + # to also not include None - the dependency.py rules + # ignore the None in any case. + if id(original) in _NO_STATE_SYMBOLS or original is None: + deleted = () + # indicate a "del" operation occurred when we don't have + # the previous value as: ([None], (), ()) + if id(current) in _NO_STATE_SYMBOLS: + current = None + else: + deleted = [original] + if current is NO_VALUE: + return cls((), (), deleted) + else: + return cls([current], (), deleted) + + @classmethod + def from_collection( + cls, + attribute: CollectionAttributeImpl, + state: InstanceState[Any], + current: Any, + ) -> History: + original = state.committed_state.get(attribute.key, _NO_HISTORY) + if current is NO_VALUE: + return cls((), (), ()) + + current = getattr(current, "_sa_adapter") + if original is NO_VALUE: + return cls(list(current), (), ()) + elif original is _NO_HISTORY: + return cls((), list(current), ()) + else: + current_states = [ + ((c is not None) and instance_state(c) or None, c) + for c in current + ] + original_states = [ + ((c is not None) and instance_state(c) or None, c) + for c in original + ] + + current_set = dict(current_states) + original_set = dict(original_states) + + return cls( + [o for s, o in current_states if s not in original_set], + [o for s, o in current_states if s in original_set], + [o for s, o in original_states if s not in current_set], + ) + + +HISTORY_BLANK = History((), (), ()) + + +def get_history( + obj: object, key: str, passive: PassiveFlag = PASSIVE_OFF +) -> History: + """Return a :class:`.History` record for the given object + and attribute key. + + This is the **pre-flush** history for a given attribute, which is + reset each time the :class:`.Session` flushes changes to the + current database transaction. + + .. note:: + + Prefer to use the :attr:`.AttributeState.history` and + :meth:`.AttributeState.load_history` accessors to retrieve the + :class:`.History` for instance attributes. + + + :param obj: an object whose class is instrumented by the + attributes package. + + :param key: string attribute name. + + :param passive: indicates loading behavior for the attribute + if the value is not already present. This is a + bitflag attribute, which defaults to the symbol + :attr:`.PASSIVE_OFF` indicating all necessary SQL + should be emitted. + + .. seealso:: + + :attr:`.AttributeState.history` + + :meth:`.AttributeState.load_history` - retrieve history + using loader callables if the value is not locally present. + + """ + + return get_state_history(instance_state(obj), key, passive) + + +def get_state_history( + state: InstanceState[Any], key: str, passive: PassiveFlag = PASSIVE_OFF +) -> History: + return state.get_history(key, passive) + + +def has_parent( + cls: Type[_O], obj: _O, key: str, optimistic: bool = False +) -> bool: + """TODO""" + manager = manager_of_class(cls) + state = instance_state(obj) + return manager.has_parent(state, key, optimistic) + + +def register_attribute( + class_: Type[_O], + key: str, + *, + comparator: interfaces.PropComparator[_T], + parententity: _InternalEntityType[_O], + doc: Optional[str] = None, + **kw: Any, +) -> InstrumentedAttribute[_T]: + desc = register_descriptor( + class_, key, comparator=comparator, parententity=parententity, doc=doc + ) + register_attribute_impl(class_, key, **kw) + return desc + + +def register_attribute_impl( + class_: Type[_O], + key: str, + uselist: bool = False, + callable_: Optional[_LoaderCallable] = None, + useobject: bool = False, + impl_class: Optional[Type[AttributeImpl]] = None, + backref: Optional[str] = None, + **kw: Any, +) -> QueryableAttribute[Any]: + manager = manager_of_class(class_) + if uselist: + factory = kw.pop("typecallable", None) + typecallable = manager.instrument_collection_class( + key, factory or list + ) + else: + typecallable = kw.pop("typecallable", None) + + dispatch = cast( + "_Dispatch[QueryableAttribute[Any]]", manager[key].dispatch + ) # noqa: E501 + + impl: AttributeImpl + + if impl_class: + # TODO: this appears to be the WriteOnlyAttributeImpl / + # DynamicAttributeImpl constructor which is hardcoded + impl = cast("Type[WriteOnlyAttributeImpl]", impl_class)( + class_, key, dispatch, **kw + ) + elif uselist: + impl = CollectionAttributeImpl( + class_, key, callable_, dispatch, typecallable=typecallable, **kw + ) + elif useobject: + impl = ScalarObjectAttributeImpl( + class_, key, callable_, dispatch, **kw + ) + else: + impl = ScalarAttributeImpl(class_, key, callable_, dispatch, **kw) + + manager[key].impl = impl + + if backref: + backref_listeners(manager[key], backref, uselist) + + manager.post_configure_attribute(key) + return manager[key] + + +def register_descriptor( + class_: Type[Any], + key: str, + *, + comparator: interfaces.PropComparator[_T], + parententity: _InternalEntityType[Any], + doc: Optional[str] = None, +) -> InstrumentedAttribute[_T]: + manager = manager_of_class(class_) + + descriptor = InstrumentedAttribute( + class_, key, comparator=comparator, parententity=parententity + ) + + descriptor.__doc__ = doc # type: ignore + + manager.instrument_attribute(key, descriptor) + return descriptor + + +def unregister_attribute(class_: Type[Any], key: str) -> None: + manager_of_class(class_).uninstrument_attribute(key) + + +def init_collection(obj: object, key: str) -> CollectionAdapter: + """Initialize a collection attribute and return the collection adapter. + + This function is used to provide direct access to collection internals + for a previously unloaded attribute. e.g.:: + + collection_adapter = init_collection(someobject, 'elements') + for elem in values: + collection_adapter.append_without_event(elem) + + For an easier way to do the above, see + :func:`~sqlalchemy.orm.attributes.set_committed_value`. + + :param obj: a mapped object + + :param key: string attribute name where the collection is located. + + """ + state = instance_state(obj) + dict_ = state.dict + return init_state_collection(state, dict_, key) + + +def init_state_collection( + state: InstanceState[Any], dict_: _InstanceDict, key: str +) -> CollectionAdapter: + """Initialize a collection attribute and return the collection adapter. + + Discards any existing collection which may be there. + + """ + attr = state.manager[key].impl + + if TYPE_CHECKING: + assert isinstance(attr, HasCollectionAdapter) + + old = dict_.pop(key, None) # discard old collection + if old is not None: + old_collection = old._sa_adapter + attr._dispose_previous_collection(state, old, old_collection, False) + + user_data = attr._default_value(state, dict_) + adapter: CollectionAdapter = attr.get_collection( + state, dict_, user_data, passive=PassiveFlag.PASSIVE_NO_FETCH + ) + adapter._reset_empty() + + return adapter + + +def set_committed_value(instance, key, value): + """Set the value of an attribute with no history events. + + Cancels any previous history present. The value should be + a scalar value for scalar-holding attributes, or + an iterable for any collection-holding attribute. + + This is the same underlying method used when a lazy loader + fires off and loads additional data from the database. + In particular, this method can be used by application code + which has loaded additional attributes or collections through + separate queries, which can then be attached to an instance + as though it were part of its original loaded state. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + state.manager[key].impl.set_committed_value(state, dict_, value) + + +def set_attribute( + instance: object, + key: str, + value: Any, + initiator: Optional[AttributeEventToken] = None, +) -> None: + """Set the value of an attribute, firing history events. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + Custom attribute management schemes will need to make usage + of this method to establish attribute state as understood + by SQLAlchemy. + + :param instance: the object that will be modified + + :param key: string name of the attribute + + :param value: value to assign + + :param initiator: an instance of :class:`.Event` that would have + been propagated from a previous event listener. This argument + is used when the :func:`.set_attribute` function is being used within + an existing event listening function where an :class:`.Event` object + is being supplied; the object may be used to track the origin of the + chain of events. + + .. versionadded:: 1.2.3 + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + state.manager[key].impl.set(state, dict_, value, initiator) + + +def get_attribute(instance: object, key: str) -> Any: + """Get the value of an attribute, firing any callables required. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + Custom attribute management schemes will need to make usage + of this method to make usage of attribute state as understood + by SQLAlchemy. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + return state.manager[key].impl.get(state, dict_) + + +def del_attribute(instance: object, key: str) -> None: + """Delete the value of an attribute, firing history events. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + Custom attribute management schemes will need to make usage + of this method to establish attribute state as understood + by SQLAlchemy. + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + state.manager[key].impl.delete(state, dict_) + + +def flag_modified(instance: object, key: str) -> None: + """Mark an attribute on an instance as 'modified'. + + This sets the 'modified' flag on the instance and + establishes an unconditional change event for the given attribute. + The attribute must have a value present, else an + :class:`.InvalidRequestError` is raised. + + To mark an object "dirty" without referring to any specific attribute + so that it is considered within a flush, use the + :func:`.attributes.flag_dirty` call. + + .. seealso:: + + :func:`.attributes.flag_dirty` + + """ + state, dict_ = instance_state(instance), instance_dict(instance) + impl = state.manager[key].impl + impl.dispatch.modified(state, impl._modified_token) + state._modified_event(dict_, impl, NO_VALUE, is_userland=True) + + +def flag_dirty(instance: object) -> None: + """Mark an instance as 'dirty' without any specific attribute mentioned. + + This is a special operation that will allow the object to travel through + the flush process for interception by events such as + :meth:`.SessionEvents.before_flush`. Note that no SQL will be emitted in + the flush process for an object that has no changes, even if marked dirty + via this method. However, a :meth:`.SessionEvents.before_flush` handler + will be able to see the object in the :attr:`.Session.dirty` collection and + may establish changes on it, which will then be included in the SQL + emitted. + + .. versionadded:: 1.2 + + .. seealso:: + + :func:`.attributes.flag_modified` + + """ + + state, dict_ = instance_state(instance), instance_dict(instance) + state._modified_event(dict_, None, NO_VALUE, is_userland=True) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/base.py new file mode 100644 index 0000000..c900529 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/base.py @@ -0,0 +1,971 @@ +# orm/base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Constants and rudimental functions used throughout the ORM. + +""" + +from __future__ import annotations + +from enum import Enum +import operator +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import no_type_check +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import exc +from ._typing import insp_is_mapper +from .. import exc as sa_exc +from .. import inspection +from .. import util +from ..sql import roles +from ..sql.elements import SQLColumnExpression +from ..sql.elements import SQLCoreOperations +from ..util import FastIntFlag +from ..util.langhelpers import TypingOnly +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _InternalEntityType + from .attributes import InstrumentedAttribute + from .dynamic import AppenderQuery + from .instrumentation import ClassManager + from .interfaces import PropComparator + from .mapper import Mapper + from .state import InstanceState + from .util import AliasedClass + from .writeonly import WriteOnlyCollection + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _InfoType + from ..sql.elements import ColumnElement + from ..sql.operators import OperatorType + +_T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) + +_O = TypeVar("_O", bound=object) + + +class LoaderCallableStatus(Enum): + PASSIVE_NO_RESULT = 0 + """Symbol returned by a loader callable or other attribute/history + retrieval operation when a value could not be determined, based + on loader callable flags. + """ + + PASSIVE_CLASS_MISMATCH = 1 + """Symbol indicating that an object is locally present for a given + primary key identity but it is not of the requested class. The + return value is therefore None and no SQL should be emitted.""" + + ATTR_WAS_SET = 2 + """Symbol returned by a loader callable to indicate the + retrieved value, or values, were assigned to their attributes + on the target object. + """ + + ATTR_EMPTY = 3 + """Symbol used internally to indicate an attribute had no callable.""" + + NO_VALUE = 4 + """Symbol which may be placed as the 'previous' value of an attribute, + indicating no value was loaded for an attribute when it was modified, + and flags indicated we were not to load it. + """ + + NEVER_SET = NO_VALUE + """ + Synonymous with NO_VALUE + + .. versionchanged:: 1.4 NEVER_SET was merged with NO_VALUE + + """ + + +( + PASSIVE_NO_RESULT, + PASSIVE_CLASS_MISMATCH, + ATTR_WAS_SET, + ATTR_EMPTY, + NO_VALUE, +) = tuple(LoaderCallableStatus) + +NEVER_SET = NO_VALUE + + +class PassiveFlag(FastIntFlag): + """Bitflag interface that passes options onto loader callables""" + + NO_CHANGE = 0 + """No callables or SQL should be emitted on attribute access + and no state should change + """ + + CALLABLES_OK = 1 + """Loader callables can be fired off if a value + is not present. + """ + + SQL_OK = 2 + """Loader callables can emit SQL at least on scalar value attributes.""" + + RELATED_OBJECT_OK = 4 + """Callables can use SQL to load related objects as well + as scalar value attributes. + """ + + INIT_OK = 8 + """Attributes should be initialized with a blank + value (None or an empty collection) upon get, if no other + value can be obtained. + """ + + NON_PERSISTENT_OK = 16 + """Callables can be emitted if the parent is not persistent.""" + + LOAD_AGAINST_COMMITTED = 32 + """Callables should use committed values as primary/foreign keys during a + load. + """ + + NO_AUTOFLUSH = 64 + """Loader callables should disable autoflush.""", + + NO_RAISE = 128 + """Loader callables should not raise any assertions""" + + DEFERRED_HISTORY_LOAD = 256 + """indicates special load of the previous value of an attribute""" + + INCLUDE_PENDING_MUTATIONS = 512 + + # pre-packaged sets of flags used as inputs + PASSIVE_OFF = ( + RELATED_OBJECT_OK | NON_PERSISTENT_OK | INIT_OK | CALLABLES_OK | SQL_OK + ) + "Callables can be emitted in all cases." + + PASSIVE_RETURN_NO_VALUE = PASSIVE_OFF ^ INIT_OK + """PASSIVE_OFF ^ INIT_OK""" + + PASSIVE_NO_INITIALIZE = PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK + "PASSIVE_RETURN_NO_VALUE ^ CALLABLES_OK" + + PASSIVE_NO_FETCH = PASSIVE_OFF ^ SQL_OK + "PASSIVE_OFF ^ SQL_OK" + + PASSIVE_NO_FETCH_RELATED = PASSIVE_OFF ^ RELATED_OBJECT_OK + "PASSIVE_OFF ^ RELATED_OBJECT_OK" + + PASSIVE_ONLY_PERSISTENT = PASSIVE_OFF ^ NON_PERSISTENT_OK + "PASSIVE_OFF ^ NON_PERSISTENT_OK" + + PASSIVE_MERGE = PASSIVE_OFF | NO_RAISE + """PASSIVE_OFF | NO_RAISE + + Symbol used specifically for session.merge() and similar cases + + """ + + +( + NO_CHANGE, + CALLABLES_OK, + SQL_OK, + RELATED_OBJECT_OK, + INIT_OK, + NON_PERSISTENT_OK, + LOAD_AGAINST_COMMITTED, + NO_AUTOFLUSH, + NO_RAISE, + DEFERRED_HISTORY_LOAD, + INCLUDE_PENDING_MUTATIONS, + PASSIVE_OFF, + PASSIVE_RETURN_NO_VALUE, + PASSIVE_NO_INITIALIZE, + PASSIVE_NO_FETCH, + PASSIVE_NO_FETCH_RELATED, + PASSIVE_ONLY_PERSISTENT, + PASSIVE_MERGE, +) = PassiveFlag.__members__.values() + +DEFAULT_MANAGER_ATTR = "_sa_class_manager" +DEFAULT_STATE_ATTR = "_sa_instance_state" + + +class EventConstants(Enum): + EXT_CONTINUE = 1 + EXT_STOP = 2 + EXT_SKIP = 3 + NO_KEY = 4 + """indicates an :class:`.AttributeEvent` event that did not have any + key argument. + + .. versionadded:: 2.0 + + """ + + +EXT_CONTINUE, EXT_STOP, EXT_SKIP, NO_KEY = tuple(EventConstants) + + +class RelationshipDirection(Enum): + """enumeration which indicates the 'direction' of a + :class:`_orm.RelationshipProperty`. + + :class:`.RelationshipDirection` is accessible from the + :attr:`_orm.Relationship.direction` attribute of + :class:`_orm.RelationshipProperty`. + + """ + + ONETOMANY = 1 + """Indicates the one-to-many direction for a :func:`_orm.relationship`. + + This symbol is typically used by the internals but may be exposed within + certain API features. + + """ + + MANYTOONE = 2 + """Indicates the many-to-one direction for a :func:`_orm.relationship`. + + This symbol is typically used by the internals but may be exposed within + certain API features. + + """ + + MANYTOMANY = 3 + """Indicates the many-to-many direction for a :func:`_orm.relationship`. + + This symbol is typically used by the internals but may be exposed within + certain API features. + + """ + + +ONETOMANY, MANYTOONE, MANYTOMANY = tuple(RelationshipDirection) + + +class InspectionAttrExtensionType(Enum): + """Symbols indicating the type of extension that a + :class:`.InspectionAttr` is part of.""" + + +class NotExtension(InspectionAttrExtensionType): + NOT_EXTENSION = "not_extension" + """Symbol indicating an :class:`InspectionAttr` that's + not part of sqlalchemy.ext. + + Is assigned to the :attr:`.InspectionAttr.extension_type` + attribute. + + """ + + +_never_set = frozenset([NEVER_SET]) + +_none_set = frozenset([None, NEVER_SET, PASSIVE_NO_RESULT]) + +_SET_DEFERRED_EXPIRED = util.symbol("SET_DEFERRED_EXPIRED") + +_DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") + +_RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE") + + +_F = TypeVar("_F", bound=Callable[..., Any]) +_Self = TypeVar("_Self") + + +def _assertions( + *assertions: Any, +) -> Callable[[_F], _F]: + @util.decorator + def generate(fn: _F, self: _Self, *args: Any, **kw: Any) -> _Self: + for assertion in assertions: + assertion(self, fn.__name__) + fn(self, *args, **kw) + return self + + return generate + + +if TYPE_CHECKING: + + def manager_of_class(cls: Type[_O]) -> ClassManager[_O]: ... + + @overload + def opt_manager_of_class(cls: AliasedClass[Any]) -> None: ... + + @overload + def opt_manager_of_class( + cls: _ExternalEntityType[_O], + ) -> Optional[ClassManager[_O]]: ... + + def opt_manager_of_class( + cls: _ExternalEntityType[_O], + ) -> Optional[ClassManager[_O]]: ... + + def instance_state(instance: _O) -> InstanceState[_O]: ... + + def instance_dict(instance: object) -> Dict[str, Any]: ... + +else: + # these can be replaced by sqlalchemy.ext.instrumentation + # if augmented class instrumentation is enabled. + + def manager_of_class(cls): + try: + return cls.__dict__[DEFAULT_MANAGER_ATTR] + except KeyError as ke: + raise exc.UnmappedClassError( + cls, f"Can't locate an instrumentation manager for class {cls}" + ) from ke + + def opt_manager_of_class(cls): + return cls.__dict__.get(DEFAULT_MANAGER_ATTR) + + instance_state = operator.attrgetter(DEFAULT_STATE_ATTR) + + instance_dict = operator.attrgetter("__dict__") + + +def instance_str(instance: object) -> str: + """Return a string describing an instance.""" + + return state_str(instance_state(instance)) + + +def state_str(state: InstanceState[Any]) -> str: + """Return a string describing an instance via its InstanceState.""" + + if state is None: + return "None" + else: + return "<%s at 0x%x>" % (state.class_.__name__, id(state.obj())) + + +def state_class_str(state: InstanceState[Any]) -> str: + """Return a string describing an instance's class via its + InstanceState. + """ + + if state is None: + return "None" + else: + return "<%s>" % (state.class_.__name__,) + + +def attribute_str(instance: object, attribute: str) -> str: + return instance_str(instance) + "." + attribute + + +def state_attribute_str(state: InstanceState[Any], attribute: str) -> str: + return state_str(state) + "." + attribute + + +def object_mapper(instance: _T) -> Mapper[_T]: + """Given an object, return the primary Mapper associated with the object + instance. + + Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError` + if no mapping is configured. + + This function is available via the inspection system as:: + + inspect(instance).mapper + + Using the inspection system will raise + :class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is + not part of a mapping. + + """ + return object_state(instance).mapper + + +def object_state(instance: _T) -> InstanceState[_T]: + """Given an object, return the :class:`.InstanceState` + associated with the object. + + Raises :class:`sqlalchemy.orm.exc.UnmappedInstanceError` + if no mapping is configured. + + Equivalent functionality is available via the :func:`_sa.inspect` + function as:: + + inspect(instance) + + Using the inspection system will raise + :class:`sqlalchemy.exc.NoInspectionAvailable` if the instance is + not part of a mapping. + + """ + state = _inspect_mapped_object(instance) + if state is None: + raise exc.UnmappedInstanceError(instance) + else: + return state + + +@inspection._inspects(object) +def _inspect_mapped_object(instance: _T) -> Optional[InstanceState[_T]]: + try: + return instance_state(instance) + except (exc.UnmappedClassError,) + exc.NO_STATE: + return None + + +def _class_to_mapper( + class_or_mapper: Union[Mapper[_T], Type[_T]] +) -> Mapper[_T]: + # can't get mypy to see an overload for this + insp = inspection.inspect(class_or_mapper, False) + if insp is not None: + return insp.mapper # type: ignore + else: + assert isinstance(class_or_mapper, type) + raise exc.UnmappedClassError(class_or_mapper) + + +def _mapper_or_none( + entity: Union[Type[_T], _InternalEntityType[_T]] +) -> Optional[Mapper[_T]]: + """Return the :class:`_orm.Mapper` for the given class or None if the + class is not mapped. + """ + + # can't get mypy to see an overload for this + insp = inspection.inspect(entity, False) + if insp is not None: + return insp.mapper # type: ignore + else: + return None + + +def _is_mapped_class(entity: Any) -> bool: + """Return True if the given object is a mapped class, + :class:`_orm.Mapper`, or :class:`.AliasedClass`. + """ + + insp = inspection.inspect(entity, False) + return ( + insp is not None + and not insp.is_clause_element + and (insp.is_mapper or insp.is_aliased_class) + ) + + +def _is_aliased_class(entity: Any) -> bool: + insp = inspection.inspect(entity, False) + return insp is not None and getattr(insp, "is_aliased_class", False) + + +@no_type_check +def _entity_descriptor(entity: _EntityType[Any], key: str) -> Any: + """Return a class attribute given an entity and string name. + + May return :class:`.InstrumentedAttribute` or user-defined + attribute. + + """ + insp = inspection.inspect(entity) + if insp.is_selectable: + description = entity + entity = insp.c + elif insp.is_aliased_class: + entity = insp.entity + description = entity + elif hasattr(insp, "mapper"): + description = entity = insp.mapper.class_ + else: + description = entity + + try: + return getattr(entity, key) + except AttributeError as err: + raise sa_exc.InvalidRequestError( + "Entity '%s' has no property '%s'" % (description, key) + ) from err + + +if TYPE_CHECKING: + + def _state_mapper(state: InstanceState[_O]) -> Mapper[_O]: ... + +else: + _state_mapper = util.dottedgetter("manager.mapper") + + +def _inspect_mapped_class( + class_: Type[_O], configure: bool = False +) -> Optional[Mapper[_O]]: + try: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: + return None + mapper = class_manager.mapper + except exc.NO_STATE: + return None + else: + if configure: + mapper._check_configure() + return mapper + + +def _parse_mapper_argument(arg: Union[Mapper[_O], Type[_O]]) -> Mapper[_O]: + insp = inspection.inspect(arg, raiseerr=False) + if insp_is_mapper(insp): + return insp + + raise sa_exc.ArgumentError(f"Mapper or mapped class expected, got {arg!r}") + + +def class_mapper(class_: Type[_O], configure: bool = True) -> Mapper[_O]: + """Given a class, return the primary :class:`_orm.Mapper` associated + with the key. + + Raises :exc:`.UnmappedClassError` if no mapping is configured + on the given class, or :exc:`.ArgumentError` if a non-class + object is passed. + + Equivalent functionality is available via the :func:`_sa.inspect` + function as:: + + inspect(some_mapped_class) + + Using the inspection system will raise + :class:`sqlalchemy.exc.NoInspectionAvailable` if the class is not mapped. + + """ + mapper = _inspect_mapped_class(class_, configure=configure) + if mapper is None: + if not isinstance(class_, type): + raise sa_exc.ArgumentError( + "Class object expected, got '%r'." % (class_,) + ) + raise exc.UnmappedClassError(class_) + else: + return mapper + + +class InspectionAttr: + """A base class applied to all ORM objects and attributes that are + related to things that can be returned by the :func:`_sa.inspect` function. + + The attributes defined here allow the usage of simple boolean + checks to test basic facts about the object returned. + + While the boolean checks here are basically the same as using + the Python isinstance() function, the flags here can be used without + the need to import all of these classes, and also such that + the SQLAlchemy class system can change while leaving the flags + here intact for forwards-compatibility. + + """ + + __slots__: Tuple[str, ...] = () + + is_selectable = False + """Return True if this object is an instance of + :class:`_expression.Selectable`.""" + + is_aliased_class = False + """True if this object is an instance of :class:`.AliasedClass`.""" + + is_instance = False + """True if this object is an instance of :class:`.InstanceState`.""" + + is_mapper = False + """True if this object is an instance of :class:`_orm.Mapper`.""" + + is_bundle = False + """True if this object is an instance of :class:`.Bundle`.""" + + is_property = False + """True if this object is an instance of :class:`.MapperProperty`.""" + + is_attribute = False + """True if this object is a Python :term:`descriptor`. + + This can refer to one of many types. Usually a + :class:`.QueryableAttribute` which handles attributes events on behalf + of a :class:`.MapperProperty`. But can also be an extension type + such as :class:`.AssociationProxy` or :class:`.hybrid_property`. + The :attr:`.InspectionAttr.extension_type` will refer to a constant + identifying the specific subtype. + + .. seealso:: + + :attr:`_orm.Mapper.all_orm_descriptors` + + """ + + _is_internal_proxy = False + """True if this object is an internal proxy object. + + .. versionadded:: 1.2.12 + + """ + + is_clause_element = False + """True if this object is an instance of + :class:`_expression.ClauseElement`.""" + + extension_type: InspectionAttrExtensionType = NotExtension.NOT_EXTENSION + """The extension type, if any. + Defaults to :attr:`.interfaces.NotExtension.NOT_EXTENSION` + + .. seealso:: + + :class:`.HybridExtensionType` + + :class:`.AssociationProxyExtensionType` + + """ + + +class InspectionAttrInfo(InspectionAttr): + """Adds the ``.info`` attribute to :class:`.InspectionAttr`. + + The rationale for :class:`.InspectionAttr` vs. :class:`.InspectionAttrInfo` + is that the former is compatible as a mixin for classes that specify + ``__slots__``; this is essentially an implementation artifact. + + """ + + __slots__ = () + + @util.ro_memoized_property + def info(self) -> _InfoType: + """Info dictionary associated with the object, allowing user-defined + data to be associated with this :class:`.InspectionAttr`. + + The dictionary is generated when first accessed. Alternatively, + it can be specified as a constructor argument to the + :func:`.column_property`, :func:`_orm.relationship`, or + :func:`.composite` + functions. + + .. seealso:: + + :attr:`.QueryableAttribute.info` + + :attr:`.SchemaItem.info` + + """ + return {} + + +class SQLORMOperations(SQLCoreOperations[_T_co], TypingOnly): + __slots__ = () + + if typing.TYPE_CHECKING: + + def of_type( + self, class_: _EntityType[Any] + ) -> PropComparator[_T_co]: ... + + def and_( + self, *criteria: _ColumnExpressionArgument[bool] + ) -> PropComparator[bool]: ... + + def any( # noqa: A001 + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: ... + + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: ... + + +class ORMDescriptor(Generic[_T_co], TypingOnly): + """Represent any Python descriptor that provides a SQL expression + construct at the class level.""" + + __slots__ = () + + if typing.TYPE_CHECKING: + + @overload + def __get__( + self, instance: Any, owner: Literal[None] + ) -> ORMDescriptor[_T_co]: ... + + @overload + def __get__( + self, instance: Literal[None], owner: Any + ) -> SQLCoreOperations[_T_co]: ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T_co: ... + + def __get__( + self, instance: object, owner: Any + ) -> Union[ORMDescriptor[_T_co], SQLCoreOperations[_T_co], _T_co]: ... + + +class _MappedAnnotationBase(Generic[_T_co], TypingOnly): + """common class for Mapped and similar ORM container classes. + + these are classes that can appear on the left side of an ORM declarative + mapping, containing a mapped class or in some cases a collection + surrounding a mapped class. + + """ + + __slots__ = () + + +class SQLORMExpression( + SQLORMOperations[_T_co], SQLColumnExpression[_T_co], TypingOnly +): + """A type that may be used to indicate any ORM-level attribute or + object that acts in place of one, in the context of SQL expression + construction. + + :class:`.SQLORMExpression` extends from the Core + :class:`.SQLColumnExpression` to add additional SQL methods that are ORM + specific, such as :meth:`.PropComparator.of_type`, and is part of the bases + for :class:`.InstrumentedAttribute`. It may be used in :pep:`484` typing to + indicate arguments or return values that should behave as ORM-level + attribute expressions. + + .. versionadded:: 2.0.0b4 + + + """ + + __slots__ = () + + +class Mapped( + SQLORMExpression[_T_co], + ORMDescriptor[_T_co], + _MappedAnnotationBase[_T_co], + roles.DDLConstraintColumnRole, +): + """Represent an ORM mapped attribute on a mapped class. + + This class represents the complete descriptor interface for any class + attribute that will have been :term:`instrumented` by the ORM + :class:`_orm.Mapper` class. Provides appropriate information to type + checkers such as pylance and mypy so that ORM-mapped attributes + are correctly typed. + + The most prominent use of :class:`_orm.Mapped` is in + the :ref:`Declarative Mapping ` form + of :class:`_orm.Mapper` configuration, where used explicitly it drives + the configuration of ORM attributes such as :func:`_orm.mapped_class` + and :func:`_orm.relationship`. + + .. seealso:: + + :ref:`orm_explicit_declarative_base` + + :ref:`orm_declarative_table` + + .. tip:: + + The :class:`_orm.Mapped` class represents attributes that are handled + directly by the :class:`_orm.Mapper` class. It does not include other + Python descriptor classes that are provided as extensions, including + :ref:`hybrids_toplevel` and the :ref:`associationproxy_toplevel`. + While these systems still make use of ORM-specific superclasses + and structures, they are not :term:`instrumented` by the + :class:`_orm.Mapper` and instead provide their own functionality + when they are accessed on a class. + + .. versionadded:: 1.4 + + + """ + + __slots__ = () + + if typing.TYPE_CHECKING: + + @overload + def __get__( + self, instance: None, owner: Any + ) -> InstrumentedAttribute[_T_co]: ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T_co: ... + + def __get__( + self, instance: Optional[object], owner: Any + ) -> Union[InstrumentedAttribute[_T_co], _T_co]: ... + + @classmethod + def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]: ... + + def __set__( + self, instance: Any, value: Union[SQLCoreOperations[_T_co], _T_co] + ) -> None: ... + + def __delete__(self, instance: Any) -> None: ... + + +class _MappedAttribute(Generic[_T_co], TypingOnly): + """Mixin for attributes which should be replaced by mapper-assigned + attributes. + + """ + + __slots__ = () + + +class _DeclarativeMapped(Mapped[_T_co], _MappedAttribute[_T_co]): + """Mixin for :class:`.MapperProperty` subclasses that allows them to + be compatible with ORM-annotated declarative mappings. + + """ + + __slots__ = () + + # MappedSQLExpression, Relationship, Composite etc. dont actually do + # SQL expression behavior. yet there is code that compares them with + # __eq__(), __ne__(), etc. Since #8847 made Mapped even more full + # featured including ColumnOperators, we need to have those methods + # be no-ops for these objects, so return NotImplemented to fall back + # to normal comparison behavior. + def operate(self, op: OperatorType, *other: Any, **kwargs: Any) -> Any: + return NotImplemented + + __sa_operate__ = operate + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> Any: + return NotImplemented + + +class DynamicMapped(_MappedAnnotationBase[_T_co]): + """Represent the ORM mapped attribute type for a "dynamic" relationship. + + The :class:`_orm.DynamicMapped` type annotation may be used in an + :ref:`Annotated Declarative Table ` mapping + to indicate that the ``lazy="dynamic"`` loader strategy should be used + for a particular :func:`_orm.relationship`. + + .. legacy:: The "dynamic" lazy loader strategy is the legacy form of what + is now the "write_only" strategy described in the section + :ref:`write_only_relationship`. + + E.g.:: + + class User(Base): + __tablename__ = "user" + id: Mapped[int] = mapped_column(primary_key=True) + addresses: DynamicMapped[Address] = relationship( + cascade="all,delete-orphan" + ) + + See the section :ref:`dynamic_relationship` for background. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`dynamic_relationship` - complete background + + :class:`.WriteOnlyMapped` - fully 2.0 style version + + """ + + __slots__ = () + + if TYPE_CHECKING: + + @overload + def __get__( + self, instance: None, owner: Any + ) -> InstrumentedAttribute[_T_co]: ... + + @overload + def __get__( + self, instance: object, owner: Any + ) -> AppenderQuery[_T_co]: ... + + def __get__( + self, instance: Optional[object], owner: Any + ) -> Union[InstrumentedAttribute[_T_co], AppenderQuery[_T_co]]: ... + + def __set__( + self, instance: Any, value: typing.Collection[_T_co] + ) -> None: ... + + +class WriteOnlyMapped(_MappedAnnotationBase[_T_co]): + """Represent the ORM mapped attribute type for a "write only" relationship. + + The :class:`_orm.WriteOnlyMapped` type annotation may be used in an + :ref:`Annotated Declarative Table ` mapping + to indicate that the ``lazy="write_only"`` loader strategy should be used + for a particular :func:`_orm.relationship`. + + E.g.:: + + class User(Base): + __tablename__ = "user" + id: Mapped[int] = mapped_column(primary_key=True) + addresses: WriteOnlyMapped[Address] = relationship( + cascade="all,delete-orphan" + ) + + See the section :ref:`write_only_relationship` for background. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`write_only_relationship` - complete background + + :class:`.DynamicMapped` - includes legacy :class:`_orm.Query` support + + """ + + __slots__ = () + + if TYPE_CHECKING: + + @overload + def __get__( + self, instance: None, owner: Any + ) -> InstrumentedAttribute[_T_co]: ... + + @overload + def __get__( + self, instance: object, owner: Any + ) -> WriteOnlyCollection[_T_co]: ... + + def __get__( + self, instance: Optional[object], owner: Any + ) -> Union[ + InstrumentedAttribute[_T_co], WriteOnlyCollection[_T_co] + ]: ... + + def __set__( + self, instance: Any, value: typing.Collection[_T_co] + ) -> None: ... diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/bulk_persistence.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/bulk_persistence.py new file mode 100644 index 0000000..5d2558d --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/bulk_persistence.py @@ -0,0 +1,2048 @@ +# orm/bulk_persistence.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""additional ORM persistence classes related to "bulk" operations, +specifically outside of the flush() process. + +""" + +from __future__ import annotations + +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import overload +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import attributes +from . import context +from . import evaluator +from . import exc as orm_exc +from . import loading +from . import persistence +from .base import NO_VALUE +from .context import AbstractORMCompileState +from .context import FromStatement +from .context import ORMFromStatementCompileState +from .context import QueryContext +from .. import exc as sa_exc +from .. import util +from ..engine import Dialect +from ..engine import result as _result +from ..sql import coercions +from ..sql import dml +from ..sql import expression +from ..sql import roles +from ..sql import select +from ..sql import sqltypes +from ..sql.base import _entity_namespace_key +from ..sql.base import CompileState +from ..sql.base import Options +from ..sql.dml import DeleteDMLState +from ..sql.dml import InsertDMLState +from ..sql.dml import UpdateDMLState +from ..util import EMPTY_DICT +from ..util.typing import Literal + +if TYPE_CHECKING: + from ._typing import DMLStrategyArgument + from ._typing import OrmExecuteOptionsParameter + from ._typing import SynchronizeSessionArgument + from .mapper import Mapper + from .session import _BindArguments + from .session import ORMExecuteState + from .session import Session + from .session import SessionTransaction + from .state import InstanceState + from ..engine import Connection + from ..engine import cursor + from ..engine.interfaces import _CoreAnyExecuteParams + +_O = TypeVar("_O", bound=object) + + +@overload +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Literal[None] = ..., + execution_options: Optional[OrmExecuteOptionsParameter] = ..., +) -> None: ... + + +@overload +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = ..., + execution_options: Optional[OrmExecuteOptionsParameter] = ..., +) -> cursor.CursorResult[Any]: ... + + +def _bulk_insert( + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + return_defaults: bool, + render_nulls: bool, + use_orm_insert_stmt: Optional[dml.Insert] = None, + execution_options: Optional[OrmExecuteOptionsParameter] = None, +) -> Optional[cursor.CursorResult[Any]]: + base_mapper = mapper.base_mapper + + if session_transaction.session.connection_callable: + raise NotImplementedError( + "connection_callable / per-instance sharding " + "not supported in bulk_insert()" + ) + + if isstates: + if return_defaults: + states = [(state, state.dict) for state in mappings] + mappings = [dict_ for (state, dict_) in states] + else: + mappings = [state.dict for state in mappings] + else: + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) + + connection = session_transaction.connection(base_mapper) + + return_result: Optional[cursor.CursorResult[Any]] = None + + mappers_to_run = [ + (table, mp) + for table, mp in base_mapper._sorted_tables.items() + if table in mapper._pks_by_table + ] + + if return_defaults: + # not used by new-style bulk inserts, only used for legacy + bookkeeping = True + elif len(mappers_to_run) > 1: + # if we have more than one table, mapper to run where we will be + # either horizontally splicing, or copying values between tables, + # we need the "bookkeeping" / deterministic returning order + bookkeeping = True + else: + bookkeeping = False + + for table, super_mapper in mappers_to_run: + # find bindparams in the statement. For bulk, we don't really know if + # a key in the params applies to a different table since we are + # potentially inserting for multiple tables here; looking at the + # bindparam() is a lot more direct. in most cases this will + # use _generate_cache_key() which is memoized, although in practice + # the ultimate statement that's executed is probably not the same + # object so that memoization might not matter much. + extra_bp_names = ( + [ + b.key + for b in use_orm_insert_stmt._get_embedded_bindparams() + if b.key in mappings[0] + ] + if use_orm_insert_stmt is not None + else () + ) + + records = ( + ( + None, + state_dict, + params, + mapper, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) + for ( + state, + state_dict, + params, + mp, + conn, + value_params, + has_all_pks, + has_all_defaults, + ) in persistence._collect_insert_commands( + table, + ((None, mapping, mapper, connection) for mapping in mappings), + bulk=True, + return_defaults=bookkeeping, + render_nulls=render_nulls, + include_bulk_keys=extra_bp_names, + ) + ) + + result = persistence._emit_insert_statements( + base_mapper, + None, + super_mapper, + table, + records, + bookkeeping=bookkeeping, + use_orm_insert_stmt=use_orm_insert_stmt, + execution_options=execution_options, + ) + if use_orm_insert_stmt is not None: + if not use_orm_insert_stmt._returning or return_result is None: + return_result = result + elif result.returns_rows: + assert bookkeeping + return_result = return_result.splice_horizontally(result) + + if return_defaults and isstates: + identity_cls = mapper._identity_class + identity_props = [p.key for p in mapper._identity_key_props] + for state, dict_ in states: + state.key = ( + identity_cls, + tuple([dict_[key] for key in identity_props]), + ) + + if use_orm_insert_stmt is not None: + assert return_result is not None + return return_result + + +@overload +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Literal[None] = ..., + enable_check_rowcount: bool = True, +) -> None: ... + + +@overload +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = ..., + enable_check_rowcount: bool = True, +) -> _result.Result[Any]: ... + + +def _bulk_update( + mapper: Mapper[Any], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + session_transaction: SessionTransaction, + isstates: bool, + update_changed_only: bool, + use_orm_update_stmt: Optional[dml.Update] = None, + enable_check_rowcount: bool = True, +) -> Optional[_result.Result[Any]]: + base_mapper = mapper.base_mapper + + search_keys = mapper._primary_key_propkeys + if mapper._version_id_prop: + search_keys = {mapper._version_id_prop.key}.union(search_keys) + + def _changed_dict(mapper, state): + return { + k: v + for k, v in state.dict.items() + if k in state.committed_state or k in search_keys + } + + if isstates: + if update_changed_only: + mappings = [_changed_dict(mapper, state) for state in mappings] + else: + mappings = [state.dict for state in mappings] + else: + mappings = [dict(m) for m in mappings] + _expand_composites(mapper, mappings) + + if session_transaction.session.connection_callable: + raise NotImplementedError( + "connection_callable / per-instance sharding " + "not supported in bulk_update()" + ) + + connection = session_transaction.connection(base_mapper) + + # find bindparams in the statement. see _bulk_insert for similar + # notes for the insert case + extra_bp_names = ( + [ + b.key + for b in use_orm_update_stmt._get_embedded_bindparams() + if b.key in mappings[0] + ] + if use_orm_update_stmt is not None + else () + ) + + for table, super_mapper in base_mapper._sorted_tables.items(): + if not mapper.isa(super_mapper) or table not in mapper._pks_by_table: + continue + + records = persistence._collect_update_commands( + None, + table, + ( + ( + None, + mapping, + mapper, + connection, + ( + mapping[mapper._version_id_prop.key] + if mapper._version_id_prop + else None + ), + ) + for mapping in mappings + ), + bulk=True, + use_orm_update_stmt=use_orm_update_stmt, + include_bulk_keys=extra_bp_names, + ) + persistence._emit_update_statements( + base_mapper, + None, + super_mapper, + table, + records, + bookkeeping=False, + use_orm_update_stmt=use_orm_update_stmt, + enable_check_rowcount=enable_check_rowcount, + ) + + if use_orm_update_stmt is not None: + return _result.null_result() + + +def _expand_composites(mapper, mappings): + composite_attrs = mapper.composites + if not composite_attrs: + return + + composite_keys = set(composite_attrs.keys()) + populators = { + key: composite_attrs[key]._populate_composite_bulk_save_mappings_fn() + for key in composite_keys + } + for mapping in mappings: + for key in composite_keys.intersection(mapping): + populators[key](mapping) + + +class ORMDMLState(AbstractORMCompileState): + is_dml_returning = True + from_statement_ctx: Optional[ORMFromStatementCompileState] = None + + @classmethod + def _get_orm_crud_kv_pairs( + cls, mapper, statement, kv_iterator, needs_to_be_cacheable + ): + core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs + + for k, v in kv_iterator: + k = coercions.expect(roles.DMLColumnRole, k) + + if isinstance(k, str): + desc = _entity_namespace_key(mapper, k, default=NO_VALUE) + if desc is NO_VALUE: + yield ( + coercions.expect(roles.DMLColumnRole, k), + ( + coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) + if needs_to_be_cacheable + else v + ), + ) + else: + yield from core_get_crud_kv_pairs( + statement, + desc._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + elif "entity_namespace" in k._annotations: + k_anno = k._annotations + attr = _entity_namespace_key( + k_anno["entity_namespace"], k_anno["proxy_key"] + ) + yield from core_get_crud_kv_pairs( + statement, + attr._bulk_update_tuples(v), + needs_to_be_cacheable, + ) + else: + yield ( + k, + ( + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) + ), + ) + + @classmethod + def _get_multi_crud_kv_pairs(cls, statement, kv_iterator): + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_multi_crud_kv_pairs( + statement, kv_iterator + ) + + return [ + dict( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, statement, value_dict.items(), False + ) + ) + for value_dict in kv_iterator + ] + + @classmethod + def _get_crud_kv_pairs(cls, statement, kv_iterator, needs_to_be_cacheable): + assert ( + needs_to_be_cacheable + ), "no test coverage for needs_to_be_cacheable=False" + + plugin_subject = statement._propagate_attrs["plugin_subject"] + + if not plugin_subject or not plugin_subject.mapper: + return UpdateDMLState._get_crud_kv_pairs( + statement, kv_iterator, needs_to_be_cacheable + ) + + return list( + cls._get_orm_crud_kv_pairs( + plugin_subject.mapper, + statement, + kv_iterator, + needs_to_be_cacheable, + ) + ) + + @classmethod + def get_entity_description(cls, statement): + ext_info = statement.table._annotations["parententity"] + mapper = ext_info.mapper + if ext_info.is_aliased_class: + _label_name = ext_info.name + else: + _label_name = mapper.class_.__name__ + + return { + "name": _label_name, + "type": mapper.class_, + "expr": ext_info.entity, + "entity": ext_info.entity, + "table": mapper.local_table, + } + + @classmethod + def get_returning_column_descriptions(cls, statement): + def _ent_for_col(c): + return c._annotations.get("parententity", None) + + def _attr_for_col(c, ent): + if ent is None: + return c + proxy_key = c._annotations.get("proxy_key", None) + if not proxy_key: + return c + else: + return getattr(ent.entity, proxy_key, c) + + return [ + { + "name": c.key, + "type": c.type, + "expr": _attr_for_col(c, ent), + "aliased": ent.is_aliased_class, + "entity": ent.entity, + } + for c, ent in [ + (c, _ent_for_col(c)) for c in statement._all_selected_columns + ] + ] + + def _setup_orm_returning( + self, + compiler, + orm_level_statement, + dml_level_statement, + dml_mapper, + *, + use_supplemental_cols=True, + ): + """establish ORM column handlers for an INSERT, UPDATE, or DELETE + which uses explicit returning(). + + called within compilation level create_for_statement. + + The _return_orm_returning() method then receives the Result + after the statement was executed, and applies ORM loading to the + state that we first established here. + + """ + + if orm_level_statement._returning: + fs = FromStatement( + orm_level_statement._returning, + dml_level_statement, + _adapt_on_names=False, + ) + fs = fs.execution_options(**orm_level_statement._execution_options) + fs = fs.options(*orm_level_statement._with_options) + self.select_statement = fs + self.from_statement_ctx = fsc = ( + ORMFromStatementCompileState.create_for_statement(fs, compiler) + ) + fsc.setup_dml_returning_compile_state(dml_mapper) + + dml_level_statement = dml_level_statement._generate() + dml_level_statement._returning = () + + cols_to_return = [c for c in fsc.primary_columns if c is not None] + + # since we are splicing result sets together, make sure there + # are columns of some kind returned in each result set + if not cols_to_return: + cols_to_return.extend(dml_mapper.primary_key) + + if use_supplemental_cols: + dml_level_statement = dml_level_statement.return_defaults( + # this is a little weird looking, but by passing + # primary key as the main list of cols, this tells + # return_defaults to omit server-default cols (and + # actually all cols, due to some weird thing we should + # clean up in crud.py). + # Since we have cols_to_return, just return what we asked + # for (plus primary key, which ORM persistence needs since + # we likely set bookkeeping=True here, which is another + # whole thing...). We dont want to clutter the + # statement up with lots of other cols the user didn't + # ask for. see #9685 + *dml_mapper.primary_key, + supplemental_cols=cols_to_return, + ) + else: + dml_level_statement = dml_level_statement.returning( + *cols_to_return + ) + + return dml_level_statement + + @classmethod + def _return_orm_returning( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + execution_context = result.context + compile_state = execution_context.compiled.compile_state + + if ( + compile_state.from_statement_ctx + and not compile_state.from_statement_ctx.compile_options._is_star + ): + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + + querycontext = QueryContext( + compile_state.from_statement_ctx, + compile_state.select_statement, + params, + session, + load_options, + execution_options, + bind_arguments, + ) + return loading.instances(result, querycontext) + else: + return result + + +class BulkUDCompileState(ORMDMLState): + class default_update_options(Options): + _dml_strategy: DMLStrategyArgument = "auto" + _synchronize_session: SynchronizeSessionArgument = "auto" + _can_use_returning: bool = False + _is_delete_using: bool = False + _is_update_from: bool = False + _autoflush: bool = True + _subject_mapper: Optional[Mapper[Any]] = None + _resolved_values = EMPTY_DICT + _eval_condition = None + _matched_rows = None + _identity_token = None + + @classmethod + def can_use_returning( + cls, + dialect: Dialect, + mapper: Mapper[Any], + *, + is_multitable: bool = False, + is_update_from: bool = False, + is_delete_using: bool = False, + is_executemany: bool = False, + ) -> bool: + raise NotImplementedError() + + @classmethod + def orm_pre_session_exec( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_pre_event, + ): + ( + update_options, + execution_options, + ) = BulkUDCompileState.default_update_options.from_execution_options( + "_sa_orm_update_options", + { + "synchronize_session", + "autoflush", + "identity_token", + "is_delete_using", + "is_update_from", + "dml_strategy", + }, + execution_options, + statement._execution_options, + ) + bind_arguments["clause"] = statement + try: + plugin_subject = statement._propagate_attrs["plugin_subject"] + except KeyError: + assert False, "statement had 'orm' plugin but no plugin_subject" + else: + if plugin_subject: + bind_arguments["mapper"] = plugin_subject.mapper + update_options += {"_subject_mapper": plugin_subject.mapper} + + if "parententity" not in statement.table._annotations: + update_options += {"_dml_strategy": "core_only"} + elif not isinstance(params, list): + if update_options._dml_strategy == "auto": + update_options += {"_dml_strategy": "orm"} + elif update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if update_options._dml_strategy == "auto": + update_options += {"_dml_strategy": "bulk"} + + sync = update_options._synchronize_session + if sync is not None: + if sync not in ("auto", "evaluate", "fetch", False): + raise sa_exc.ArgumentError( + "Valid strategies for session synchronization " + "are 'auto', 'evaluate', 'fetch', False" + ) + if update_options._dml_strategy == "bulk" and sync == "fetch": + raise sa_exc.InvalidRequestError( + "The 'fetch' synchronization strategy is not available " + "for 'bulk' ORM updates (i.e. multiple parameter sets)" + ) + + if not is_pre_event: + if update_options._autoflush: + session._autoflush() + + if update_options._dml_strategy == "orm": + if update_options._synchronize_session == "auto": + update_options = cls._do_pre_synchronize_auto( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "evaluate": + update_options = cls._do_pre_synchronize_evaluate( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "fetch": + update_options = cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "auto": + update_options += {"_synchronize_session": "evaluate"} + + # indicators from the "pre exec" step that are then + # added to the DML statement, which will also be part of the cache + # key. The compile level create_for_statement() method will then + # consume these at compiler time. + statement = statement._annotate( + { + "synchronize_session": update_options._synchronize_session, + "is_delete_using": update_options._is_delete_using, + "is_update_from": update_options._is_update_from, + "dml_strategy": update_options._dml_strategy, + "can_use_returning": update_options._can_use_returning, + } + ) + + return ( + statement, + util.immutabledict(execution_options).union( + {"_sa_orm_update_options": update_options} + ), + ) + + @classmethod + def orm_setup_cursor_result( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + # this stage of the execution is called after the + # do_orm_execute event hook. meaning for an extension like + # horizontal sharding, this step happens *within* the horizontal + # sharding event handler which calls session.execute() re-entrantly + # and will occur for each backend individually. + # the sharding extension then returns its own merged result from the + # individual ones we return here. + + update_options = execution_options["_sa_orm_update_options"] + if update_options._dml_strategy == "orm": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_evaluate( + session, statement, result, update_options + ) + elif update_options._synchronize_session == "fetch": + cls._do_post_synchronize_fetch( + session, statement, result, update_options + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_bulk_evaluate( + session, params, result, update_options + ) + return result + + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod + def _adjust_for_extra_criteria(cls, global_attributes, ext_info): + """Apply extra criteria filtering. + + For all distinct single-table-inheritance mappers represented in the + table being updated or deleted, produce additional WHERE criteria such + that only the appropriate subtypes are selected from the total results. + + Additionally, add WHERE criteria originating from LoaderCriteriaOptions + collected from the statement. + + """ + + return_crit = () + + adapter = ext_info._adapter if ext_info.is_aliased_class else None + + if ( + "additional_entity_criteria", + ext_info.mapper, + ) in global_attributes: + return_crit += tuple( + ae._resolve_where_criteria(ext_info) + for ae in global_attributes[ + ("additional_entity_criteria", ext_info.mapper) + ] + if ae.include_aliases or ae.entity is ext_info + ) + + if ext_info.mapper._single_table_criterion is not None: + return_crit += (ext_info.mapper._single_table_criterion,) + + if adapter: + return_crit = tuple(adapter.traverse(crit) for crit in return_crit) + + return return_crit + + @classmethod + def _interpret_returning_rows(cls, mapper, rows): + """translate from local inherited table columns to base mapper + primary key columns. + + Joined inheritance mappers always establish the primary key in terms of + the base table. When we UPDATE a sub-table, we can only get + RETURNING for the sub-table's columns. + + Here, we create a lookup from the local sub table's primary key + columns to the base table PK columns so that we can get identity + key values from RETURNING that's against the joined inheritance + sub-table. + + the complexity here is to support more than one level deep of + inheritance, where we have to link columns to each other across + the inheritance hierarchy. + + """ + + if mapper.local_table is not mapper.base_mapper.local_table: + return rows + + # this starts as a mapping of + # local_pk_col: local_pk_col. + # we will then iteratively rewrite the "value" of the dict with + # each successive superclass column + local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key} + + for mp in mapper.iterate_to_root(): + if mp.inherits is None: + break + elif mp.local_table is mp.inherits.local_table: + continue + + t_to_e = dict(mp._table_to_equated[mp.inherits.local_table]) + col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]} + for pk, super_ in local_pk_to_base_pk.items(): + local_pk_to_base_pk[pk] = col_to_col[super_] + + lookup = { + local_pk_to_base_pk[lpk]: idx + for idx, lpk in enumerate(mapper.local_table.primary_key) + } + primary_key_convert = [ + lookup[bpk] for bpk in mapper.base_mapper.primary_key + ] + return [tuple(row[idx] for idx in primary_key_convert) for row in rows] + + @classmethod + def _get_matched_objects_on_criteria(cls, update_options, states): + mapper = update_options._subject_mapper + eval_condition = update_options._eval_condition + + raw_data = [ + (state.obj(), state, state.dict) + for state in states + if state.mapper.isa(mapper) and not state.expired + ] + + identity_token = update_options._identity_token + if identity_token is not None: + raw_data = [ + (obj, state, dict_) + for obj, state, dict_ in raw_data + if state.identity_token == identity_token + ] + + result = [] + for obj, state, dict_ in raw_data: + evaled_condition = eval_condition(obj) + + # caution: don't use "in ()" or == here, _EXPIRE_OBJECT + # evaluates as True for all comparisons + if ( + evaled_condition is True + or evaled_condition is evaluator._EXPIRED_OBJECT + ): + result.append( + ( + obj, + state, + dict_, + evaled_condition is evaluator._EXPIRED_OBJECT, + ) + ) + return result + + @classmethod + def _eval_condition_from_statement(cls, update_options, statement): + mapper = update_options._subject_mapper + target_cls = mapper.class_ + + evaluator_compiler = evaluator._EvaluatorCompiler(target_cls) + crit = () + if statement._where_criteria: + crit += statement._where_criteria + + global_attributes = {} + for opt in statement._with_options: + if opt._is_criteria_option: + opt.get_global_criteria(global_attributes) + + if global_attributes: + crit += cls._adjust_for_extra_criteria(global_attributes, mapper) + + if crit: + eval_condition = evaluator_compiler.process(*crit) + else: + # workaround for mypy https://github.com/python/mypy/issues/14027 + def _eval_condition(obj): + return True + + eval_condition = _eval_condition + + return eval_condition + + @classmethod + def _do_pre_synchronize_auto( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): + """setup auto sync strategy + + + "auto" checks if we can use "evaluate" first, then falls back + to "fetch" + + evaluate is vastly more efficient for the common case + where session is empty, only has a few objects, and the UPDATE + statement can potentially match thousands/millions of rows. + + OTOH more complex criteria that fails to work with "evaluate" + we would hope usually correlates with fewer net rows. + + """ + + try: + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) + + except evaluator.UnevaluatableError: + pass + else: + return update_options + { + "_eval_condition": eval_condition, + "_synchronize_session": "evaluate", + } + + update_options += {"_synchronize_session": "fetch"} + return cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + + @classmethod + def _do_pre_synchronize_evaluate( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): + try: + eval_condition = cls._eval_condition_from_statement( + update_options, statement + ) + + except evaluator.UnevaluatableError as err: + raise sa_exc.InvalidRequestError( + 'Could not evaluate current criteria in Python: "%s". ' + "Specify 'fetch' or False for the " + "synchronize_session execution option." % err + ) from err + + return update_options + { + "_eval_condition": eval_condition, + } + + @classmethod + def _get_resolved_values(cls, mapper, statement): + if statement._multi_values: + return [] + elif statement._ordered_values: + return list(statement._ordered_values) + elif statement._values: + return list(statement._values.items()) + else: + return [] + + @classmethod + def _resolved_keys_as_propnames(cls, mapper, resolved_values): + values = [] + for k, v in resolved_values: + if mapper and isinstance(k, expression.ColumnElement): + try: + attr = mapper._columntoproperty[k] + except orm_exc.UnmappedColumnError: + pass + else: + values.append((attr.key, v)) + else: + raise sa_exc.InvalidRequestError( + "Attribute name not found, can't be " + "synchronized back to objects: %r" % k + ) + return values + + @classmethod + def _do_pre_synchronize_fetch( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): + mapper = update_options._subject_mapper + + select_stmt = ( + select(*(mapper.primary_key + (mapper.select_identity_token,))) + .select_from(mapper) + .options(*statement._with_options) + ) + select_stmt._where_criteria = statement._where_criteria + + # conditionally run the SELECT statement for pre-fetch, testing the + # "bind" for if we can use RETURNING or not using the do_orm_execute + # event. If RETURNING is available, the do_orm_execute event + # will cancel the SELECT from being actually run. + # + # The way this is organized seems strange, why don't we just + # call can_use_returning() before invoking the statement and get + # answer?, why does this go through the whole execute phase using an + # event? Answer: because we are integrating with extensions such + # as the horizontal sharding extention that "multiplexes" an individual + # statement run through multiple engines, and it uses + # do_orm_execute() to do that. + + can_use_returning = None + + def skip_for_returning(orm_context: ORMExecuteState) -> Any: + bind = orm_context.session.get_bind(**orm_context.bind_arguments) + nonlocal can_use_returning + + per_bind_result = cls.can_use_returning( + bind.dialect, + mapper, + is_update_from=update_options._is_update_from, + is_delete_using=update_options._is_delete_using, + is_executemany=orm_context.is_executemany, + ) + + if can_use_returning is not None: + if can_use_returning != per_bind_result: + raise sa_exc.InvalidRequestError( + "For synchronize_session='fetch', can't mix multiple " + "backends where some support RETURNING and others " + "don't" + ) + elif orm_context.is_executemany and not per_bind_result: + raise sa_exc.InvalidRequestError( + "For synchronize_session='fetch', can't use multiple " + "parameter sets in ORM mode, which this backend does not " + "support with RETURNING" + ) + else: + can_use_returning = per_bind_result + + if per_bind_result: + return _result.null_result() + else: + return None + + result = session.execute( + select_stmt, + params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _add_event=skip_for_returning, + ) + matched_rows = result.fetchall() + + return update_options + { + "_matched_rows": matched_rows, + "_can_use_returning": can_use_returning, + } + + +@CompileState.plugin_for("orm", "insert") +class BulkORMInsert(ORMDMLState, InsertDMLState): + class default_insert_options(Options): + _dml_strategy: DMLStrategyArgument = "auto" + _render_nulls: bool = False + _return_defaults: bool = False + _subject_mapper: Optional[Mapper[Any]] = None + _autoflush: bool = True + _populate_existing: bool = False + + select_statement: Optional[FromStatement] = None + + @classmethod + def orm_pre_session_exec( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_pre_event, + ): + ( + insert_options, + execution_options, + ) = BulkORMInsert.default_insert_options.from_execution_options( + "_sa_orm_insert_options", + {"dml_strategy", "autoflush", "populate_existing", "render_nulls"}, + execution_options, + statement._execution_options, + ) + bind_arguments["clause"] = statement + try: + plugin_subject = statement._propagate_attrs["plugin_subject"] + except KeyError: + assert False, "statement had 'orm' plugin but no plugin_subject" + else: + if plugin_subject: + bind_arguments["mapper"] = plugin_subject.mapper + insert_options += {"_subject_mapper": plugin_subject.mapper} + + if not params: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "orm"} + elif insert_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + 'Can\'t use "bulk" ORM insert strategy without ' + "passing separate parameters" + ) + else: + if insert_options._dml_strategy == "auto": + insert_options += {"_dml_strategy": "bulk"} + + if insert_options._dml_strategy != "raw": + # for ORM object loading, like ORMContext, we have to disable + # result set adapt_to_context, because we will be generating a + # new statement with specific columns that's cached inside of + # an ORMFromStatementCompileState, which we will re-use for + # each result. + if not execution_options: + execution_options = context._orm_load_exec_options + else: + execution_options = execution_options.union( + context._orm_load_exec_options + ) + + if not is_pre_event and insert_options._autoflush: + session._autoflush() + + statement = statement._annotate( + {"dml_strategy": insert_options._dml_strategy} + ) + + return ( + statement, + util.immutabledict(execution_options).union( + {"_sa_orm_insert_options": insert_options} + ), + ) + + @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Insert, + params: _CoreAnyExecuteParams, + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + insert_options = execution_options.get( + "_sa_orm_insert_options", cls.default_insert_options + ) + + if insert_options._dml_strategy not in ( + "raw", + "bulk", + "orm", + "auto", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM insert strategy " + "are 'raw', 'orm', 'bulk', 'auto" + ) + + result: _result.Result[Any] + + if insert_options._dml_strategy == "raw": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return result + + if insert_options._dml_strategy == "bulk": + mapper = insert_options._subject_mapper + + if ( + statement._post_values_clause is not None + and mapper._multiple_persistence_tables + ): + raise sa_exc.InvalidRequestError( + "bulk INSERT with a 'post values' clause " + "(typically upsert) not supported for multi-table " + f"mapper {mapper}" + ) + + assert mapper is not None + assert session._transaction is not None + result = _bulk_insert( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + return_defaults=insert_options._return_defaults, + render_nulls=insert_options._render_nulls, + use_orm_insert_stmt=statement, + execution_options=execution_options, + ) + elif insert_options._dml_strategy == "orm": + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + else: + raise AssertionError() + + if not bool(statement._returning): + return result + + if insert_options._populate_existing: + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + load_options += {"_populate_existing": True} + execution_options = execution_options.union( + {"_sa_orm_load_options": load_options} + ) + + return cls._return_orm_returning( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod + def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert: + self = cast( + BulkORMInsert, + super().create_for_statement(statement, compiler, **kw), + ) + + if compiler is not None: + toplevel = not compiler.stack + else: + toplevel = True + if not toplevel: + return self + + mapper = statement._propagate_attrs["plugin_subject"] + dml_strategy = statement._annotations.get("dml_strategy", "raw") + if dml_strategy == "bulk": + self._setup_for_bulk_insert(compiler) + elif dml_strategy == "orm": + self._setup_for_orm_insert(compiler, mapper) + + return self + + @classmethod + def _resolved_keys_as_col_keys(cls, mapper, resolved_value_dict): + return { + col.key if col is not None else k: v + for col, k, v in ( + (mapper.c.get(k), k, v) for k, v in resolved_value_dict.items() + ) + } + + def _setup_for_orm_insert(self, compiler, mapper): + statement = orm_level_statement = cast(dml.Insert, self.statement) + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + dml_mapper=mapper, + use_supplemental_cols=False, + ) + self.statement = statement + + def _setup_for_bulk_insert(self, compiler): + """establish an INSERT statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_insert_statement(). + + """ + statement = orm_level_statement = cast(dml.Insert, self.statement) + an = statement._annotations + + emit_insert_table, emit_insert_mapper = ( + an["_emit_insert_table"], + an["_emit_insert_mapper"], + ) + + statement = statement._clone() + + statement.table = emit_insert_table + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_insert_table + } + + statement = self._setup_orm_returning( + compiler, + orm_level_statement, + statement, + dml_mapper=emit_insert_mapper, + use_supplemental_cols=True, + ) + + if ( + self.from_statement_ctx is not None + and self.from_statement_ctx.compile_options._is_star + ): + raise sa_exc.CompileError( + "Can't use RETURNING * with bulk ORM INSERT. " + "Please use a different INSERT form, such as INSERT..VALUES " + "or INSERT with a Core Connection" + ) + + self.statement = statement + + +@CompileState.plugin_for("orm", "update") +class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): + @classmethod + def create_for_statement(cls, statement, compiler, **kw): + self = cls.__new__(cls) + + dml_strategy = statement._annotations.get( + "dml_strategy", "unspecified" + ) + + toplevel = not compiler.stack + + if toplevel and dml_strategy == "bulk": + self._setup_for_bulk_update(statement, compiler) + elif ( + dml_strategy == "core_only" + or dml_strategy == "unspecified" + and "parententity" not in statement.table._annotations + ): + UpdateDMLState.__init__(self, statement, compiler, **kw) + elif not toplevel or dml_strategy in ("orm", "unspecified"): + self._setup_for_orm_update(statement, compiler) + + return self + + def _setup_for_orm_update(self, statement, compiler, **kw): + orm_level_statement = statement + + toplevel = not compiler.stack + + ext_info = statement.table._annotations["parententity"] + + self.mapper = mapper = ext_info.mapper + + self._resolved_values = self._get_resolved_values(mapper, statement) + + self._init_global_attributes( + statement, + compiler, + toplevel=toplevel, + process_criteria_for_toplevel=toplevel, + ) + + if statement._values: + self._resolved_values = dict(self._resolved_values) + + new_stmt = statement._clone() + + # note if the statement has _multi_values, these + # are passed through to the new statement, which will then raise + # InvalidRequestError because UPDATE doesn't support multi_values + # right now. + if statement._ordered_values: + new_stmt._ordered_values = self._resolved_values + elif statement._values: + new_stmt._values = self._resolved_values + + new_crit = self._adjust_for_extra_criteria( + self.global_attributes, mapper + ) + if new_crit: + new_stmt = new_stmt.where(*new_crit) + + # if we are against a lambda statement we might not be the + # topmost object that received per-execute annotations + + # do this first as we need to determine if there is + # UPDATE..FROM + + UpdateDMLState.__init__(self, new_stmt, compiler, **kw) + + use_supplemental_cols = False + + if not toplevel: + synchronize_session = None + else: + synchronize_session = compiler._annotations.get( + "synchronize_session", None + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need to run can_use_returning() at this level + # unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, mapper, is_multitable=self.is_multitable + ) + ) + + if synchronize_session == "fetch" and can_use_returning: + use_supplemental_cols = True + + # NOTE: we might want to RETURNING the actual columns to be + # synchronized also. however this is complicated and difficult + # to align against the behavior of "evaluate". Additionally, + # in a large number (if not the majority) of cases, we have the + # "evaluate" answer, usually a fixed value, in memory already and + # there's no need to re-fetch the same value + # over and over again. so perhaps if it could be RETURNING just + # the elements that were based on a SQL expression and not + # a constant. For now it doesn't quite seem worth it + new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key) + + if toplevel: + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + dml_mapper=mapper, + use_supplemental_cols=use_supplemental_cols, + ) + + self.statement = new_stmt + + def _setup_for_bulk_update(self, statement, compiler, **kw): + """establish an UPDATE statement within the context of + bulk insert. + + This method will be within the "conn.execute()" call that is invoked + by persistence._emit_update_statement(). + + """ + statement = cast(dml.Update, statement) + an = statement._annotations + + emit_update_table, _ = ( + an["_emit_update_table"], + an["_emit_update_mapper"], + ) + + statement = statement._clone() + statement.table = emit_update_table + + UpdateDMLState.__init__(self, statement, compiler, **kw) + + if self._ordered_values: + raise sa_exc.InvalidRequestError( + "bulk ORM UPDATE does not support ordered_values() for " + "custom UPDATE statements with bulk parameter sets. Use a " + "non-bulk UPDATE statement or use values()." + ) + + if self._dict_parameters: + self._dict_parameters = { + col: val + for col, val in self._dict_parameters.items() + if col.table is emit_update_table + } + self.statement = statement + + @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Update, + params: _CoreAnyExecuteParams, + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy not in ( + "orm", + "auto", + "bulk", + "core_only", + ): + raise sa_exc.ArgumentError( + "Valid strategies for ORM UPDATE strategy " + "are 'orm', 'auto', 'bulk', 'core_only'" + ) + + result: _result.Result[Any] + + if update_options._dml_strategy == "bulk": + enable_check_rowcount = not statement._where_criteria + + assert update_options._synchronize_session != "fetch" + + if ( + statement._where_criteria + and update_options._synchronize_session == "evaluate" + ): + raise sa_exc.InvalidRequestError( + "bulk synchronize of persistent objects not supported " + "when using bulk update with additional WHERE " + "criteria right now. add synchronize_session=None " + "execution option to bypass synchronize of persistent " + "objects." + ) + mapper = update_options._subject_mapper + assert mapper is not None + assert session._transaction is not None + result = _bulk_update( + mapper, + cast( + "Iterable[Dict[str, Any]]", + [params] if isinstance(params, dict) else params, + ), + session._transaction, + isstates=False, + update_changed_only=False, + use_orm_update_stmt=statement, + enable_check_rowcount=enable_check_rowcount, + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + else: + return super().orm_execute_statement( + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) + + @classmethod + def can_use_returning( + cls, + dialect: Dialect, + mapper: Mapper[Any], + *, + is_multitable: bool = False, + is_update_from: bool = False, + is_delete_using: bool = False, + is_executemany: bool = False, + ) -> bool: + # normal answer for "should we use RETURNING" at all. + normal_answer = ( + dialect.update_returning and mapper.local_table.implicit_returning + ) + if not normal_answer: + return False + + if is_executemany: + return dialect.update_executemany_returning + + # these workarounds are currently hypothetical for UPDATE, + # unlike DELETE where they impact MariaDB + if is_update_from: + return dialect.update_returning_multifrom + + elif is_multitable and not dialect.update_returning_multifrom: + raise sa_exc.CompileError( + f'Dialect "{dialect.name}" does not support RETURNING ' + "with UPDATE..FROM; for synchronize_session='fetch', " + "please add the additional execution option " + "'is_update_from=True' to the statement to indicate that " + "a separate SELECT should be used for this backend." + ) + + return True + + @classmethod + def _do_post_synchronize_bulk_evaluate( + cls, session, params, result, update_options + ): + if not params: + return + + mapper = update_options._subject_mapper + pk_keys = [prop.key for prop in mapper._identity_key_props] + + identity_map = session.identity_map + + for param in params: + identity_key = mapper.identity_key_from_primary_key( + (param[key] for key in pk_keys), + update_options._identity_token, + ) + state = identity_map.fast_get_state(identity_key) + if not state: + continue + + evaluated_keys = set(param).difference(pk_keys) + + dict_ = state.dict + # only evaluate unmodified attributes + to_evaluate = state.unmodified.intersection(evaluated_keys) + for key in to_evaluate: + if key in dict_: + dict_[key] = param[key] + + state.manager.dispatch.refresh(state, None, to_evaluate) + + state._commit(dict_, list(to_evaluate)) + + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. + to_expire = evaluated_keys.intersection(dict_).difference( + to_evaluate + ) + if to_expire: + state._expire_attributes(dict_, to_expire) + + @classmethod + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), + ) + + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [(obj, state, dict_) for obj, state, dict_, _ in matched_objects], + ) + + @classmethod + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): + target_mapper = update_options._subject_mapper + + returned_defaults_rows = result.returned_defaults_rows + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) + + matched_rows = [ + tuple(row) + (update_options._identity_token,) + for row in pk_rows + ] + else: + matched_rows = update_options._matched_rows + + objs = [ + session.identity_map[identity_key] + for identity_key in [ + target_mapper.identity_key_from_primary_key( + list(primary_key), + identity_token=identity_token, + ) + for primary_key, identity_token in [ + (row[0:-1], row[-1]) for row in matched_rows + ] + if update_options._identity_token is None + or identity_token == update_options._identity_token + ] + if identity_key in session.identity_map + ] + + if not objs: + return + + cls._apply_update_set_values_to_objects( + session, + update_options, + statement, + [ + ( + obj, + attributes.instance_state(obj), + attributes.instance_dict(obj), + ) + for obj in objs + ], + ) + + @classmethod + def _apply_update_set_values_to_objects( + cls, session, update_options, statement, matched_objects + ): + """apply values to objects derived from an update statement, e.g. + UPDATE..SET + + """ + mapper = update_options._subject_mapper + target_cls = mapper.class_ + evaluator_compiler = evaluator._EvaluatorCompiler(target_cls) + resolved_values = cls._get_resolved_values(mapper, statement) + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + value_evaluators = {} + for key, value in resolved_keys_as_propnames: + try: + _evaluator = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) + except evaluator.UnevaluatableError: + pass + else: + value_evaluators[key] = _evaluator + + evaluated_keys = list(value_evaluators.keys()) + attrib = {k for k, v in resolved_keys_as_propnames} + + states = set() + for obj, state, dict_ in matched_objects: + to_evaluate = state.unmodified.intersection(evaluated_keys) + + for key in to_evaluate: + if key in dict_: + # only run eval for attributes that are present. + dict_[key] = value_evaluators[key](obj) + + state.manager.dispatch.refresh(state, None, to_evaluate) + + state._commit(dict_, list(to_evaluate)) + + # attributes that were formerly modified instead get expired. + # this only gets hit if the session had pending changes + # and autoflush were set to False. + to_expire = attrib.intersection(dict_).difference(to_evaluate) + if to_expire: + state._expire_attributes(dict_, to_expire) + + states.add(state) + session._register_altered(states) + + +@CompileState.plugin_for("orm", "delete") +class BulkORMDelete(BulkUDCompileState, DeleteDMLState): + @classmethod + def create_for_statement(cls, statement, compiler, **kw): + self = cls.__new__(cls) + + dml_strategy = statement._annotations.get( + "dml_strategy", "unspecified" + ) + + if ( + dml_strategy == "core_only" + or dml_strategy == "unspecified" + and "parententity" not in statement.table._annotations + ): + DeleteDMLState.__init__(self, statement, compiler, **kw) + return self + + toplevel = not compiler.stack + + orm_level_statement = statement + + ext_info = statement.table._annotations["parententity"] + self.mapper = mapper = ext_info.mapper + + self._init_global_attributes( + statement, + compiler, + toplevel=toplevel, + process_criteria_for_toplevel=toplevel, + ) + + new_stmt = statement._clone() + + new_crit = cls._adjust_for_extra_criteria( + self.global_attributes, mapper + ) + if new_crit: + new_stmt = new_stmt.where(*new_crit) + + # do this first as we need to determine if there is + # DELETE..FROM + DeleteDMLState.__init__(self, new_stmt, compiler, **kw) + + use_supplemental_cols = False + + if not toplevel: + synchronize_session = None + else: + synchronize_session = compiler._annotations.get( + "synchronize_session", None + ) + can_use_returning = compiler._annotations.get( + "can_use_returning", None + ) + if can_use_returning is not False: + # even though pre_exec has determined basic + # can_use_returning for the dialect, if we are to use + # RETURNING we need to run can_use_returning() at this level + # unconditionally because is_delete_using was not known + # at the pre_exec level + can_use_returning = ( + synchronize_session == "fetch" + and self.can_use_returning( + compiler.dialect, + mapper, + is_multitable=self.is_multitable, + is_delete_using=compiler._annotations.get( + "is_delete_using", False + ), + ) + ) + + if can_use_returning: + use_supplemental_cols = True + + new_stmt = new_stmt.return_defaults(*new_stmt.table.primary_key) + + if toplevel: + new_stmt = self._setup_orm_returning( + compiler, + orm_level_statement, + new_stmt, + dml_mapper=mapper, + use_supplemental_cols=use_supplemental_cols, + ) + + self.statement = new_stmt + + return self + + @classmethod + def orm_execute_statement( + cls, + session: Session, + statement: dml.Delete, + params: _CoreAnyExecuteParams, + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + conn: Connection, + ) -> _result.Result: + update_options = execution_options.get( + "_sa_orm_update_options", cls.default_update_options + ) + + if update_options._dml_strategy == "bulk": + raise sa_exc.InvalidRequestError( + "Bulk ORM DELETE not supported right now. " + "Statement may be invoked at the " + "Core level using " + "session.connection().execute(stmt, parameters)" + ) + + if update_options._dml_strategy not in ("orm", "auto", "core_only"): + raise sa_exc.ArgumentError( + "Valid strategies for ORM DELETE strategy are 'orm', 'auto', " + "'core_only'" + ) + + return super().orm_execute_statement( + session, statement, params, execution_options, bind_arguments, conn + ) + + @classmethod + def can_use_returning( + cls, + dialect: Dialect, + mapper: Mapper[Any], + *, + is_multitable: bool = False, + is_update_from: bool = False, + is_delete_using: bool = False, + is_executemany: bool = False, + ) -> bool: + # normal answer for "should we use RETURNING" at all. + normal_answer = ( + dialect.delete_returning and mapper.local_table.implicit_returning + ) + if not normal_answer: + return False + + # now get into special workarounds because MariaDB supports + # DELETE...RETURNING but not DELETE...USING...RETURNING. + if is_delete_using: + # is_delete_using hint was passed. use + # additional dialect feature (True for PG, False for MariaDB) + return dialect.delete_returning_multifrom + + elif is_multitable and not dialect.delete_returning_multifrom: + # is_delete_using hint was not passed, but we determined + # at compile time that this is in fact a DELETE..USING. + # it's too late to continue since we did not pre-SELECT. + # raise that we need that hint up front. + + raise sa_exc.CompileError( + f'Dialect "{dialect.name}" does not support RETURNING ' + "with DELETE..USING; for synchronize_session='fetch', " + "please add the additional execution option " + "'is_delete_using=True' to the statement to indicate that " + "a separate SELECT should be used for this backend." + ) + + return True + + @classmethod + def _do_post_synchronize_evaluate( + cls, session, statement, result, update_options + ): + matched_objects = cls._get_matched_objects_on_criteria( + update_options, + session.identity_map.all_states(), + ) + + to_delete = [] + + for _, state, dict_, is_partially_expired in matched_objects: + if is_partially_expired: + state._expire(dict_, session.identity_map._modified) + else: + to_delete.append(state) + + if to_delete: + session._remove_newly_deleted(to_delete) + + @classmethod + def _do_post_synchronize_fetch( + cls, session, statement, result, update_options + ): + target_mapper = update_options._subject_mapper + + returned_defaults_rows = result.returned_defaults_rows + + if returned_defaults_rows: + pk_rows = cls._interpret_returning_rows( + target_mapper, returned_defaults_rows + ) + + matched_rows = [ + tuple(row) + (update_options._identity_token,) + for row in pk_rows + ] + else: + matched_rows = update_options._matched_rows + + for row in matched_rows: + primary_key = row[0:-1] + identity_token = row[-1] + + # TODO: inline this and call remove_newly_deleted + # once + identity_key = target_mapper.identity_key_from_primary_key( + list(primary_key), + identity_token=identity_token, + ) + if identity_key in session.identity_map: + session._remove_newly_deleted( + [ + attributes.instance_state( + session.identity_map[identity_key] + ) + ] + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/clsregistry.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/clsregistry.py new file mode 100644 index 0000000..26113d8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/clsregistry.py @@ -0,0 +1,570 @@ +# orm/clsregistry.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Routines to handle the string class registry used by declarative. + +This system allows specification of classes and expressions used in +:func:`_orm.relationship` using strings. + +""" + +from __future__ import annotations + +import re +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generator +from typing import Iterable +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import NoReturn +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import interfaces +from .descriptor_props import SynonymProperty +from .properties import ColumnProperty +from .util import class_mapper +from .. import exc +from .. import inspection +from .. import util +from ..sql.schema import _get_table_key +from ..util.typing import CallableReference + +if TYPE_CHECKING: + from .relationships import RelationshipProperty + from ..sql.schema import MetaData + from ..sql.schema import Table + +_T = TypeVar("_T", bound=Any) + +_ClsRegistryType = MutableMapping[str, Union[type, "ClsRegistryToken"]] + +# strong references to registries which we place in +# the _decl_class_registry, which is usually weak referencing. +# the internal registries here link to classes with weakrefs and remove +# themselves when all references to contained classes are removed. +_registries: Set[ClsRegistryToken] = set() + + +def add_class( + classname: str, cls: Type[_T], decl_class_registry: _ClsRegistryType +) -> None: + """Add a class to the _decl_class_registry associated with the + given declarative class. + + """ + if classname in decl_class_registry: + # class already exists. + existing = decl_class_registry[classname] + if not isinstance(existing, _MultipleClassMarker): + existing = decl_class_registry[classname] = _MultipleClassMarker( + [cls, cast("Type[Any]", existing)] + ) + else: + decl_class_registry[classname] = cls + + try: + root_module = cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ) + except KeyError: + decl_class_registry["_sa_module_registry"] = root_module = ( + _ModuleMarker("_sa_module_registry", None) + ) + + tokens = cls.__module__.split(".") + + # build up a tree like this: + # modulename: myapp.snacks.nuts + # + # myapp->snack->nuts->(classes) + # snack->nuts->(classes) + # nuts->(classes) + # + # this allows partial token paths to be used. + while tokens: + token = tokens.pop(0) + module = root_module.get_module(token) + for token in tokens: + module = module.get_module(token) + + try: + module.add_class(classname, cls) + except AttributeError as ae: + if not isinstance(module, _ModuleMarker): + raise exc.InvalidRequestError( + f'name "{classname}" matches both a ' + "class name and a module name" + ) from ae + else: + raise + + +def remove_class( + classname: str, cls: Type[Any], decl_class_registry: _ClsRegistryType +) -> None: + if classname in decl_class_registry: + existing = decl_class_registry[classname] + if isinstance(existing, _MultipleClassMarker): + existing.remove_item(cls) + else: + del decl_class_registry[classname] + + try: + root_module = cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ) + except KeyError: + return + + tokens = cls.__module__.split(".") + + while tokens: + token = tokens.pop(0) + module = root_module.get_module(token) + for token in tokens: + module = module.get_module(token) + try: + module.remove_class(classname, cls) + except AttributeError: + if not isinstance(module, _ModuleMarker): + pass + else: + raise + + +def _key_is_empty( + key: str, + decl_class_registry: _ClsRegistryType, + test: Callable[[Any], bool], +) -> bool: + """test if a key is empty of a certain object. + + used for unit tests against the registry to see if garbage collection + is working. + + "test" is a callable that will be passed an object should return True + if the given object is the one we were looking for. + + We can't pass the actual object itself b.c. this is for testing garbage + collection; the caller will have to have removed references to the + object itself. + + """ + if key not in decl_class_registry: + return True + + thing = decl_class_registry[key] + if isinstance(thing, _MultipleClassMarker): + for sub_thing in thing.contents: + if test(sub_thing): + return False + else: + raise NotImplementedError("unknown codepath") + else: + return not test(thing) + + +class ClsRegistryToken: + """an object that can be in the registry._class_registry as a value.""" + + __slots__ = () + + +class _MultipleClassMarker(ClsRegistryToken): + """refers to multiple classes of the same name + within _decl_class_registry. + + """ + + __slots__ = "on_remove", "contents", "__weakref__" + + contents: Set[weakref.ref[Type[Any]]] + on_remove: CallableReference[Optional[Callable[[], None]]] + + def __init__( + self, + classes: Iterable[Type[Any]], + on_remove: Optional[Callable[[], None]] = None, + ): + self.on_remove = on_remove + self.contents = { + weakref.ref(item, self._remove_item) for item in classes + } + _registries.add(self) + + def remove_item(self, cls: Type[Any]) -> None: + self._remove_item(weakref.ref(cls)) + + def __iter__(self) -> Generator[Optional[Type[Any]], None, None]: + return (ref() for ref in self.contents) + + def attempt_get(self, path: List[str], key: str) -> Type[Any]: + if len(self.contents) > 1: + raise exc.InvalidRequestError( + 'Multiple classes found for path "%s" ' + "in the registry of this declarative " + "base. Please use a fully module-qualified path." + % (".".join(path + [key])) + ) + else: + ref = list(self.contents)[0] + cls = ref() + if cls is None: + raise NameError(key) + return cls + + def _remove_item(self, ref: weakref.ref[Type[Any]]) -> None: + self.contents.discard(ref) + if not self.contents: + _registries.discard(self) + if self.on_remove: + self.on_remove() + + def add_item(self, item: Type[Any]) -> None: + # protect against class registration race condition against + # asynchronous garbage collection calling _remove_item, + # [ticket:3208] and [ticket:10782] + modules = { + cls.__module__ + for cls in [ref() for ref in list(self.contents)] + if cls is not None + } + if item.__module__ in modules: + util.warn( + "This declarative base already contains a class with the " + "same class name and module name as %s.%s, and will " + "be replaced in the string-lookup table." + % (item.__module__, item.__name__) + ) + self.contents.add(weakref.ref(item, self._remove_item)) + + +class _ModuleMarker(ClsRegistryToken): + """Refers to a module name within + _decl_class_registry. + + """ + + __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__" + + parent: Optional[_ModuleMarker] + contents: Dict[str, Union[_ModuleMarker, _MultipleClassMarker]] + mod_ns: _ModNS + path: List[str] + + def __init__(self, name: str, parent: Optional[_ModuleMarker]): + self.parent = parent + self.name = name + self.contents = {} + self.mod_ns = _ModNS(self) + if self.parent: + self.path = self.parent.path + [self.name] + else: + self.path = [] + _registries.add(self) + + def __contains__(self, name: str) -> bool: + return name in self.contents + + def __getitem__(self, name: str) -> ClsRegistryToken: + return self.contents[name] + + def _remove_item(self, name: str) -> None: + self.contents.pop(name, None) + if not self.contents and self.parent is not None: + self.parent._remove_item(self.name) + _registries.discard(self) + + def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]: + return self.mod_ns.__getattr__(key) + + def get_module(self, name: str) -> _ModuleMarker: + if name not in self.contents: + marker = _ModuleMarker(name, self) + self.contents[name] = marker + else: + marker = cast(_ModuleMarker, self.contents[name]) + return marker + + def add_class(self, name: str, cls: Type[Any]) -> None: + if name in self.contents: + existing = cast(_MultipleClassMarker, self.contents[name]) + try: + existing.add_item(cls) + except AttributeError as ae: + if not isinstance(existing, _MultipleClassMarker): + raise exc.InvalidRequestError( + f'name "{name}" matches both a ' + "class name and a module name" + ) from ae + else: + raise + else: + existing = self.contents[name] = _MultipleClassMarker( + [cls], on_remove=lambda: self._remove_item(name) + ) + + def remove_class(self, name: str, cls: Type[Any]) -> None: + if name in self.contents: + existing = cast(_MultipleClassMarker, self.contents[name]) + existing.remove_item(cls) + + +class _ModNS: + __slots__ = ("__parent",) + + __parent: _ModuleMarker + + def __init__(self, parent: _ModuleMarker): + self.__parent = parent + + def __getattr__(self, key: str) -> Union[_ModNS, Type[Any]]: + try: + value = self.__parent.contents[key] + except KeyError: + pass + else: + if value is not None: + if isinstance(value, _ModuleMarker): + return value.mod_ns + else: + assert isinstance(value, _MultipleClassMarker) + return value.attempt_get(self.__parent.path, key) + raise NameError( + "Module %r has no mapped classes " + "registered under the name %r" % (self.__parent.name, key) + ) + + +class _GetColumns: + __slots__ = ("cls",) + + cls: Type[Any] + + def __init__(self, cls: Type[Any]): + self.cls = cls + + def __getattr__(self, key: str) -> Any: + mp = class_mapper(self.cls, configure=False) + if mp: + if key not in mp.all_orm_descriptors: + raise AttributeError( + "Class %r does not have a mapped column named %r" + % (self.cls, key) + ) + + desc = mp.all_orm_descriptors[key] + if desc.extension_type is interfaces.NotExtension.NOT_EXTENSION: + assert isinstance(desc, attributes.QueryableAttribute) + prop = desc.property + if isinstance(prop, SynonymProperty): + key = prop.name + elif not isinstance(prop, ColumnProperty): + raise exc.InvalidRequestError( + "Property %r is not an instance of" + " ColumnProperty (i.e. does not correspond" + " directly to a Column)." % key + ) + return getattr(self.cls, key) + + +inspection._inspects(_GetColumns)( + lambda target: inspection.inspect(target.cls) +) + + +class _GetTable: + __slots__ = "key", "metadata" + + key: str + metadata: MetaData + + def __init__(self, key: str, metadata: MetaData): + self.key = key + self.metadata = metadata + + def __getattr__(self, key: str) -> Table: + return self.metadata.tables[_get_table_key(key, self.key)] + + +def _determine_container(key: str, value: Any) -> _GetColumns: + if isinstance(value, _MultipleClassMarker): + value = value.attempt_get([], key) + return _GetColumns(value) + + +class _class_resolver: + __slots__ = ( + "cls", + "prop", + "arg", + "fallback", + "_dict", + "_resolvers", + "favor_tables", + ) + + cls: Type[Any] + prop: RelationshipProperty[Any] + fallback: Mapping[str, Any] + arg: str + favor_tables: bool + _resolvers: Tuple[Callable[[str], Any], ...] + + def __init__( + self, + cls: Type[Any], + prop: RelationshipProperty[Any], + fallback: Mapping[str, Any], + arg: str, + favor_tables: bool = False, + ): + self.cls = cls + self.prop = prop + self.arg = arg + self.fallback = fallback + self._dict = util.PopulateDict(self._access_cls) + self._resolvers = () + self.favor_tables = favor_tables + + def _access_cls(self, key: str) -> Any: + cls = self.cls + + manager = attributes.manager_of_class(cls) + decl_base = manager.registry + assert decl_base is not None + decl_class_registry = decl_base._class_registry + metadata = decl_base.metadata + + if self.favor_tables: + if key in metadata.tables: + return metadata.tables[key] + elif key in metadata._schemas: + return _GetTable(key, getattr(cls, "metadata", metadata)) + + if key in decl_class_registry: + return _determine_container(key, decl_class_registry[key]) + + if not self.favor_tables: + if key in metadata.tables: + return metadata.tables[key] + elif key in metadata._schemas: + return _GetTable(key, getattr(cls, "metadata", metadata)) + + if "_sa_module_registry" in decl_class_registry and key in cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ): + registry = cast( + _ModuleMarker, decl_class_registry["_sa_module_registry"] + ) + return registry.resolve_attr(key) + elif self._resolvers: + for resolv in self._resolvers: + value = resolv(key) + if value is not None: + return value + + return self.fallback[key] + + def _raise_for_name(self, name: str, err: Exception) -> NoReturn: + generic_match = re.match(r"(.+)\[(.+)\]", name) + + if generic_match: + clsarg = generic_match.group(2).strip("'") + raise exc.InvalidRequestError( + f"When initializing mapper {self.prop.parent}, " + f'expression "relationship({self.arg!r})" seems to be ' + "using a generic class as the argument to relationship(); " + "please state the generic argument " + "using an annotation, e.g. " + f'"{self.prop.key}: Mapped[{generic_match.group(1)}' + f"['{clsarg}']] = relationship()\"" + ) from err + else: + raise exc.InvalidRequestError( + "When initializing mapper %s, expression %r failed to " + "locate a name (%r). If this is a class name, consider " + "adding this relationship() to the %r class after " + "both dependent classes have been defined." + % (self.prop.parent, self.arg, name, self.cls) + ) from err + + def _resolve_name(self) -> Union[Table, Type[Any], _ModNS]: + name = self.arg + d = self._dict + rval = None + try: + for token in name.split("."): + if rval is None: + rval = d[token] + else: + rval = getattr(rval, token) + except KeyError as err: + self._raise_for_name(name, err) + except NameError as n: + self._raise_for_name(n.args[0], n) + else: + if isinstance(rval, _GetColumns): + return rval.cls + else: + if TYPE_CHECKING: + assert isinstance(rval, (type, Table, _ModNS)) + return rval + + def __call__(self) -> Any: + try: + x = eval(self.arg, globals(), self._dict) + + if isinstance(x, _GetColumns): + return x.cls + else: + return x + except NameError as n: + self._raise_for_name(n.args[0], n) + + +_fallback_dict: Mapping[str, Any] = None # type: ignore + + +def _resolver(cls: Type[Any], prop: RelationshipProperty[Any]) -> Tuple[ + Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]], + Callable[[str, bool], _class_resolver], +]: + global _fallback_dict + + if _fallback_dict is None: + import sqlalchemy + from . import foreign + from . import remote + + _fallback_dict = util.immutabledict(sqlalchemy.__dict__).union( + {"foreign": foreign, "remote": remote} + ) + + def resolve_arg(arg: str, favor_tables: bool = False) -> _class_resolver: + return _class_resolver( + cls, prop, _fallback_dict, arg, favor_tables=favor_tables + ) + + def resolve_name( + arg: str, + ) -> Callable[[], Union[Type[Any], Table, _ModNS]]: + return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name + + return resolve_name, resolve_arg diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/collections.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/collections.py new file mode 100644 index 0000000..6fefd78 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/collections.py @@ -0,0 +1,1618 @@ +# orm/collections.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Support for collections of mapped entities. + +The collections package supplies the machinery used to inform the ORM of +collection membership changes. An instrumentation via decoration approach is +used, allowing arbitrary types (including built-ins) to be used as entity +collections without requiring inheritance from a base class. + +Instrumentation decoration relays membership change events to the +:class:`.CollectionAttributeImpl` that is currently managing the collection. +The decorators observe function call arguments and return values, tracking +entities entering or leaving the collection. Two decorator approaches are +provided. One is a bundle of generic decorators that map function arguments +and return values to events:: + + from sqlalchemy.orm.collections import collection + class MyClass: + # ... + + @collection.adds(1) + def store(self, item): + self.data.append(item) + + @collection.removes_return() + def pop(self): + return self.data.pop() + + +The second approach is a bundle of targeted decorators that wrap appropriate +append and remove notifiers around the mutation methods present in the +standard Python ``list``, ``set`` and ``dict`` interfaces. These could be +specified in terms of generic decorator recipes, but are instead hand-tooled +for increased efficiency. The targeted decorators occasionally implement +adapter-like behavior, such as mapping bulk-set methods (``extend``, +``update``, ``__setslice__``, etc.) into the series of atomic mutation events +that the ORM requires. + +The targeted decorators are used internally for automatic instrumentation of +entity collection classes. Every collection class goes through a +transformation process roughly like so: + +1. If the class is a built-in, substitute a trivial sub-class +2. Is this class already instrumented? +3. Add in generic decorators +4. Sniff out the collection interface through duck-typing +5. Add targeted decoration to any undecorated interface method + +This process modifies the class at runtime, decorating methods and adding some +bookkeeping properties. This isn't possible (or desirable) for built-in +classes like ``list``, so trivial sub-classes are substituted to hold +decoration:: + + class InstrumentedList(list): + pass + +Collection classes can be specified in ``relationship(collection_class=)`` as +types or a function that returns an instance. Collection classes are +inspected and instrumented during the mapper compilation phase. The +collection_class callable will be executed once to produce a specimen +instance, and the type of that specimen will be instrumented. Functions that +return built-in types like ``lists`` will be adapted to produce instrumented +instances. + +When extending a known type like ``list``, additional decorations are not +generally not needed. Odds are, the extension method will delegate to a +method that's already instrumented. For example:: + + class QueueIsh(list): + def push(self, item): + self.append(item) + def shift(self): + return self.pop(0) + +There's no need to decorate these methods. ``append`` and ``pop`` are already +instrumented as part of the ``list`` interface. Decorating them would fire +duplicate events, which should be avoided. + +The targeted decoration tries not to rely on other methods in the underlying +collection class, but some are unavoidable. Many depend on 'read' methods +being present to properly instrument a 'write', for example, ``__setitem__`` +needs ``__getitem__``. "Bulk" methods like ``update`` and ``extend`` may also +reimplemented in terms of atomic appends and removes, so the ``extend`` +decoration will actually perform many ``append`` operations and not call the +underlying method at all. + +Tight control over bulk operation and the firing of events is also possible by +implementing the instrumentation internally in your methods. The basic +instrumentation package works under the general assumption that collection +mutation will not raise unusual exceptions. If you want to closely +orchestrate append and remove events with exception management, internal +instrumentation may be the answer. Within your method, +``collection_adapter(self)`` will retrieve an object that you can use for +explicit control over triggering append and remove events. + +The owning object and :class:`.CollectionAttributeImpl` are also reachable +through the adapter, allowing for some very sophisticated behavior. + +""" +from __future__ import annotations + +import operator +import threading +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Dict +from typing import Iterable +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from .base import NO_KEY +from .. import exc as sa_exc +from .. import util +from ..sql.base import NO_ARG +from ..util.compat import inspect_getfullargspec +from ..util.typing import Protocol + +if typing.TYPE_CHECKING: + from .attributes import AttributeEventToken + from .attributes import CollectionAttributeImpl + from .mapped_collection import attribute_keyed_dict + from .mapped_collection import column_keyed_dict + from .mapped_collection import keyfunc_mapping + from .mapped_collection import KeyFuncDict # noqa: F401 + from .state import InstanceState + + +__all__ = [ + "collection", + "collection_adapter", + "keyfunc_mapping", + "column_keyed_dict", + "attribute_keyed_dict", + "column_keyed_dict", + "attribute_keyed_dict", + "MappedCollection", + "KeyFuncDict", +] + +__instrumentation_mutex = threading.Lock() + + +_CollectionFactoryType = Callable[[], "_AdaptedCollectionProtocol"] + +_T = TypeVar("_T", bound=Any) +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) +_COL = TypeVar("_COL", bound="Collection[Any]") +_FN = TypeVar("_FN", bound="Callable[..., Any]") + + +class _CollectionConverterProtocol(Protocol): + def __call__(self, collection: _COL) -> _COL: ... + + +class _AdaptedCollectionProtocol(Protocol): + _sa_adapter: CollectionAdapter + _sa_appender: Callable[..., Any] + _sa_remover: Callable[..., Any] + _sa_iterator: Callable[..., Iterable[Any]] + _sa_converter: _CollectionConverterProtocol + + +class collection: + """Decorators for entity collection classes. + + The decorators fall into two groups: annotations and interception recipes. + + The annotating decorators (appender, remover, iterator, converter, + internally_instrumented) indicate the method's purpose and take no + arguments. They are not written with parens:: + + @collection.appender + def append(self, append): ... + + The recipe decorators all require parens, even those that take no + arguments:: + + @collection.adds('entity') + def insert(self, position, entity): ... + + @collection.removes_return() + def popitem(self): ... + + """ + + # Bundled as a class solely for ease of use: packaging, doc strings, + # importability. + + @staticmethod + def appender(fn): + """Tag the method as the collection appender. + + The appender method is called with one positional argument: the value + to append. The method will be automatically decorated with 'adds(1)' + if not already decorated:: + + @collection.appender + def add(self, append): ... + + # or, equivalently + @collection.appender + @collection.adds(1) + def add(self, append): ... + + # for mapping type, an 'append' may kick out a previous value + # that occupies that slot. consider d['a'] = 'foo'- any previous + # value in d['a'] is discarded. + @collection.appender + @collection.replaces(1) + def add(self, entity): + key = some_key_func(entity) + previous = None + if key in self: + previous = self[key] + self[key] = entity + return previous + + If the value to append is not allowed in the collection, you may + raise an exception. Something to remember is that the appender + will be called for each object mapped by a database query. If the + database contains rows that violate your collection semantics, you + will need to get creative to fix the problem, as access via the + collection will not work. + + If the appender method is internally instrumented, you must also + receive the keyword argument '_sa_initiator' and ensure its + promulgation to collection events. + + """ + fn._sa_instrument_role = "appender" + return fn + + @staticmethod + def remover(fn): + """Tag the method as the collection remover. + + The remover method is called with one positional argument: the value + to remove. The method will be automatically decorated with + :meth:`removes_return` if not already decorated:: + + @collection.remover + def zap(self, entity): ... + + # or, equivalently + @collection.remover + @collection.removes_return() + def zap(self, ): ... + + If the value to remove is not present in the collection, you may + raise an exception or return None to ignore the error. + + If the remove method is internally instrumented, you must also + receive the keyword argument '_sa_initiator' and ensure its + promulgation to collection events. + + """ + fn._sa_instrument_role = "remover" + return fn + + @staticmethod + def iterator(fn): + """Tag the method as the collection remover. + + The iterator method is called with no arguments. It is expected to + return an iterator over all collection members:: + + @collection.iterator + def __iter__(self): ... + + """ + fn._sa_instrument_role = "iterator" + return fn + + @staticmethod + def internally_instrumented(fn): + """Tag the method as instrumented. + + This tag will prevent any decoration from being applied to the + method. Use this if you are orchestrating your own calls to + :func:`.collection_adapter` in one of the basic SQLAlchemy + interface methods, or to prevent an automatic ABC method + decoration from wrapping your implementation:: + + # normally an 'extend' method on a list-like class would be + # automatically intercepted and re-implemented in terms of + # SQLAlchemy events and append(). your implementation will + # never be called, unless: + @collection.internally_instrumented + def extend(self, items): ... + + """ + fn._sa_instrumented = True + return fn + + @staticmethod + @util.deprecated( + "1.3", + "The :meth:`.collection.converter` handler is deprecated and will " + "be removed in a future release. Please refer to the " + ":class:`.AttributeEvents.bulk_replace` listener interface in " + "conjunction with the :func:`.event.listen` function.", + ) + def converter(fn): + """Tag the method as the collection converter. + + This optional method will be called when a collection is being + replaced entirely, as in:: + + myobj.acollection = [newvalue1, newvalue2] + + The converter method will receive the object being assigned and should + return an iterable of values suitable for use by the ``appender`` + method. A converter must not assign values or mutate the collection, + its sole job is to adapt the value the user provides into an iterable + of values for the ORM's use. + + The default converter implementation will use duck-typing to do the + conversion. A dict-like collection will be convert into an iterable + of dictionary values, and other types will simply be iterated:: + + @collection.converter + def convert(self, other): ... + + If the duck-typing of the object does not match the type of this + collection, a TypeError is raised. + + Supply an implementation of this method if you want to expand the + range of possible types that can be assigned in bulk or perform + validation on the values about to be assigned. + + """ + fn._sa_instrument_role = "converter" + return fn + + @staticmethod + def adds(arg): + """Mark the method as adding an entity to the collection. + + Adds "add to collection" handling to the method. The decorator + argument indicates which method argument holds the SQLAlchemy-relevant + value. Arguments can be specified positionally (i.e. integer) or by + name:: + + @collection.adds(1) + def push(self, item): ... + + @collection.adds('entity') + def do_stuff(self, thing, entity=None): ... + + """ + + def decorator(fn): + fn._sa_instrument_before = ("fire_append_event", arg) + return fn + + return decorator + + @staticmethod + def replaces(arg): + """Mark the method as replacing an entity in the collection. + + Adds "add to collection" and "remove from collection" handling to + the method. The decorator argument indicates which method argument + holds the SQLAlchemy-relevant value to be added, and return value, if + any will be considered the value to remove. + + Arguments can be specified positionally (i.e. integer) or by name:: + + @collection.replaces(2) + def __setitem__(self, index, item): ... + + """ + + def decorator(fn): + fn._sa_instrument_before = ("fire_append_event", arg) + fn._sa_instrument_after = "fire_remove_event" + return fn + + return decorator + + @staticmethod + def removes(arg): + """Mark the method as removing an entity in the collection. + + Adds "remove from collection" handling to the method. The decorator + argument indicates which method argument holds the SQLAlchemy-relevant + value to be removed. Arguments can be specified positionally (i.e. + integer) or by name:: + + @collection.removes(1) + def zap(self, item): ... + + For methods where the value to remove is not known at call-time, use + collection.removes_return. + + """ + + def decorator(fn): + fn._sa_instrument_before = ("fire_remove_event", arg) + return fn + + return decorator + + @staticmethod + def removes_return(): + """Mark the method as removing an entity in the collection. + + Adds "remove from collection" handling to the method. The return + value of the method, if any, is considered the value to remove. The + method arguments are not inspected:: + + @collection.removes_return() + def pop(self): ... + + For methods where the value to remove is known at call-time, use + collection.remove. + + """ + + def decorator(fn): + fn._sa_instrument_after = "fire_remove_event" + return fn + + return decorator + + +if TYPE_CHECKING: + + def collection_adapter(collection: Collection[Any]) -> CollectionAdapter: + """Fetch the :class:`.CollectionAdapter` for a collection.""" + +else: + collection_adapter = operator.attrgetter("_sa_adapter") + + +class CollectionAdapter: + """Bridges between the ORM and arbitrary Python collections. + + Proxies base-level collection operations (append, remove, iterate) + to the underlying Python collection, and emits add/remove events for + entities entering or leaving the collection. + + The ORM uses :class:`.CollectionAdapter` exclusively for interaction with + entity collections. + + + """ + + __slots__ = ( + "attr", + "_key", + "_data", + "owner_state", + "_converter", + "invalidated", + "empty", + ) + + attr: CollectionAttributeImpl + _key: str + + # this is actually a weakref; see note in constructor + _data: Callable[..., _AdaptedCollectionProtocol] + + owner_state: InstanceState[Any] + _converter: _CollectionConverterProtocol + invalidated: bool + empty: bool + + def __init__( + self, + attr: CollectionAttributeImpl, + owner_state: InstanceState[Any], + data: _AdaptedCollectionProtocol, + ): + self.attr = attr + self._key = attr.key + + # this weakref stays referenced throughout the lifespan of + # CollectionAdapter. so while the weakref can return None, this + # is realistically only during garbage collection of this object, so + # we type this as a callable that returns _AdaptedCollectionProtocol + # in all cases. + self._data = weakref.ref(data) # type: ignore + + self.owner_state = owner_state + data._sa_adapter = self + self._converter = data._sa_converter + self.invalidated = False + self.empty = False + + def _warn_invalidated(self) -> None: + util.warn("This collection has been invalidated.") + + @property + def data(self) -> _AdaptedCollectionProtocol: + "The entity collection being adapted." + return self._data() + + @property + def _referenced_by_owner(self) -> bool: + """return True if the owner state still refers to this collection. + + This will return False within a bulk replace operation, + where this collection is the one being replaced. + + """ + return self.owner_state.dict[self._key] is self._data() + + def bulk_appender(self): + return self._data()._sa_appender + + def append_with_event( + self, item: Any, initiator: Optional[AttributeEventToken] = None + ) -> None: + """Add an entity to the collection, firing mutation events.""" + + self._data()._sa_appender(item, _sa_initiator=initiator) + + def _set_empty(self, user_data): + assert ( + not self.empty + ), "This collection adapter is already in the 'empty' state" + self.empty = True + self.owner_state._empty_collections[self._key] = user_data + + def _reset_empty(self) -> None: + assert ( + self.empty + ), "This collection adapter is not in the 'empty' state" + self.empty = False + self.owner_state.dict[self._key] = ( + self.owner_state._empty_collections.pop(self._key) + ) + + def _refuse_empty(self) -> NoReturn: + raise sa_exc.InvalidRequestError( + "This is a special 'empty' collection which cannot accommodate " + "internal mutation operations" + ) + + def append_without_event(self, item: Any) -> None: + """Add or restore an entity to the collection, firing no events.""" + + if self.empty: + self._refuse_empty() + self._data()._sa_appender(item, _sa_initiator=False) + + def append_multiple_without_event(self, items: Iterable[Any]) -> None: + """Add or restore an entity to the collection, firing no events.""" + if self.empty: + self._refuse_empty() + appender = self._data()._sa_appender + for item in items: + appender(item, _sa_initiator=False) + + def bulk_remover(self): + return self._data()._sa_remover + + def remove_with_event( + self, item: Any, initiator: Optional[AttributeEventToken] = None + ) -> None: + """Remove an entity from the collection, firing mutation events.""" + self._data()._sa_remover(item, _sa_initiator=initiator) + + def remove_without_event(self, item: Any) -> None: + """Remove an entity from the collection, firing no events.""" + if self.empty: + self._refuse_empty() + self._data()._sa_remover(item, _sa_initiator=False) + + def clear_with_event( + self, initiator: Optional[AttributeEventToken] = None + ) -> None: + """Empty the collection, firing a mutation event for each entity.""" + + if self.empty: + self._refuse_empty() + remover = self._data()._sa_remover + for item in list(self): + remover(item, _sa_initiator=initiator) + + def clear_without_event(self) -> None: + """Empty the collection, firing no events.""" + + if self.empty: + self._refuse_empty() + remover = self._data()._sa_remover + for item in list(self): + remover(item, _sa_initiator=False) + + def __iter__(self): + """Iterate over entities in the collection.""" + + return iter(self._data()._sa_iterator()) + + def __len__(self): + """Count entities in the collection.""" + return len(list(self._data()._sa_iterator())) + + def __bool__(self): + return True + + def _fire_append_wo_mutation_event_bulk( + self, items, initiator=None, key=NO_KEY + ): + if not items: + return + + if initiator is not False: + if self.invalidated: + self._warn_invalidated() + + if self.empty: + self._reset_empty() + + for item in items: + self.attr.fire_append_wo_mutation_event( + self.owner_state, + self.owner_state.dict, + item, + initiator, + key, + ) + + def fire_append_wo_mutation_event(self, item, initiator=None, key=NO_KEY): + """Notify that a entity is entering the collection but is already + present. + + + Initiator is a token owned by the InstrumentedAttribute that + initiated the membership mutation, and should be left as None + unless you are passing along an initiator value from a chained + operation. + + .. versionadded:: 1.4.15 + + """ + if initiator is not False: + if self.invalidated: + self._warn_invalidated() + + if self.empty: + self._reset_empty() + + return self.attr.fire_append_wo_mutation_event( + self.owner_state, self.owner_state.dict, item, initiator, key + ) + else: + return item + + def fire_append_event(self, item, initiator=None, key=NO_KEY): + """Notify that a entity has entered the collection. + + Initiator is a token owned by the InstrumentedAttribute that + initiated the membership mutation, and should be left as None + unless you are passing along an initiator value from a chained + operation. + + """ + if initiator is not False: + if self.invalidated: + self._warn_invalidated() + + if self.empty: + self._reset_empty() + + return self.attr.fire_append_event( + self.owner_state, self.owner_state.dict, item, initiator, key + ) + else: + return item + + def _fire_remove_event_bulk(self, items, initiator=None, key=NO_KEY): + if not items: + return + + if initiator is not False: + if self.invalidated: + self._warn_invalidated() + + if self.empty: + self._reset_empty() + + for item in items: + self.attr.fire_remove_event( + self.owner_state, + self.owner_state.dict, + item, + initiator, + key, + ) + + def fire_remove_event(self, item, initiator=None, key=NO_KEY): + """Notify that a entity has been removed from the collection. + + Initiator is the InstrumentedAttribute that initiated the membership + mutation, and should be left as None unless you are passing along + an initiator value from a chained operation. + + """ + if initiator is not False: + if self.invalidated: + self._warn_invalidated() + + if self.empty: + self._reset_empty() + + self.attr.fire_remove_event( + self.owner_state, self.owner_state.dict, item, initiator, key + ) + + def fire_pre_remove_event(self, initiator=None, key=NO_KEY): + """Notify that an entity is about to be removed from the collection. + + Only called if the entity cannot be removed after calling + fire_remove_event(). + + """ + if self.invalidated: + self._warn_invalidated() + self.attr.fire_pre_remove_event( + self.owner_state, + self.owner_state.dict, + initiator=initiator, + key=key, + ) + + def __getstate__(self): + return { + "key": self._key, + "owner_state": self.owner_state, + "owner_cls": self.owner_state.class_, + "data": self.data, + "invalidated": self.invalidated, + "empty": self.empty, + } + + def __setstate__(self, d): + self._key = d["key"] + self.owner_state = d["owner_state"] + + # see note in constructor regarding this type: ignore + self._data = weakref.ref(d["data"]) # type: ignore + + self._converter = d["data"]._sa_converter + d["data"]._sa_adapter = self + self.invalidated = d["invalidated"] + self.attr = getattr(d["owner_cls"], self._key).impl + self.empty = d.get("empty", False) + + +def bulk_replace(values, existing_adapter, new_adapter, initiator=None): + """Load a new collection, firing events based on prior like membership. + + Appends instances in ``values`` onto the ``new_adapter``. Events will be + fired for any instance not present in the ``existing_adapter``. Any + instances in ``existing_adapter`` not present in ``values`` will have + remove events fired upon them. + + :param values: An iterable of collection member instances + + :param existing_adapter: A :class:`.CollectionAdapter` of + instances to be replaced + + :param new_adapter: An empty :class:`.CollectionAdapter` + to load with ``values`` + + + """ + + assert isinstance(values, list) + + idset = util.IdentitySet + existing_idset = idset(existing_adapter or ()) + constants = existing_idset.intersection(values or ()) + additions = idset(values or ()).difference(constants) + removals = existing_idset.difference(constants) + + appender = new_adapter.bulk_appender() + + for member in values or (): + if member in additions: + appender(member, _sa_initiator=initiator) + elif member in constants: + appender(member, _sa_initiator=False) + + if existing_adapter: + existing_adapter._fire_append_wo_mutation_event_bulk( + constants, initiator=initiator + ) + existing_adapter._fire_remove_event_bulk(removals, initiator=initiator) + + +def prepare_instrumentation( + factory: Union[Type[Collection[Any]], _CollectionFactoryType], +) -> _CollectionFactoryType: + """Prepare a callable for future use as a collection class factory. + + Given a collection class factory (either a type or no-arg callable), + return another factory that will produce compatible instances when + called. + + This function is responsible for converting collection_class=list + into the run-time behavior of collection_class=InstrumentedList. + + """ + + impl_factory: _CollectionFactoryType + + # Convert a builtin to 'Instrumented*' + if factory in __canned_instrumentation: + impl_factory = __canned_instrumentation[factory] + else: + impl_factory = cast(_CollectionFactoryType, factory) + + cls: Union[_CollectionFactoryType, Type[Collection[Any]]] + + # Create a specimen + cls = type(impl_factory()) + + # Did factory callable return a builtin? + if cls in __canned_instrumentation: + # if so, just convert. + # in previous major releases, this codepath wasn't working and was + # not covered by tests. prior to that it supplied a "wrapper" + # function that would return the class, though the rationale for this + # case is not known + impl_factory = __canned_instrumentation[cls] + cls = type(impl_factory()) + + # Instrument the class if needed. + if __instrumentation_mutex.acquire(): + try: + if getattr(cls, "_sa_instrumented", None) != id(cls): + _instrument_class(cls) + finally: + __instrumentation_mutex.release() + + return impl_factory + + +def _instrument_class(cls): + """Modify methods in a class and install instrumentation.""" + + # In the normal call flow, a request for any of the 3 basic collection + # types is transformed into one of our trivial subclasses + # (e.g. InstrumentedList). Catch anything else that sneaks in here... + if cls.__module__ == "__builtin__": + raise sa_exc.ArgumentError( + "Can not instrument a built-in type. Use a " + "subclass, even a trivial one." + ) + + roles, methods = _locate_roles_and_methods(cls) + + _setup_canned_roles(cls, roles, methods) + + _assert_required_roles(cls, roles, methods) + + _set_collection_attributes(cls, roles, methods) + + +def _locate_roles_and_methods(cls): + """search for _sa_instrument_role-decorated methods in + method resolution order, assign to roles. + + """ + + roles: Dict[str, str] = {} + methods: Dict[str, Tuple[Optional[str], Optional[int], Optional[str]]] = {} + + for supercls in cls.__mro__: + for name, method in vars(supercls).items(): + if not callable(method): + continue + + # note role declarations + if hasattr(method, "_sa_instrument_role"): + role = method._sa_instrument_role + assert role in ( + "appender", + "remover", + "iterator", + "converter", + ) + roles.setdefault(role, name) + + # transfer instrumentation requests from decorated function + # to the combined queue + before: Optional[Tuple[str, int]] = None + after: Optional[str] = None + + if hasattr(method, "_sa_instrument_before"): + op, argument = method._sa_instrument_before + assert op in ("fire_append_event", "fire_remove_event") + before = op, argument + if hasattr(method, "_sa_instrument_after"): + op = method._sa_instrument_after + assert op in ("fire_append_event", "fire_remove_event") + after = op + if before: + methods[name] = before + (after,) + elif after: + methods[name] = None, None, after + return roles, methods + + +def _setup_canned_roles(cls, roles, methods): + """see if this class has "canned" roles based on a known + collection type (dict, set, list). Apply those roles + as needed to the "roles" dictionary, and also + prepare "decorator" methods + + """ + collection_type = util.duck_type_collection(cls) + if collection_type in __interfaces: + assert collection_type is not None + canned_roles, decorators = __interfaces[collection_type] + for role, name in canned_roles.items(): + roles.setdefault(role, name) + + # apply ABC auto-decoration to methods that need it + for method, decorator in decorators.items(): + fn = getattr(cls, method, None) + if ( + fn + and method not in methods + and not hasattr(fn, "_sa_instrumented") + ): + setattr(cls, method, decorator(fn)) + + +def _assert_required_roles(cls, roles, methods): + """ensure all roles are present, and apply implicit instrumentation if + needed + + """ + if "appender" not in roles or not hasattr(cls, roles["appender"]): + raise sa_exc.ArgumentError( + "Type %s must elect an appender method to be " + "a collection class" % cls.__name__ + ) + elif roles["appender"] not in methods and not hasattr( + getattr(cls, roles["appender"]), "_sa_instrumented" + ): + methods[roles["appender"]] = ("fire_append_event", 1, None) + + if "remover" not in roles or not hasattr(cls, roles["remover"]): + raise sa_exc.ArgumentError( + "Type %s must elect a remover method to be " + "a collection class" % cls.__name__ + ) + elif roles["remover"] not in methods and not hasattr( + getattr(cls, roles["remover"]), "_sa_instrumented" + ): + methods[roles["remover"]] = ("fire_remove_event", 1, None) + + if "iterator" not in roles or not hasattr(cls, roles["iterator"]): + raise sa_exc.ArgumentError( + "Type %s must elect an iterator method to be " + "a collection class" % cls.__name__ + ) + + +def _set_collection_attributes(cls, roles, methods): + """apply ad-hoc instrumentation from decorators, class-level defaults + and implicit role declarations + + """ + for method_name, (before, argument, after) in methods.items(): + setattr( + cls, + method_name, + _instrument_membership_mutator( + getattr(cls, method_name), before, argument, after + ), + ) + # intern the role map + for role, method_name in roles.items(): + setattr(cls, "_sa_%s" % role, getattr(cls, method_name)) + + cls._sa_adapter = None + + if not hasattr(cls, "_sa_converter"): + cls._sa_converter = None + cls._sa_instrumented = id(cls) + + +def _instrument_membership_mutator(method, before, argument, after): + """Route method args and/or return value through the collection + adapter.""" + # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))' + if before: + fn_args = list( + util.flatten_iterator(inspect_getfullargspec(method)[0]) + ) + if isinstance(argument, int): + pos_arg = argument + named_arg = len(fn_args) > argument and fn_args[argument] or None + else: + if argument in fn_args: + pos_arg = fn_args.index(argument) + else: + pos_arg = None + named_arg = argument + del fn_args + + def wrapper(*args, **kw): + if before: + if pos_arg is None: + if named_arg not in kw: + raise sa_exc.ArgumentError( + "Missing argument %s" % argument + ) + value = kw[named_arg] + else: + if len(args) > pos_arg: + value = args[pos_arg] + elif named_arg in kw: + value = kw[named_arg] + else: + raise sa_exc.ArgumentError( + "Missing argument %s" % argument + ) + + initiator = kw.pop("_sa_initiator", None) + if initiator is False: + executor = None + else: + executor = args[0]._sa_adapter + + if before and executor: + getattr(executor, before)(value, initiator) + + if not after or not executor: + return method(*args, **kw) + else: + res = method(*args, **kw) + if res is not None: + getattr(executor, after)(res, initiator) + return res + + wrapper._sa_instrumented = True # type: ignore[attr-defined] + if hasattr(method, "_sa_instrument_role"): + wrapper._sa_instrument_role = method._sa_instrument_role # type: ignore[attr-defined] # noqa: E501 + wrapper.__name__ = method.__name__ + wrapper.__doc__ = method.__doc__ + return wrapper + + +def __set_wo_mutation(collection, item, _sa_initiator=None): + """Run set wo mutation events. + + The collection is not mutated. + + """ + if _sa_initiator is not False: + executor = collection._sa_adapter + if executor: + executor.fire_append_wo_mutation_event( + item, _sa_initiator, key=None + ) + + +def __set(collection, item, _sa_initiator, key): + """Run set events. + + This event always occurs before the collection is actually mutated. + + """ + + if _sa_initiator is not False: + executor = collection._sa_adapter + if executor: + item = executor.fire_append_event(item, _sa_initiator, key=key) + return item + + +def __del(collection, item, _sa_initiator, key): + """Run del events. + + This event occurs before the collection is actually mutated, *except* + in the case of a pop operation, in which case it occurs afterwards. + For pop operations, the __before_pop hook is called before the + operation occurs. + + """ + if _sa_initiator is not False: + executor = collection._sa_adapter + if executor: + executor.fire_remove_event(item, _sa_initiator, key=key) + + +def __before_pop(collection, _sa_initiator=None): + """An event which occurs on a before a pop() operation occurs.""" + executor = collection._sa_adapter + if executor: + executor.fire_pre_remove_event(_sa_initiator) + + +def _list_decorators() -> Dict[str, Callable[[_FN], _FN]]: + """Tailored instrumentation wrappers for any list-like class.""" + + def _tidy(fn): + fn._sa_instrumented = True + fn.__doc__ = getattr(list, fn.__name__).__doc__ + + def append(fn): + def append(self, item, _sa_initiator=None): + item = __set(self, item, _sa_initiator, NO_KEY) + fn(self, item) + + _tidy(append) + return append + + def remove(fn): + def remove(self, value, _sa_initiator=None): + __del(self, value, _sa_initiator, NO_KEY) + # testlib.pragma exempt:__eq__ + fn(self, value) + + _tidy(remove) + return remove + + def insert(fn): + def insert(self, index, value): + value = __set(self, value, None, index) + fn(self, index, value) + + _tidy(insert) + return insert + + def __setitem__(fn): + def __setitem__(self, index, value): + if not isinstance(index, slice): + existing = self[index] + if existing is not None: + __del(self, existing, None, index) + value = __set(self, value, None, index) + fn(self, index, value) + else: + # slice assignment requires __delitem__, insert, __len__ + step = index.step or 1 + start = index.start or 0 + if start < 0: + start += len(self) + if index.stop is not None: + stop = index.stop + else: + stop = len(self) + if stop < 0: + stop += len(self) + + if step == 1: + if value is self: + return + for i in range(start, stop, step): + if len(self) > start: + del self[start] + + for i, item in enumerate(value): + self.insert(i + start, item) + else: + rng = list(range(start, stop, step)) + if len(value) != len(rng): + raise ValueError( + "attempt to assign sequence of size %s to " + "extended slice of size %s" + % (len(value), len(rng)) + ) + for i, item in zip(rng, value): + self.__setitem__(i, item) + + _tidy(__setitem__) + return __setitem__ + + def __delitem__(fn): + def __delitem__(self, index): + if not isinstance(index, slice): + item = self[index] + __del(self, item, None, index) + fn(self, index) + else: + # slice deletion requires __getslice__ and a slice-groking + # __getitem__ for stepped deletion + # note: not breaking this into atomic dels + for item in self[index]: + __del(self, item, None, index) + fn(self, index) + + _tidy(__delitem__) + return __delitem__ + + def extend(fn): + def extend(self, iterable): + for value in list(iterable): + self.append(value) + + _tidy(extend) + return extend + + def __iadd__(fn): + def __iadd__(self, iterable): + # list.__iadd__ takes any iterable and seems to let TypeError + # raise as-is instead of returning NotImplemented + for value in list(iterable): + self.append(value) + return self + + _tidy(__iadd__) + return __iadd__ + + def pop(fn): + def pop(self, index=-1): + __before_pop(self) + item = fn(self, index) + __del(self, item, None, index) + return item + + _tidy(pop) + return pop + + def clear(fn): + def clear(self, index=-1): + for item in self: + __del(self, item, None, index) + fn(self) + + _tidy(clear) + return clear + + # __imul__ : not wrapping this. all members of the collection are already + # present, so no need to fire appends... wrapping it with an explicit + # decorator is still possible, so events on *= can be had if they're + # desired. hard to imagine a use case for __imul__, though. + + l = locals().copy() + l.pop("_tidy") + return l + + +def _dict_decorators() -> Dict[str, Callable[[_FN], _FN]]: + """Tailored instrumentation wrappers for any dict-like mapping class.""" + + def _tidy(fn): + fn._sa_instrumented = True + fn.__doc__ = getattr(dict, fn.__name__).__doc__ + + def __setitem__(fn): + def __setitem__(self, key, value, _sa_initiator=None): + if key in self: + __del(self, self[key], _sa_initiator, key) + value = __set(self, value, _sa_initiator, key) + fn(self, key, value) + + _tidy(__setitem__) + return __setitem__ + + def __delitem__(fn): + def __delitem__(self, key, _sa_initiator=None): + if key in self: + __del(self, self[key], _sa_initiator, key) + fn(self, key) + + _tidy(__delitem__) + return __delitem__ + + def clear(fn): + def clear(self): + for key in self: + __del(self, self[key], None, key) + fn(self) + + _tidy(clear) + return clear + + def pop(fn): + def pop(self, key, default=NO_ARG): + __before_pop(self) + _to_del = key in self + if default is NO_ARG: + item = fn(self, key) + else: + item = fn(self, key, default) + if _to_del: + __del(self, item, None, key) + return item + + _tidy(pop) + return pop + + def popitem(fn): + def popitem(self): + __before_pop(self) + item = fn(self) + __del(self, item[1], None, 1) + return item + + _tidy(popitem) + return popitem + + def setdefault(fn): + def setdefault(self, key, default=None): + if key not in self: + self.__setitem__(key, default) + return default + else: + value = self.__getitem__(key) + if value is default: + __set_wo_mutation(self, value, None) + + return value + + _tidy(setdefault) + return setdefault + + def update(fn): + def update(self, __other=NO_ARG, **kw): + if __other is not NO_ARG: + if hasattr(__other, "keys"): + for key in list(__other): + if key not in self or self[key] is not __other[key]: + self[key] = __other[key] + else: + __set_wo_mutation(self, __other[key], None) + else: + for key, value in __other: + if key not in self or self[key] is not value: + self[key] = value + else: + __set_wo_mutation(self, value, None) + for key in kw: + if key not in self or self[key] is not kw[key]: + self[key] = kw[key] + else: + __set_wo_mutation(self, kw[key], None) + + _tidy(update) + return update + + l = locals().copy() + l.pop("_tidy") + return l + + +_set_binop_bases = (set, frozenset) + + +def _set_binops_check_strict(self: Any, obj: Any) -> bool: + """Allow only set, frozenset and self.__class__-derived + objects in binops.""" + return isinstance(obj, _set_binop_bases + (self.__class__,)) + + +def _set_binops_check_loose(self: Any, obj: Any) -> bool: + """Allow anything set-like to participate in set binops.""" + return ( + isinstance(obj, _set_binop_bases + (self.__class__,)) + or util.duck_type_collection(obj) == set + ) + + +def _set_decorators() -> Dict[str, Callable[[_FN], _FN]]: + """Tailored instrumentation wrappers for any set-like class.""" + + def _tidy(fn): + fn._sa_instrumented = True + fn.__doc__ = getattr(set, fn.__name__).__doc__ + + def add(fn): + def add(self, value, _sa_initiator=None): + if value not in self: + value = __set(self, value, _sa_initiator, NO_KEY) + else: + __set_wo_mutation(self, value, _sa_initiator) + # testlib.pragma exempt:__hash__ + fn(self, value) + + _tidy(add) + return add + + def discard(fn): + def discard(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ + if value in self: + __del(self, value, _sa_initiator, NO_KEY) + # testlib.pragma exempt:__hash__ + fn(self, value) + + _tidy(discard) + return discard + + def remove(fn): + def remove(self, value, _sa_initiator=None): + # testlib.pragma exempt:__hash__ + if value in self: + __del(self, value, _sa_initiator, NO_KEY) + # testlib.pragma exempt:__hash__ + fn(self, value) + + _tidy(remove) + return remove + + def pop(fn): + def pop(self): + __before_pop(self) + item = fn(self) + # for set in particular, we have no way to access the item + # that will be popped before pop is called. + __del(self, item, None, NO_KEY) + return item + + _tidy(pop) + return pop + + def clear(fn): + def clear(self): + for item in list(self): + self.remove(item) + + _tidy(clear) + return clear + + def update(fn): + def update(self, value): + for item in value: + self.add(item) + + _tidy(update) + return update + + def __ior__(fn): + def __ior__(self, value): + if not _set_binops_check_strict(self, value): + return NotImplemented + for item in value: + self.add(item) + return self + + _tidy(__ior__) + return __ior__ + + def difference_update(fn): + def difference_update(self, value): + for item in value: + self.discard(item) + + _tidy(difference_update) + return difference_update + + def __isub__(fn): + def __isub__(self, value): + if not _set_binops_check_strict(self, value): + return NotImplemented + for item in value: + self.discard(item) + return self + + _tidy(__isub__) + return __isub__ + + def intersection_update(fn): + def intersection_update(self, other): + want, have = self.intersection(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + + _tidy(intersection_update) + return intersection_update + + def __iand__(fn): + def __iand__(self, other): + if not _set_binops_check_strict(self, other): + return NotImplemented + want, have = self.intersection(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + return self + + _tidy(__iand__) + return __iand__ + + def symmetric_difference_update(fn): + def symmetric_difference_update(self, other): + want, have = self.symmetric_difference(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + + _tidy(symmetric_difference_update) + return symmetric_difference_update + + def __ixor__(fn): + def __ixor__(self, other): + if not _set_binops_check_strict(self, other): + return NotImplemented + want, have = self.symmetric_difference(other), set(self) + remove, add = have - want, want - have + + for item in remove: + self.remove(item) + for item in add: + self.add(item) + return self + + _tidy(__ixor__) + return __ixor__ + + l = locals().copy() + l.pop("_tidy") + return l + + +class InstrumentedList(List[_T]): + """An instrumented version of the built-in list.""" + + +class InstrumentedSet(Set[_T]): + """An instrumented version of the built-in set.""" + + +class InstrumentedDict(Dict[_KT, _VT]): + """An instrumented version of the built-in dict.""" + + +__canned_instrumentation: util.immutabledict[Any, _CollectionFactoryType] = ( + util.immutabledict( + { + list: InstrumentedList, + set: InstrumentedSet, + dict: InstrumentedDict, + } + ) +) + +__interfaces: util.immutabledict[ + Any, + Tuple[ + Dict[str, str], + Dict[str, Callable[..., Any]], + ], +] = util.immutabledict( + { + list: ( + { + "appender": "append", + "remover": "remove", + "iterator": "__iter__", + }, + _list_decorators(), + ), + set: ( + {"appender": "add", "remover": "remove", "iterator": "__iter__"}, + _set_decorators(), + ), + # decorators are required for dicts and object collections. + dict: ({"iterator": "values"}, _dict_decorators()), + } +) + + +def __go(lcls): + global keyfunc_mapping, mapped_collection + global column_keyed_dict, column_mapped_collection + global MappedCollection, KeyFuncDict + global attribute_keyed_dict, attribute_mapped_collection + + from .mapped_collection import keyfunc_mapping + from .mapped_collection import column_keyed_dict + from .mapped_collection import attribute_keyed_dict + from .mapped_collection import KeyFuncDict + + from .mapped_collection import mapped_collection + from .mapped_collection import column_mapped_collection + from .mapped_collection import attribute_mapped_collection + from .mapped_collection import MappedCollection + + # ensure instrumentation is associated with + # these built-in classes; if a user-defined class + # subclasses these and uses @internally_instrumented, + # the superclass is otherwise not instrumented. + # see [ticket:2406]. + _instrument_class(InstrumentedList) + _instrument_class(InstrumentedSet) + _instrument_class(KeyFuncDict) + + +__go(locals()) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/context.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/context.py new file mode 100644 index 0000000..3056016 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/context.py @@ -0,0 +1,3243 @@ +# orm/context.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from __future__ import annotations + +import itertools +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import attributes +from . import interfaces +from . import loading +from .base import _is_aliased_class +from .interfaces import ORMColumnDescription +from .interfaces import ORMColumnsClauseRole +from .path_registry import PathRegistry +from .util import _entity_corresponds_to +from .util import _ORMJoin +from .util import _TraceAdaptRole +from .util import AliasedClass +from .util import Bundle +from .util import ORMAdapter +from .util import ORMStatementAdapter +from .. import exc as sa_exc +from .. import future +from .. import inspect +from .. import sql +from .. import util +from ..sql import coercions +from ..sql import expression +from ..sql import roles +from ..sql import util as sql_util +from ..sql import visitors +from ..sql._typing import _TP +from ..sql._typing import is_dml +from ..sql._typing import is_insert_update +from ..sql._typing import is_select_base +from ..sql.base import _select_iterables +from ..sql.base import CacheableOptions +from ..sql.base import CompileState +from ..sql.base import Executable +from ..sql.base import Generative +from ..sql.base import Options +from ..sql.dml import UpdateBase +from ..sql.elements import GroupedElement +from ..sql.elements import TextClause +from ..sql.selectable import CompoundSelectState +from ..sql.selectable import LABEL_STYLE_DISAMBIGUATE_ONLY +from ..sql.selectable import LABEL_STYLE_NONE +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import Select +from ..sql.selectable import SelectLabelStyle +from ..sql.selectable import SelectState +from ..sql.selectable import TypedReturnsRows +from ..sql.visitors import InternalTraversal + +if TYPE_CHECKING: + from ._typing import _InternalEntityType + from ._typing import OrmExecuteOptionsParameter + from .loading import PostLoad + from .mapper import Mapper + from .query import Query + from .session import _BindArguments + from .session import Session + from ..engine import Result + from ..engine.interfaces import _CoreSingleExecuteParams + from ..sql._typing import _ColumnsClauseArgument + from ..sql.compiler import SQLCompiler + from ..sql.dml import _DMLTableElement + from ..sql.elements import ColumnElement + from ..sql.selectable import _JoinTargetElement + from ..sql.selectable import _LabelConventionCallable + from ..sql.selectable import _SetupJoinsElement + from ..sql.selectable import ExecutableReturnsRows + from ..sql.selectable import SelectBase + from ..sql.type_api import TypeEngine + +_T = TypeVar("_T", bound=Any) +_path_registry = PathRegistry.root + +_EMPTY_DICT = util.immutabledict() + + +LABEL_STYLE_LEGACY_ORM = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM + + +class QueryContext: + __slots__ = ( + "top_level_context", + "compile_state", + "query", + "params", + "load_options", + "bind_arguments", + "execution_options", + "session", + "autoflush", + "populate_existing", + "invoke_all_eagers", + "version_check", + "refresh_state", + "create_eager_joins", + "propagated_loader_options", + "attributes", + "runid", + "partials", + "post_load_paths", + "identity_token", + "yield_per", + "loaders_require_buffering", + "loaders_require_uniquing", + ) + + runid: int + post_load_paths: Dict[PathRegistry, PostLoad] + compile_state: ORMCompileState + + class default_load_options(Options): + _only_return_tuples = False + _populate_existing = False + _version_check = False + _invoke_all_eagers = True + _autoflush = True + _identity_token = None + _yield_per = None + _refresh_state = None + _lazy_loaded_from = None + _legacy_uniquing = False + _sa_top_level_orm_context = None + _is_user_refresh = False + + def __init__( + self, + compile_state: CompileState, + statement: Union[Select[Any], FromStatement[Any]], + params: _CoreSingleExecuteParams, + session: Session, + load_options: Union[ + Type[QueryContext.default_load_options], + QueryContext.default_load_options, + ], + execution_options: Optional[OrmExecuteOptionsParameter] = None, + bind_arguments: Optional[_BindArguments] = None, + ): + self.load_options = load_options + self.execution_options = execution_options or _EMPTY_DICT + self.bind_arguments = bind_arguments or _EMPTY_DICT + self.compile_state = compile_state + self.query = statement + self.session = session + self.loaders_require_buffering = False + self.loaders_require_uniquing = False + self.params = params + self.top_level_context = load_options._sa_top_level_orm_context + + cached_options = compile_state.select_statement._with_options + uncached_options = statement._with_options + + # see issue #7447 , #8399 for some background + # propagated loader options will be present on loaded InstanceState + # objects under state.load_options and are typically used by + # LazyLoader to apply options to the SELECT statement it emits. + # For compile state options (i.e. loader strategy options), these + # need to line up with the ".load_path" attribute which in + # loader.py is pulled from context.compile_state.current_path. + # so, this means these options have to be the ones from the + # *cached* statement that's travelling with compile_state, not the + # *current* statement which won't match up for an ad-hoc + # AliasedClass + self.propagated_loader_options = tuple( + opt._adapt_cached_option_to_uncached_option(self, uncached_opt) + for opt, uncached_opt in zip(cached_options, uncached_options) + if opt.propagate_to_loaders + ) + + self.attributes = dict(compile_state.attributes) + + self.autoflush = load_options._autoflush + self.populate_existing = load_options._populate_existing + self.invoke_all_eagers = load_options._invoke_all_eagers + self.version_check = load_options._version_check + self.refresh_state = load_options._refresh_state + self.yield_per = load_options._yield_per + self.identity_token = load_options._identity_token + + def _get_top_level_context(self) -> QueryContext: + return self.top_level_context or self + + +_orm_load_exec_options = util.immutabledict( + {"_result_disable_adapt_to_context": True} +) + + +class AbstractORMCompileState(CompileState): + is_dml_returning = False + + def _init_global_attributes( + self, statement, compiler, *, toplevel, process_criteria_for_toplevel + ): + self.attributes = {} + + if compiler is None: + # this is the legacy / testing only ORM _compile_state() use case. + # there is no need to apply criteria options for this. + self.global_attributes = ga = {} + assert toplevel + return + else: + self.global_attributes = ga = compiler._global_attributes + + if toplevel: + ga["toplevel_orm"] = True + + if process_criteria_for_toplevel: + for opt in statement._with_options: + if opt._is_criteria_option: + opt.process_compile_state(self) + + return + elif ga.get("toplevel_orm", False): + return + + stack_0 = compiler.stack[0] + + try: + toplevel_stmt = stack_0["selectable"] + except KeyError: + pass + else: + for opt in toplevel_stmt._with_options: + if opt._is_compile_state and opt._is_criteria_option: + opt.process_compile_state(self) + + ga["toplevel_orm"] = True + + @classmethod + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> AbstractORMCompileState: + """Create a context for a statement given a :class:`.Compiler`. + + This method is always invoked in the context of SQLCompiler.process(). + + For a Select object, this would be invoked from + SQLCompiler.visit_select(). For the special FromStatement object used + by Query to indicate "Query.from_statement()", this is called by + FromStatement._compiler_dispatch() that would be called by + SQLCompiler.process(). + """ + return super().create_for_statement(statement, compiler, **kw) + + @classmethod + def orm_pre_session_exec( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_pre_event, + ): + raise NotImplementedError() + + @classmethod + def orm_execute_statement( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + conn, + ) -> Result: + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + return cls.orm_setup_cursor_result( + session, + statement, + params, + execution_options, + bind_arguments, + result, + ) + + @classmethod + def orm_setup_cursor_result( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + raise NotImplementedError() + + +class AutoflushOnlyORMCompileState(AbstractORMCompileState): + """ORM compile state that is a passthrough, except for autoflush.""" + + @classmethod + def orm_pre_session_exec( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_pre_event, + ): + # consume result-level load_options. These may have been set up + # in an ORMExecuteState hook + ( + load_options, + execution_options, + ) = QueryContext.default_load_options.from_execution_options( + "_sa_orm_load_options", + { + "autoflush", + }, + execution_options, + statement._execution_options, + ) + + if not is_pre_event and load_options._autoflush: + session._autoflush() + + return statement, execution_options + + @classmethod + def orm_setup_cursor_result( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + return result + + +class ORMCompileState(AbstractORMCompileState): + class default_compile_options(CacheableOptions): + _cache_key_traversal = [ + ("_use_legacy_query_style", InternalTraversal.dp_boolean), + ("_for_statement", InternalTraversal.dp_boolean), + ("_bake_ok", InternalTraversal.dp_boolean), + ("_current_path", InternalTraversal.dp_has_cache_key), + ("_enable_single_crit", InternalTraversal.dp_boolean), + ("_enable_eagerloads", InternalTraversal.dp_boolean), + ("_only_load_props", InternalTraversal.dp_plain_obj), + ("_set_base_alias", InternalTraversal.dp_boolean), + ("_for_refresh_state", InternalTraversal.dp_boolean), + ("_render_for_subquery", InternalTraversal.dp_boolean), + ("_is_star", InternalTraversal.dp_boolean), + ] + + # set to True by default from Query._statement_20(), to indicate + # the rendered query should look like a legacy ORM query. right + # now this basically indicates we should use tablename_columnname + # style labels. Generally indicates the statement originated + # from a Query object. + _use_legacy_query_style = False + + # set *only* when we are coming from the Query.statement + # accessor, or a Query-level equivalent such as + # query.subquery(). this supersedes "toplevel". + _for_statement = False + + _bake_ok = True + _current_path = _path_registry + _enable_single_crit = True + _enable_eagerloads = True + _only_load_props = None + _set_base_alias = False + _for_refresh_state = False + _render_for_subquery = False + _is_star = False + + attributes: Dict[Any, Any] + global_attributes: Dict[Any, Any] + + statement: Union[Select[Any], FromStatement[Any]] + select_statement: Union[Select[Any], FromStatement[Any]] + _entities: List[_QueryEntity] + _polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter] + compile_options: Union[ + Type[default_compile_options], default_compile_options + ] + _primary_entity: Optional[_QueryEntity] + use_legacy_query_style: bool + _label_convention: _LabelConventionCallable + primary_columns: List[ColumnElement[Any]] + secondary_columns: List[ColumnElement[Any]] + dedupe_columns: Set[ColumnElement[Any]] + create_eager_joins: List[ + # TODO: this structure is set up by JoinedLoader + Tuple[Any, ...] + ] + current_path: PathRegistry = _path_registry + _has_mapper_entities = False + + def __init__(self, *arg, **kw): + raise NotImplementedError() + + if TYPE_CHECKING: + + @classmethod + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: ... + + def _append_dedupe_col_collection(self, obj, col_collection): + dedupe = self.dedupe_columns + if obj not in dedupe: + dedupe.add(obj) + col_collection.append(obj) + + @classmethod + def _column_naming_convention( + cls, label_style: SelectLabelStyle, legacy: bool + ) -> _LabelConventionCallable: + if legacy: + + def name(col, col_name=None): + if col_name: + return col_name + else: + return getattr(col, "key") + + return name + else: + return SelectState._column_naming_convention(label_style) + + @classmethod + def get_column_descriptions(cls, statement): + return _column_descriptions(statement) + + @classmethod + def orm_pre_session_exec( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + is_pre_event, + ): + # consume result-level load_options. These may have been set up + # in an ORMExecuteState hook + ( + load_options, + execution_options, + ) = QueryContext.default_load_options.from_execution_options( + "_sa_orm_load_options", + { + "populate_existing", + "autoflush", + "yield_per", + "identity_token", + "sa_top_level_orm_context", + }, + execution_options, + statement._execution_options, + ) + + # default execution options for ORM results: + # 1. _result_disable_adapt_to_context=True + # this will disable the ResultSetMetadata._adapt_to_context() + # step which we don't need, as we have result processors cached + # against the original SELECT statement before caching. + + if "sa_top_level_orm_context" in execution_options: + ctx = execution_options["sa_top_level_orm_context"] + execution_options = ctx.query._execution_options.merge_with( + ctx.execution_options, execution_options + ) + + if not execution_options: + execution_options = _orm_load_exec_options + else: + execution_options = execution_options.union(_orm_load_exec_options) + + # would have been placed here by legacy Query only + if load_options._yield_per: + execution_options = execution_options.union( + {"yield_per": load_options._yield_per} + ) + + if ( + getattr(statement._compile_options, "_current_path", None) + and len(statement._compile_options._current_path) > 10 + and execution_options.get("compiled_cache", True) is not None + ): + execution_options: util.immutabledict[str, Any] = ( + execution_options.union( + { + "compiled_cache": None, + "_cache_disable_reason": "excess depth for " + "ORM loader options", + } + ) + ) + + bind_arguments["clause"] = statement + + # new in 1.4 - the coercions system is leveraged to allow the + # "subject" mapper of a statement be propagated to the top + # as the statement is built. "subject" mapper is the generally + # standard object used as an identifier for multi-database schemes. + + # we are here based on the fact that _propagate_attrs contains + # "compile_state_plugin": "orm". The "plugin_subject" + # needs to be present as well. + + try: + plugin_subject = statement._propagate_attrs["plugin_subject"] + except KeyError: + assert False, "statement had 'orm' plugin but no plugin_subject" + else: + if plugin_subject: + bind_arguments["mapper"] = plugin_subject.mapper + + if not is_pre_event and load_options._autoflush: + session._autoflush() + + return statement, execution_options + + @classmethod + def orm_setup_cursor_result( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + result, + ): + execution_context = result.context + compile_state = execution_context.compiled.compile_state + + # cover edge case where ORM entities used in legacy select + # were passed to session.execute: + # session.execute(legacy_select([User.id, User.name])) + # see test_query->test_legacy_tuple_old_select + + load_options = execution_options.get( + "_sa_orm_load_options", QueryContext.default_load_options + ) + + if compile_state.compile_options._is_star: + return result + + querycontext = QueryContext( + compile_state, + statement, + params, + session, + load_options, + execution_options, + bind_arguments, + ) + return loading.instances(result, querycontext) + + @property + def _lead_mapper_entities(self): + """return all _MapperEntity objects in the lead entities collection. + + Does **not** include entities that have been replaced by + with_entities(), with_only_columns() + + """ + return [ + ent for ent in self._entities if isinstance(ent, _MapperEntity) + ] + + def _create_with_polymorphic_adapter(self, ext_info, selectable): + """given MapperEntity or ORMColumnEntity, setup polymorphic loading + if called for by the Mapper. + + As of #8168 in 2.0.0rc1, polymorphic adapters, which greatly increase + the complexity of the query creation process, are not used at all + except in the quasi-legacy cases of with_polymorphic referring to an + alias and/or subquery. This would apply to concrete polymorphic + loading, and joined inheritance where a subquery is + passed to with_polymorphic (which is completely unnecessary in modern + use). + + """ + if ( + not ext_info.is_aliased_class + and ext_info.mapper.persist_selectable + not in self._polymorphic_adapters + ): + for mp in ext_info.mapper.iterate_to_root(): + self._mapper_loads_polymorphically_with( + mp, + ORMAdapter( + _TraceAdaptRole.WITH_POLYMORPHIC_ADAPTER, + mp, + equivalents=mp._equivalent_columns, + selectable=selectable, + ), + ) + + def _mapper_loads_polymorphically_with(self, mapper, adapter): + for m2 in mapper._with_polymorphic_mappers or [mapper]: + self._polymorphic_adapters[m2] = adapter + + for m in m2.iterate_to_root(): + self._polymorphic_adapters[m.local_table] = adapter + + @classmethod + def _create_entities_collection(cls, query, legacy): + raise NotImplementedError( + "this method only works for ORMSelectCompileState" + ) + + +class DMLReturningColFilter: + """an adapter used for the DML RETURNING case. + + Has a subset of the interface used by + :class:`.ORMAdapter` and is used for :class:`._QueryEntity` + instances to set up their columns as used in RETURNING for a + DML statement. + + """ + + __slots__ = ("mapper", "columns", "__weakref__") + + def __init__(self, target_mapper, immediate_dml_mapper): + if ( + immediate_dml_mapper is not None + and target_mapper.local_table + is not immediate_dml_mapper.local_table + ): + # joined inh, or in theory other kinds of multi-table mappings + self.mapper = immediate_dml_mapper + else: + # single inh, normal mappings, etc. + self.mapper = target_mapper + self.columns = self.columns = util.WeakPopulateDict( + self.adapt_check_present # type: ignore + ) + + def __call__(self, col, as_filter): + for cc in sql_util._find_columns(col): + c2 = self.adapt_check_present(cc) + if c2 is not None: + return col + else: + return None + + def adapt_check_present(self, col): + mapper = self.mapper + prop = mapper._columntoproperty.get(col, None) + if prop is None: + return None + return mapper.local_table.c.corresponding_column(col) + + +@sql.base.CompileState.plugin_for("orm", "orm_from_statement") +class ORMFromStatementCompileState(ORMCompileState): + _from_obj_alias = None + _has_mapper_entities = False + + statement_container: FromStatement + requested_statement: Union[SelectBase, TextClause, UpdateBase] + dml_table: Optional[_DMLTableElement] = None + + _has_orm_entities = False + multi_row_eager_loaders = False + eager_adding_joins = False + compound_eager_adapter = None + + extra_criteria_entities = _EMPTY_DICT + eager_joins = _EMPTY_DICT + + @classmethod + def create_for_statement( + cls, + statement_container: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMFromStatementCompileState: + assert isinstance(statement_container, FromStatement) + + if compiler is not None and compiler.stack: + raise sa_exc.CompileError( + "The ORM FromStatement construct only supports being " + "invoked as the topmost statement, as it is only intended to " + "define how result rows should be returned." + ) + + self = cls.__new__(cls) + self._primary_entity = None + + self.use_legacy_query_style = ( + statement_container._compile_options._use_legacy_query_style + ) + self.statement_container = self.select_statement = statement_container + self.requested_statement = statement = statement_container.element + + if statement.is_dml: + self.dml_table = statement.table + self.is_dml_returning = True + + self._entities = [] + self._polymorphic_adapters = {} + + self.compile_options = statement_container._compile_options + + if ( + self.use_legacy_query_style + and isinstance(statement, expression.SelectBase) + and not statement._is_textual + and not statement.is_dml + and statement._label_style is LABEL_STYLE_NONE + ): + self.statement = statement.set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + else: + self.statement = statement + + self._label_convention = self._column_naming_convention( + ( + statement._label_style + if not statement._is_textual and not statement.is_dml + else LABEL_STYLE_NONE + ), + self.use_legacy_query_style, + ) + + _QueryEntity.to_compile_state( + self, + statement_container._raw_columns, + self._entities, + is_current_entities=True, + ) + + self.current_path = statement_container._compile_options._current_path + + self._init_global_attributes( + statement_container, + compiler, + process_criteria_for_toplevel=False, + toplevel=True, + ) + + if statement_container._with_options: + for opt in statement_container._with_options: + if opt._is_compile_state: + opt.process_compile_state(self) + + if statement_container._with_context_options: + for fn, key in statement_container._with_context_options: + fn(self) + + self.primary_columns = [] + self.secondary_columns = [] + self.dedupe_columns = set() + self.create_eager_joins = [] + self._fallback_from_clauses = [] + + self.order_by = None + + if isinstance(self.statement, expression.TextClause): + # TextClause has no "column" objects at all. for this case, + # we generate columns from our _QueryEntity objects, then + # flip on all the "please match no matter what" parameters. + self.extra_criteria_entities = {} + + for entity in self._entities: + entity.setup_compile_state(self) + + compiler._ordered_columns = compiler._textual_ordered_columns = ( + False + ) + + # enable looser result column matching. this is shown to be + # needed by test_query.py::TextTest + compiler._loose_column_name_matching = True + + for c in self.primary_columns: + compiler.process( + c, + within_columns_clause=True, + add_to_result_map=compiler._add_to_result_map, + ) + else: + # for everyone else, Select, Insert, Update, TextualSelect, they + # have column objects already. After much + # experimentation here, the best approach seems to be, use + # those columns completely, don't interfere with the compiler + # at all; just in ORM land, use an adapter to convert from + # our ORM columns to whatever columns are in the statement, + # before we look in the result row. Adapt on names + # to accept cases such as issue #9217, however also allow + # this to be overridden for cases such as #9273. + self._from_obj_alias = ORMStatementAdapter( + _TraceAdaptRole.ADAPT_FROM_STATEMENT, + self.statement, + adapt_on_names=statement_container._adapt_on_names, + ) + + return self + + def _adapt_col_list(self, cols, current_adapter): + return cols + + def _get_current_adapter(self): + return None + + def setup_dml_returning_compile_state(self, dml_mapper): + """used by BulkORMInsert (and Update / Delete?) to set up a handler + for RETURNING to return ORM objects and expressions + + """ + target_mapper = self.statement._propagate_attrs.get( + "plugin_subject", None + ) + adapter = DMLReturningColFilter(target_mapper, dml_mapper) + + if self.compile_options._is_star and (len(self._entities) != 1): + raise sa_exc.CompileError( + "Can't generate ORM query that includes multiple expressions " + "at the same time as '*'; query for '*' alone if present" + ) + + for entity in self._entities: + entity.setup_dml_returning_compile_state(self, adapter) + + +class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): + """Core construct that represents a load of ORM objects from various + :class:`.ReturnsRows` and other classes including: + + :class:`.Select`, :class:`.TextClause`, :class:`.TextualSelect`, + :class:`.CompoundSelect`, :class`.Insert`, :class:`.Update`, + and in theory, :class:`.Delete`. + + """ + + __visit_name__ = "orm_from_statement" + + _compile_options = ORMFromStatementCompileState.default_compile_options + + _compile_state_factory = ORMFromStatementCompileState.create_for_statement + + _for_update_arg = None + + element: Union[ExecutableReturnsRows, TextClause] + + _adapt_on_names: bool + + _traverse_internals = [ + ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ("element", InternalTraversal.dp_clauseelement), + ] + Executable._executable_traverse_internals + + _cache_key_traversal = _traverse_internals + [ + ("_compile_options", InternalTraversal.dp_has_cache_key) + ] + + def __init__( + self, + entities: Iterable[_ColumnsClauseArgument[Any]], + element: Union[ExecutableReturnsRows, TextClause], + _adapt_on_names: bool = True, + ): + self._raw_columns = [ + coercions.expect( + roles.ColumnsClauseRole, + ent, + apply_propagate_attrs=self, + post_inspect=True, + ) + for ent in util.to_list(entities) + ] + self.element = element + self.is_dml = element.is_dml + self._label_style = ( + element._label_style if is_select_base(element) else None + ) + self._adapt_on_names = _adapt_on_names + + def _compiler_dispatch(self, compiler, **kw): + """provide a fixed _compiler_dispatch method. + + This is roughly similar to using the sqlalchemy.ext.compiler + ``@compiles`` extension. + + """ + + compile_state = self._compile_state_factory(self, compiler, **kw) + + toplevel = not compiler.stack + + if toplevel: + compiler.compile_state = compile_state + + return compiler.process(compile_state.statement, **kw) + + @property + def column_descriptions(self): + """Return a :term:`plugin-enabled` 'column descriptions' structure + referring to the columns which are SELECTed by this statement. + + See the section :ref:`queryguide_inspection` for an overview + of this feature. + + .. seealso:: + + :ref:`queryguide_inspection` - ORM background + + """ + meth = cast( + ORMSelectCompileState, SelectState.get_plugin_class(self) + ).get_column_descriptions + return meth(self) + + def _ensure_disambiguated_names(self): + return self + + def get_children(self, **kw): + yield from itertools.chain.from_iterable( + element._from_objects for element in self._raw_columns + ) + yield from super().get_children(**kw) + + @property + def _all_selected_columns(self): + return self.element._all_selected_columns + + @property + def _return_defaults(self): + return self.element._return_defaults if is_dml(self.element) else None + + @property + def _returning(self): + return self.element._returning if is_dml(self.element) else None + + @property + def _inline(self): + return self.element._inline if is_insert_update(self.element) else None + + +@sql.base.CompileState.plugin_for("orm", "compound_select") +class CompoundSelectCompileState( + AutoflushOnlyORMCompileState, CompoundSelectState +): + pass + + +@sql.base.CompileState.plugin_for("orm", "select") +class ORMSelectCompileState(ORMCompileState, SelectState): + _already_joined_edges = () + + _memoized_entities = _EMPTY_DICT + + _from_obj_alias = None + _has_mapper_entities = False + + _has_orm_entities = False + multi_row_eager_loaders = False + eager_adding_joins = False + compound_eager_adapter = None + + correlate = None + correlate_except = None + _where_criteria = () + _having_criteria = () + + @classmethod + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMSelectCompileState: + """compiler hook, we arrive here from compiler.visit_select() only.""" + + self = cls.__new__(cls) + + if compiler is not None: + toplevel = not compiler.stack + else: + toplevel = True + + select_statement = statement + + # if we are a select() that was never a legacy Query, we won't + # have ORM level compile options. + statement._compile_options = cls.default_compile_options.safe_merge( + statement._compile_options + ) + + if select_statement._execution_options: + # execution options should not impact the compilation of a + # query, and at the moment subqueryloader is putting some things + # in here that we explicitly don't want stuck in a cache. + self.select_statement = select_statement._clone() + self.select_statement._execution_options = util.immutabledict() + else: + self.select_statement = select_statement + + # indicates this select() came from Query.statement + self.for_statement = select_statement._compile_options._for_statement + + # generally if we are from Query or directly from a select() + self.use_legacy_query_style = ( + select_statement._compile_options._use_legacy_query_style + ) + + self._entities = [] + self._primary_entity = None + self._polymorphic_adapters = {} + + self.compile_options = select_statement._compile_options + + if not toplevel: + # for subqueries, turn off eagerloads and set + # "render_for_subquery". + self.compile_options += { + "_enable_eagerloads": False, + "_render_for_subquery": True, + } + + # determine label style. we can make different decisions here. + # at the moment, trying to see if we can always use DISAMBIGUATE_ONLY + # rather than LABEL_STYLE_NONE, and if we can use disambiguate style + # for new style ORM selects too. + if ( + self.use_legacy_query_style + and self.select_statement._label_style is LABEL_STYLE_LEGACY_ORM + ): + if not self.for_statement: + self.label_style = LABEL_STYLE_TABLENAME_PLUS_COL + else: + self.label_style = LABEL_STYLE_DISAMBIGUATE_ONLY + else: + self.label_style = self.select_statement._label_style + + if select_statement._memoized_select_entities: + self._memoized_entities = { + memoized_entities: _QueryEntity.to_compile_state( + self, + memoized_entities._raw_columns, + [], + is_current_entities=False, + ) + for memoized_entities in ( + select_statement._memoized_select_entities + ) + } + + # label_convention is stateful and will yield deduping keys if it + # sees the same key twice. therefore it's important that it is not + # invoked for the above "memoized" entities that aren't actually + # in the columns clause + self._label_convention = self._column_naming_convention( + statement._label_style, self.use_legacy_query_style + ) + + _QueryEntity.to_compile_state( + self, + select_statement._raw_columns, + self._entities, + is_current_entities=True, + ) + + self.current_path = select_statement._compile_options._current_path + + self.eager_order_by = () + + self._init_global_attributes( + select_statement, + compiler, + toplevel=toplevel, + process_criteria_for_toplevel=False, + ) + + if toplevel and ( + select_statement._with_options + or select_statement._memoized_select_entities + ): + for ( + memoized_entities + ) in select_statement._memoized_select_entities: + for opt in memoized_entities._with_options: + if opt._is_compile_state: + opt.process_compile_state_replaced_entities( + self, + [ + ent + for ent in self._memoized_entities[ + memoized_entities + ] + if isinstance(ent, _MapperEntity) + ], + ) + + for opt in self.select_statement._with_options: + if opt._is_compile_state: + opt.process_compile_state(self) + + # uncomment to print out the context.attributes structure + # after it's been set up above + # self._dump_option_struct() + + if select_statement._with_context_options: + for fn, key in select_statement._with_context_options: + fn(self) + + self.primary_columns = [] + self.secondary_columns = [] + self.dedupe_columns = set() + self.eager_joins = {} + self.extra_criteria_entities = {} + self.create_eager_joins = [] + self._fallback_from_clauses = [] + + # normalize the FROM clauses early by themselves, as this makes + # it an easier job when we need to assemble a JOIN onto these, + # for select.join() as well as joinedload(). As of 1.4 there are now + # potentially more complex sets of FROM objects here as the use + # of lambda statements for lazyload, load_on_pk etc. uses more + # cloning of the select() construct. See #6495 + self.from_clauses = self._normalize_froms( + info.selectable for info in select_statement._from_obj + ) + + # this is a fairly arbitrary break into a second method, + # so it might be nicer to break up create_for_statement() + # and _setup_for_generate into three or four logical sections + self._setup_for_generate() + + SelectState.__init__(self, self.statement, compiler, **kw) + return self + + def _dump_option_struct(self): + print("\n---------------------------------------------------\n") + print(f"current path: {self.current_path}") + for key in self.attributes: + if isinstance(key, tuple) and key[0] == "loader": + print(f"\nLoader: {PathRegistry.coerce(key[1])}") + print(f" {self.attributes[key]}") + print(f" {self.attributes[key].__dict__}") + elif isinstance(key, tuple) and key[0] == "path_with_polymorphic": + print(f"\nWith Polymorphic: {PathRegistry.coerce(key[1])}") + print(f" {self.attributes[key]}") + + def _setup_for_generate(self): + query = self.select_statement + + self.statement = None + self._join_entities = () + + if self.compile_options._set_base_alias: + # legacy Query only + self._set_select_from_alias() + + for memoized_entities in query._memoized_select_entities: + if memoized_entities._setup_joins: + self._join( + memoized_entities._setup_joins, + self._memoized_entities[memoized_entities], + ) + + if query._setup_joins: + self._join(query._setup_joins, self._entities) + + current_adapter = self._get_current_adapter() + + if query._where_criteria: + self._where_criteria = query._where_criteria + + if current_adapter: + self._where_criteria = tuple( + current_adapter(crit, True) + for crit in self._where_criteria + ) + + # TODO: some complexity with order_by here was due to mapper.order_by. + # now that this is removed we can hopefully make order_by / + # group_by act identically to how they are in Core select. + self.order_by = ( + self._adapt_col_list(query._order_by_clauses, current_adapter) + if current_adapter and query._order_by_clauses not in (None, False) + else query._order_by_clauses + ) + + if query._having_criteria: + self._having_criteria = tuple( + current_adapter(crit, True) if current_adapter else crit + for crit in query._having_criteria + ) + + self.group_by = ( + self._adapt_col_list( + util.flatten_iterator(query._group_by_clauses), current_adapter + ) + if current_adapter and query._group_by_clauses not in (None, False) + else query._group_by_clauses or None + ) + + if self.eager_order_by: + adapter = self.from_clauses[0]._target_adapter + self.eager_order_by = adapter.copy_and_process(self.eager_order_by) + + if query._distinct_on: + self.distinct_on = self._adapt_col_list( + query._distinct_on, current_adapter + ) + else: + self.distinct_on = () + + self.distinct = query._distinct + + if query._correlate: + # ORM mapped entities that are mapped to joins can be passed + # to .correlate, so here they are broken into their component + # tables. + self.correlate = tuple( + util.flatten_iterator( + sql_util.surface_selectables(s) if s is not None else None + for s in query._correlate + ) + ) + elif query._correlate_except is not None: + self.correlate_except = tuple( + util.flatten_iterator( + sql_util.surface_selectables(s) if s is not None else None + for s in query._correlate_except + ) + ) + elif not query._auto_correlate: + self.correlate = (None,) + + # PART II + + self._for_update_arg = query._for_update_arg + + if self.compile_options._is_star and (len(self._entities) != 1): + raise sa_exc.CompileError( + "Can't generate ORM query that includes multiple expressions " + "at the same time as '*'; query for '*' alone if present" + ) + for entity in self._entities: + entity.setup_compile_state(self) + + for rec in self.create_eager_joins: + strategy = rec[0] + strategy(self, *rec[1:]) + + # else "load from discrete FROMs" mode, + # i.e. when each _MappedEntity has its own FROM + + if self.compile_options._enable_single_crit: + self._adjust_for_extra_criteria() + + if not self.primary_columns: + if self.compile_options._only_load_props: + assert False, "no columns were included in _only_load_props" + + raise sa_exc.InvalidRequestError( + "Query contains no columns with which to SELECT from." + ) + + if not self.from_clauses: + self.from_clauses = list(self._fallback_from_clauses) + + if self.order_by is False: + self.order_by = None + + if ( + self.multi_row_eager_loaders + and self.eager_adding_joins + and self._should_nest_selectable + ): + self.statement = self._compound_eager_statement() + else: + self.statement = self._simple_statement() + + if self.for_statement: + ezero = self._mapper_zero() + if ezero is not None: + # TODO: this goes away once we get rid of the deep entity + # thing + self.statement = self.statement._annotate( + {"deepentity": ezero} + ) + + @classmethod + def _create_entities_collection(cls, query, legacy): + """Creates a partial ORMSelectCompileState that includes + the full collection of _MapperEntity and other _QueryEntity objects. + + Supports a few remaining use cases that are pre-compilation + but still need to gather some of the column / adaption information. + + """ + self = cls.__new__(cls) + + self._entities = [] + self._primary_entity = None + self._polymorphic_adapters = {} + + self._label_convention = self._column_naming_convention( + query._label_style, legacy + ) + + # entities will also set up polymorphic adapters for mappers + # that have with_polymorphic configured + _QueryEntity.to_compile_state( + self, query._raw_columns, self._entities, is_current_entities=True + ) + return self + + @classmethod + def determine_last_joined_entity(cls, statement): + setup_joins = statement._setup_joins + + return _determine_last_joined_entity(setup_joins, None) + + @classmethod + def all_selected_columns(cls, statement): + for element in statement._raw_columns: + if ( + element.is_selectable + and "entity_namespace" in element._annotations + ): + ens = element._annotations["entity_namespace"] + if not ens.is_mapper and not ens.is_aliased_class: + yield from _select_iterables([element]) + else: + yield from _select_iterables(ens._all_column_expressions) + else: + yield from _select_iterables([element]) + + @classmethod + def get_columns_clause_froms(cls, statement): + return cls._normalize_froms( + itertools.chain.from_iterable( + ( + element._from_objects + if "parententity" not in element._annotations + else [ + element._annotations[ + "parententity" + ].__clause_element__() + ] + ) + for element in statement._raw_columns + ) + ) + + @classmethod + def from_statement(cls, statement, from_statement): + from_statement = coercions.expect( + roles.ReturnsRowsRole, + from_statement, + apply_propagate_attrs=statement, + ) + + stmt = FromStatement(statement._raw_columns, from_statement) + + stmt.__dict__.update( + _with_options=statement._with_options, + _with_context_options=statement._with_context_options, + _execution_options=statement._execution_options, + _propagate_attrs=statement._propagate_attrs, + ) + return stmt + + def _set_select_from_alias(self): + """used only for legacy Query cases""" + + query = self.select_statement # query + + assert self.compile_options._set_base_alias + assert len(query._from_obj) == 1 + + adapter = self._get_select_from_alias_from_obj(query._from_obj[0]) + if adapter: + self.compile_options += {"_enable_single_crit": False} + self._from_obj_alias = adapter + + def _get_select_from_alias_from_obj(self, from_obj): + """used only for legacy Query cases""" + + info = from_obj + + if "parententity" in info._annotations: + info = info._annotations["parententity"] + + if hasattr(info, "mapper"): + if not info.is_aliased_class: + raise sa_exc.ArgumentError( + "A selectable (FromClause) instance is " + "expected when the base alias is being set." + ) + else: + return info._adapter + + elif isinstance(info.selectable, sql.selectable.AliasedReturnsRows): + equivs = self._all_equivs() + assert info is info.selectable + return ORMStatementAdapter( + _TraceAdaptRole.LEGACY_SELECT_FROM_ALIAS, + info.selectable, + equivalents=equivs, + ) + else: + return None + + def _mapper_zero(self): + """return the Mapper associated with the first QueryEntity.""" + return self._entities[0].mapper + + def _entity_zero(self): + """Return the 'entity' (mapper or AliasedClass) associated + with the first QueryEntity, or alternatively the 'select from' + entity if specified.""" + + for ent in self.from_clauses: + if "parententity" in ent._annotations: + return ent._annotations["parententity"] + for qent in self._entities: + if qent.entity_zero: + return qent.entity_zero + + return None + + def _only_full_mapper_zero(self, methname): + if self._entities != [self._primary_entity]: + raise sa_exc.InvalidRequestError( + "%s() can only be used against " + "a single mapped class." % methname + ) + return self._primary_entity.entity_zero + + def _only_entity_zero(self, rationale=None): + if len(self._entities) > 1: + raise sa_exc.InvalidRequestError( + rationale + or "This operation requires a Query " + "against a single mapper." + ) + return self._entity_zero() + + def _all_equivs(self): + equivs = {} + + for memoized_entities in self._memoized_entities.values(): + for ent in [ + ent + for ent in memoized_entities + if isinstance(ent, _MapperEntity) + ]: + equivs.update(ent.mapper._equivalent_columns) + + for ent in [ + ent for ent in self._entities if isinstance(ent, _MapperEntity) + ]: + equivs.update(ent.mapper._equivalent_columns) + return equivs + + def _compound_eager_statement(self): + # for eager joins present and LIMIT/OFFSET/DISTINCT, + # wrap the query inside a select, + # then append eager joins onto that + + if self.order_by: + # the default coercion for ORDER BY is now the OrderByRole, + # which adds an additional post coercion to ByOfRole in that + # elements are converted into label references. For the + # eager load / subquery wrapping case, we need to un-coerce + # the original expressions outside of the label references + # in order to have them render. + unwrapped_order_by = [ + ( + elem.element + if isinstance(elem, sql.elements._label_reference) + else elem + ) + for elem in self.order_by + ] + + order_by_col_expr = sql_util.expand_column_list_from_order_by( + self.primary_columns, unwrapped_order_by + ) + else: + order_by_col_expr = [] + unwrapped_order_by = None + + # put FOR UPDATE on the inner query, where MySQL will honor it, + # as well as if it has an OF so PostgreSQL can use it. + inner = self._select_statement( + self.primary_columns + + [c for c in order_by_col_expr if c not in self.dedupe_columns], + self.from_clauses, + self._where_criteria, + self._having_criteria, + self.label_style, + self.order_by, + for_update=self._for_update_arg, + hints=self.select_statement._hints, + statement_hints=self.select_statement._statement_hints, + correlate=self.correlate, + correlate_except=self.correlate_except, + **self._select_args, + ) + + inner = inner.alias() + + equivs = self._all_equivs() + + self.compound_eager_adapter = ORMStatementAdapter( + _TraceAdaptRole.COMPOUND_EAGER_STATEMENT, inner, equivalents=equivs + ) + + statement = future.select( + *([inner] + self.secondary_columns) # use_labels=self.labels + ) + statement._label_style = self.label_style + + # Oracle however does not allow FOR UPDATE on the subquery, + # and the Oracle dialect ignores it, plus for PostgreSQL, MySQL + # we expect that all elements of the row are locked, so also put it + # on the outside (except in the case of PG when OF is used) + if ( + self._for_update_arg is not None + and self._for_update_arg.of is None + ): + statement._for_update_arg = self._for_update_arg + + from_clause = inner + for eager_join in self.eager_joins.values(): + # EagerLoader places a 'stop_on' attribute on the join, + # giving us a marker as to where the "splice point" of + # the join should be + from_clause = sql_util.splice_joins( + from_clause, eager_join, eager_join.stop_on + ) + + statement.select_from.non_generative(statement, from_clause) + + if unwrapped_order_by: + statement.order_by.non_generative( + statement, + *self.compound_eager_adapter.copy_and_process( + unwrapped_order_by + ), + ) + + statement.order_by.non_generative(statement, *self.eager_order_by) + return statement + + def _simple_statement(self): + statement = self._select_statement( + self.primary_columns + self.secondary_columns, + tuple(self.from_clauses) + tuple(self.eager_joins.values()), + self._where_criteria, + self._having_criteria, + self.label_style, + self.order_by, + for_update=self._for_update_arg, + hints=self.select_statement._hints, + statement_hints=self.select_statement._statement_hints, + correlate=self.correlate, + correlate_except=self.correlate_except, + **self._select_args, + ) + + if self.eager_order_by: + statement.order_by.non_generative(statement, *self.eager_order_by) + return statement + + def _select_statement( + self, + raw_columns, + from_obj, + where_criteria, + having_criteria, + label_style, + order_by, + for_update, + hints, + statement_hints, + correlate, + correlate_except, + limit_clause, + offset_clause, + fetch_clause, + fetch_clause_options, + distinct, + distinct_on, + prefixes, + suffixes, + group_by, + independent_ctes, + independent_ctes_opts, + ): + statement = Select._create_raw_select( + _raw_columns=raw_columns, + _from_obj=from_obj, + _label_style=label_style, + ) + + if where_criteria: + statement._where_criteria = where_criteria + if having_criteria: + statement._having_criteria = having_criteria + + if order_by: + statement._order_by_clauses += tuple(order_by) + + if distinct_on: + statement.distinct.non_generative(statement, *distinct_on) + elif distinct: + statement.distinct.non_generative(statement) + + if group_by: + statement._group_by_clauses += tuple(group_by) + + statement._limit_clause = limit_clause + statement._offset_clause = offset_clause + statement._fetch_clause = fetch_clause + statement._fetch_clause_options = fetch_clause_options + statement._independent_ctes = independent_ctes + statement._independent_ctes_opts = independent_ctes_opts + + if prefixes: + statement._prefixes = prefixes + + if suffixes: + statement._suffixes = suffixes + + statement._for_update_arg = for_update + + if hints: + statement._hints = hints + if statement_hints: + statement._statement_hints = statement_hints + + if correlate: + statement.correlate.non_generative(statement, *correlate) + + if correlate_except is not None: + statement.correlate_except.non_generative( + statement, *correlate_except + ) + + return statement + + def _adapt_polymorphic_element(self, element): + if "parententity" in element._annotations: + search = element._annotations["parententity"] + alias = self._polymorphic_adapters.get(search, None) + if alias: + return alias.adapt_clause(element) + + if isinstance(element, expression.FromClause): + search = element + elif hasattr(element, "table"): + search = element.table + else: + return None + + alias = self._polymorphic_adapters.get(search, None) + if alias: + return alias.adapt_clause(element) + + def _adapt_col_list(self, cols, current_adapter): + if current_adapter: + return [current_adapter(o, True) for o in cols] + else: + return cols + + def _get_current_adapter(self): + adapters = [] + + if self._from_obj_alias: + # used for legacy going forward for query set_ops, e.g. + # union(), union_all(), etc. + # 1.4 and previously, also used for from_self(), + # select_entity_from() + # + # for the "from obj" alias, apply extra rule to the + # 'ORM only' check, if this query were generated from a + # subquery of itself, i.e. _from_selectable(), apply adaption + # to all SQL constructs. + adapters.append( + ( + True, + self._from_obj_alias.replace, + ) + ) + + # this was *hopefully* the only adapter we were going to need + # going forward...however, we unfortunately need _from_obj_alias + # for query.union(), which we can't drop + if self._polymorphic_adapters: + adapters.append((False, self._adapt_polymorphic_element)) + + if not adapters: + return None + + def _adapt_clause(clause, as_filter): + # do we adapt all expression elements or only those + # tagged as 'ORM' constructs ? + + def replace(elem): + is_orm_adapt = ( + "_orm_adapt" in elem._annotations + or "parententity" in elem._annotations + ) + for always_adapt, adapter in adapters: + if is_orm_adapt or always_adapt: + e = adapter(elem) + if e is not None: + return e + + return visitors.replacement_traverse(clause, {}, replace) + + return _adapt_clause + + def _join(self, args, entities_collection): + for right, onclause, from_, flags in args: + isouter = flags["isouter"] + full = flags["full"] + + right = inspect(right) + if onclause is not None: + onclause = inspect(onclause) + + if isinstance(right, interfaces.PropComparator): + if onclause is not None: + raise sa_exc.InvalidRequestError( + "No 'on clause' argument may be passed when joining " + "to a relationship path as a target" + ) + + onclause = right + right = None + elif "parententity" in right._annotations: + right = right._annotations["parententity"] + + if onclause is None: + if not right.is_selectable and not hasattr(right, "mapper"): + raise sa_exc.ArgumentError( + "Expected mapped entity or " + "selectable/table as join target" + ) + + of_type = None + + if isinstance(onclause, interfaces.PropComparator): + # descriptor/property given (or determined); this tells us + # explicitly what the expected "left" side of the join is. + + of_type = getattr(onclause, "_of_type", None) + + if right is None: + if of_type: + right = of_type + else: + right = onclause.property + + try: + right = right.entity + except AttributeError as err: + raise sa_exc.ArgumentError( + "Join target %s does not refer to a " + "mapped entity" % right + ) from err + + left = onclause._parententity + + prop = onclause.property + if not isinstance(onclause, attributes.QueryableAttribute): + onclause = prop + + # check for this path already present. don't render in that + # case. + if (left, right, prop.key) in self._already_joined_edges: + continue + + if from_ is not None: + if ( + from_ is not left + and from_._annotations.get("parententity", None) + is not left + ): + raise sa_exc.InvalidRequestError( + "explicit from clause %s does not match left side " + "of relationship attribute %s" + % ( + from_._annotations.get("parententity", from_), + onclause, + ) + ) + elif from_ is not None: + prop = None + left = from_ + else: + # no descriptor/property given; we will need to figure out + # what the effective "left" side is + prop = left = None + + # figure out the final "left" and "right" sides and create an + # ORMJoin to add to our _from_obj tuple + self._join_left_to_right( + entities_collection, + left, + right, + onclause, + prop, + isouter, + full, + ) + + def _join_left_to_right( + self, + entities_collection, + left, + right, + onclause, + prop, + outerjoin, + full, + ): + """given raw "left", "right", "onclause" parameters consumed from + a particular key within _join(), add a real ORMJoin object to + our _from_obj list (or augment an existing one) + + """ + + if left is None: + # left not given (e.g. no relationship object/name specified) + # figure out the best "left" side based on our existing froms / + # entities + assert prop is None + ( + left, + replace_from_obj_index, + use_entity_index, + ) = self._join_determine_implicit_left_side( + entities_collection, left, right, onclause + ) + else: + # left is given via a relationship/name, or as explicit left side. + # Determine where in our + # "froms" list it should be spliced/appended as well as what + # existing entity it corresponds to. + ( + replace_from_obj_index, + use_entity_index, + ) = self._join_place_explicit_left_side(entities_collection, left) + + if left is right: + raise sa_exc.InvalidRequestError( + "Can't construct a join from %s to %s, they " + "are the same entity" % (left, right) + ) + + # the right side as given often needs to be adapted. additionally + # a lot of things can be wrong with it. handle all that and + # get back the new effective "right" side + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop + ) + + if not r_info.is_selectable: + extra_criteria = self._get_extra_criteria(r_info) + else: + extra_criteria = () + + if replace_from_obj_index is not None: + # splice into an existing element in the + # self._from_obj list + left_clause = self.from_clauses[replace_from_obj_index] + + self.from_clauses = ( + self.from_clauses[:replace_from_obj_index] + + [ + _ORMJoin( + left_clause, + right, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, + ) + ] + + self.from_clauses[replace_from_obj_index + 1 :] + ) + else: + # add a new element to the self._from_obj list + if use_entity_index is not None: + # make use of _MapperEntity selectable, which is usually + # entity_zero.selectable, but if with_polymorphic() were used + # might be distinct + assert isinstance( + entities_collection[use_entity_index], _MapperEntity + ) + left_clause = entities_collection[use_entity_index].selectable + else: + left_clause = left + + self.from_clauses = self.from_clauses + [ + _ORMJoin( + left_clause, + r_info, + onclause, + isouter=outerjoin, + full=full, + _extra_criteria=extra_criteria, + ) + ] + + def _join_determine_implicit_left_side( + self, entities_collection, left, right, onclause + ): + """When join conditions don't express the left side explicitly, + determine if an existing FROM or entity in this query + can serve as the left hand side. + + """ + + # when we are here, it means join() was called without an ORM- + # specific way of telling us what the "left" side is, e.g.: + # + # join(RightEntity) + # + # or + # + # join(RightEntity, RightEntity.foo == LeftEntity.bar) + # + + r_info = inspect(right) + + replace_from_obj_index = use_entity_index = None + + if self.from_clauses: + # we have a list of FROMs already. So by definition this + # join has to connect to one of those FROMs. + + indexes = sql_util.find_left_clause_to_join_from( + self.from_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + replace_from_obj_index = indexes[0] + left = self.from_clauses[replace_from_obj_index] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." % (right,) + ) + + elif entities_collection: + # we have no explicit FROMs, so the implicit left has to + # come from our list of entities. + + potential = {} + for entity_index, ent in enumerate(entities_collection): + entity = ent.entity_zero_or_selectable + if entity is None: + continue + ent_info = inspect(entity) + if ent_info is r_info: # left and right are the same, skip + continue + + # by using a dictionary with the selectables as keys this + # de-duplicates those selectables as occurs when the query is + # against a series of columns from the same selectable + if isinstance(ent, _MapperEntity): + potential[ent.selectable] = (entity_index, entity) + else: + potential[ent_info.selectable] = (None, entity) + + all_clauses = list(potential.keys()) + indexes = sql_util.find_left_clause_to_join_from( + all_clauses, r_info.selectable, onclause + ) + + if len(indexes) == 1: + use_entity_index, left = potential[all_clauses[indexes[0]]] + elif len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't determine which FROM clause to join " + "from, there are multiple FROMS which can " + "join to this entity. Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." + ) + else: + raise sa_exc.InvalidRequestError( + "Don't know how to join to %r. " + "Please use the .select_from() " + "method to establish an explicit left side, as well as " + "providing an explicit ON clause if not present already " + "to help resolve the ambiguity." % (right,) + ) + else: + raise sa_exc.InvalidRequestError( + "No entities to join from; please use " + "select_from() to establish the left " + "entity/selectable of this join" + ) + + return left, replace_from_obj_index, use_entity_index + + def _join_place_explicit_left_side(self, entities_collection, left): + """When join conditions express a left side explicitly, determine + where in our existing list of FROM clauses we should join towards, + or if we need to make a new join, and if so is it from one of our + existing entities. + + """ + + # when we are here, it means join() was called with an indicator + # as to an exact left side, which means a path to a + # Relationship was given, e.g.: + # + # join(RightEntity, LeftEntity.right) + # + # or + # + # join(LeftEntity.right) + # + # as well as string forms: + # + # join(RightEntity, "right") + # + # etc. + # + + replace_from_obj_index = use_entity_index = None + + l_info = inspect(left) + if self.from_clauses: + indexes = sql_util.find_left_clause_that_matches_given( + self.from_clauses, l_info.selectable + ) + + if len(indexes) > 1: + raise sa_exc.InvalidRequestError( + "Can't identify which entity in which to assign the " + "left side of this join. Please use a more specific " + "ON clause." + ) + + # have an index, means the left side is already present in + # an existing FROM in the self._from_obj tuple + if indexes: + replace_from_obj_index = indexes[0] + + # no index, means we need to add a new element to the + # self._from_obj tuple + + # no from element present, so we will have to add to the + # self._from_obj tuple. Determine if this left side matches up + # with existing mapper entities, in which case we want to apply the + # aliasing / adaptation rules present on that entity if any + if ( + replace_from_obj_index is None + and entities_collection + and hasattr(l_info, "mapper") + ): + for idx, ent in enumerate(entities_collection): + # TODO: should we be checking for multiple mapper entities + # matching? + if isinstance(ent, _MapperEntity) and ent.corresponds_to(left): + use_entity_index = idx + break + + return replace_from_obj_index, use_entity_index + + def _join_check_and_adapt_right_side(self, left, right, onclause, prop): + """transform the "right" side of the join as well as the onclause + according to polymorphic mapping translations, aliasing on the query + or on the join, special cases where the right and left side have + overlapping tables. + + """ + + l_info = inspect(left) + r_info = inspect(right) + + overlap = False + + right_mapper = getattr(r_info, "mapper", None) + # if the target is a joined inheritance mapping, + # be more liberal about auto-aliasing. + if right_mapper and ( + right_mapper.with_polymorphic + or isinstance(right_mapper.persist_selectable, expression.Join) + ): + for from_obj in self.from_clauses or [l_info.selectable]: + if sql_util.selectables_overlap( + l_info.selectable, from_obj + ) and sql_util.selectables_overlap( + from_obj, r_info.selectable + ): + overlap = True + break + + if overlap and l_info.selectable is r_info.selectable: + raise sa_exc.InvalidRequestError( + "Can't join table/selectable '%s' to itself" + % l_info.selectable + ) + + right_mapper, right_selectable, right_is_aliased = ( + getattr(r_info, "mapper", None), + r_info.selectable, + getattr(r_info, "is_aliased_class", False), + ) + + if ( + right_mapper + and prop + and not right_mapper.common_parent(prop.mapper) + ): + raise sa_exc.InvalidRequestError( + "Join target %s does not correspond to " + "the right side of join condition %s" % (right, onclause) + ) + + # _join_entities is used as a hint for single-table inheritance + # purposes at the moment + if hasattr(r_info, "mapper"): + self._join_entities += (r_info,) + + need_adapter = False + + # test for joining to an unmapped selectable as the target + if r_info.is_clause_element: + if prop: + right_mapper = prop.mapper + + if right_selectable._is_lateral: + # orm_only is disabled to suit the case where we have to + # adapt an explicit correlate(Entity) - the select() loses + # the ORM-ness in this case right now, ideally it would not + current_adapter = self._get_current_adapter() + if current_adapter is not None: + # TODO: we had orm_only=False here before, removing + # it didn't break things. if we identify the rationale, + # may need to apply "_orm_only" annotation here. + right = current_adapter(right, True) + + elif prop: + # joining to selectable with a mapper property given + # as the ON clause + + if not right_selectable.is_derived_from( + right_mapper.persist_selectable + ): + raise sa_exc.InvalidRequestError( + "Selectable '%s' is not derived from '%s'" + % ( + right_selectable.description, + right_mapper.persist_selectable.description, + ) + ) + + # if the destination selectable is a plain select(), + # turn it into an alias(). + if isinstance(right_selectable, expression.SelectBase): + right_selectable = coercions.expect( + roles.FromClauseRole, right_selectable + ) + need_adapter = True + + # make the right hand side target into an ORM entity + right = AliasedClass(right_mapper, right_selectable) + + util.warn_deprecated( + "An alias is being generated automatically against " + "joined entity %s for raw clauseelement, which is " + "deprecated and will be removed in a later release. " + "Use the aliased() " + "construct explicitly, see the linked example." + % right_mapper, + "1.4", + code="xaj1", + ) + + # test for overlap: + # orm/inheritance/relationships.py + # SelfReferentialM2MTest + aliased_entity = right_mapper and not right_is_aliased and overlap + + if not need_adapter and aliased_entity: + # there are a few places in the ORM that automatic aliasing + # is still desirable, and can't be automatic with a Core + # only approach. For illustrations of "overlaps" see + # test/orm/inheritance/test_relationships.py. There are also + # general overlap cases with many-to-many tables where automatic + # aliasing is desirable. + right = AliasedClass(right, flat=True) + need_adapter = True + + util.warn( + "An alias is being generated automatically against " + "joined entity %s due to overlapping tables. This is a " + "legacy pattern which may be " + "deprecated in a later release. Use the " + "aliased(, flat=True) " + "construct explicitly, see the linked example." % right_mapper, + code="xaj2", + ) + + if need_adapter: + # if need_adapter is True, we are in a deprecated case and + # a warning has been emitted. + assert right_mapper + + adapter = ORMAdapter( + _TraceAdaptRole.DEPRECATED_JOIN_ADAPT_RIGHT_SIDE, + inspect(right), + equivalents=right_mapper._equivalent_columns, + ) + + # if an alias() on the right side was generated, + # which is intended to wrap a the right side in a subquery, + # ensure that columns retrieved from this target in the result + # set are also adapted. + self._mapper_loads_polymorphically_with(right_mapper, adapter) + elif ( + not r_info.is_clause_element + and not right_is_aliased + and right_mapper._has_aliased_polymorphic_fromclause + ): + # for the case where the target mapper has a with_polymorphic + # set up, ensure an adapter is set up for criteria that works + # against this mapper. Previously, this logic used to + # use the "create_aliases or aliased_entity" case to generate + # an aliased() object, but this creates an alias that isn't + # strictly necessary. + # see test/orm/test_core_compilation.py + # ::RelNaturalAliasedJoinsTest::test_straight + # and similar + self._mapper_loads_polymorphically_with( + right_mapper, + ORMAdapter( + _TraceAdaptRole.WITH_POLYMORPHIC_ADAPTER_RIGHT_JOIN, + right_mapper, + selectable=right_mapper.selectable, + equivalents=right_mapper._equivalent_columns, + ), + ) + # if the onclause is a ClauseElement, adapt it with any + # adapters that are in place right now + if isinstance(onclause, expression.ClauseElement): + current_adapter = self._get_current_adapter() + if current_adapter: + onclause = current_adapter(onclause, True) + + # if joining on a MapperProperty path, + # track the path to prevent redundant joins + if prop: + self._already_joined_edges += ((left, right, prop.key),) + + return inspect(right), right, onclause + + @property + def _select_args(self): + return { + "limit_clause": self.select_statement._limit_clause, + "offset_clause": self.select_statement._offset_clause, + "distinct": self.distinct, + "distinct_on": self.distinct_on, + "prefixes": self.select_statement._prefixes, + "suffixes": self.select_statement._suffixes, + "group_by": self.group_by or None, + "fetch_clause": self.select_statement._fetch_clause, + "fetch_clause_options": ( + self.select_statement._fetch_clause_options + ), + "independent_ctes": self.select_statement._independent_ctes, + "independent_ctes_opts": ( + self.select_statement._independent_ctes_opts + ), + } + + @property + def _should_nest_selectable(self): + kwargs = self._select_args + return ( + kwargs.get("limit_clause") is not None + or kwargs.get("offset_clause") is not None + or kwargs.get("distinct", False) + or kwargs.get("distinct_on", ()) + or kwargs.get("group_by", False) + ) + + def _get_extra_criteria(self, ext_info): + if ( + "additional_entity_criteria", + ext_info.mapper, + ) in self.global_attributes: + return tuple( + ae._resolve_where_criteria(ext_info) + for ae in self.global_attributes[ + ("additional_entity_criteria", ext_info.mapper) + ] + if (ae.include_aliases or ae.entity is ext_info) + and ae._should_include(self) + ) + else: + return () + + def _adjust_for_extra_criteria(self): + """Apply extra criteria filtering. + + For all distinct single-table-inheritance mappers represented in + the columns clause of this query, as well as the "select from entity", + add criterion to the WHERE + clause of the given QueryContext such that only the appropriate + subtypes are selected from the total results. + + Additionally, add WHERE criteria originating from LoaderCriteriaOptions + associated with the global context. + + """ + + for fromclause in self.from_clauses: + ext_info = fromclause._annotations.get("parententity", None) + + if ( + ext_info + and ( + ext_info.mapper._single_table_criterion is not None + or ("additional_entity_criteria", ext_info.mapper) + in self.global_attributes + ) + and ext_info not in self.extra_criteria_entities + ): + self.extra_criteria_entities[ext_info] = ( + ext_info, + ext_info._adapter if ext_info.is_aliased_class else None, + ) + + search = set(self.extra_criteria_entities.values()) + + for ext_info, adapter in search: + if ext_info in self._join_entities: + continue + + single_crit = ext_info.mapper._single_table_criterion + + if self.compile_options._for_refresh_state: + additional_entity_criteria = [] + else: + additional_entity_criteria = self._get_extra_criteria(ext_info) + + if single_crit is not None: + additional_entity_criteria += (single_crit,) + + current_adapter = self._get_current_adapter() + for crit in additional_entity_criteria: + if adapter: + crit = adapter.traverse(crit) + + if current_adapter: + crit = sql_util._deep_annotate(crit, {"_orm_adapt": True}) + crit = current_adapter(crit, False) + self._where_criteria += (crit,) + + +def _column_descriptions( + query_or_select_stmt: Union[Query, Select, FromStatement], + compile_state: Optional[ORMSelectCompileState] = None, + legacy: bool = False, +) -> List[ORMColumnDescription]: + if compile_state is None: + compile_state = ORMSelectCompileState._create_entities_collection( + query_or_select_stmt, legacy=legacy + ) + ctx = compile_state + d = [ + { + "name": ent._label_name, + "type": ent.type, + "aliased": getattr(insp_ent, "is_aliased_class", False), + "expr": ent.expr, + "entity": ( + getattr(insp_ent, "entity", None) + if ent.entity_zero is not None + and not insp_ent.is_clause_element + else None + ), + } + for ent, insp_ent in [ + (_ent, _ent.entity_zero) for _ent in ctx._entities + ] + ] + return d + + +def _legacy_filter_by_entity_zero( + query_or_augmented_select: Union[Query[Any], Select[Any]] +) -> Optional[_InternalEntityType[Any]]: + self = query_or_augmented_select + if self._setup_joins: + _last_joined_entity = self._last_joined_entity + if _last_joined_entity is not None: + return _last_joined_entity + + if self._from_obj and "parententity" in self._from_obj[0]._annotations: + return self._from_obj[0]._annotations["parententity"] + + return _entity_from_pre_ent_zero(self) + + +def _entity_from_pre_ent_zero( + query_or_augmented_select: Union[Query[Any], Select[Any]] +) -> Optional[_InternalEntityType[Any]]: + self = query_or_augmented_select + if not self._raw_columns: + return None + + ent = self._raw_columns[0] + + if "parententity" in ent._annotations: + return ent._annotations["parententity"] + elif isinstance(ent, ORMColumnsClauseRole): + return ent.entity + elif "bundle" in ent._annotations: + return ent._annotations["bundle"] + else: + return ent + + +def _determine_last_joined_entity( + setup_joins: Tuple[_SetupJoinsElement, ...], + entity_zero: Optional[_InternalEntityType[Any]] = None, +) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]: + if not setup_joins: + return None + + (target, onclause, from_, flags) = setup_joins[-1] + + if isinstance( + target, + attributes.QueryableAttribute, + ): + return target.entity + else: + return target + + +class _QueryEntity: + """represent an entity column returned within a Query result.""" + + __slots__ = () + + supports_single_entity: bool + + _non_hashable_value = False + _null_column_type = False + use_id_for_hash = False + + _label_name: Optional[str] + type: Union[Type[Any], TypeEngine[Any]] + expr: Union[_InternalEntityType, ColumnElement[Any]] + entity_zero: Optional[_InternalEntityType] + + def setup_compile_state(self, compile_state: ORMCompileState) -> None: + raise NotImplementedError() + + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + raise NotImplementedError() + + def row_processor(self, context, result): + raise NotImplementedError() + + @classmethod + def to_compile_state( + cls, compile_state, entities, entities_collection, is_current_entities + ): + for idx, entity in enumerate(entities): + if entity._is_lambda_element: + if entity._is_sequence: + cls.to_compile_state( + compile_state, + entity._resolved, + entities_collection, + is_current_entities, + ) + continue + else: + entity = entity._resolved + + if entity.is_clause_element: + if entity.is_selectable: + if "parententity" in entity._annotations: + _MapperEntity( + compile_state, + entity, + entities_collection, + is_current_entities, + ) + else: + _ColumnEntity._for_columns( + compile_state, + entity._select_iterable, + entities_collection, + idx, + is_current_entities, + ) + else: + if entity._annotations.get("bundle", False): + _BundleEntity( + compile_state, + entity, + entities_collection, + is_current_entities, + ) + elif entity._is_clause_list: + # this is legacy only - test_composites.py + # test_query_cols_legacy + _ColumnEntity._for_columns( + compile_state, + entity._select_iterable, + entities_collection, + idx, + is_current_entities, + ) + else: + _ColumnEntity._for_columns( + compile_state, + [entity], + entities_collection, + idx, + is_current_entities, + ) + elif entity.is_bundle: + _BundleEntity(compile_state, entity, entities_collection) + + return entities_collection + + +class _MapperEntity(_QueryEntity): + """mapper/class/AliasedClass entity""" + + __slots__ = ( + "expr", + "mapper", + "entity_zero", + "is_aliased_class", + "path", + "_extra_entities", + "_label_name", + "_with_polymorphic_mappers", + "selectable", + "_polymorphic_discriminator", + ) + + expr: _InternalEntityType + mapper: Mapper[Any] + entity_zero: _InternalEntityType + is_aliased_class: bool + path: PathRegistry + _label_name: str + + def __init__( + self, compile_state, entity, entities_collection, is_current_entities + ): + entities_collection.append(self) + if is_current_entities: + if compile_state._primary_entity is None: + compile_state._primary_entity = self + compile_state._has_mapper_entities = True + compile_state._has_orm_entities = True + + entity = entity._annotations["parententity"] + entity._post_inspect + ext_info = self.entity_zero = entity + entity = ext_info.entity + + self.expr = entity + self.mapper = mapper = ext_info.mapper + + self._extra_entities = (self.expr,) + + if ext_info.is_aliased_class: + self._label_name = ext_info.name + else: + self._label_name = mapper.class_.__name__ + + self.is_aliased_class = ext_info.is_aliased_class + self.path = ext_info._path_registry + + self.selectable = ext_info.selectable + self._with_polymorphic_mappers = ext_info.with_polymorphic_mappers + self._polymorphic_discriminator = ext_info.polymorphic_on + + if mapper._should_select_with_poly_adapter: + compile_state._create_with_polymorphic_adapter( + ext_info, self.selectable + ) + + supports_single_entity = True + + _non_hashable_value = True + use_id_for_hash = True + + @property + def type(self): + return self.mapper.class_ + + @property + def entity_zero_or_selectable(self): + return self.entity_zero + + def corresponds_to(self, entity): + return _entity_corresponds_to(self.entity_zero, entity) + + def _get_entity_clauses(self, compile_state): + adapter = None + + if not self.is_aliased_class: + if compile_state._polymorphic_adapters: + adapter = compile_state._polymorphic_adapters.get( + self.mapper, None + ) + else: + adapter = self.entity_zero._adapter + + if adapter: + if compile_state._from_obj_alias: + ret = adapter.wrap(compile_state._from_obj_alias) + else: + ret = adapter + else: + ret = compile_state._from_obj_alias + + return ret + + def row_processor(self, context, result): + compile_state = context.compile_state + adapter = self._get_entity_clauses(compile_state) + + if compile_state.compound_eager_adapter and adapter: + adapter = adapter.wrap(compile_state.compound_eager_adapter) + elif not adapter: + adapter = compile_state.compound_eager_adapter + + if compile_state._primary_entity is self: + only_load_props = compile_state.compile_options._only_load_props + refresh_state = context.refresh_state + else: + only_load_props = refresh_state = None + + _instance = loading._instance_processor( + self, + self.mapper, + context, + result, + self.path, + adapter, + only_load_props=only_load_props, + refresh_state=refresh_state, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + + return _instance, self._label_name, self._extra_entities + + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + loading._setup_entity_query( + compile_state, + self.mapper, + self, + self.path, + adapter, + compile_state.primary_columns, + with_polymorphic=self._with_polymorphic_mappers, + only_load_props=compile_state.compile_options._only_load_props, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + + def setup_compile_state(self, compile_state): + adapter = self._get_entity_clauses(compile_state) + + single_table_crit = self.mapper._single_table_criterion + if ( + single_table_crit is not None + or ("additional_entity_criteria", self.mapper) + in compile_state.global_attributes + ): + ext_info = self.entity_zero + compile_state.extra_criteria_entities[ext_info] = ( + ext_info, + ext_info._adapter if ext_info.is_aliased_class else None, + ) + + loading._setup_entity_query( + compile_state, + self.mapper, + self, + self.path, + adapter, + compile_state.primary_columns, + with_polymorphic=self._with_polymorphic_mappers, + only_load_props=compile_state.compile_options._only_load_props, + polymorphic_discriminator=self._polymorphic_discriminator, + ) + compile_state._fallback_from_clauses.append(self.selectable) + + +class _BundleEntity(_QueryEntity): + _extra_entities = () + + __slots__ = ( + "bundle", + "expr", + "type", + "_label_name", + "_entities", + "supports_single_entity", + ) + + _entities: List[_QueryEntity] + bundle: Bundle + type: Type[Any] + _label_name: str + supports_single_entity: bool + expr: Bundle + + def __init__( + self, + compile_state, + expr, + entities_collection, + is_current_entities, + setup_entities=True, + parent_bundle=None, + ): + compile_state._has_orm_entities = True + + expr = expr._annotations["bundle"] + if parent_bundle: + parent_bundle._entities.append(self) + else: + entities_collection.append(self) + + if isinstance( + expr, (attributes.QueryableAttribute, interfaces.PropComparator) + ): + bundle = expr.__clause_element__() + else: + bundle = expr + + self.bundle = self.expr = bundle + self.type = type(bundle) + self._label_name = bundle.name + self._entities = [] + + if setup_entities: + for expr in bundle.exprs: + if "bundle" in expr._annotations: + _BundleEntity( + compile_state, + expr, + entities_collection, + is_current_entities, + parent_bundle=self, + ) + elif isinstance(expr, Bundle): + _BundleEntity( + compile_state, + expr, + entities_collection, + is_current_entities, + parent_bundle=self, + ) + else: + _ORMColumnEntity._for_columns( + compile_state, + [expr], + entities_collection, + None, + is_current_entities, + parent_bundle=self, + ) + + self.supports_single_entity = self.bundle.single_entity + + @property + def mapper(self): + ezero = self.entity_zero + if ezero is not None: + return ezero.mapper + else: + return None + + @property + def entity_zero(self): + for ent in self._entities: + ezero = ent.entity_zero + if ezero is not None: + return ezero + else: + return None + + def corresponds_to(self, entity): + # TODO: we might be able to implement this but for now + # we are working around it + return False + + @property + def entity_zero_or_selectable(self): + for ent in self._entities: + ezero = ent.entity_zero_or_selectable + if ezero is not None: + return ezero + else: + return None + + def setup_compile_state(self, compile_state): + for ent in self._entities: + ent.setup_compile_state(compile_state) + + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + return self.setup_compile_state(compile_state) + + def row_processor(self, context, result): + procs, labels, extra = zip( + *[ent.row_processor(context, result) for ent in self._entities] + ) + + proc = self.bundle.create_row_processor(context.query, procs, labels) + + return proc, self._label_name, self._extra_entities + + +class _ColumnEntity(_QueryEntity): + __slots__ = ( + "_fetch_column", + "_row_processor", + "raw_column_index", + "translate_raw_column", + ) + + @classmethod + def _for_columns( + cls, + compile_state, + columns, + entities_collection, + raw_column_index, + is_current_entities, + parent_bundle=None, + ): + for column in columns: + annotations = column._annotations + if "parententity" in annotations: + _entity = annotations["parententity"] + else: + _entity = sql_util.extract_first_column_annotation( + column, "parententity" + ) + + if _entity: + if "identity_token" in column._annotations: + _IdentityTokenEntity( + compile_state, + column, + entities_collection, + _entity, + raw_column_index, + is_current_entities, + parent_bundle=parent_bundle, + ) + else: + _ORMColumnEntity( + compile_state, + column, + entities_collection, + _entity, + raw_column_index, + is_current_entities, + parent_bundle=parent_bundle, + ) + else: + _RawColumnEntity( + compile_state, + column, + entities_collection, + raw_column_index, + is_current_entities, + parent_bundle=parent_bundle, + ) + + @property + def type(self): + return self.column.type + + @property + def _non_hashable_value(self): + return not self.column.type.hashable + + @property + def _null_column_type(self): + return self.column.type._isnull + + def row_processor(self, context, result): + compile_state = context.compile_state + + # the resulting callable is entirely cacheable so just return + # it if we already made one + if self._row_processor is not None: + getter, label_name, extra_entities = self._row_processor + if self.translate_raw_column: + extra_entities += ( + context.query._raw_columns[self.raw_column_index], + ) + + return getter, label_name, extra_entities + + # retrieve the column that would have been set up in + # setup_compile_state, to avoid doing redundant work + if self._fetch_column is not None: + column = self._fetch_column + else: + # fetch_column will be None when we are doing a from_statement + # and setup_compile_state may not have been called. + column = self.column + + # previously, the RawColumnEntity didn't look for from_obj_alias + # however I can't think of a case where we would be here and + # we'd want to ignore it if this is the from_statement use case. + # it's not really a use case to have raw columns + from_statement + if compile_state._from_obj_alias: + column = compile_state._from_obj_alias.columns[column] + + if column._annotations: + # annotated columns perform more slowly in compiler and + # result due to the __eq__() method, so use deannotated + column = column._deannotate() + + if compile_state.compound_eager_adapter: + column = compile_state.compound_eager_adapter.columns[column] + + getter = result._getter(column) + ret = getter, self._label_name, self._extra_entities + self._row_processor = ret + + if self.translate_raw_column: + extra_entities = self._extra_entities + ( + context.query._raw_columns[self.raw_column_index], + ) + return getter, self._label_name, extra_entities + else: + return ret + + +class _RawColumnEntity(_ColumnEntity): + entity_zero = None + mapper = None + supports_single_entity = False + + __slots__ = ( + "expr", + "column", + "_label_name", + "entity_zero_or_selectable", + "_extra_entities", + ) + + def __init__( + self, + compile_state, + column, + entities_collection, + raw_column_index, + is_current_entities, + parent_bundle=None, + ): + self.expr = column + self.raw_column_index = raw_column_index + self.translate_raw_column = raw_column_index is not None + + if column._is_star: + compile_state.compile_options += {"_is_star": True} + + if not is_current_entities or column._is_text_clause: + self._label_name = None + else: + self._label_name = compile_state._label_convention(column) + + if parent_bundle: + parent_bundle._entities.append(self) + else: + entities_collection.append(self) + + self.column = column + self.entity_zero_or_selectable = ( + self.column._from_objects[0] if self.column._from_objects else None + ) + self._extra_entities = (self.expr, self.column) + self._fetch_column = self._row_processor = None + + def corresponds_to(self, entity): + return False + + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + return self.setup_compile_state(compile_state) + + def setup_compile_state(self, compile_state): + current_adapter = compile_state._get_current_adapter() + if current_adapter: + column = current_adapter(self.column, False) + if column is None: + return + else: + column = self.column + + if column._annotations: + # annotated columns perform more slowly in compiler and + # result due to the __eq__() method, so use deannotated + column = column._deannotate() + + compile_state.dedupe_columns.add(column) + compile_state.primary_columns.append(column) + self._fetch_column = column + + +class _ORMColumnEntity(_ColumnEntity): + """Column/expression based entity.""" + + supports_single_entity = False + + __slots__ = ( + "expr", + "mapper", + "column", + "_label_name", + "entity_zero_or_selectable", + "entity_zero", + "_extra_entities", + ) + + def __init__( + self, + compile_state, + column, + entities_collection, + parententity, + raw_column_index, + is_current_entities, + parent_bundle=None, + ): + annotations = column._annotations + + _entity = parententity + + # an AliasedClass won't have proxy_key in the annotations for + # a column if it was acquired using the class' adapter directly, + # such as using AliasedInsp._adapt_element(). this occurs + # within internal loaders. + + orm_key = annotations.get("proxy_key", None) + proxy_owner = annotations.get("proxy_owner", _entity) + if orm_key: + self.expr = getattr(proxy_owner.entity, orm_key) + self.translate_raw_column = False + else: + # if orm_key is not present, that means this is an ad-hoc + # SQL ColumnElement, like a CASE() or other expression. + # include this column position from the invoked statement + # in the ORM-level ResultSetMetaData on each execute, so that + # it can be targeted by identity after caching + self.expr = column + self.translate_raw_column = raw_column_index is not None + + self.raw_column_index = raw_column_index + + if is_current_entities: + self._label_name = compile_state._label_convention( + column, col_name=orm_key + ) + else: + self._label_name = None + + _entity._post_inspect + self.entity_zero = self.entity_zero_or_selectable = ezero = _entity + self.mapper = mapper = _entity.mapper + + if parent_bundle: + parent_bundle._entities.append(self) + else: + entities_collection.append(self) + + compile_state._has_orm_entities = True + + self.column = column + + self._fetch_column = self._row_processor = None + + self._extra_entities = (self.expr, self.column) + + if mapper._should_select_with_poly_adapter: + compile_state._create_with_polymorphic_adapter( + ezero, ezero.selectable + ) + + def corresponds_to(self, entity): + if _is_aliased_class(entity): + # TODO: polymorphic subclasses ? + return entity is self.entity_zero + else: + return not _is_aliased_class( + self.entity_zero + ) and entity.common_parent(self.entity_zero) + + def setup_dml_returning_compile_state( + self, + compile_state: ORMCompileState, + adapter: DMLReturningColFilter, + ) -> None: + self._fetch_column = self.column + column = adapter(self.column, False) + if column is not None: + compile_state.dedupe_columns.add(column) + compile_state.primary_columns.append(column) + + def setup_compile_state(self, compile_state): + current_adapter = compile_state._get_current_adapter() + if current_adapter: + column = current_adapter(self.column, False) + if column is None: + assert compile_state.is_dml_returning + self._fetch_column = self.column + return + else: + column = self.column + + ezero = self.entity_zero + + single_table_crit = self.mapper._single_table_criterion + if ( + single_table_crit is not None + or ("additional_entity_criteria", self.mapper) + in compile_state.global_attributes + ): + compile_state.extra_criteria_entities[ezero] = ( + ezero, + ezero._adapter if ezero.is_aliased_class else None, + ) + + if column._annotations and not column._expression_label: + # annotated columns perform more slowly in compiler and + # result due to the __eq__() method, so use deannotated + column = column._deannotate() + + # use entity_zero as the from if we have it. this is necessary + # for polymorphic scenarios where our FROM is based on ORM entity, + # not the FROM of the column. but also, don't use it if our column + # doesn't actually have any FROMs that line up, such as when its + # a scalar subquery. + if set(self.column._from_objects).intersection( + ezero.selectable._from_objects + ): + compile_state._fallback_from_clauses.append(ezero.selectable) + + compile_state.dedupe_columns.add(column) + compile_state.primary_columns.append(column) + self._fetch_column = column + + +class _IdentityTokenEntity(_ORMColumnEntity): + translate_raw_column = False + + def setup_compile_state(self, compile_state): + pass + + def row_processor(self, context, result): + def getter(row): + return context.load_options._identity_token + + return getter, self._label_name, self._extra_entities diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/decl_api.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/decl_api.py new file mode 100644 index 0000000..09128ea --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/decl_api.py @@ -0,0 +1,1875 @@ +# orm/decl_api.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Public API functions and helpers for declarative.""" + +from __future__ import annotations + +import itertools +import re +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import clsregistry +from . import instrumentation +from . import interfaces +from . import mapperlib +from ._orm_constructors import composite +from ._orm_constructors import deferred +from ._orm_constructors import mapped_column +from ._orm_constructors import relationship +from ._orm_constructors import synonym +from .attributes import InstrumentedAttribute +from .base import _inspect_mapped_class +from .base import _is_mapped_class +from .base import Mapped +from .base import ORMDescriptor +from .decl_base import _add_attribute +from .decl_base import _as_declarative +from .decl_base import _ClassScanMapperConfig +from .decl_base import _declarative_constructor +from .decl_base import _DeferredMapperConfig +from .decl_base import _del_attribute +from .decl_base import _mapper +from .descriptor_props import Composite +from .descriptor_props import Synonym +from .descriptor_props import Synonym as _orm_synonym +from .mapper import Mapper +from .properties import MappedColumn +from .relationships import RelationshipProperty +from .state import InstanceState +from .. import exc +from .. import inspection +from .. import util +from ..sql import sqltypes +from ..sql.base import _NoArg +from ..sql.elements import SQLCoreOperations +from ..sql.schema import MetaData +from ..sql.selectable import FromClause +from ..util import hybridmethod +from ..util import hybridproperty +from ..util import typing as compat_typing +from ..util.typing import CallableReference +from ..util.typing import flatten_newtype +from ..util.typing import is_generic +from ..util.typing import is_literal +from ..util.typing import is_newtype +from ..util.typing import is_pep695 +from ..util.typing import Literal +from ..util.typing import Self + +if TYPE_CHECKING: + from ._typing import _O + from ._typing import _RegistryType + from .decl_base import _DataclassArguments + from .instrumentation import ClassManager + from .interfaces import MapperProperty + from .state import InstanceState # noqa + from ..sql._typing import _TypeEngineArgument + from ..sql.type_api import _MatchedOnType + +_T = TypeVar("_T", bound=Any) + +_TT = TypeVar("_TT", bound=Any) + +# it's not clear how to have Annotated, Union objects etc. as keys here +# from a typing perspective so just leave it open ended for now +_TypeAnnotationMapType = Mapping[Any, "_TypeEngineArgument[Any]"] +_MutableTypeAnnotationMapType = Dict[Any, "_TypeEngineArgument[Any]"] + +_DeclaredAttrDecorated = Callable[ + ..., Union[Mapped[_T], ORMDescriptor[_T], SQLCoreOperations[_T]] +] + + +def has_inherited_table(cls: Type[_O]) -> bool: + """Given a class, return True if any of the classes it inherits from has a + mapped table, otherwise return False. + + This is used in declarative mixins to build attributes that behave + differently for the base class vs. a subclass in an inheritance + hierarchy. + + .. seealso:: + + :ref:`decl_mixin_inheritance` + + """ + for class_ in cls.__mro__[1:]: + if getattr(class_, "__table__", None) is not None: + return True + return False + + +class _DynamicAttributesType(type): + def __setattr__(cls, key: str, value: Any) -> None: + if "__mapper__" in cls.__dict__: + _add_attribute(cls, key, value) + else: + type.__setattr__(cls, key, value) + + def __delattr__(cls, key: str) -> None: + if "__mapper__" in cls.__dict__: + _del_attribute(cls, key) + else: + type.__delattr__(cls, key) + + +class DeclarativeAttributeIntercept( + _DynamicAttributesType, + # Inspectable is used only by the mypy plugin + inspection.Inspectable[Mapper[Any]], +): + """Metaclass that may be used in conjunction with the + :class:`_orm.DeclarativeBase` class to support addition of class + attributes dynamically. + + """ + + +@compat_typing.dataclass_transform( + field_specifiers=( + MappedColumn, + RelationshipProperty, + Composite, + Synonym, + mapped_column, + relationship, + composite, + synonym, + deferred, + ), +) +class DCTransformDeclarative(DeclarativeAttributeIntercept): + """metaclass that includes @dataclass_transforms""" + + +class DeclarativeMeta(DeclarativeAttributeIntercept): + metadata: MetaData + registry: RegistryType + + def __init__( + cls, classname: Any, bases: Any, dict_: Any, **kw: Any + ) -> None: + # use cls.__dict__, which can be modified by an + # __init_subclass__() method (#7900) + dict_ = cls.__dict__ + + # early-consume registry from the initial declarative base, + # assign privately to not conflict with subclass attributes named + # "registry" + reg = getattr(cls, "_sa_registry", None) + if reg is None: + reg = dict_.get("registry", None) + if not isinstance(reg, registry): + raise exc.InvalidRequestError( + "Declarative base class has no 'registry' attribute, " + "or registry is not a sqlalchemy.orm.registry() object" + ) + else: + cls._sa_registry = reg + + if not cls.__dict__.get("__abstract__", False): + _as_declarative(reg, cls, dict_) + type.__init__(cls, classname, bases, dict_) + + +def synonym_for( + name: str, map_column: bool = False +) -> Callable[[Callable[..., Any]], Synonym[Any]]: + """Decorator that produces an :func:`_orm.synonym` + attribute in conjunction with a Python descriptor. + + The function being decorated is passed to :func:`_orm.synonym` as the + :paramref:`.orm.synonym.descriptor` parameter:: + + class MyClass(Base): + __tablename__ = 'my_table' + + id = Column(Integer, primary_key=True) + _job_status = Column("job_status", String(50)) + + @synonym_for("job_status") + @property + def job_status(self): + return "Status: %s" % self._job_status + + The :ref:`hybrid properties ` feature of SQLAlchemy + is typically preferred instead of synonyms, which is a more legacy + feature. + + .. seealso:: + + :ref:`synonyms` - Overview of synonyms + + :func:`_orm.synonym` - the mapper-level function + + :ref:`mapper_hybrids` - The Hybrid Attribute extension provides an + updated approach to augmenting attribute behavior more flexibly than + can be achieved with synonyms. + + """ + + def decorate(fn: Callable[..., Any]) -> Synonym[Any]: + return _orm_synonym(name, map_column=map_column, descriptor=fn) + + return decorate + + +class _declared_attr_common: + def __init__( + self, + fn: Callable[..., Any], + cascading: bool = False, + quiet: bool = False, + ): + # suppport + # @declared_attr + # @classmethod + # def foo(cls) -> Mapped[thing]: + # ... + # which seems to help typing tools interpret the fn as a classmethod + # for situations where needed + if isinstance(fn, classmethod): + fn = fn.__func__ + + self.fget = fn + self._cascading = cascading + self._quiet = quiet + self.__doc__ = fn.__doc__ + + def _collect_return_annotation(self) -> Optional[Type[Any]]: + return util.get_annotations(self.fget).get("return") + + def __get__(self, instance: Optional[object], owner: Any) -> Any: + # the declared_attr needs to make use of a cache that exists + # for the span of the declarative scan_attributes() phase. + # to achieve this we look at the class manager that's configured. + + # note this method should not be called outside of the declarative + # setup phase + + cls = owner + manager = attributes.opt_manager_of_class(cls) + if manager is None: + if not re.match(r"^__.+__$", self.fget.__name__): + # if there is no manager at all, then this class hasn't been + # run through declarative or mapper() at all, emit a warning. + util.warn( + "Unmanaged access of declarative attribute %s from " + "non-mapped class %s" % (self.fget.__name__, cls.__name__) + ) + return self.fget(cls) + elif manager.is_mapped: + # the class is mapped, which means we're outside of the declarative + # scan setup, just run the function. + return self.fget(cls) + + # here, we are inside of the declarative scan. use the registry + # that is tracking the values of these attributes. + declarative_scan = manager.declarative_scan() + + # assert that we are in fact in the declarative scan + assert declarative_scan is not None + + reg = declarative_scan.declared_attr_reg + + if self in reg: + return reg[self] + else: + reg[self] = obj = self.fget(cls) + return obj + + +class _declared_directive(_declared_attr_common, Generic[_T]): + # see mapping_api.rst for docstring + + if typing.TYPE_CHECKING: + + def __init__( + self, + fn: Callable[..., _T], + cascading: bool = False, + ): ... + + def __get__(self, instance: Optional[object], owner: Any) -> _T: ... + + def __set__(self, instance: Any, value: Any) -> None: ... + + def __delete__(self, instance: Any) -> None: ... + + def __call__(self, fn: Callable[..., _TT]) -> _declared_directive[_TT]: + # extensive fooling of mypy underway... + ... + + +class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common): + """Mark a class-level method as representing the definition of + a mapped property or Declarative directive. + + :class:`_orm.declared_attr` is typically applied as a decorator to a class + level method, turning the attribute into a scalar-like property that can be + invoked from the uninstantiated class. The Declarative mapping process + looks for these :class:`_orm.declared_attr` callables as it scans classes, + and assumes any attribute marked with :class:`_orm.declared_attr` will be a + callable that will produce an object specific to the Declarative mapping or + table configuration. + + :class:`_orm.declared_attr` is usually applicable to + :ref:`mixins `, to define relationships that are to be + applied to different implementors of the class. It may also be used to + define dynamically generated column expressions and other Declarative + attributes. + + Example:: + + class ProvidesUserMixin: + "A mixin that adds a 'user' relationship to classes." + + user_id: Mapped[int] = mapped_column(ForeignKey("user_table.id")) + + @declared_attr + def user(cls) -> Mapped["User"]: + return relationship("User") + + When used with Declarative directives such as ``__tablename__``, the + :meth:`_orm.declared_attr.directive` modifier may be used which indicates + to :pep:`484` typing tools that the given method is not dealing with + :class:`_orm.Mapped` attributes:: + + class CreateTableName: + @declared_attr.directive + def __tablename__(cls) -> str: + return cls.__name__.lower() + + :class:`_orm.declared_attr` can also be applied directly to mapped + classes, to allow for attributes that dynamically configure themselves + on subclasses when using mapped inheritance schemes. Below + illustrates :class:`_orm.declared_attr` to create a dynamic scheme + for generating the :paramref:`_orm.Mapper.polymorphic_identity` parameter + for subclasses:: + + class Employee(Base): + __tablename__ = 'employee' + + id: Mapped[int] = mapped_column(primary_key=True) + type: Mapped[str] = mapped_column(String(50)) + + @declared_attr.directive + def __mapper_args__(cls) -> Dict[str, Any]: + if cls.__name__ == 'Employee': + return { + "polymorphic_on":cls.type, + "polymorphic_identity":"Employee" + } + else: + return {"polymorphic_identity":cls.__name__} + + class Engineer(Employee): + pass + + :class:`_orm.declared_attr` supports decorating functions that are + explicitly decorated with ``@classmethod``. This is never necessary from a + runtime perspective, however may be needed in order to support :pep:`484` + typing tools that don't otherwise recognize the decorated function as + having class-level behaviors for the ``cls`` parameter:: + + class SomethingMixin: + x: Mapped[int] + y: Mapped[int] + + @declared_attr + @classmethod + def x_plus_y(cls) -> Mapped[int]: + return column_property(cls.x + cls.y) + + .. versionadded:: 2.0 - :class:`_orm.declared_attr` can accommodate a + function decorated with ``@classmethod`` to help with :pep:`484` + integration where needed. + + + .. seealso:: + + :ref:`orm_mixins_toplevel` - Declarative Mixin documentation with + background on use patterns for :class:`_orm.declared_attr`. + + """ # noqa: E501 + + if typing.TYPE_CHECKING: + + def __init__( + self, + fn: _DeclaredAttrDecorated[_T], + cascading: bool = False, + ): ... + + def __set__(self, instance: Any, value: Any) -> None: ... + + def __delete__(self, instance: Any) -> None: ... + + # this is the Mapped[] API where at class descriptor get time we want + # the type checker to see InstrumentedAttribute[_T]. However the + # callable function prior to mapping in fact calls the given + # declarative function that does not return InstrumentedAttribute + @overload + def __get__( + self, instance: None, owner: Any + ) -> InstrumentedAttribute[_T]: ... + + @overload + def __get__(self, instance: object, owner: Any) -> _T: ... + + def __get__( + self, instance: Optional[object], owner: Any + ) -> Union[InstrumentedAttribute[_T], _T]: ... + + @hybridmethod + def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T]: + return _stateful_declared_attr(**kw) + + @hybridproperty + def directive(cls) -> _declared_directive[Any]: + # see mapping_api.rst for docstring + return _declared_directive # type: ignore + + @hybridproperty + def cascading(cls) -> _stateful_declared_attr[_T]: + # see mapping_api.rst for docstring + return cls._stateful(cascading=True) + + +class _stateful_declared_attr(declared_attr[_T]): + kw: Dict[str, Any] + + def __init__(self, **kw: Any): + self.kw = kw + + @hybridmethod + def _stateful(self, **kw: Any) -> _stateful_declared_attr[_T]: + new_kw = self.kw.copy() + new_kw.update(kw) + return _stateful_declared_attr(**new_kw) + + def __call__(self, fn: _DeclaredAttrDecorated[_T]) -> declared_attr[_T]: + return declared_attr(fn, **self.kw) + + +def declarative_mixin(cls: Type[_T]) -> Type[_T]: + """Mark a class as providing the feature of "declarative mixin". + + E.g.:: + + from sqlalchemy.orm import declared_attr + from sqlalchemy.orm import declarative_mixin + + @declarative_mixin + class MyMixin: + + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + __table_args__ = {'mysql_engine': 'InnoDB'} + __mapper_args__= {'always_refresh': True} + + id = Column(Integer, primary_key=True) + + class MyModel(MyMixin, Base): + name = Column(String(1000)) + + The :func:`_orm.declarative_mixin` decorator currently does not modify + the given class in any way; it's current purpose is strictly to assist + the :ref:`Mypy plugin ` in being able to identify + SQLAlchemy declarative mixin classes when no other context is present. + + .. versionadded:: 1.4.6 + + .. seealso:: + + :ref:`orm_mixins_toplevel` + + :ref:`mypy_declarative_mixins` - in the + :ref:`Mypy plugin documentation ` + + """ # noqa: E501 + + return cls + + +def _setup_declarative_base(cls: Type[Any]) -> None: + if "metadata" in cls.__dict__: + metadata = cls.__dict__["metadata"] + else: + metadata = None + + if "type_annotation_map" in cls.__dict__: + type_annotation_map = cls.__dict__["type_annotation_map"] + else: + type_annotation_map = None + + reg = cls.__dict__.get("registry", None) + if reg is not None: + if not isinstance(reg, registry): + raise exc.InvalidRequestError( + "Declarative base class has a 'registry' attribute that is " + "not an instance of sqlalchemy.orm.registry()" + ) + elif type_annotation_map is not None: + raise exc.InvalidRequestError( + "Declarative base class has both a 'registry' attribute and a " + "type_annotation_map entry. Per-base type_annotation_maps " + "are not supported. Please apply the type_annotation_map " + "to this registry directly." + ) + + else: + reg = registry( + metadata=metadata, type_annotation_map=type_annotation_map + ) + cls.registry = reg + + cls._sa_registry = reg + + if "metadata" not in cls.__dict__: + cls.metadata = cls.registry.metadata + + if getattr(cls, "__init__", object.__init__) is object.__init__: + cls.__init__ = cls.registry.constructor + + +class MappedAsDataclass(metaclass=DCTransformDeclarative): + """Mixin class to indicate when mapping this class, also convert it to be + a dataclass. + + .. seealso:: + + :ref:`orm_declarative_native_dataclasses` - complete background + on SQLAlchemy native dataclass mapping + + .. versionadded:: 2.0 + + """ + + def __init_subclass__( + cls, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + eq: Union[_NoArg, bool] = _NoArg.NO_ARG, + order: Union[_NoArg, bool] = _NoArg.NO_ARG, + unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG, + match_args: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + dataclass_callable: Union[ + _NoArg, Callable[..., Type[Any]] + ] = _NoArg.NO_ARG, + **kw: Any, + ) -> None: + apply_dc_transforms: _DataclassArguments = { + "init": init, + "repr": repr, + "eq": eq, + "order": order, + "unsafe_hash": unsafe_hash, + "match_args": match_args, + "kw_only": kw_only, + "dataclass_callable": dataclass_callable, + } + + current_transforms: _DataclassArguments + + if hasattr(cls, "_sa_apply_dc_transforms"): + current = cls._sa_apply_dc_transforms + + _ClassScanMapperConfig._assert_dc_arguments(current) + + cls._sa_apply_dc_transforms = current_transforms = { # type: ignore # noqa: E501 + k: current.get(k, _NoArg.NO_ARG) if v is _NoArg.NO_ARG else v + for k, v in apply_dc_transforms.items() + } + else: + cls._sa_apply_dc_transforms = current_transforms = ( + apply_dc_transforms + ) + + super().__init_subclass__(**kw) + + if not _is_mapped_class(cls): + new_anno = ( + _ClassScanMapperConfig._update_annotations_for_non_mapped_class + )(cls) + _ClassScanMapperConfig._apply_dataclasses_to_any_class( + current_transforms, cls, new_anno + ) + + +class DeclarativeBase( + # Inspectable is used only by the mypy plugin + inspection.Inspectable[InstanceState[Any]], + metaclass=DeclarativeAttributeIntercept, +): + """Base class used for declarative class definitions. + + The :class:`_orm.DeclarativeBase` allows for the creation of new + declarative bases in such a way that is compatible with type checkers:: + + + from sqlalchemy.orm import DeclarativeBase + + class Base(DeclarativeBase): + pass + + + The above ``Base`` class is now usable as the base for new declarative + mappings. The superclass makes use of the ``__init_subclass__()`` + method to set up new classes and metaclasses aren't used. + + When first used, the :class:`_orm.DeclarativeBase` class instantiates a new + :class:`_orm.registry` to be used with the base, assuming one was not + provided explicitly. The :class:`_orm.DeclarativeBase` class supports + class-level attributes which act as parameters for the construction of this + registry; such as to indicate a specific :class:`_schema.MetaData` + collection as well as a specific value for + :paramref:`_orm.registry.type_annotation_map`:: + + from typing_extensions import Annotated + + from sqlalchemy import BigInteger + from sqlalchemy import MetaData + from sqlalchemy import String + from sqlalchemy.orm import DeclarativeBase + + bigint = Annotated[int, "bigint"] + my_metadata = MetaData() + + class Base(DeclarativeBase): + metadata = my_metadata + type_annotation_map = { + str: String().with_variant(String(255), "mysql", "mariadb"), + bigint: BigInteger() + } + + Class-level attributes which may be specified include: + + :param metadata: optional :class:`_schema.MetaData` collection. + If a :class:`_orm.registry` is constructed automatically, this + :class:`_schema.MetaData` collection will be used to construct it. + Otherwise, the local :class:`_schema.MetaData` collection will supercede + that used by an existing :class:`_orm.registry` passed using the + :paramref:`_orm.DeclarativeBase.registry` parameter. + :param type_annotation_map: optional type annotation map that will be + passed to the :class:`_orm.registry` as + :paramref:`_orm.registry.type_annotation_map`. + :param registry: supply a pre-existing :class:`_orm.registry` directly. + + .. versionadded:: 2.0 Added :class:`.DeclarativeBase`, so that declarative + base classes may be constructed in such a way that is also recognized + by :pep:`484` type checkers. As a result, :class:`.DeclarativeBase` + and other subclassing-oriented APIs should be seen as + superseding previous "class returned by a function" APIs, namely + :func:`_orm.declarative_base` and :meth:`_orm.registry.generate_base`, + where the base class returned cannot be recognized by type checkers + without using plugins. + + **__init__ behavior** + + In a plain Python class, the base-most ``__init__()`` method in the class + hierarchy is ``object.__init__()``, which accepts no arguments. However, + when the :class:`_orm.DeclarativeBase` subclass is first declared, the + class is given an ``__init__()`` method that links to the + :paramref:`_orm.registry.constructor` constructor function, if no + ``__init__()`` method is already present; this is the usual declarative + constructor that will assign keyword arguments as attributes on the + instance, assuming those attributes are established at the class level + (i.e. are mapped, or are linked to a descriptor). This constructor is + **never accessed by a mapped class without being called explicitly via + super()**, as mapped classes are themselves given an ``__init__()`` method + directly which calls :paramref:`_orm.registry.constructor`, so in the + default case works independently of what the base-most ``__init__()`` + method does. + + .. versionchanged:: 2.0.1 :class:`_orm.DeclarativeBase` has a default + constructor that links to :paramref:`_orm.registry.constructor` by + default, so that calls to ``super().__init__()`` can access this + constructor. Previously, due to an implementation mistake, this default + constructor was missing, and calling ``super().__init__()`` would invoke + ``object.__init__()``. + + The :class:`_orm.DeclarativeBase` subclass may also declare an explicit + ``__init__()`` method which will replace the use of the + :paramref:`_orm.registry.constructor` function at this level:: + + class Base(DeclarativeBase): + def __init__(self, id=None): + self.id = id + + Mapped classes still will not invoke this constructor implicitly; it + remains only accessible by calling ``super().__init__()``:: + + class MyClass(Base): + def __init__(self, id=None, name=None): + self.name = name + super().__init__(id=id) + + Note that this is a different behavior from what functions like the legacy + :func:`_orm.declarative_base` would do; the base created by those functions + would always install :paramref:`_orm.registry.constructor` for + ``__init__()``. + + + """ + + if typing.TYPE_CHECKING: + + def _sa_inspect_type(self) -> Mapper[Self]: ... + + def _sa_inspect_instance(self) -> InstanceState[Self]: ... + + _sa_registry: ClassVar[_RegistryType] + + registry: ClassVar[_RegistryType] + """Refers to the :class:`_orm.registry` in use where new + :class:`_orm.Mapper` objects will be associated.""" + + metadata: ClassVar[MetaData] + """Refers to the :class:`_schema.MetaData` collection that will be used + for new :class:`_schema.Table` objects. + + .. seealso:: + + :ref:`orm_declarative_metadata` + + """ + + __name__: ClassVar[str] + + # this ideally should be Mapper[Self], but mypy as of 1.4.1 does not + # like it, and breaks the declared_attr_one test. Pyright/pylance is + # ok with it. + __mapper__: ClassVar[Mapper[Any]] + """The :class:`_orm.Mapper` object to which a particular class is + mapped. + + May also be acquired using :func:`_sa.inspect`, e.g. + ``inspect(klass)``. + + """ + + __table__: ClassVar[FromClause] + """The :class:`_sql.FromClause` to which a particular subclass is + mapped. + + This is usually an instance of :class:`_schema.Table` but may also + refer to other kinds of :class:`_sql.FromClause` such as + :class:`_sql.Subquery`, depending on how the class is mapped. + + .. seealso:: + + :ref:`orm_declarative_metadata` + + """ + + # pyright/pylance do not consider a classmethod a ClassVar so use Any + # https://github.com/microsoft/pylance-release/issues/3484 + __tablename__: Any + """String name to assign to the generated + :class:`_schema.Table` object, if not specified directly via + :attr:`_orm.DeclarativeBase.__table__`. + + .. seealso:: + + :ref:`orm_declarative_table` + + """ + + __mapper_args__: Any + """Dictionary of arguments which will be passed to the + :class:`_orm.Mapper` constructor. + + .. seealso:: + + :ref:`orm_declarative_mapper_options` + + """ + + __table_args__: Any + """A dictionary or tuple of arguments that will be passed to the + :class:`_schema.Table` constructor. See + :ref:`orm_declarative_table_configuration` + for background on the specific structure of this collection. + + .. seealso:: + + :ref:`orm_declarative_table_configuration` + + """ + + def __init__(self, **kw: Any): ... + + def __init_subclass__(cls, **kw: Any) -> None: + if DeclarativeBase in cls.__bases__: + _check_not_declarative(cls, DeclarativeBase) + _setup_declarative_base(cls) + else: + _as_declarative(cls._sa_registry, cls, cls.__dict__) + super().__init_subclass__(**kw) + + +def _check_not_declarative(cls: Type[Any], base: Type[Any]) -> None: + cls_dict = cls.__dict__ + if ( + "__table__" in cls_dict + and not ( + callable(cls_dict["__table__"]) + or hasattr(cls_dict["__table__"], "__get__") + ) + ) or isinstance(cls_dict.get("__tablename__", None), str): + raise exc.InvalidRequestError( + f"Cannot use {base.__name__!r} directly as a declarative base " + "class. Create a Base by creating a subclass of it." + ) + + +class DeclarativeBaseNoMeta( + # Inspectable is used only by the mypy plugin + inspection.Inspectable[InstanceState[Any]] +): + """Same as :class:`_orm.DeclarativeBase`, but does not use a metaclass + to intercept new attributes. + + The :class:`_orm.DeclarativeBaseNoMeta` base may be used when use of + custom metaclasses is desirable. + + .. versionadded:: 2.0 + + + """ + + _sa_registry: ClassVar[_RegistryType] + + registry: ClassVar[_RegistryType] + """Refers to the :class:`_orm.registry` in use where new + :class:`_orm.Mapper` objects will be associated.""" + + metadata: ClassVar[MetaData] + """Refers to the :class:`_schema.MetaData` collection that will be used + for new :class:`_schema.Table` objects. + + .. seealso:: + + :ref:`orm_declarative_metadata` + + """ + + # this ideally should be Mapper[Self], but mypy as of 1.4.1 does not + # like it, and breaks the declared_attr_one test. Pyright/pylance is + # ok with it. + __mapper__: ClassVar[Mapper[Any]] + """The :class:`_orm.Mapper` object to which a particular class is + mapped. + + May also be acquired using :func:`_sa.inspect`, e.g. + ``inspect(klass)``. + + """ + + __table__: Optional[FromClause] + """The :class:`_sql.FromClause` to which a particular subclass is + mapped. + + This is usually an instance of :class:`_schema.Table` but may also + refer to other kinds of :class:`_sql.FromClause` such as + :class:`_sql.Subquery`, depending on how the class is mapped. + + .. seealso:: + + :ref:`orm_declarative_metadata` + + """ + + if typing.TYPE_CHECKING: + + def _sa_inspect_type(self) -> Mapper[Self]: ... + + def _sa_inspect_instance(self) -> InstanceState[Self]: ... + + __tablename__: Any + """String name to assign to the generated + :class:`_schema.Table` object, if not specified directly via + :attr:`_orm.DeclarativeBase.__table__`. + + .. seealso:: + + :ref:`orm_declarative_table` + + """ + + __mapper_args__: Any + """Dictionary of arguments which will be passed to the + :class:`_orm.Mapper` constructor. + + .. seealso:: + + :ref:`orm_declarative_mapper_options` + + """ + + __table_args__: Any + """A dictionary or tuple of arguments that will be passed to the + :class:`_schema.Table` constructor. See + :ref:`orm_declarative_table_configuration` + for background on the specific structure of this collection. + + .. seealso:: + + :ref:`orm_declarative_table_configuration` + + """ + + def __init__(self, **kw: Any): ... + + def __init_subclass__(cls, **kw: Any) -> None: + if DeclarativeBaseNoMeta in cls.__bases__: + _check_not_declarative(cls, DeclarativeBaseNoMeta) + _setup_declarative_base(cls) + else: + _as_declarative(cls._sa_registry, cls, cls.__dict__) + super().__init_subclass__(**kw) + + +def add_mapped_attribute( + target: Type[_O], key: str, attr: MapperProperty[Any] +) -> None: + """Add a new mapped attribute to an ORM mapped class. + + E.g.:: + + add_mapped_attribute(User, "addresses", relationship(Address)) + + This may be used for ORM mappings that aren't using a declarative + metaclass that intercepts attribute set operations. + + .. versionadded:: 2.0 + + + """ + _add_attribute(target, key, attr) + + +def declarative_base( + *, + metadata: Optional[MetaData] = None, + mapper: Optional[Callable[..., Mapper[Any]]] = None, + cls: Type[Any] = object, + name: str = "Base", + class_registry: Optional[clsregistry._ClsRegistryType] = None, + type_annotation_map: Optional[_TypeAnnotationMapType] = None, + constructor: Callable[..., None] = _declarative_constructor, + metaclass: Type[Any] = DeclarativeMeta, +) -> Any: + r"""Construct a base class for declarative class definitions. + + The new base class will be given a metaclass that produces + appropriate :class:`~sqlalchemy.schema.Table` objects and makes + the appropriate :class:`_orm.Mapper` calls based on the + information provided declaratively in the class and any subclasses + of the class. + + .. versionchanged:: 2.0 Note that the :func:`_orm.declarative_base` + function is superseded by the new :class:`_orm.DeclarativeBase` class, + which generates a new "base" class using subclassing, rather than + return value of a function. This allows an approach that is compatible + with :pep:`484` typing tools. + + The :func:`_orm.declarative_base` function is a shorthand version + of using the :meth:`_orm.registry.generate_base` + method. That is, the following:: + + from sqlalchemy.orm import declarative_base + + Base = declarative_base() + + Is equivalent to:: + + from sqlalchemy.orm import registry + + mapper_registry = registry() + Base = mapper_registry.generate_base() + + See the docstring for :class:`_orm.registry` + and :meth:`_orm.registry.generate_base` + for more details. + + .. versionchanged:: 1.4 The :func:`_orm.declarative_base` + function is now a specialization of the more generic + :class:`_orm.registry` class. The function also moves to the + ``sqlalchemy.orm`` package from the ``declarative.ext`` package. + + + :param metadata: + An optional :class:`~sqlalchemy.schema.MetaData` instance. All + :class:`~sqlalchemy.schema.Table` objects implicitly declared by + subclasses of the base will share this MetaData. A MetaData instance + will be created if none is provided. The + :class:`~sqlalchemy.schema.MetaData` instance will be available via the + ``metadata`` attribute of the generated declarative base class. + + :param mapper: + An optional callable, defaults to :class:`_orm.Mapper`. Will + be used to map subclasses to their Tables. + + :param cls: + Defaults to :class:`object`. A type to use as the base for the generated + declarative base class. May be a class or tuple of classes. + + :param name: + Defaults to ``Base``. The display name for the generated + class. Customizing this is not required, but can improve clarity in + tracebacks and debugging. + + :param constructor: + Specify the implementation for the ``__init__`` function on a mapped + class that has no ``__init__`` of its own. Defaults to an + implementation that assigns \**kwargs for declared + fields and relationships to an instance. If ``None`` is supplied, + no __init__ will be provided and construction will fall back to + cls.__init__ by way of the normal Python semantics. + + :param class_registry: optional dictionary that will serve as the + registry of class names-> mapped classes when string names + are used to identify classes inside of :func:`_orm.relationship` + and others. Allows two or more declarative base classes + to share the same registry of class names for simplified + inter-base relationships. + + :param type_annotation_map: optional dictionary of Python types to + SQLAlchemy :class:`_types.TypeEngine` classes or instances. This + is used exclusively by the :class:`_orm.MappedColumn` construct + to produce column types based on annotations within the + :class:`_orm.Mapped` type. + + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`orm_declarative_mapped_column_type_map` + + :param metaclass: + Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__ + compatible callable to use as the meta type of the generated + declarative base class. + + .. seealso:: + + :class:`_orm.registry` + + """ + + return registry( + metadata=metadata, + class_registry=class_registry, + constructor=constructor, + type_annotation_map=type_annotation_map, + ).generate_base( + mapper=mapper, + cls=cls, + name=name, + metaclass=metaclass, + ) + + +class registry: + """Generalized registry for mapping classes. + + The :class:`_orm.registry` serves as the basis for maintaining a collection + of mappings, and provides configurational hooks used to map classes. + + The three general kinds of mappings supported are Declarative Base, + Declarative Decorator, and Imperative Mapping. All of these mapping + styles may be used interchangeably: + + * :meth:`_orm.registry.generate_base` returns a new declarative base + class, and is the underlying implementation of the + :func:`_orm.declarative_base` function. + + * :meth:`_orm.registry.mapped` provides a class decorator that will + apply declarative mapping to a class without the use of a declarative + base class. + + * :meth:`_orm.registry.map_imperatively` will produce a + :class:`_orm.Mapper` for a class without scanning the class for + declarative class attributes. This method suits the use case historically + provided by the ``sqlalchemy.orm.mapper()`` classical mapping function, + which is removed as of SQLAlchemy 2.0. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`orm_mapping_classes_toplevel` - overview of class mapping + styles. + + """ + + _class_registry: clsregistry._ClsRegistryType + _managers: weakref.WeakKeyDictionary[ClassManager[Any], Literal[True]] + _non_primary_mappers: weakref.WeakKeyDictionary[Mapper[Any], Literal[True]] + metadata: MetaData + constructor: CallableReference[Callable[..., None]] + type_annotation_map: _MutableTypeAnnotationMapType + _dependents: Set[_RegistryType] + _dependencies: Set[_RegistryType] + _new_mappers: bool + + def __init__( + self, + *, + metadata: Optional[MetaData] = None, + class_registry: Optional[clsregistry._ClsRegistryType] = None, + type_annotation_map: Optional[_TypeAnnotationMapType] = None, + constructor: Callable[..., None] = _declarative_constructor, + ): + r"""Construct a new :class:`_orm.registry` + + :param metadata: + An optional :class:`_schema.MetaData` instance. All + :class:`_schema.Table` objects generated using declarative + table mapping will make use of this :class:`_schema.MetaData` + collection. If this argument is left at its default of ``None``, + a blank :class:`_schema.MetaData` collection is created. + + :param constructor: + Specify the implementation for the ``__init__`` function on a mapped + class that has no ``__init__`` of its own. Defaults to an + implementation that assigns \**kwargs for declared + fields and relationships to an instance. If ``None`` is supplied, + no __init__ will be provided and construction will fall back to + cls.__init__ by way of the normal Python semantics. + + :param class_registry: optional dictionary that will serve as the + registry of class names-> mapped classes when string names + are used to identify classes inside of :func:`_orm.relationship` + and others. Allows two or more declarative base classes + to share the same registry of class names for simplified + inter-base relationships. + + :param type_annotation_map: optional dictionary of Python types to + SQLAlchemy :class:`_types.TypeEngine` classes or instances. + The provided dict will update the default type mapping. This + is used exclusively by the :class:`_orm.MappedColumn` construct + to produce column types based on annotations within the + :class:`_orm.Mapped` type. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`orm_declarative_mapped_column_type_map` + + + """ + lcl_metadata = metadata or MetaData() + + if class_registry is None: + class_registry = weakref.WeakValueDictionary() + + self._class_registry = class_registry + self._managers = weakref.WeakKeyDictionary() + self._non_primary_mappers = weakref.WeakKeyDictionary() + self.metadata = lcl_metadata + self.constructor = constructor + self.type_annotation_map = {} + if type_annotation_map is not None: + self.update_type_annotation_map(type_annotation_map) + self._dependents = set() + self._dependencies = set() + + self._new_mappers = False + + with mapperlib._CONFIGURE_MUTEX: + mapperlib._mapper_registries[self] = True + + def update_type_annotation_map( + self, + type_annotation_map: _TypeAnnotationMapType, + ) -> None: + """update the :paramref:`_orm.registry.type_annotation_map` with new + values.""" + + self.type_annotation_map.update( + { + sub_type: sqltype + for typ, sqltype in type_annotation_map.items() + for sub_type in compat_typing.expand_unions( + typ, include_union=True, discard_none=True + ) + } + ) + + def _resolve_type( + self, python_type: _MatchedOnType + ) -> Optional[sqltypes.TypeEngine[Any]]: + search: Iterable[Tuple[_MatchedOnType, Type[Any]]] + python_type_type: Type[Any] + + if is_generic(python_type): + if is_literal(python_type): + python_type_type = cast("Type[Any]", python_type) + + search = ( # type: ignore[assignment] + (python_type, python_type_type), + (Literal, python_type_type), + ) + else: + python_type_type = python_type.__origin__ + search = ((python_type, python_type_type),) + elif is_newtype(python_type): + python_type_type = flatten_newtype(python_type) + search = ((python_type, python_type_type),) + elif is_pep695(python_type): + python_type_type = python_type.__value__ + flattened = None + search = ((python_type, python_type_type),) + else: + python_type_type = cast("Type[Any]", python_type) + flattened = None + search = ((pt, pt) for pt in python_type_type.__mro__) + + for pt, flattened in search: + # we search through full __mro__ for types. however... + sql_type = self.type_annotation_map.get(pt) + if sql_type is None: + sql_type = sqltypes._type_map_get(pt) # type: ignore # noqa: E501 + + if sql_type is not None: + sql_type_inst = sqltypes.to_instance(sql_type) + + # ... this additional step will reject most + # type -> supertype matches, such as if we had + # a MyInt(int) subclass. note also we pass NewType() + # here directly; these always have to be in the + # type_annotation_map to be useful + resolved_sql_type = sql_type_inst._resolve_for_python_type( + python_type_type, + pt, + flattened, + ) + if resolved_sql_type is not None: + return resolved_sql_type + + return None + + @property + def mappers(self) -> FrozenSet[Mapper[Any]]: + """read only collection of all :class:`_orm.Mapper` objects.""" + + return frozenset(manager.mapper for manager in self._managers).union( + self._non_primary_mappers + ) + + def _set_depends_on(self, registry: RegistryType) -> None: + if registry is self: + return + registry._dependents.add(self) + self._dependencies.add(registry) + + def _flag_new_mapper(self, mapper: Mapper[Any]) -> None: + mapper._ready_for_configure = True + if self._new_mappers: + return + + for reg in self._recurse_with_dependents({self}): + reg._new_mappers = True + + @classmethod + def _recurse_with_dependents( + cls, registries: Set[RegistryType] + ) -> Iterator[RegistryType]: + todo = registries + done = set() + while todo: + reg = todo.pop() + done.add(reg) + + # if yielding would remove dependents, make sure we have + # them before + todo.update(reg._dependents.difference(done)) + yield reg + + # if yielding would add dependents, make sure we have them + # after + todo.update(reg._dependents.difference(done)) + + @classmethod + def _recurse_with_dependencies( + cls, registries: Set[RegistryType] + ) -> Iterator[RegistryType]: + todo = registries + done = set() + while todo: + reg = todo.pop() + done.add(reg) + + # if yielding would remove dependencies, make sure we have + # them before + todo.update(reg._dependencies.difference(done)) + + yield reg + + # if yielding would remove dependencies, make sure we have + # them before + todo.update(reg._dependencies.difference(done)) + + def _mappers_to_configure(self) -> Iterator[Mapper[Any]]: + return itertools.chain( + ( + manager.mapper + for manager in list(self._managers) + if manager.is_mapped + and not manager.mapper.configured + and manager.mapper._ready_for_configure + ), + ( + npm + for npm in list(self._non_primary_mappers) + if not npm.configured and npm._ready_for_configure + ), + ) + + def _add_non_primary_mapper(self, np_mapper: Mapper[Any]) -> None: + self._non_primary_mappers[np_mapper] = True + + def _dispose_cls(self, cls: Type[_O]) -> None: + clsregistry.remove_class(cls.__name__, cls, self._class_registry) + + def _add_manager(self, manager: ClassManager[Any]) -> None: + self._managers[manager] = True + if manager.is_mapped: + raise exc.ArgumentError( + "Class '%s' already has a primary mapper defined. " + % manager.class_ + ) + assert manager.registry is None + manager.registry = self + + def configure(self, cascade: bool = False) -> None: + """Configure all as-yet unconfigured mappers in this + :class:`_orm.registry`. + + The configure step is used to reconcile and initialize the + :func:`_orm.relationship` linkages between mapped classes, as well as + to invoke configuration events such as the + :meth:`_orm.MapperEvents.before_configured` and + :meth:`_orm.MapperEvents.after_configured`, which may be used by ORM + extensions or user-defined extension hooks. + + If one or more mappers in this registry contain + :func:`_orm.relationship` constructs that refer to mapped classes in + other registries, this registry is said to be *dependent* on those + registries. In order to configure those dependent registries + automatically, the :paramref:`_orm.registry.configure.cascade` flag + should be set to ``True``. Otherwise, if they are not configured, an + exception will be raised. The rationale behind this behavior is to + allow an application to programmatically invoke configuration of + registries while controlling whether or not the process implicitly + reaches other registries. + + As an alternative to invoking :meth:`_orm.registry.configure`, the ORM + function :func:`_orm.configure_mappers` function may be used to ensure + configuration is complete for all :class:`_orm.registry` objects in + memory. This is generally simpler to use and also predates the usage of + :class:`_orm.registry` objects overall. However, this function will + impact all mappings throughout the running Python process and may be + more memory/time consuming for an application that has many registries + in use for different purposes that may not be needed immediately. + + .. seealso:: + + :func:`_orm.configure_mappers` + + + .. versionadded:: 1.4.0b2 + + """ + mapperlib._configure_registries({self}, cascade=cascade) + + def dispose(self, cascade: bool = False) -> None: + """Dispose of all mappers in this :class:`_orm.registry`. + + After invocation, all the classes that were mapped within this registry + will no longer have class instrumentation associated with them. This + method is the per-:class:`_orm.registry` analogue to the + application-wide :func:`_orm.clear_mappers` function. + + If this registry contains mappers that are dependencies of other + registries, typically via :func:`_orm.relationship` links, then those + registries must be disposed as well. When such registries exist in + relation to this one, their :meth:`_orm.registry.dispose` method will + also be called, if the :paramref:`_orm.registry.dispose.cascade` flag + is set to ``True``; otherwise, an error is raised if those registries + were not already disposed. + + .. versionadded:: 1.4.0b2 + + .. seealso:: + + :func:`_orm.clear_mappers` + + """ + + mapperlib._dispose_registries({self}, cascade=cascade) + + def _dispose_manager_and_mapper(self, manager: ClassManager[Any]) -> None: + if "mapper" in manager.__dict__: + mapper = manager.mapper + + mapper._set_dispose_flags() + + class_ = manager.class_ + self._dispose_cls(class_) + instrumentation._instrumentation_factory.unregister(class_) + + def generate_base( + self, + mapper: Optional[Callable[..., Mapper[Any]]] = None, + cls: Type[Any] = object, + name: str = "Base", + metaclass: Type[Any] = DeclarativeMeta, + ) -> Any: + """Generate a declarative base class. + + Classes that inherit from the returned class object will be + automatically mapped using declarative mapping. + + E.g.:: + + from sqlalchemy.orm import registry + + mapper_registry = registry() + + Base = mapper_registry.generate_base() + + class MyClass(Base): + __tablename__ = "my_table" + id = Column(Integer, primary_key=True) + + The above dynamically generated class is equivalent to the + non-dynamic example below:: + + from sqlalchemy.orm import registry + from sqlalchemy.orm.decl_api import DeclarativeMeta + + mapper_registry = registry() + + class Base(metaclass=DeclarativeMeta): + __abstract__ = True + registry = mapper_registry + metadata = mapper_registry.metadata + + __init__ = mapper_registry.constructor + + .. versionchanged:: 2.0 Note that the + :meth:`_orm.registry.generate_base` method is superseded by the new + :class:`_orm.DeclarativeBase` class, which generates a new "base" + class using subclassing, rather than return value of a function. + This allows an approach that is compatible with :pep:`484` typing + tools. + + The :meth:`_orm.registry.generate_base` method provides the + implementation for the :func:`_orm.declarative_base` function, which + creates the :class:`_orm.registry` and base class all at once. + + See the section :ref:`orm_declarative_mapping` for background and + examples. + + :param mapper: + An optional callable, defaults to :class:`_orm.Mapper`. + This function is used to generate new :class:`_orm.Mapper` objects. + + :param cls: + Defaults to :class:`object`. A type to use as the base for the + generated declarative base class. May be a class or tuple of classes. + + :param name: + Defaults to ``Base``. The display name for the generated + class. Customizing this is not required, but can improve clarity in + tracebacks and debugging. + + :param metaclass: + Defaults to :class:`.DeclarativeMeta`. A metaclass or __metaclass__ + compatible callable to use as the meta type of the generated + declarative base class. + + .. seealso:: + + :ref:`orm_declarative_mapping` + + :func:`_orm.declarative_base` + + """ + metadata = self.metadata + + bases = not isinstance(cls, tuple) and (cls,) or cls + + class_dict: Dict[str, Any] = dict(registry=self, metadata=metadata) + if isinstance(cls, type): + class_dict["__doc__"] = cls.__doc__ + + if self.constructor is not None: + class_dict["__init__"] = self.constructor + + class_dict["__abstract__"] = True + if mapper: + class_dict["__mapper_cls__"] = mapper + + if hasattr(cls, "__class_getitem__"): + + def __class_getitem__(cls: Type[_T], key: Any) -> Type[_T]: + # allow generic classes in py3.9+ + return cls + + class_dict["__class_getitem__"] = __class_getitem__ + + return metaclass(name, bases, class_dict) + + @compat_typing.dataclass_transform( + field_specifiers=( + MappedColumn, + RelationshipProperty, + Composite, + Synonym, + mapped_column, + relationship, + composite, + synonym, + deferred, + ), + ) + @overload + def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]: ... + + @overload + def mapped_as_dataclass( + self, + __cls: Literal[None] = ..., + *, + init: Union[_NoArg, bool] = ..., + repr: Union[_NoArg, bool] = ..., # noqa: A002 + eq: Union[_NoArg, bool] = ..., + order: Union[_NoArg, bool] = ..., + unsafe_hash: Union[_NoArg, bool] = ..., + match_args: Union[_NoArg, bool] = ..., + kw_only: Union[_NoArg, bool] = ..., + dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] = ..., + ) -> Callable[[Type[_O]], Type[_O]]: ... + + def mapped_as_dataclass( + self, + __cls: Optional[Type[_O]] = None, + *, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, + repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 + eq: Union[_NoArg, bool] = _NoArg.NO_ARG, + order: Union[_NoArg, bool] = _NoArg.NO_ARG, + unsafe_hash: Union[_NoArg, bool] = _NoArg.NO_ARG, + match_args: Union[_NoArg, bool] = _NoArg.NO_ARG, + kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, + dataclass_callable: Union[ + _NoArg, Callable[..., Type[Any]] + ] = _NoArg.NO_ARG, + ) -> Union[Type[_O], Callable[[Type[_O]], Type[_O]]]: + """Class decorator that will apply the Declarative mapping process + to a given class, and additionally convert the class to be a + Python dataclass. + + .. seealso:: + + :ref:`orm_declarative_native_dataclasses` - complete background + on SQLAlchemy native dataclass mapping + + + .. versionadded:: 2.0 + + + """ + + def decorate(cls: Type[_O]) -> Type[_O]: + setattr( + cls, + "_sa_apply_dc_transforms", + { + "init": init, + "repr": repr, + "eq": eq, + "order": order, + "unsafe_hash": unsafe_hash, + "match_args": match_args, + "kw_only": kw_only, + "dataclass_callable": dataclass_callable, + }, + ) + _as_declarative(self, cls, cls.__dict__) + return cls + + if __cls: + return decorate(__cls) + else: + return decorate + + def mapped(self, cls: Type[_O]) -> Type[_O]: + """Class decorator that will apply the Declarative mapping process + to a given class. + + E.g.:: + + from sqlalchemy.orm import registry + + mapper_registry = registry() + + @mapper_registry.mapped + class Foo: + __tablename__ = 'some_table' + + id = Column(Integer, primary_key=True) + name = Column(String) + + See the section :ref:`orm_declarative_mapping` for complete + details and examples. + + :param cls: class to be mapped. + + :return: the class that was passed. + + .. seealso:: + + :ref:`orm_declarative_mapping` + + :meth:`_orm.registry.generate_base` - generates a base class + that will apply Declarative mapping to subclasses automatically + using a Python metaclass. + + .. seealso:: + + :meth:`_orm.registry.mapped_as_dataclass` + + """ + _as_declarative(self, cls, cls.__dict__) + return cls + + def as_declarative_base(self, **kw: Any) -> Callable[[Type[_T]], Type[_T]]: + """ + Class decorator which will invoke + :meth:`_orm.registry.generate_base` + for a given base class. + + E.g.:: + + from sqlalchemy.orm import registry + + mapper_registry = registry() + + @mapper_registry.as_declarative_base() + class Base: + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + id = Column(Integer, primary_key=True) + + class MyMappedClass(Base): + # ... + + All keyword arguments passed to + :meth:`_orm.registry.as_declarative_base` are passed + along to :meth:`_orm.registry.generate_base`. + + """ + + def decorate(cls: Type[_T]) -> Type[_T]: + kw["cls"] = cls + kw["name"] = cls.__name__ + return self.generate_base(**kw) # type: ignore + + return decorate + + def map_declaratively(self, cls: Type[_O]) -> Mapper[_O]: + """Map a class declaratively. + + In this form of mapping, the class is scanned for mapping information, + including for columns to be associated with a table, and/or an + actual table object. + + Returns the :class:`_orm.Mapper` object. + + E.g.:: + + from sqlalchemy.orm import registry + + mapper_registry = registry() + + class Foo: + __tablename__ = 'some_table' + + id = Column(Integer, primary_key=True) + name = Column(String) + + mapper = mapper_registry.map_declaratively(Foo) + + This function is more conveniently invoked indirectly via either the + :meth:`_orm.registry.mapped` class decorator or by subclassing a + declarative metaclass generated from + :meth:`_orm.registry.generate_base`. + + See the section :ref:`orm_declarative_mapping` for complete + details and examples. + + :param cls: class to be mapped. + + :return: a :class:`_orm.Mapper` object. + + .. seealso:: + + :ref:`orm_declarative_mapping` + + :meth:`_orm.registry.mapped` - more common decorator interface + to this function. + + :meth:`_orm.registry.map_imperatively` + + """ + _as_declarative(self, cls, cls.__dict__) + return cls.__mapper__ # type: ignore + + def map_imperatively( + self, + class_: Type[_O], + local_table: Optional[FromClause] = None, + **kw: Any, + ) -> Mapper[_O]: + r"""Map a class imperatively. + + In this form of mapping, the class is not scanned for any mapping + information. Instead, all mapping constructs are passed as + arguments. + + This method is intended to be fully equivalent to the now-removed + SQLAlchemy ``mapper()`` function, except that it's in terms of + a particular registry. + + E.g.:: + + from sqlalchemy.orm import registry + + mapper_registry = registry() + + my_table = Table( + "my_table", + mapper_registry.metadata, + Column('id', Integer, primary_key=True) + ) + + class MyClass: + pass + + mapper_registry.map_imperatively(MyClass, my_table) + + See the section :ref:`orm_imperative_mapping` for complete background + and usage examples. + + :param class\_: The class to be mapped. Corresponds to the + :paramref:`_orm.Mapper.class_` parameter. + + :param local_table: the :class:`_schema.Table` or other + :class:`_sql.FromClause` object that is the subject of the mapping. + Corresponds to the + :paramref:`_orm.Mapper.local_table` parameter. + + :param \**kw: all other keyword arguments are passed to the + :class:`_orm.Mapper` constructor directly. + + .. seealso:: + + :ref:`orm_imperative_mapping` + + :ref:`orm_declarative_mapping` + + """ + return _mapper(self, class_, local_table, kw) + + +RegistryType = registry + +if not TYPE_CHECKING: + # allow for runtime type resolution of ``ClassVar[_RegistryType]`` + _RegistryType = registry # noqa + + +def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]: + """ + Class decorator which will adapt a given class into a + :func:`_orm.declarative_base`. + + This function makes use of the :meth:`_orm.registry.as_declarative_base` + method, by first creating a :class:`_orm.registry` automatically + and then invoking the decorator. + + E.g.:: + + from sqlalchemy.orm import as_declarative + + @as_declarative() + class Base: + @declared_attr + def __tablename__(cls): + return cls.__name__.lower() + id = Column(Integer, primary_key=True) + + class MyMappedClass(Base): + # ... + + .. seealso:: + + :meth:`_orm.registry.as_declarative_base` + + """ + metadata, class_registry = ( + kw.pop("metadata", None), + kw.pop("class_registry", None), + ) + + return registry( + metadata=metadata, class_registry=class_registry + ).as_declarative_base(**kw) + + +@inspection._inspects( + DeclarativeMeta, DeclarativeBase, DeclarativeAttributeIntercept +) +def _inspect_decl_meta(cls: Type[Any]) -> Optional[Mapper[Any]]: + mp: Optional[Mapper[Any]] = _inspect_mapped_class(cls) + if mp is None: + if _DeferredMapperConfig.has_cls(cls): + _DeferredMapperConfig.raise_unmapped_for_cls(cls) + return mp diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/decl_base.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/decl_base.py new file mode 100644 index 0000000..96530c3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/decl_base.py @@ -0,0 +1,2152 @@ +# orm/decl_base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Internal implementation for declarative.""" + +from __future__ import annotations + +import collections +import dataclasses +import re +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping +from typing import NamedTuple +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import clsregistry +from . import exc as orm_exc +from . import instrumentation +from . import mapperlib +from ._typing import _O +from ._typing import attr_is_internal_proxy +from .attributes import InstrumentedAttribute +from .attributes import QueryableAttribute +from .base import _is_mapped_class +from .base import InspectionAttr +from .descriptor_props import CompositeProperty +from .descriptor_props import SynonymProperty +from .interfaces import _AttributeOptions +from .interfaces import _DCAttributeOptions +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MappedAttribute +from .interfaces import _MapsColumns +from .interfaces import MapperProperty +from .mapper import Mapper +from .properties import ColumnProperty +from .properties import MappedColumn +from .util import _extract_mapped_subtype +from .util import _is_mapped_annotation +from .util import class_mapper +from .util import de_stringify_annotation +from .. import event +from .. import exc +from .. import util +from ..sql import expression +from ..sql.base import _NoArg +from ..sql.schema import Column +from ..sql.schema import Table +from ..util import topological +from ..util.typing import _AnnotationScanType +from ..util.typing import is_fwd_ref +from ..util.typing import is_literal +from ..util.typing import Protocol +from ..util.typing import TypedDict +from ..util.typing import typing_get_args + +if TYPE_CHECKING: + from ._typing import _ClassDict + from ._typing import _RegistryType + from .base import Mapped + from .decl_api import declared_attr + from .instrumentation import ClassManager + from ..sql.elements import NamedColumn + from ..sql.schema import MetaData + from ..sql.selectable import FromClause + +_T = TypeVar("_T", bound=Any) + +_MapperKwArgs = Mapping[str, Any] +_TableArgsType = Union[Tuple[Any, ...], Dict[str, Any]] + + +class MappedClassProtocol(Protocol[_O]): + """A protocol representing a SQLAlchemy mapped class. + + The protocol is generic on the type of class, use + ``MappedClassProtocol[Any]`` to allow any mapped class. + """ + + __name__: str + __mapper__: Mapper[_O] + __table__: FromClause + + def __call__(self, **kw: Any) -> _O: ... + + +class _DeclMappedClassProtocol(MappedClassProtocol[_O], Protocol): + "Internal more detailed version of ``MappedClassProtocol``." + metadata: MetaData + __tablename__: str + __mapper_args__: _MapperKwArgs + __table_args__: Optional[_TableArgsType] + + _sa_apply_dc_transforms: Optional[_DataclassArguments] + + def __declare_first__(self) -> None: ... + + def __declare_last__(self) -> None: ... + + +class _DataclassArguments(TypedDict): + init: Union[_NoArg, bool] + repr: Union[_NoArg, bool] + eq: Union[_NoArg, bool] + order: Union[_NoArg, bool] + unsafe_hash: Union[_NoArg, bool] + match_args: Union[_NoArg, bool] + kw_only: Union[_NoArg, bool] + dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] + + +def _declared_mapping_info( + cls: Type[Any], +) -> Optional[Union[_DeferredMapperConfig, Mapper[Any]]]: + # deferred mapping + if _DeferredMapperConfig.has_cls(cls): + return _DeferredMapperConfig.config_for_cls(cls) + # regular mapping + elif _is_mapped_class(cls): + return class_mapper(cls, configure=False) + else: + return None + + +def _is_supercls_for_inherits(cls: Type[Any]) -> bool: + """return True if this class will be used as a superclass to set in + 'inherits'. + + This includes deferred mapper configs that aren't mapped yet, however does + not include classes with _sa_decl_prepare_nocascade (e.g. + ``AbstractConcreteBase``); these concrete-only classes are not set up as + "inherits" until after mappers are configured using + mapper._set_concrete_base() + + """ + if _DeferredMapperConfig.has_cls(cls): + return not _get_immediate_cls_attr( + cls, "_sa_decl_prepare_nocascade", strict=True + ) + # regular mapping + elif _is_mapped_class(cls): + return True + else: + return False + + +def _resolve_for_abstract_or_classical(cls: Type[Any]) -> Optional[Type[Any]]: + if cls is object: + return None + + sup: Optional[Type[Any]] + + if cls.__dict__.get("__abstract__", False): + for base_ in cls.__bases__: + sup = _resolve_for_abstract_or_classical(base_) + if sup is not None: + return sup + else: + return None + else: + clsmanager = _dive_for_cls_manager(cls) + + if clsmanager: + return clsmanager.class_ + else: + return cls + + +def _get_immediate_cls_attr( + cls: Type[Any], attrname: str, strict: bool = False +) -> Optional[Any]: + """return an attribute of the class that is either present directly + on the class, e.g. not on a superclass, or is from a superclass but + this superclass is a non-mapped mixin, that is, not a descendant of + the declarative base and is also not classically mapped. + + This is used to detect attributes that indicate something about + a mapped class independently from any mapped classes that it may + inherit from. + + """ + + # the rules are different for this name than others, + # make sure we've moved it out. transitional + assert attrname != "__abstract__" + + if not issubclass(cls, object): + return None + + if attrname in cls.__dict__: + return getattr(cls, attrname) + + for base in cls.__mro__[1:]: + _is_classical_inherits = _dive_for_cls_manager(base) is not None + + if attrname in base.__dict__ and ( + base is cls + or ( + (base in cls.__bases__ if strict else True) + and not _is_classical_inherits + ) + ): + return getattr(base, attrname) + else: + return None + + +def _dive_for_cls_manager(cls: Type[_O]) -> Optional[ClassManager[_O]]: + # because the class manager registration is pluggable, + # we need to do the search for every class in the hierarchy, + # rather than just a simple "cls._sa_class_manager" + + for base in cls.__mro__: + manager: Optional[ClassManager[_O]] = attributes.opt_manager_of_class( + base + ) + if manager: + return manager + return None + + +def _as_declarative( + registry: _RegistryType, cls: Type[Any], dict_: _ClassDict +) -> Optional[_MapperConfig]: + # declarative scans the class for attributes. no table or mapper + # args passed separately. + return _MapperConfig.setup_mapping(registry, cls, dict_, None, {}) + + +def _mapper( + registry: _RegistryType, + cls: Type[_O], + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, +) -> Mapper[_O]: + _ImperativeMapperConfig(registry, cls, table, mapper_kw) + return cast("MappedClassProtocol[_O]", cls).__mapper__ + + +@util.preload_module("sqlalchemy.orm.decl_api") +def _is_declarative_props(obj: Any) -> bool: + _declared_attr_common = util.preloaded.orm_decl_api._declared_attr_common + + return isinstance(obj, (_declared_attr_common, util.classproperty)) + + +def _check_declared_props_nocascade( + obj: Any, name: str, cls: Type[_O] +) -> bool: + if _is_declarative_props(obj): + if getattr(obj, "_cascading", False): + util.warn( + "@declared_attr.cascading is not supported on the %s " + "attribute on class %s. This attribute invokes for " + "subclasses in any case." % (name, cls) + ) + return True + else: + return False + + +class _MapperConfig: + __slots__ = ( + "cls", + "classname", + "properties", + "declared_attr_reg", + "__weakref__", + ) + + cls: Type[Any] + classname: str + properties: util.OrderedDict[ + str, + Union[ + Sequence[NamedColumn[Any]], NamedColumn[Any], MapperProperty[Any] + ], + ] + declared_attr_reg: Dict[declared_attr[Any], Any] + + @classmethod + def setup_mapping( + cls, + registry: _RegistryType, + cls_: Type[_O], + dict_: _ClassDict, + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, + ) -> Optional[_MapperConfig]: + manager = attributes.opt_manager_of_class(cls) + if manager and manager.class_ is cls_: + raise exc.InvalidRequestError( + f"Class {cls!r} already has been instrumented declaratively" + ) + + if cls_.__dict__.get("__abstract__", False): + return None + + defer_map = _get_immediate_cls_attr( + cls_, "_sa_decl_prepare_nocascade", strict=True + ) or hasattr(cls_, "_sa_decl_prepare") + + if defer_map: + return _DeferredMapperConfig( + registry, cls_, dict_, table, mapper_kw + ) + else: + return _ClassScanMapperConfig( + registry, cls_, dict_, table, mapper_kw + ) + + def __init__( + self, + registry: _RegistryType, + cls_: Type[Any], + mapper_kw: _MapperKwArgs, + ): + self.cls = util.assert_arg_type(cls_, type, "cls_") + self.classname = cls_.__name__ + self.properties = util.OrderedDict() + self.declared_attr_reg = {} + + if not mapper_kw.get("non_primary", False): + instrumentation.register_class( + self.cls, + finalize=False, + registry=registry, + declarative_scan=self, + init_method=registry.constructor, + ) + else: + manager = attributes.opt_manager_of_class(self.cls) + if not manager or not manager.is_mapped: + raise exc.InvalidRequestError( + "Class %s has no primary mapper configured. Configure " + "a primary mapper first before setting up a non primary " + "Mapper." % self.cls + ) + + def set_cls_attribute(self, attrname: str, value: _T) -> _T: + manager = instrumentation.manager_of_class(self.cls) + manager.install_member(attrname, value) + return value + + def map(self, mapper_kw: _MapperKwArgs = ...) -> Mapper[Any]: + raise NotImplementedError() + + def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None: + self.map(mapper_kw) + + +class _ImperativeMapperConfig(_MapperConfig): + __slots__ = ("local_table", "inherits") + + def __init__( + self, + registry: _RegistryType, + cls_: Type[_O], + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, + ): + super().__init__(registry, cls_, mapper_kw) + + self.local_table = self.set_cls_attribute("__table__", table) + + with mapperlib._CONFIGURE_MUTEX: + if not mapper_kw.get("non_primary", False): + clsregistry.add_class( + self.classname, self.cls, registry._class_registry + ) + + self._setup_inheritance(mapper_kw) + + self._early_mapping(mapper_kw) + + def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: + mapper_cls = Mapper + + return self.set_cls_attribute( + "__mapper__", + mapper_cls(self.cls, self.local_table, **mapper_kw), + ) + + def _setup_inheritance(self, mapper_kw: _MapperKwArgs) -> None: + cls = self.cls + + inherits = mapper_kw.get("inherits", None) + + if inherits is None: + # since we search for classical mappings now, search for + # multiple mapped bases as well and raise an error. + inherits_search = [] + for base_ in cls.__bases__: + c = _resolve_for_abstract_or_classical(base_) + if c is None: + continue + + if _is_supercls_for_inherits(c) and c not in inherits_search: + inherits_search.append(c) + + if inherits_search: + if len(inherits_search) > 1: + raise exc.InvalidRequestError( + "Class %s has multiple mapped bases: %r" + % (cls, inherits_search) + ) + inherits = inherits_search[0] + elif isinstance(inherits, Mapper): + inherits = inherits.class_ + + self.inherits = inherits + + +class _CollectedAnnotation(NamedTuple): + raw_annotation: _AnnotationScanType + mapped_container: Optional[Type[Mapped[Any]]] + extracted_mapped_annotation: Union[Type[Any], str] + is_dataclass: bool + attr_value: Any + originating_module: str + originating_class: Type[Any] + + +class _ClassScanMapperConfig(_MapperConfig): + __slots__ = ( + "registry", + "clsdict_view", + "collected_attributes", + "collected_annotations", + "local_table", + "persist_selectable", + "declared_columns", + "column_ordering", + "column_copies", + "table_args", + "tablename", + "mapper_args", + "mapper_args_fn", + "inherits", + "single", + "allow_dataclass_fields", + "dataclass_setup_arguments", + "is_dataclass_prior_to_mapping", + "allow_unmapped_annotations", + ) + + is_deferred = False + registry: _RegistryType + clsdict_view: _ClassDict + collected_annotations: Dict[str, _CollectedAnnotation] + collected_attributes: Dict[str, Any] + local_table: Optional[FromClause] + persist_selectable: Optional[FromClause] + declared_columns: util.OrderedSet[Column[Any]] + column_ordering: Dict[Column[Any], int] + column_copies: Dict[ + Union[MappedColumn[Any], Column[Any]], + Union[MappedColumn[Any], Column[Any]], + ] + tablename: Optional[str] + mapper_args: Mapping[str, Any] + table_args: Optional[_TableArgsType] + mapper_args_fn: Optional[Callable[[], Dict[str, Any]]] + inherits: Optional[Type[Any]] + single: bool + + is_dataclass_prior_to_mapping: bool + allow_unmapped_annotations: bool + + dataclass_setup_arguments: Optional[_DataclassArguments] + """if the class has SQLAlchemy native dataclass parameters, where + we will turn the class into a dataclass within the declarative mapping + process. + + """ + + allow_dataclass_fields: bool + """if true, look for dataclass-processed Field objects on the target + class as well as superclasses and extract ORM mapping directives from + the "metadata" attribute of each Field. + + if False, dataclass fields can still be used, however they won't be + mapped. + + """ + + def __init__( + self, + registry: _RegistryType, + cls_: Type[_O], + dict_: _ClassDict, + table: Optional[FromClause], + mapper_kw: _MapperKwArgs, + ): + # grab class dict before the instrumentation manager has been added. + # reduces cycles + self.clsdict_view = ( + util.immutabledict(dict_) if dict_ else util.EMPTY_DICT + ) + super().__init__(registry, cls_, mapper_kw) + self.registry = registry + self.persist_selectable = None + + self.collected_attributes = {} + self.collected_annotations = {} + self.declared_columns = util.OrderedSet() + self.column_ordering = {} + self.column_copies = {} + self.single = False + self.dataclass_setup_arguments = dca = getattr( + self.cls, "_sa_apply_dc_transforms", None + ) + + self.allow_unmapped_annotations = getattr( + self.cls, "__allow_unmapped__", False + ) or bool(self.dataclass_setup_arguments) + + self.is_dataclass_prior_to_mapping = cld = dataclasses.is_dataclass( + cls_ + ) + + sdk = _get_immediate_cls_attr(cls_, "__sa_dataclass_metadata_key__") + + # we don't want to consume Field objects from a not-already-dataclass. + # the Field objects won't have their "name" or "type" populated, + # and while it seems like we could just set these on Field as we + # read them, Field is documented as "user read only" and we need to + # stay far away from any off-label use of dataclasses APIs. + if (not cld or dca) and sdk: + raise exc.InvalidRequestError( + "SQLAlchemy mapped dataclasses can't consume mapping " + "information from dataclass.Field() objects if the immediate " + "class is not already a dataclass." + ) + + # if already a dataclass, and __sa_dataclass_metadata_key__ present, + # then also look inside of dataclass.Field() objects yielded by + # dataclasses.get_fields(cls) when scanning for attributes + self.allow_dataclass_fields = bool(sdk and cld) + + self._setup_declared_events() + + self._scan_attributes() + + self._setup_dataclasses_transforms() + + with mapperlib._CONFIGURE_MUTEX: + clsregistry.add_class( + self.classname, self.cls, registry._class_registry + ) + + self._setup_inheriting_mapper(mapper_kw) + + self._extract_mappable_attributes() + + self._extract_declared_columns() + + self._setup_table(table) + + self._setup_inheriting_columns(mapper_kw) + + self._early_mapping(mapper_kw) + + def _setup_declared_events(self) -> None: + if _get_immediate_cls_attr(self.cls, "__declare_last__"): + + @event.listens_for(Mapper, "after_configured") + def after_configured() -> None: + cast( + "_DeclMappedClassProtocol[Any]", self.cls + ).__declare_last__() + + if _get_immediate_cls_attr(self.cls, "__declare_first__"): + + @event.listens_for(Mapper, "before_configured") + def before_configured() -> None: + cast( + "_DeclMappedClassProtocol[Any]", self.cls + ).__declare_first__() + + def _cls_attr_override_checker( + self, cls: Type[_O] + ) -> Callable[[str, Any], bool]: + """Produce a function that checks if a class has overridden an + attribute, taking SQLAlchemy-enabled dataclass fields into account. + + """ + + if self.allow_dataclass_fields: + sa_dataclass_metadata_key = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" + ) + else: + sa_dataclass_metadata_key = None + + if not sa_dataclass_metadata_key: + + def attribute_is_overridden(key: str, obj: Any) -> bool: + return getattr(cls, key, obj) is not obj + + else: + all_datacls_fields = { + f.name: f.metadata[sa_dataclass_metadata_key] + for f in util.dataclass_fields(cls) + if sa_dataclass_metadata_key in f.metadata + } + local_datacls_fields = { + f.name: f.metadata[sa_dataclass_metadata_key] + for f in util.local_dataclass_fields(cls) + if sa_dataclass_metadata_key in f.metadata + } + + absent = object() + + def attribute_is_overridden(key: str, obj: Any) -> bool: + if _is_declarative_props(obj): + obj = obj.fget + + # this function likely has some failure modes still if + # someone is doing a deep mixing of the same attribute + # name as plain Python attribute vs. dataclass field. + + ret = local_datacls_fields.get(key, absent) + if _is_declarative_props(ret): + ret = ret.fget + + if ret is obj: + return False + elif ret is not absent: + return True + + all_field = all_datacls_fields.get(key, absent) + + ret = getattr(cls, key, obj) + + if ret is obj: + return False + + # for dataclasses, this could be the + # 'default' of the field. so filter more specifically + # for an already-mapped InstrumentedAttribute + if ret is not absent and isinstance( + ret, InstrumentedAttribute + ): + return True + + if all_field is obj: + return False + elif all_field is not absent: + return True + + # can't find another attribute + return False + + return attribute_is_overridden + + _include_dunders = { + "__table__", + "__mapper_args__", + "__tablename__", + "__table_args__", + } + + _match_exclude_dunders = re.compile(r"^(?:_sa_|__)") + + def _cls_attr_resolver( + self, cls: Type[Any] + ) -> Callable[[], Iterable[Tuple[str, Any, Any, bool]]]: + """produce a function to iterate the "attributes" of a class + which we want to consider for mapping, adjusting for SQLAlchemy fields + embedded in dataclass fields. + + """ + cls_annotations = util.get_annotations(cls) + + cls_vars = vars(cls) + + _include_dunders = self._include_dunders + _match_exclude_dunders = self._match_exclude_dunders + + names = [ + n + for n in util.merge_lists_w_ordering( + list(cls_vars), list(cls_annotations) + ) + if not _match_exclude_dunders.match(n) or n in _include_dunders + ] + + if self.allow_dataclass_fields: + sa_dataclass_metadata_key: Optional[str] = _get_immediate_cls_attr( + cls, "__sa_dataclass_metadata_key__" + ) + else: + sa_dataclass_metadata_key = None + + if not sa_dataclass_metadata_key: + + def local_attributes_for_class() -> ( + Iterable[Tuple[str, Any, Any, bool]] + ): + return ( + ( + name, + cls_vars.get(name), + cls_annotations.get(name), + False, + ) + for name in names + ) + + else: + dataclass_fields = { + field.name: field for field in util.local_dataclass_fields(cls) + } + + fixed_sa_dataclass_metadata_key = sa_dataclass_metadata_key + + def local_attributes_for_class() -> ( + Iterable[Tuple[str, Any, Any, bool]] + ): + for name in names: + field = dataclass_fields.get(name, None) + if field and sa_dataclass_metadata_key in field.metadata: + yield field.name, _as_dc_declaredattr( + field.metadata, fixed_sa_dataclass_metadata_key + ), cls_annotations.get(field.name), True + else: + yield name, cls_vars.get(name), cls_annotations.get( + name + ), False + + return local_attributes_for_class + + def _scan_attributes(self) -> None: + cls = self.cls + + cls_as_Decl = cast("_DeclMappedClassProtocol[Any]", cls) + + clsdict_view = self.clsdict_view + collected_attributes = self.collected_attributes + column_copies = self.column_copies + _include_dunders = self._include_dunders + mapper_args_fn = None + table_args = inherited_table_args = None + + tablename = None + fixed_table = "__table__" in clsdict_view + + attribute_is_overridden = self._cls_attr_override_checker(self.cls) + + bases = [] + + for base in cls.__mro__: + # collect bases and make sure standalone columns are copied + # to be the column they will ultimately be on the class, + # so that declared_attr functions use the right columns. + # need to do this all the way up the hierarchy first + # (see #8190) + + class_mapped = base is not cls and _is_supercls_for_inherits(base) + + local_attributes_for_class = self._cls_attr_resolver(base) + + if not class_mapped and base is not cls: + locally_collected_columns = self._produce_column_copies( + local_attributes_for_class, + attribute_is_overridden, + fixed_table, + base, + ) + else: + locally_collected_columns = {} + + bases.append( + ( + base, + class_mapped, + local_attributes_for_class, + locally_collected_columns, + ) + ) + + for ( + base, + class_mapped, + local_attributes_for_class, + locally_collected_columns, + ) in bases: + # this transfer can also take place as we scan each name + # for finer-grained control of how collected_attributes is + # populated, as this is what impacts column ordering. + # however it's simpler to get it out of the way here. + collected_attributes.update(locally_collected_columns) + + for ( + name, + obj, + annotation, + is_dataclass_field, + ) in local_attributes_for_class(): + if name in _include_dunders: + if name == "__mapper_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not mapper_args_fn and ( + not class_mapped or check_decl + ): + # don't even invoke __mapper_args__ until + # after we've determined everything about the + # mapped table. + # make a copy of it so a class-level dictionary + # is not overwritten when we update column-based + # arguments. + def _mapper_args_fn() -> Dict[str, Any]: + return dict(cls_as_Decl.__mapper_args__) + + mapper_args_fn = _mapper_args_fn + + elif name == "__tablename__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not tablename and (not class_mapped or check_decl): + tablename = cls_as_Decl.__tablename__ + elif name == "__table_args__": + check_decl = _check_declared_props_nocascade( + obj, name, cls + ) + if not table_args and (not class_mapped or check_decl): + table_args = cls_as_Decl.__table_args__ + if not isinstance( + table_args, (tuple, dict, type(None)) + ): + raise exc.ArgumentError( + "__table_args__ value must be a tuple, " + "dict, or None" + ) + if base is not cls: + inherited_table_args = True + else: + # skip all other dunder names, which at the moment + # should only be __table__ + continue + elif class_mapped: + if _is_declarative_props(obj) and not obj._quiet: + util.warn( + "Regular (i.e. not __special__) " + "attribute '%s.%s' uses @declared_attr, " + "but owning class %s is mapped - " + "not applying to subclass %s." + % (base.__name__, name, base, cls) + ) + + continue + elif base is not cls: + # we're a mixin, abstract base, or something that is + # acting like that for now. + + if isinstance(obj, (Column, MappedColumn)): + # already copied columns to the mapped class. + continue + elif isinstance(obj, MapperProperty): + raise exc.InvalidRequestError( + "Mapper properties (i.e. deferred," + "column_property(), relationship(), etc.) must " + "be declared as @declared_attr callables " + "on declarative mixin classes. For dataclass " + "field() objects, use a lambda:" + ) + elif _is_declarative_props(obj): + # tried to get overloads to tell this to + # pylance, no luck + assert obj is not None + + if obj._cascading: + if name in clsdict_view: + # unfortunately, while we can use the user- + # defined attribute here to allow a clean + # override, if there's another + # subclass below then it still tries to use + # this. not sure if there is enough + # information here to add this as a feature + # later on. + util.warn( + "Attribute '%s' on class %s cannot be " + "processed due to " + "@declared_attr.cascading; " + "skipping" % (name, cls) + ) + collected_attributes[name] = column_copies[obj] = ( + ret + ) = obj.__get__(obj, cls) + setattr(cls, name, ret) + else: + if is_dataclass_field: + # access attribute using normal class access + # first, to see if it's been mapped on a + # superclass. note if the dataclasses.field() + # has "default", this value can be anything. + ret = getattr(cls, name, None) + + # so, if it's anything that's not ORM + # mapped, assume we should invoke the + # declared_attr + if not isinstance(ret, InspectionAttr): + ret = obj.fget() + else: + # access attribute using normal class access. + # if the declared attr already took place + # on a superclass that is mapped, then + # this is no longer a declared_attr, it will + # be the InstrumentedAttribute + ret = getattr(cls, name) + + # correct for proxies created from hybrid_property + # or similar. note there is no known case that + # produces nested proxies, so we are only + # looking one level deep right now. + + if ( + isinstance(ret, InspectionAttr) + and attr_is_internal_proxy(ret) + and not isinstance( + ret.original_property, MapperProperty + ) + ): + ret = ret.descriptor + + collected_attributes[name] = column_copies[obj] = ( + ret + ) + + if ( + isinstance(ret, (Column, MapperProperty)) + and ret.doc is None + ): + ret.doc = obj.__doc__ + + self._collect_annotation( + name, + obj._collect_return_annotation(), + base, + True, + obj, + ) + elif _is_mapped_annotation(annotation, cls, base): + # Mapped annotation without any object. + # product_column_copies should have handled this. + # if future support for other MapperProperty, + # then test if this name is already handled and + # otherwise proceed to generate. + if not fixed_table: + assert ( + name in collected_attributes + or attribute_is_overridden(name, None) + ) + continue + else: + # here, the attribute is some other kind of + # property that we assume is not part of the + # declarative mapping. however, check for some + # more common mistakes + self._warn_for_decl_attributes(base, name, obj) + elif is_dataclass_field and ( + name not in clsdict_view or clsdict_view[name] is not obj + ): + # here, we are definitely looking at the target class + # and not a superclass. this is currently a + # dataclass-only path. if the name is only + # a dataclass field and isn't in local cls.__dict__, + # put the object there. + # assert that the dataclass-enabled resolver agrees + # with what we are seeing + + assert not attribute_is_overridden(name, obj) + + if _is_declarative_props(obj): + obj = obj.fget() + + collected_attributes[name] = obj + self._collect_annotation( + name, annotation, base, False, obj + ) + else: + collected_annotation = self._collect_annotation( + name, annotation, base, None, obj + ) + is_mapped = ( + collected_annotation is not None + and collected_annotation.mapped_container is not None + ) + generated_obj = ( + collected_annotation.attr_value + if collected_annotation is not None + else obj + ) + if obj is None and not fixed_table and is_mapped: + collected_attributes[name] = ( + generated_obj + if generated_obj is not None + else MappedColumn() + ) + elif name in clsdict_view: + collected_attributes[name] = obj + # else if the name is not in the cls.__dict__, + # don't collect it as an attribute. + # we will see the annotation only, which is meaningful + # both for mapping and dataclasses setup + + if inherited_table_args and not tablename: + table_args = None + + self.table_args = table_args + self.tablename = tablename + self.mapper_args_fn = mapper_args_fn + + def _setup_dataclasses_transforms(self) -> None: + dataclass_setup_arguments = self.dataclass_setup_arguments + if not dataclass_setup_arguments: + return + + # can't use is_dataclass since it uses hasattr + if "__dataclass_fields__" in self.cls.__dict__: + raise exc.InvalidRequestError( + f"Class {self.cls} is already a dataclass; ensure that " + "base classes / decorator styles of establishing dataclasses " + "are not being mixed. " + "This can happen if a class that inherits from " + "'MappedAsDataclass', even indirectly, is been mapped with " + "'@registry.mapped_as_dataclass'" + ) + + warn_for_non_dc_attrs = collections.defaultdict(list) + + def _allow_dataclass_field( + key: str, originating_class: Type[Any] + ) -> bool: + if ( + originating_class is not self.cls + and "__dataclass_fields__" not in originating_class.__dict__ + ): + warn_for_non_dc_attrs[originating_class].append(key) + + return True + + manager = instrumentation.manager_of_class(self.cls) + assert manager is not None + + field_list = [ + _AttributeOptions._get_arguments_for_make_dataclass( + key, + anno, + mapped_container, + self.collected_attributes.get(key, _NoArg.NO_ARG), + ) + for key, anno, mapped_container in ( + ( + key, + mapped_anno if mapped_anno else raw_anno, + mapped_container, + ) + for key, ( + raw_anno, + mapped_container, + mapped_anno, + is_dc, + attr_value, + originating_module, + originating_class, + ) in self.collected_annotations.items() + if _allow_dataclass_field(key, originating_class) + and ( + key not in self.collected_attributes + # issue #9226; check for attributes that we've collected + # which are already instrumented, which we would assume + # mean we are in an ORM inheritance mapping and this + # attribute is already mapped on the superclass. Under + # no circumstance should any QueryableAttribute be sent to + # the dataclass() function; anything that's mapped should + # be Field and that's it + or not isinstance( + self.collected_attributes[key], QueryableAttribute + ) + ) + ) + ] + + if warn_for_non_dc_attrs: + for ( + originating_class, + non_dc_attrs, + ) in warn_for_non_dc_attrs.items(): + util.warn_deprecated( + f"When transforming {self.cls} to a dataclass, " + f"attribute(s) " + f"{', '.join(repr(key) for key in non_dc_attrs)} " + f"originates from superclass " + f"{originating_class}, which is not a dataclass. This " + f"usage is deprecated and will raise an error in " + f"SQLAlchemy 2.1. When declaring SQLAlchemy Declarative " + f"Dataclasses, ensure that all mixin classes and other " + f"superclasses which include attributes are also a " + f"subclass of MappedAsDataclass.", + "2.0", + code="dcmx", + ) + + annotations = {} + defaults = {} + for item in field_list: + if len(item) == 2: + name, tp = item + elif len(item) == 3: + name, tp, spec = item + defaults[name] = spec + else: + assert False + annotations[name] = tp + + for k, v in defaults.items(): + setattr(self.cls, k, v) + + self._apply_dataclasses_to_any_class( + dataclass_setup_arguments, self.cls, annotations + ) + + @classmethod + def _update_annotations_for_non_mapped_class( + cls, klass: Type[_O] + ) -> Mapping[str, _AnnotationScanType]: + cls_annotations = util.get_annotations(klass) + + new_anno = {} + for name, annotation in cls_annotations.items(): + if _is_mapped_annotation(annotation, klass, klass): + extracted = _extract_mapped_subtype( + annotation, + klass, + klass.__module__, + name, + type(None), + required=False, + is_dataclass_field=False, + expect_mapped=False, + ) + if extracted: + inner, _ = extracted + new_anno[name] = inner + else: + new_anno[name] = annotation + return new_anno + + @classmethod + def _apply_dataclasses_to_any_class( + cls, + dataclass_setup_arguments: _DataclassArguments, + klass: Type[_O], + use_annotations: Mapping[str, _AnnotationScanType], + ) -> None: + cls._assert_dc_arguments(dataclass_setup_arguments) + + dataclass_callable = dataclass_setup_arguments["dataclass_callable"] + if dataclass_callable is _NoArg.NO_ARG: + dataclass_callable = dataclasses.dataclass + + restored: Optional[Any] + + if use_annotations: + # apply constructed annotations that should look "normal" to a + # dataclasses callable, based on the fields present. This + # means remove the Mapped[] container and ensure all Field + # entries have an annotation + restored = getattr(klass, "__annotations__", None) + klass.__annotations__ = cast("Dict[str, Any]", use_annotations) + else: + restored = None + + try: + dataclass_callable( + klass, + **{ + k: v + for k, v in dataclass_setup_arguments.items() + if v is not _NoArg.NO_ARG and k != "dataclass_callable" + }, + ) + except (TypeError, ValueError) as ex: + raise exc.InvalidRequestError( + f"Python dataclasses error encountered when creating " + f"dataclass for {klass.__name__!r}: " + f"{ex!r}. Please refer to Python dataclasses " + "documentation for additional information.", + code="dcte", + ) from ex + finally: + # restore original annotations outside of the dataclasses + # process; for mixins and __abstract__ superclasses, SQLAlchemy + # Declarative will need to see the Mapped[] container inside the + # annotations in order to map subclasses + if use_annotations: + if restored is None: + del klass.__annotations__ + else: + klass.__annotations__ = restored + + @classmethod + def _assert_dc_arguments(cls, arguments: _DataclassArguments) -> None: + allowed = { + "init", + "repr", + "order", + "eq", + "unsafe_hash", + "kw_only", + "match_args", + "dataclass_callable", + } + disallowed_args = set(arguments).difference(allowed) + if disallowed_args: + msg = ", ".join(f"{arg!r}" for arg in sorted(disallowed_args)) + raise exc.ArgumentError( + f"Dataclass argument(s) {msg} are not accepted" + ) + + def _collect_annotation( + self, + name: str, + raw_annotation: _AnnotationScanType, + originating_class: Type[Any], + expect_mapped: Optional[bool], + attr_value: Any, + ) -> Optional[_CollectedAnnotation]: + if name in self.collected_annotations: + return self.collected_annotations[name] + + if raw_annotation is None: + return None + + is_dataclass = self.is_dataclass_prior_to_mapping + allow_unmapped = self.allow_unmapped_annotations + + if expect_mapped is None: + is_dataclass_field = isinstance(attr_value, dataclasses.Field) + expect_mapped = ( + not is_dataclass_field + and not allow_unmapped + and ( + attr_value is None + or isinstance(attr_value, _MappedAttribute) + ) + ) + else: + is_dataclass_field = False + + is_dataclass_field = False + extracted = _extract_mapped_subtype( + raw_annotation, + self.cls, + originating_class.__module__, + name, + type(attr_value), + required=False, + is_dataclass_field=is_dataclass_field, + expect_mapped=expect_mapped + and not is_dataclass, # self.allow_dataclass_fields, + ) + + if extracted is None: + # ClassVar can come out here + return None + + extracted_mapped_annotation, mapped_container = extracted + + if attr_value is None and not is_literal(extracted_mapped_annotation): + for elem in typing_get_args(extracted_mapped_annotation): + if isinstance(elem, str) or is_fwd_ref( + elem, check_generic=True + ): + elem = de_stringify_annotation( + self.cls, + elem, + originating_class.__module__, + include_generic=True, + ) + # look in Annotated[...] for an ORM construct, + # such as Annotated[int, mapped_column(primary_key=True)] + if isinstance(elem, _IntrospectsAnnotations): + attr_value = elem.found_in_pep593_annotated() + + self.collected_annotations[name] = ca = _CollectedAnnotation( + raw_annotation, + mapped_container, + extracted_mapped_annotation, + is_dataclass, + attr_value, + originating_class.__module__, + originating_class, + ) + return ca + + def _warn_for_decl_attributes( + self, cls: Type[Any], key: str, c: Any + ) -> None: + if isinstance(c, expression.ColumnElement): + util.warn( + f"Attribute '{key}' on class {cls} appears to " + "be a non-schema SQLAlchemy expression " + "object; this won't be part of the declarative mapping. " + "To map arbitrary expressions, use ``column_property()`` " + "or a similar function such as ``deferred()``, " + "``query_expression()`` etc. " + ) + + def _produce_column_copies( + self, + attributes_for_class: Callable[ + [], Iterable[Tuple[str, Any, Any, bool]] + ], + attribute_is_overridden: Callable[[str, Any], bool], + fixed_table: bool, + originating_class: Type[Any], + ) -> Dict[str, Union[Column[Any], MappedColumn[Any]]]: + cls = self.cls + dict_ = self.clsdict_view + locally_collected_attributes = {} + column_copies = self.column_copies + # copy mixin columns to the mapped class + + for name, obj, annotation, is_dataclass in attributes_for_class(): + if ( + not fixed_table + and obj is None + and _is_mapped_annotation(annotation, cls, originating_class) + ): + # obj is None means this is the annotation only path + + if attribute_is_overridden(name, obj): + # perform same "overridden" check as we do for + # Column/MappedColumn, this is how a mixin col is not + # applied to an inherited subclass that does not have + # the mixin. the anno-only path added here for + # #9564 + continue + + collected_annotation = self._collect_annotation( + name, annotation, originating_class, True, obj + ) + obj = ( + collected_annotation.attr_value + if collected_annotation is not None + else obj + ) + if obj is None: + obj = MappedColumn() + + locally_collected_attributes[name] = obj + setattr(cls, name, obj) + + elif isinstance(obj, (Column, MappedColumn)): + if attribute_is_overridden(name, obj): + # if column has been overridden + # (like by the InstrumentedAttribute of the + # superclass), skip. don't collect the annotation + # either (issue #8718) + continue + + collected_annotation = self._collect_annotation( + name, annotation, originating_class, True, obj + ) + obj = ( + collected_annotation.attr_value + if collected_annotation is not None + else obj + ) + + if name not in dict_ and not ( + "__table__" in dict_ + and (getattr(obj, "name", None) or name) + in dict_["__table__"].c + ): + if obj.foreign_keys: + for fk in obj.foreign_keys: + if ( + fk._table_column is not None + and fk._table_column.table is None + ): + raise exc.InvalidRequestError( + "Columns with foreign keys to " + "non-table-bound " + "columns must be declared as " + "@declared_attr callables " + "on declarative mixin classes. " + "For dataclass " + "field() objects, use a lambda:." + ) + + column_copies[obj] = copy_ = obj._copy() + + locally_collected_attributes[name] = copy_ + setattr(cls, name, copy_) + + return locally_collected_attributes + + def _extract_mappable_attributes(self) -> None: + cls = self.cls + collected_attributes = self.collected_attributes + + our_stuff = self.properties + + _include_dunders = self._include_dunders + + late_mapped = _get_immediate_cls_attr( + cls, "_sa_decl_prepare_nocascade", strict=True + ) + + allow_unmapped_annotations = self.allow_unmapped_annotations + expect_annotations_wo_mapped = ( + allow_unmapped_annotations or self.is_dataclass_prior_to_mapping + ) + + look_for_dataclass_things = bool(self.dataclass_setup_arguments) + + for k in list(collected_attributes): + if k in _include_dunders: + continue + + value = collected_attributes[k] + + if _is_declarative_props(value): + # @declared_attr in collected_attributes only occurs here for a + # @declared_attr that's directly on the mapped class; + # for a mixin, these have already been evaluated + if value._cascading: + util.warn( + "Use of @declared_attr.cascading only applies to " + "Declarative 'mixin' and 'abstract' classes. " + "Currently, this flag is ignored on mapped class " + "%s" % self.cls + ) + + value = getattr(cls, k) + + elif ( + isinstance(value, QueryableAttribute) + and value.class_ is not cls + and value.key != k + ): + # detect a QueryableAttribute that's already mapped being + # assigned elsewhere in userland, turn into a synonym() + value = SynonymProperty(value.key) + setattr(cls, k, value) + + if ( + isinstance(value, tuple) + and len(value) == 1 + and isinstance(value[0], (Column, _MappedAttribute)) + ): + util.warn( + "Ignoring declarative-like tuple value of attribute " + "'%s': possibly a copy-and-paste error with a comma " + "accidentally placed at the end of the line?" % k + ) + continue + elif look_for_dataclass_things and isinstance( + value, dataclasses.Field + ): + # we collected a dataclass Field; dataclasses would have + # set up the correct state on the class + continue + elif not isinstance(value, (Column, _DCAttributeOptions)): + # using @declared_attr for some object that + # isn't Column/MapperProperty/_DCAttributeOptions; remove + # from the clsdict_view + # and place the evaluated value onto the class. + collected_attributes.pop(k) + self._warn_for_decl_attributes(cls, k, value) + if not late_mapped: + setattr(cls, k, value) + continue + # we expect to see the name 'metadata' in some valid cases; + # however at this point we see it's assigned to something trying + # to be mapped, so raise for that. + # TODO: should "registry" here be also? might be too late + # to change that now (2.0 betas) + elif k in ("metadata",): + raise exc.InvalidRequestError( + f"Attribute name '{k}' is reserved when using the " + "Declarative API." + ) + elif isinstance(value, Column): + _undefer_column_name( + k, self.column_copies.get(value, value) # type: ignore + ) + else: + if isinstance(value, _IntrospectsAnnotations): + ( + annotation, + mapped_container, + extracted_mapped_annotation, + is_dataclass, + attr_value, + originating_module, + originating_class, + ) = self.collected_annotations.get( + k, (None, None, None, False, None, None, None) + ) + + # issue #8692 - don't do any annotation interpretation if + # an annotation were present and a container such as + # Mapped[] etc. were not used. If annotation is None, + # do declarative_scan so that the property can raise + # for required + if ( + mapped_container is not None + or annotation is None + # issue #10516: need to do declarative_scan even with + # a non-Mapped annotation if we are doing + # __allow_unmapped__, for things like col.name + # assignment + or allow_unmapped_annotations + ): + try: + value.declarative_scan( + self, + self.registry, + cls, + originating_module, + k, + mapped_container, + annotation, + extracted_mapped_annotation, + is_dataclass, + ) + except NameError as ne: + raise exc.ArgumentError( + f"Could not resolve all types within mapped " + f'annotation: "{annotation}". Ensure all ' + f"types are written correctly and are " + f"imported within the module in use." + ) from ne + else: + # assert that we were expecting annotations + # without Mapped[] were going to be passed. + # otherwise an error should have been raised + # by util._extract_mapped_subtype before we got here. + assert expect_annotations_wo_mapped + + if isinstance(value, _DCAttributeOptions): + if ( + value._has_dataclass_arguments + and not look_for_dataclass_things + ): + if isinstance(value, MapperProperty): + argnames = [ + "init", + "default_factory", + "repr", + "default", + ] + else: + argnames = ["init", "default_factory", "repr"] + + args = { + a + for a in argnames + if getattr( + value._attribute_options, f"dataclasses_{a}" + ) + is not _NoArg.NO_ARG + } + + raise exc.ArgumentError( + f"Attribute '{k}' on class {cls} includes " + f"dataclasses argument(s): " + f"{', '.join(sorted(repr(a) for a in args))} but " + f"class does not specify " + "SQLAlchemy native dataclass configuration." + ) + + if not isinstance(value, (MapperProperty, _MapsColumns)): + # filter for _DCAttributeOptions objects that aren't + # MapperProperty / mapped_column(). Currently this + # includes AssociationProxy. pop it from the things + # we're going to map and set it up as a descriptor + # on the class. + collected_attributes.pop(k) + + # Assoc Prox (or other descriptor object that may + # use _DCAttributeOptions) is usually here, except if + # 1. we're a + # dataclass, dataclasses would have removed the + # attr here or 2. assoc proxy is coming from a + # superclass, we want it to be direct here so it + # tracks state or 3. assoc prox comes from + # declared_attr, uncommon case + setattr(cls, k, value) + continue + + our_stuff[k] = value + + def _extract_declared_columns(self) -> None: + our_stuff = self.properties + + # extract columns from the class dict + declared_columns = self.declared_columns + column_ordering = self.column_ordering + name_to_prop_key = collections.defaultdict(set) + + for key, c in list(our_stuff.items()): + if isinstance(c, _MapsColumns): + mp_to_assign = c.mapper_property_to_assign + if mp_to_assign: + our_stuff[key] = mp_to_assign + else: + # if no mapper property to assign, this currently means + # this is a MappedColumn that will produce a Column for us + del our_stuff[key] + + for col, sort_order in c.columns_to_assign: + if not isinstance(c, CompositeProperty): + name_to_prop_key[col.name].add(key) + declared_columns.add(col) + + # we would assert this, however we want the below + # warning to take effect instead. See #9630 + # assert col not in column_ordering + + column_ordering[col] = sort_order + + # if this is a MappedColumn and the attribute key we + # have is not what the column has for its key, map the + # Column explicitly under the attribute key name. + # otherwise, Mapper will map it under the column key. + if mp_to_assign is None and key != col.key: + our_stuff[key] = col + elif isinstance(c, Column): + # undefer previously occurred here, and now occurs earlier. + # ensure every column we get here has been named + assert c.name is not None + name_to_prop_key[c.name].add(key) + declared_columns.add(c) + # if the column is the same name as the key, + # remove it from the explicit properties dict. + # the normal rules for assigning column-based properties + # will take over, including precedence of columns + # in multi-column ColumnProperties. + if key == c.key: + del our_stuff[key] + + for name, keys in name_to_prop_key.items(): + if len(keys) > 1: + util.warn( + "On class %r, Column object %r named " + "directly multiple times, " + "only one will be used: %s. " + "Consider using orm.synonym instead" + % (self.classname, name, (", ".join(sorted(keys)))) + ) + + def _setup_table(self, table: Optional[FromClause] = None) -> None: + cls = self.cls + cls_as_Decl = cast("MappedClassProtocol[Any]", cls) + + tablename = self.tablename + table_args = self.table_args + clsdict_view = self.clsdict_view + declared_columns = self.declared_columns + column_ordering = self.column_ordering + + manager = attributes.manager_of_class(cls) + + if "__table__" not in clsdict_view and table is None: + if hasattr(cls, "__table_cls__"): + table_cls = cast( + Type[Table], + util.unbound_method_to_callable(cls.__table_cls__), # type: ignore # noqa: E501 + ) + else: + table_cls = Table + + if tablename is not None: + args: Tuple[Any, ...] = () + table_kw: Dict[str, Any] = {} + + if table_args: + if isinstance(table_args, dict): + table_kw = table_args + elif isinstance(table_args, tuple): + if isinstance(table_args[-1], dict): + args, table_kw = table_args[0:-1], table_args[-1] + else: + args = table_args + + autoload_with = clsdict_view.get("__autoload_with__") + if autoload_with: + table_kw["autoload_with"] = autoload_with + + autoload = clsdict_view.get("__autoload__") + if autoload: + table_kw["autoload"] = True + + sorted_columns = sorted( + declared_columns, + key=lambda c: column_ordering.get(c, 0), + ) + table = self.set_cls_attribute( + "__table__", + table_cls( + tablename, + self._metadata_for_cls(manager), + *sorted_columns, + *args, + **table_kw, + ), + ) + else: + if table is None: + table = cls_as_Decl.__table__ + if declared_columns: + for c in declared_columns: + if not table.c.contains_column(c): + raise exc.ArgumentError( + "Can't add additional column %r when " + "specifying __table__" % c.key + ) + + self.local_table = table + + def _metadata_for_cls(self, manager: ClassManager[Any]) -> MetaData: + meta: Optional[MetaData] = getattr(self.cls, "metadata", None) + if meta is not None: + return meta + else: + return manager.registry.metadata + + def _setup_inheriting_mapper(self, mapper_kw: _MapperKwArgs) -> None: + cls = self.cls + + inherits = mapper_kw.get("inherits", None) + + if inherits is None: + # since we search for classical mappings now, search for + # multiple mapped bases as well and raise an error. + inherits_search = [] + for base_ in cls.__bases__: + c = _resolve_for_abstract_or_classical(base_) + if c is None: + continue + + if _is_supercls_for_inherits(c) and c not in inherits_search: + inherits_search.append(c) + + if inherits_search: + if len(inherits_search) > 1: + raise exc.InvalidRequestError( + "Class %s has multiple mapped bases: %r" + % (cls, inherits_search) + ) + inherits = inherits_search[0] + elif isinstance(inherits, Mapper): + inherits = inherits.class_ + + self.inherits = inherits + + clsdict_view = self.clsdict_view + if "__table__" not in clsdict_view and self.tablename is None: + self.single = True + + def _setup_inheriting_columns(self, mapper_kw: _MapperKwArgs) -> None: + table = self.local_table + cls = self.cls + table_args = self.table_args + declared_columns = self.declared_columns + + if ( + table is None + and self.inherits is None + and not _get_immediate_cls_attr(cls, "__no_table__") + ): + raise exc.InvalidRequestError( + "Class %r does not have a __table__ or __tablename__ " + "specified and does not inherit from an existing " + "table-mapped class." % cls + ) + elif self.inherits: + inherited_mapper_or_config = _declared_mapping_info(self.inherits) + assert inherited_mapper_or_config is not None + inherited_table = inherited_mapper_or_config.local_table + inherited_persist_selectable = ( + inherited_mapper_or_config.persist_selectable + ) + + if table is None: + # single table inheritance. + # ensure no table args + if table_args: + raise exc.ArgumentError( + "Can't place __table_args__ on an inherited class " + "with no table." + ) + + # add any columns declared here to the inherited table. + if declared_columns and not isinstance(inherited_table, Table): + raise exc.ArgumentError( + f"Can't declare columns on single-table-inherited " + f"subclass {self.cls}; superclass {self.inherits} " + "is not mapped to a Table" + ) + + for col in declared_columns: + assert inherited_table is not None + if col.name in inherited_table.c: + if inherited_table.c[col.name] is col: + continue + raise exc.ArgumentError( + f"Column '{col}' on class {cls.__name__} " + f"conflicts with existing column " + f"'{inherited_table.c[col.name]}'. If using " + f"Declarative, consider using the " + "use_existing_column parameter of mapped_column() " + "to resolve conflicts." + ) + if col.primary_key: + raise exc.ArgumentError( + "Can't place primary key columns on an inherited " + "class with no table." + ) + + if TYPE_CHECKING: + assert isinstance(inherited_table, Table) + + inherited_table.append_column(col) + if ( + inherited_persist_selectable is not None + and inherited_persist_selectable is not inherited_table + ): + inherited_persist_selectable._refresh_for_new_column( + col + ) + + def _prepare_mapper_arguments(self, mapper_kw: _MapperKwArgs) -> None: + properties = self.properties + + if self.mapper_args_fn: + mapper_args = self.mapper_args_fn() + else: + mapper_args = {} + + if mapper_kw: + mapper_args.update(mapper_kw) + + if "properties" in mapper_args: + properties = dict(properties) + properties.update(mapper_args["properties"]) + + # make sure that column copies are used rather + # than the original columns from any mixins + for k in ("version_id_col", "polymorphic_on"): + if k in mapper_args: + v = mapper_args[k] + mapper_args[k] = self.column_copies.get(v, v) + + if "primary_key" in mapper_args: + mapper_args["primary_key"] = [ + self.column_copies.get(v, v) + for v in util.to_list(mapper_args["primary_key"]) + ] + + if "inherits" in mapper_args: + inherits_arg = mapper_args["inherits"] + if isinstance(inherits_arg, Mapper): + inherits_arg = inherits_arg.class_ + + if inherits_arg is not self.inherits: + raise exc.InvalidRequestError( + "mapper inherits argument given for non-inheriting " + "class %s" % (mapper_args["inherits"]) + ) + + if self.inherits: + mapper_args["inherits"] = self.inherits + + if self.inherits and not mapper_args.get("concrete", False): + # note the superclass is expected to have a Mapper assigned and + # not be a deferred config, as this is called within map() + inherited_mapper = class_mapper(self.inherits, False) + inherited_table = inherited_mapper.local_table + + # single or joined inheritance + # exclude any cols on the inherited table which are + # not mapped on the parent class, to avoid + # mapping columns specific to sibling/nephew classes + if "exclude_properties" not in mapper_args: + mapper_args["exclude_properties"] = exclude_properties = { + c.key + for c in inherited_table.c + if c not in inherited_mapper._columntoproperty + }.union(inherited_mapper.exclude_properties or ()) + exclude_properties.difference_update( + [c.key for c in self.declared_columns] + ) + + # look through columns in the current mapper that + # are keyed to a propname different than the colname + # (if names were the same, we'd have popped it out above, + # in which case the mapper makes this combination). + # See if the superclass has a similar column property. + # If so, join them together. + for k, col in list(properties.items()): + if not isinstance(col, expression.ColumnElement): + continue + if k in inherited_mapper._props: + p = inherited_mapper._props[k] + if isinstance(p, ColumnProperty): + # note here we place the subclass column + # first. See [ticket:1892] for background. + properties[k] = [col] + p.columns + result_mapper_args = mapper_args.copy() + result_mapper_args["properties"] = properties + self.mapper_args = result_mapper_args + + def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: + self._prepare_mapper_arguments(mapper_kw) + if hasattr(self.cls, "__mapper_cls__"): + mapper_cls = cast( + "Type[Mapper[Any]]", + util.unbound_method_to_callable( + self.cls.__mapper_cls__ # type: ignore + ), + ) + else: + mapper_cls = Mapper + + return self.set_cls_attribute( + "__mapper__", + mapper_cls(self.cls, self.local_table, **self.mapper_args), + ) + + +@util.preload_module("sqlalchemy.orm.decl_api") +def _as_dc_declaredattr( + field_metadata: Mapping[str, Any], sa_dataclass_metadata_key: str +) -> Any: + # wrap lambdas inside dataclass fields inside an ad-hoc declared_attr. + # we can't write it because field.metadata is immutable :( so we have + # to go through extra trouble to compare these + decl_api = util.preloaded.orm_decl_api + obj = field_metadata[sa_dataclass_metadata_key] + if callable(obj) and not isinstance(obj, decl_api.declared_attr): + return decl_api.declared_attr(obj) + else: + return obj + + +class _DeferredMapperConfig(_ClassScanMapperConfig): + _cls: weakref.ref[Type[Any]] + + is_deferred = True + + _configs: util.OrderedDict[ + weakref.ref[Type[Any]], _DeferredMapperConfig + ] = util.OrderedDict() + + def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None: + pass + + # mypy disallows plain property override of variable + @property # type: ignore + def cls(self) -> Type[Any]: + return self._cls() # type: ignore + + @cls.setter + def cls(self, class_: Type[Any]) -> None: + self._cls = weakref.ref(class_, self._remove_config_cls) + self._configs[self._cls] = self + + @classmethod + def _remove_config_cls(cls, ref: weakref.ref[Type[Any]]) -> None: + cls._configs.pop(ref, None) + + @classmethod + def has_cls(cls, class_: Type[Any]) -> bool: + # 2.6 fails on weakref if class_ is an old style class + return isinstance(class_, type) and weakref.ref(class_) in cls._configs + + @classmethod + def raise_unmapped_for_cls(cls, class_: Type[Any]) -> NoReturn: + if hasattr(class_, "_sa_raise_deferred_config"): + class_._sa_raise_deferred_config() + + raise orm_exc.UnmappedClassError( + class_, + msg=( + f"Class {orm_exc._safe_cls_name(class_)} has a deferred " + "mapping on it. It is not yet usable as a mapped class." + ), + ) + + @classmethod + def config_for_cls(cls, class_: Type[Any]) -> _DeferredMapperConfig: + return cls._configs[weakref.ref(class_)] + + @classmethod + def classes_for_base( + cls, base_cls: Type[Any], sort: bool = True + ) -> List[_DeferredMapperConfig]: + classes_for_base = [ + m + for m, cls_ in [(m, m.cls) for m in cls._configs.values()] + if cls_ is not None and issubclass(cls_, base_cls) + ] + + if not sort: + return classes_for_base + + all_m_by_cls = {m.cls: m for m in classes_for_base} + + tuples: List[Tuple[_DeferredMapperConfig, _DeferredMapperConfig]] = [] + for m_cls in all_m_by_cls: + tuples.extend( + (all_m_by_cls[base_cls], all_m_by_cls[m_cls]) + for base_cls in m_cls.__bases__ + if base_cls in all_m_by_cls + ) + return list(topological.sort(tuples, classes_for_base)) + + def map(self, mapper_kw: _MapperKwArgs = util.EMPTY_DICT) -> Mapper[Any]: + self._configs.pop(self._cls, None) + return super().map(mapper_kw) + + +def _add_attribute( + cls: Type[Any], key: str, value: MapperProperty[Any] +) -> None: + """add an attribute to an existing declarative class. + + This runs through the logic to determine MapperProperty, + adds it to the Mapper, adds a column to the mapped Table, etc. + + """ + + if "__mapper__" in cls.__dict__: + mapped_cls = cast("MappedClassProtocol[Any]", cls) + + def _table_or_raise(mc: MappedClassProtocol[Any]) -> Table: + if isinstance(mc.__table__, Table): + return mc.__table__ + raise exc.InvalidRequestError( + f"Cannot add a new attribute to mapped class {mc.__name__!r} " + "because it's not mapped against a table." + ) + + if isinstance(value, Column): + _undefer_column_name(key, value) + _table_or_raise(mapped_cls).append_column( + value, replace_existing=True + ) + mapped_cls.__mapper__.add_property(key, value) + elif isinstance(value, _MapsColumns): + mp = value.mapper_property_to_assign + for col, _ in value.columns_to_assign: + _undefer_column_name(key, col) + _table_or_raise(mapped_cls).append_column( + col, replace_existing=True + ) + if not mp: + mapped_cls.__mapper__.add_property(key, col) + if mp: + mapped_cls.__mapper__.add_property(key, mp) + elif isinstance(value, MapperProperty): + mapped_cls.__mapper__.add_property(key, value) + elif isinstance(value, QueryableAttribute) and value.key != key: + # detect a QueryableAttribute that's already mapped being + # assigned elsewhere in userland, turn into a synonym() + value = SynonymProperty(value.key) + mapped_cls.__mapper__.add_property(key, value) + else: + type.__setattr__(cls, key, value) + mapped_cls.__mapper__._expire_memoizations() + else: + type.__setattr__(cls, key, value) + + +def _del_attribute(cls: Type[Any], key: str) -> None: + if ( + "__mapper__" in cls.__dict__ + and key in cls.__dict__ + and not cast( + "MappedClassProtocol[Any]", cls + ).__mapper__._dispose_called + ): + value = cls.__dict__[key] + if isinstance( + value, (Column, _MapsColumns, MapperProperty, QueryableAttribute) + ): + raise NotImplementedError( + "Can't un-map individual mapped attributes on a mapped class." + ) + else: + type.__delattr__(cls, key) + cast( + "MappedClassProtocol[Any]", cls + ).__mapper__._expire_memoizations() + else: + type.__delattr__(cls, key) + + +def _declarative_constructor(self: Any, **kwargs: Any) -> None: + """A simple constructor that allows initialization from kwargs. + + Sets attributes on the constructed instance using the names and + values in ``kwargs``. + + Only keys that are present as + attributes of the instance's class are allowed. These could be, + for example, any mapped columns or relationships. + """ + cls_ = type(self) + for k in kwargs: + if not hasattr(cls_, k): + raise TypeError( + "%r is an invalid keyword argument for %s" % (k, cls_.__name__) + ) + setattr(self, k, kwargs[k]) + + +_declarative_constructor.__name__ = "__init__" + + +def _undefer_column_name(key: str, column: Column[Any]) -> None: + if column.key is None: + column.key = key + if column.name is None: + column.name = key diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/dependency.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/dependency.py new file mode 100644 index 0000000..71c06fb --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/dependency.py @@ -0,0 +1,1304 @@ +# orm/dependency.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""Relationship dependencies. + +""" + +from __future__ import annotations + +from . import attributes +from . import exc +from . import sync +from . import unitofwork +from . import util as mapperutil +from .interfaces import MANYTOMANY +from .interfaces import MANYTOONE +from .interfaces import ONETOMANY +from .. import exc as sa_exc +from .. import sql +from .. import util + + +class DependencyProcessor: + def __init__(self, prop): + self.prop = prop + self.cascade = prop.cascade + self.mapper = prop.mapper + self.parent = prop.parent + self.secondary = prop.secondary + self.direction = prop.direction + self.post_update = prop.post_update + self.passive_deletes = prop.passive_deletes + self.passive_updates = prop.passive_updates + self.enable_typechecks = prop.enable_typechecks + if self.passive_deletes: + self._passive_delete_flag = attributes.PASSIVE_NO_INITIALIZE + else: + self._passive_delete_flag = attributes.PASSIVE_OFF + if self.passive_updates: + self._passive_update_flag = attributes.PASSIVE_NO_INITIALIZE + else: + self._passive_update_flag = attributes.PASSIVE_OFF + + self.sort_key = "%s_%s" % (self.parent._sort_key, prop.key) + self.key = prop.key + if not self.prop.synchronize_pairs: + raise sa_exc.ArgumentError( + "Can't build a DependencyProcessor for relationship %s. " + "No target attributes to populate between parent and " + "child are present" % self.prop + ) + + @classmethod + def from_relationship(cls, prop): + return _direction_to_processor[prop.direction](prop) + + def hasparent(self, state): + """return True if the given object instance has a parent, + according to the ``InstrumentedAttribute`` handled by this + ``DependencyProcessor``. + + """ + return self.parent.class_manager.get_impl(self.key).hasparent(state) + + def per_property_preprocessors(self, uow): + """establish actions and dependencies related to a flush. + + These actions will operate on all relevant states in + the aggregate. + + """ + uow.register_preprocessor(self, True) + + def per_property_flush_actions(self, uow): + after_save = unitofwork.ProcessAll(uow, self, False, True) + before_delete = unitofwork.ProcessAll(uow, self, True, True) + + parent_saves = unitofwork.SaveUpdateAll( + uow, self.parent.primary_base_mapper + ) + child_saves = unitofwork.SaveUpdateAll( + uow, self.mapper.primary_base_mapper + ) + + parent_deletes = unitofwork.DeleteAll( + uow, self.parent.primary_base_mapper + ) + child_deletes = unitofwork.DeleteAll( + uow, self.mapper.primary_base_mapper + ) + + self.per_property_dependencies( + uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ) + + def per_state_flush_actions(self, uow, states, isdelete): + """establish actions and dependencies related to a flush. + + These actions will operate on all relevant states + individually. This occurs only if there are cycles + in the 'aggregated' version of events. + + """ + + child_base_mapper = self.mapper.primary_base_mapper + child_saves = unitofwork.SaveUpdateAll(uow, child_base_mapper) + child_deletes = unitofwork.DeleteAll(uow, child_base_mapper) + + # locate and disable the aggregate processors + # for this dependency + + if isdelete: + before_delete = unitofwork.ProcessAll(uow, self, True, True) + before_delete.disabled = True + else: + after_save = unitofwork.ProcessAll(uow, self, False, True) + after_save.disabled = True + + # check if the "child" side is part of the cycle + + if child_saves not in uow.cycles: + # based on the current dependencies we use, the saves/ + # deletes should always be in the 'cycles' collection + # together. if this changes, we will have to break up + # this method a bit more. + assert child_deletes not in uow.cycles + + # child side is not part of the cycle, so we will link per-state + # actions to the aggregate "saves", "deletes" actions + child_actions = [(child_saves, False), (child_deletes, True)] + child_in_cycles = False + else: + child_in_cycles = True + + # check if the "parent" side is part of the cycle + if not isdelete: + parent_saves = unitofwork.SaveUpdateAll( + uow, self.parent.base_mapper + ) + parent_deletes = before_delete = None + if parent_saves in uow.cycles: + parent_in_cycles = True + else: + parent_deletes = unitofwork.DeleteAll(uow, self.parent.base_mapper) + parent_saves = after_save = None + if parent_deletes in uow.cycles: + parent_in_cycles = True + + # now create actions /dependencies for each state. + + for state in states: + # detect if there's anything changed or loaded + # by a preprocessor on this state/attribute. In the + # case of deletes we may try to load missing items here as well. + sum_ = state.manager[self.key].impl.get_all_pending( + state, + state.dict, + ( + self._passive_delete_flag + if isdelete + else attributes.PASSIVE_NO_INITIALIZE + ), + ) + + if not sum_: + continue + + if isdelete: + before_delete = unitofwork.ProcessState(uow, self, True, state) + if parent_in_cycles: + parent_deletes = unitofwork.DeleteState(uow, state) + else: + after_save = unitofwork.ProcessState(uow, self, False, state) + if parent_in_cycles: + parent_saves = unitofwork.SaveUpdateState(uow, state) + + if child_in_cycles: + child_actions = [] + for child_state, child in sum_: + if child_state not in uow.states: + child_action = (None, None) + else: + (deleted, listonly) = uow.states[child_state] + if deleted: + child_action = ( + unitofwork.DeleteState(uow, child_state), + True, + ) + else: + child_action = ( + unitofwork.SaveUpdateState(uow, child_state), + False, + ) + child_actions.append(child_action) + + # establish dependencies between our possibly per-state + # parent action and our possibly per-state child action. + for child_action, childisdelete in child_actions: + self.per_state_dependencies( + uow, + parent_saves, + parent_deletes, + child_action, + after_save, + before_delete, + isdelete, + childisdelete, + ) + + def presort_deletes(self, uowcommit, states): + return False + + def presort_saves(self, uowcommit, states): + return False + + def process_deletes(self, uowcommit, states): + pass + + def process_saves(self, uowcommit, states): + pass + + def prop_has_changes(self, uowcommit, states, isdelete): + if not isdelete or self.passive_deletes: + passive = ( + attributes.PASSIVE_NO_INITIALIZE + | attributes.INCLUDE_PENDING_MUTATIONS + ) + elif self.direction is MANYTOONE: + # here, we were hoping to optimize having to fetch many-to-one + # for history and ignore it, if there's no further cascades + # to take place. however there are too many less common conditions + # that still take place and tests in test_relationships / + # test_cascade etc. will still fail. + passive = attributes.PASSIVE_NO_FETCH_RELATED + else: + passive = ( + attributes.PASSIVE_OFF | attributes.INCLUDE_PENDING_MUTATIONS + ) + + for s in states: + # TODO: add a high speed method + # to InstanceState which returns: attribute + # has a non-None value, or had one + history = uowcommit.get_attribute_history(s, self.key, passive) + if history and not history.empty(): + return True + else: + return ( + states + and not self.prop._is_self_referential + and self.mapper in uowcommit.mappers + ) + + def _verify_canload(self, state): + if self.prop.uselist and state is None: + raise exc.FlushError( + "Can't flush None value found in " + "collection %s" % (self.prop,) + ) + elif state is not None and not self.mapper._canload( + state, allow_subtypes=not self.enable_typechecks + ): + if self.mapper._canload(state, allow_subtypes=True): + raise exc.FlushError( + "Attempting to flush an item of type " + "%(x)s as a member of collection " + '"%(y)s". Expected an object of type ' + "%(z)s or a polymorphic subclass of " + "this type. If %(x)s is a subclass of " + '%(z)s, configure mapper "%(zm)s" to ' + "load this subtype polymorphically, or " + "set enable_typechecks=False to allow " + "any subtype to be accepted for flush. " + % { + "x": state.class_, + "y": self.prop, + "z": self.mapper.class_, + "zm": self.mapper, + } + ) + else: + raise exc.FlushError( + "Attempting to flush an item of type " + "%(x)s as a member of collection " + '"%(y)s". Expected an object of type ' + "%(z)s or a polymorphic subclass of " + "this type." + % { + "x": state.class_, + "y": self.prop, + "z": self.mapper.class_, + } + ) + + def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): + raise NotImplementedError() + + def _get_reversed_processed_set(self, uow): + if not self.prop._reverse_property: + return None + + process_key = tuple( + sorted([self.key] + [p.key for p in self.prop._reverse_property]) + ) + return uow.memo(("reverse_key", process_key), set) + + def _post_update(self, state, uowcommit, related, is_m2o_delete=False): + for x in related: + if not is_m2o_delete or x is not None: + uowcommit.register_post_update( + state, [r for l, r in self.prop.synchronize_pairs] + ) + break + + def _pks_changed(self, uowcommit, state): + raise NotImplementedError() + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.prop) + + +class OneToManyDP(DependencyProcessor): + def per_property_dependencies( + self, + uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ): + if self.post_update: + child_post_updates = unitofwork.PostUpdateAll( + uow, self.mapper.primary_base_mapper, False + ) + child_pre_updates = unitofwork.PostUpdateAll( + uow, self.mapper.primary_base_mapper, True + ) + + uow.dependencies.update( + [ + (child_saves, after_save), + (parent_saves, after_save), + (after_save, child_post_updates), + (before_delete, child_pre_updates), + (child_pre_updates, parent_deletes), + (child_pre_updates, child_deletes), + ] + ) + else: + uow.dependencies.update( + [ + (parent_saves, after_save), + (after_save, child_saves), + (after_save, child_deletes), + (child_saves, parent_deletes), + (child_deletes, parent_deletes), + (before_delete, child_saves), + (before_delete, child_deletes), + ] + ) + + def per_state_dependencies( + self, + uow, + save_parent, + delete_parent, + child_action, + after_save, + before_delete, + isdelete, + childisdelete, + ): + if self.post_update: + child_post_updates = unitofwork.PostUpdateAll( + uow, self.mapper.primary_base_mapper, False + ) + child_pre_updates = unitofwork.PostUpdateAll( + uow, self.mapper.primary_base_mapper, True + ) + + # TODO: this whole block is not covered + # by any tests + if not isdelete: + if childisdelete: + uow.dependencies.update( + [ + (child_action, after_save), + (after_save, child_post_updates), + ] + ) + else: + uow.dependencies.update( + [ + (save_parent, after_save), + (child_action, after_save), + (after_save, child_post_updates), + ] + ) + else: + if childisdelete: + uow.dependencies.update( + [ + (before_delete, child_pre_updates), + (child_pre_updates, delete_parent), + ] + ) + else: + uow.dependencies.update( + [ + (before_delete, child_pre_updates), + (child_pre_updates, delete_parent), + ] + ) + elif not isdelete: + uow.dependencies.update( + [ + (save_parent, after_save), + (after_save, child_action), + (save_parent, child_action), + ] + ) + else: + uow.dependencies.update( + [(before_delete, child_action), (child_action, delete_parent)] + ) + + def presort_deletes(self, uowcommit, states): + # head object is being deleted, and we manage its list of + # child objects the child objects have to have their + # foreign key to the parent set to NULL + should_null_fks = ( + not self.cascade.delete and not self.passive_deletes == "all" + ) + + for state in states: + history = uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + if history: + for child in history.deleted: + if child is not None and self.hasparent(child) is False: + if self.cascade.delete_orphan: + uowcommit.register_object(child, isdelete=True) + else: + uowcommit.register_object(child) + + if should_null_fks: + for child in history.unchanged: + if child is not None: + uowcommit.register_object( + child, operation="delete", prop=self.prop + ) + + def presort_saves(self, uowcommit, states): + children_added = uowcommit.memo(("children_added", self), set) + + should_null_fks = ( + not self.cascade.delete_orphan + and not self.passive_deletes == "all" + ) + + for state in states: + pks_changed = self._pks_changed(uowcommit, state) + + if not pks_changed or self.passive_updates: + passive = ( + attributes.PASSIVE_NO_INITIALIZE + | attributes.INCLUDE_PENDING_MUTATIONS + ) + else: + passive = ( + attributes.PASSIVE_OFF + | attributes.INCLUDE_PENDING_MUTATIONS + ) + + history = uowcommit.get_attribute_history(state, self.key, passive) + if history: + for child in history.added: + if child is not None: + uowcommit.register_object( + child, + cancel_delete=True, + operation="add", + prop=self.prop, + ) + + children_added.update(history.added) + + for child in history.deleted: + if not self.cascade.delete_orphan: + if should_null_fks: + uowcommit.register_object( + child, + isdelete=False, + operation="delete", + prop=self.prop, + ) + elif self.hasparent(child) is False: + uowcommit.register_object( + child, + isdelete=True, + operation="delete", + prop=self.prop, + ) + for c, m, st_, dct_ in self.mapper.cascade_iterator( + "delete", child + ): + uowcommit.register_object(st_, isdelete=True) + + if pks_changed: + if history: + for child in history.unchanged: + if child is not None: + uowcommit.register_object( + child, + False, + self.passive_updates, + operation="pk change", + prop=self.prop, + ) + + def process_deletes(self, uowcommit, states): + # head object is being deleted, and we manage its list of + # child objects the child objects have to have their foreign + # key to the parent set to NULL this phase can be called + # safely for any cascade but is unnecessary if delete cascade + # is on. + + if self.post_update or not self.passive_deletes == "all": + children_added = uowcommit.memo(("children_added", self), set) + + for state in states: + history = uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + if history: + for child in history.deleted: + if ( + child is not None + and self.hasparent(child) is False + ): + self._synchronize( + state, child, None, True, uowcommit, False + ) + if self.post_update and child: + self._post_update(child, uowcommit, [state]) + + if self.post_update or not self.cascade.delete: + for child in set(history.unchanged).difference( + children_added + ): + if child is not None: + self._synchronize( + state, child, None, True, uowcommit, False + ) + if self.post_update and child: + self._post_update( + child, uowcommit, [state] + ) + + # technically, we can even remove each child from the + # collection here too. but this would be a somewhat + # inconsistent behavior since it wouldn't happen + # if the old parent wasn't deleted but child was moved. + + def process_saves(self, uowcommit, states): + should_null_fks = ( + not self.cascade.delete_orphan + and not self.passive_deletes == "all" + ) + + for state in states: + history = uowcommit.get_attribute_history( + state, self.key, attributes.PASSIVE_NO_INITIALIZE + ) + if history: + for child in history.added: + self._synchronize( + state, child, None, False, uowcommit, False + ) + if child is not None and self.post_update: + self._post_update(child, uowcommit, [state]) + + for child in history.deleted: + if ( + should_null_fks + and not self.cascade.delete_orphan + and not self.hasparent(child) + ): + self._synchronize( + state, child, None, True, uowcommit, False + ) + + if self._pks_changed(uowcommit, state): + for child in history.unchanged: + self._synchronize( + state, child, None, False, uowcommit, True + ) + + def _synchronize( + self, state, child, associationrow, clearkeys, uowcommit, pks_changed + ): + source = state + dest = child + self._verify_canload(child) + if dest is None or ( + not self.post_update and uowcommit.is_deleted(dest) + ): + return + if clearkeys: + sync.clear(dest, self.mapper, self.prop.synchronize_pairs) + else: + sync.populate( + source, + self.parent, + dest, + self.mapper, + self.prop.synchronize_pairs, + uowcommit, + self.passive_updates and pks_changed, + ) + + def _pks_changed(self, uowcommit, state): + return sync.source_modified( + uowcommit, state, self.parent, self.prop.synchronize_pairs + ) + + +class ManyToOneDP(DependencyProcessor): + def __init__(self, prop): + DependencyProcessor.__init__(self, prop) + for mapper in self.mapper.self_and_descendants: + mapper._dependency_processors.append(DetectKeySwitch(prop)) + + def per_property_dependencies( + self, + uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ): + if self.post_update: + parent_post_updates = unitofwork.PostUpdateAll( + uow, self.parent.primary_base_mapper, False + ) + parent_pre_updates = unitofwork.PostUpdateAll( + uow, self.parent.primary_base_mapper, True + ) + + uow.dependencies.update( + [ + (child_saves, after_save), + (parent_saves, after_save), + (after_save, parent_post_updates), + (after_save, parent_pre_updates), + (before_delete, parent_pre_updates), + (parent_pre_updates, child_deletes), + (parent_pre_updates, parent_deletes), + ] + ) + else: + uow.dependencies.update( + [ + (child_saves, after_save), + (after_save, parent_saves), + (parent_saves, child_deletes), + (parent_deletes, child_deletes), + ] + ) + + def per_state_dependencies( + self, + uow, + save_parent, + delete_parent, + child_action, + after_save, + before_delete, + isdelete, + childisdelete, + ): + if self.post_update: + if not isdelete: + parent_post_updates = unitofwork.PostUpdateAll( + uow, self.parent.primary_base_mapper, False + ) + if childisdelete: + uow.dependencies.update( + [ + (after_save, parent_post_updates), + (parent_post_updates, child_action), + ] + ) + else: + uow.dependencies.update( + [ + (save_parent, after_save), + (child_action, after_save), + (after_save, parent_post_updates), + ] + ) + else: + parent_pre_updates = unitofwork.PostUpdateAll( + uow, self.parent.primary_base_mapper, True + ) + + uow.dependencies.update( + [ + (before_delete, parent_pre_updates), + (parent_pre_updates, delete_parent), + (parent_pre_updates, child_action), + ] + ) + + elif not isdelete: + if not childisdelete: + uow.dependencies.update( + [(child_action, after_save), (after_save, save_parent)] + ) + else: + uow.dependencies.update([(after_save, save_parent)]) + + else: + if childisdelete: + uow.dependencies.update([(delete_parent, child_action)]) + + def presort_deletes(self, uowcommit, states): + if self.cascade.delete or self.cascade.delete_orphan: + for state in states: + history = uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + if history: + if self.cascade.delete_orphan: + todelete = history.sum() + else: + todelete = history.non_deleted() + for child in todelete: + if child is None: + continue + uowcommit.register_object( + child, + isdelete=True, + operation="delete", + prop=self.prop, + ) + t = self.mapper.cascade_iterator("delete", child) + for c, m, st_, dct_ in t: + uowcommit.register_object(st_, isdelete=True) + + def presort_saves(self, uowcommit, states): + for state in states: + uowcommit.register_object(state, operation="add", prop=self.prop) + if self.cascade.delete_orphan: + history = uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + if history: + for child in history.deleted: + if self.hasparent(child) is False: + uowcommit.register_object( + child, + isdelete=True, + operation="delete", + prop=self.prop, + ) + + t = self.mapper.cascade_iterator("delete", child) + for c, m, st_, dct_ in t: + uowcommit.register_object(st_, isdelete=True) + + def process_deletes(self, uowcommit, states): + if ( + self.post_update + and not self.cascade.delete_orphan + and not self.passive_deletes == "all" + ): + # post_update means we have to update our + # row to not reference the child object + # before we can DELETE the row + for state in states: + self._synchronize(state, None, None, True, uowcommit) + if state and self.post_update: + history = uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + if history: + self._post_update( + state, uowcommit, history.sum(), is_m2o_delete=True + ) + + def process_saves(self, uowcommit, states): + for state in states: + history = uowcommit.get_attribute_history( + state, self.key, attributes.PASSIVE_NO_INITIALIZE + ) + if history: + if history.added: + for child in history.added: + self._synchronize( + state, child, None, False, uowcommit, "add" + ) + elif history.deleted: + self._synchronize( + state, None, None, True, uowcommit, "delete" + ) + if self.post_update: + self._post_update(state, uowcommit, history.sum()) + + def _synchronize( + self, + state, + child, + associationrow, + clearkeys, + uowcommit, + operation=None, + ): + if state is None or ( + not self.post_update and uowcommit.is_deleted(state) + ): + return + + if ( + operation is not None + and child is not None + and not uowcommit.session._contains_state(child) + ): + util.warn( + "Object of type %s not in session, %s " + "operation along '%s' won't proceed" + % (mapperutil.state_class_str(child), operation, self.prop) + ) + return + + if clearkeys or child is None: + sync.clear(state, self.parent, self.prop.synchronize_pairs) + else: + self._verify_canload(child) + sync.populate( + child, + self.mapper, + state, + self.parent, + self.prop.synchronize_pairs, + uowcommit, + False, + ) + + +class DetectKeySwitch(DependencyProcessor): + """For many-to-one relationships with no one-to-many backref, + searches for parents through the unit of work when a primary + key has changed and updates them. + + Theoretically, this approach could be expanded to support transparent + deletion of objects referenced via many-to-one as well, although + the current attribute system doesn't do enough bookkeeping for this + to be efficient. + + """ + + def per_property_preprocessors(self, uow): + if self.prop._reverse_property: + if self.passive_updates: + return + else: + if False in ( + prop.passive_updates + for prop in self.prop._reverse_property + ): + return + + uow.register_preprocessor(self, False) + + def per_property_flush_actions(self, uow): + parent_saves = unitofwork.SaveUpdateAll(uow, self.parent.base_mapper) + after_save = unitofwork.ProcessAll(uow, self, False, False) + uow.dependencies.update([(parent_saves, after_save)]) + + def per_state_flush_actions(self, uow, states, isdelete): + pass + + def presort_deletes(self, uowcommit, states): + pass + + def presort_saves(self, uow, states): + if not self.passive_updates: + # for non-passive updates, register in the preprocess stage + # so that mapper save_obj() gets a hold of changes + self._process_key_switches(states, uow) + + def prop_has_changes(self, uow, states, isdelete): + if not isdelete and self.passive_updates: + d = self._key_switchers(uow, states) + return bool(d) + + return False + + def process_deletes(self, uowcommit, states): + assert False + + def process_saves(self, uowcommit, states): + # for passive updates, register objects in the process stage + # so that we avoid ManyToOneDP's registering the object without + # the listonly flag in its own preprocess stage (results in UPDATE) + # statements being emitted + assert self.passive_updates + self._process_key_switches(states, uowcommit) + + def _key_switchers(self, uow, states): + switched, notswitched = uow.memo( + ("pk_switchers", self), lambda: (set(), set()) + ) + + allstates = switched.union(notswitched) + for s in states: + if s not in allstates: + if self._pks_changed(uow, s): + switched.add(s) + else: + notswitched.add(s) + return switched + + def _process_key_switches(self, deplist, uowcommit): + switchers = self._key_switchers(uowcommit, deplist) + if switchers: + # if primary key values have actually changed somewhere, perform + # a linear search through the UOW in search of a parent. + for state in uowcommit.session.identity_map.all_states(): + if not issubclass(state.class_, self.parent.class_): + continue + dict_ = state.dict + related = state.get_impl(self.key).get( + state, dict_, passive=self._passive_update_flag + ) + if ( + related is not attributes.PASSIVE_NO_RESULT + and related is not None + ): + if self.prop.uselist: + if not related: + continue + related_obj = related[0] + else: + related_obj = related + related_state = attributes.instance_state(related_obj) + if related_state in switchers: + uowcommit.register_object( + state, False, self.passive_updates + ) + sync.populate( + related_state, + self.mapper, + state, + self.parent, + self.prop.synchronize_pairs, + uowcommit, + self.passive_updates, + ) + + def _pks_changed(self, uowcommit, state): + return bool(state.key) and sync.source_modified( + uowcommit, state, self.mapper, self.prop.synchronize_pairs + ) + + +class ManyToManyDP(DependencyProcessor): + def per_property_dependencies( + self, + uow, + parent_saves, + child_saves, + parent_deletes, + child_deletes, + after_save, + before_delete, + ): + uow.dependencies.update( + [ + (parent_saves, after_save), + (child_saves, after_save), + (after_save, child_deletes), + # a rowswitch on the parent from deleted to saved + # can make this one occur, as the "save" may remove + # an element from the + # "deleted" list before we have a chance to + # process its child rows + (before_delete, parent_saves), + (before_delete, parent_deletes), + (before_delete, child_deletes), + (before_delete, child_saves), + ] + ) + + def per_state_dependencies( + self, + uow, + save_parent, + delete_parent, + child_action, + after_save, + before_delete, + isdelete, + childisdelete, + ): + if not isdelete: + if childisdelete: + uow.dependencies.update( + [(save_parent, after_save), (after_save, child_action)] + ) + else: + uow.dependencies.update( + [(save_parent, after_save), (child_action, after_save)] + ) + else: + uow.dependencies.update( + [(before_delete, child_action), (before_delete, delete_parent)] + ) + + def presort_deletes(self, uowcommit, states): + # TODO: no tests fail if this whole + # thing is removed !!!! + if not self.passive_deletes: + # if no passive deletes, load history on + # the collection, so that prop_has_changes() + # returns True + for state in states: + uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + + def presort_saves(self, uowcommit, states): + if not self.passive_updates: + # if no passive updates, load history on + # each collection where parent has changed PK, + # so that prop_has_changes() returns True + for state in states: + if self._pks_changed(uowcommit, state): + history = uowcommit.get_attribute_history( + state, self.key, attributes.PASSIVE_OFF + ) + + if not self.cascade.delete_orphan: + return + + # check for child items removed from the collection + # if delete_orphan check is turned on. + for state in states: + history = uowcommit.get_attribute_history( + state, self.key, attributes.PASSIVE_NO_INITIALIZE + ) + if history: + for child in history.deleted: + if self.hasparent(child) is False: + uowcommit.register_object( + child, + isdelete=True, + operation="delete", + prop=self.prop, + ) + for c, m, st_, dct_ in self.mapper.cascade_iterator( + "delete", child + ): + uowcommit.register_object(st_, isdelete=True) + + def process_deletes(self, uowcommit, states): + secondary_delete = [] + secondary_insert = [] + secondary_update = [] + + processed = self._get_reversed_processed_set(uowcommit) + tmp = set() + for state in states: + # this history should be cached already, as + # we loaded it in preprocess_deletes + history = uowcommit.get_attribute_history( + state, self.key, self._passive_delete_flag + ) + if history: + for child in history.non_added(): + if child is None or ( + processed is not None and (state, child) in processed + ): + continue + associationrow = {} + if not self._synchronize( + state, + child, + associationrow, + False, + uowcommit, + "delete", + ): + continue + secondary_delete.append(associationrow) + + tmp.update((c, state) for c in history.non_added()) + + if processed is not None: + processed.update(tmp) + + self._run_crud( + uowcommit, secondary_insert, secondary_update, secondary_delete + ) + + def process_saves(self, uowcommit, states): + secondary_delete = [] + secondary_insert = [] + secondary_update = [] + + processed = self._get_reversed_processed_set(uowcommit) + tmp = set() + + for state in states: + need_cascade_pks = not self.passive_updates and self._pks_changed( + uowcommit, state + ) + if need_cascade_pks: + passive = ( + attributes.PASSIVE_OFF + | attributes.INCLUDE_PENDING_MUTATIONS + ) + else: + passive = ( + attributes.PASSIVE_NO_INITIALIZE + | attributes.INCLUDE_PENDING_MUTATIONS + ) + history = uowcommit.get_attribute_history(state, self.key, passive) + if history: + for child in history.added: + if processed is not None and (state, child) in processed: + continue + associationrow = {} + if not self._synchronize( + state, child, associationrow, False, uowcommit, "add" + ): + continue + secondary_insert.append(associationrow) + for child in history.deleted: + if processed is not None and (state, child) in processed: + continue + associationrow = {} + if not self._synchronize( + state, + child, + associationrow, + False, + uowcommit, + "delete", + ): + continue + secondary_delete.append(associationrow) + + tmp.update((c, state) for c in history.added + history.deleted) + + if need_cascade_pks: + for child in history.unchanged: + associationrow = {} + sync.update( + state, + self.parent, + associationrow, + "old_", + self.prop.synchronize_pairs, + ) + sync.update( + child, + self.mapper, + associationrow, + "old_", + self.prop.secondary_synchronize_pairs, + ) + + secondary_update.append(associationrow) + + if processed is not None: + processed.update(tmp) + + self._run_crud( + uowcommit, secondary_insert, secondary_update, secondary_delete + ) + + def _run_crud( + self, uowcommit, secondary_insert, secondary_update, secondary_delete + ): + connection = uowcommit.transaction.connection(self.mapper) + + if secondary_delete: + associationrow = secondary_delete[0] + statement = self.secondary.delete().where( + sql.and_( + *[ + c == sql.bindparam(c.key, type_=c.type) + for c in self.secondary.c + if c.key in associationrow + ] + ) + ) + result = connection.execute(statement, secondary_delete) + + if ( + result.supports_sane_multi_rowcount() + ) and result.rowcount != len(secondary_delete): + raise exc.StaleDataError( + "DELETE statement on table '%s' expected to delete " + "%d row(s); Only %d were matched." + % ( + self.secondary.description, + len(secondary_delete), + result.rowcount, + ) + ) + + if secondary_update: + associationrow = secondary_update[0] + statement = self.secondary.update().where( + sql.and_( + *[ + c == sql.bindparam("old_" + c.key, type_=c.type) + for c in self.secondary.c + if c.key in associationrow + ] + ) + ) + result = connection.execute(statement, secondary_update) + + if ( + result.supports_sane_multi_rowcount() + ) and result.rowcount != len(secondary_update): + raise exc.StaleDataError( + "UPDATE statement on table '%s' expected to update " + "%d row(s); Only %d were matched." + % ( + self.secondary.description, + len(secondary_update), + result.rowcount, + ) + ) + + if secondary_insert: + statement = self.secondary.insert() + connection.execute(statement, secondary_insert) + + def _synchronize( + self, state, child, associationrow, clearkeys, uowcommit, operation + ): + # this checks for None if uselist=True + self._verify_canload(child) + + # but if uselist=False we get here. If child is None, + # no association row can be generated, so return. + if child is None: + return False + + if child is not None and not uowcommit.session._contains_state(child): + if not child.deleted: + util.warn( + "Object of type %s not in session, %s " + "operation along '%s' won't proceed" + % (mapperutil.state_class_str(child), operation, self.prop) + ) + return False + + sync.populate_dict( + state, self.parent, associationrow, self.prop.synchronize_pairs + ) + sync.populate_dict( + child, + self.mapper, + associationrow, + self.prop.secondary_synchronize_pairs, + ) + + return True + + def _pks_changed(self, uowcommit, state): + return sync.source_modified( + uowcommit, state, self.parent, self.prop.synchronize_pairs + ) + + +_direction_to_processor = { + ONETOMANY: OneToManyDP, + MANYTOONE: ManyToOneDP, + MANYTOMANY: ManyToManyDP, +} diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/descriptor_props.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/descriptor_props.py new file mode 100644 index 0000000..a3650f5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/descriptor_props.py @@ -0,0 +1,1074 @@ +# orm/descriptor_props.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Descriptor properties are more "auxiliary" properties +that exist as configurational elements, but don't participate +as actively in the load/persist ORM loop. + +""" +from __future__ import annotations + +from dataclasses import is_dataclass +import inspect +import itertools +import operator +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import util as orm_util +from .base import _DeclarativeMapped +from .base import LoaderCallableStatus +from .base import Mapped +from .base import PassiveFlag +from .base import SQLORMOperations +from .interfaces import _AttributeOptions +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MapsColumns +from .interfaces import MapperProperty +from .interfaces import PropComparator +from .util import _none_set +from .util import de_stringify_annotation +from .. import event +from .. import exc as sa_exc +from .. import schema +from .. import sql +from .. import util +from ..sql import expression +from ..sql import operators +from ..sql.elements import BindParameter +from ..util.typing import is_fwd_ref +from ..util.typing import is_pep593 +from ..util.typing import typing_get_args + +if typing.TYPE_CHECKING: + from ._typing import _InstanceDict + from ._typing import _RegistryType + from .attributes import History + from .attributes import InstrumentedAttribute + from .attributes import QueryableAttribute + from .context import ORMCompileState + from .decl_base import _ClassScanMapperConfig + from .mapper import Mapper + from .properties import ColumnProperty + from .properties import MappedColumn + from .state import InstanceState + from ..engine.base import Connection + from ..engine.row import Row + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _InfoType + from ..sql.elements import ClauseList + from ..sql.elements import ColumnElement + from ..sql.operators import OperatorType + from ..sql.schema import Column + from ..sql.selectable import Select + from ..util.typing import _AnnotationScanType + from ..util.typing import CallableReference + from ..util.typing import DescriptorReference + from ..util.typing import RODescriptorReference + +_T = TypeVar("_T", bound=Any) +_PT = TypeVar("_PT", bound=Any) + + +class DescriptorProperty(MapperProperty[_T]): + """:class:`.MapperProperty` which proxies access to a + user-defined descriptor.""" + + doc: Optional[str] = None + + uses_objects = False + _links_to_entity = False + + descriptor: DescriptorReference[Any] + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + raise NotImplementedError() + + def instrument_class(self, mapper: Mapper[Any]) -> None: + prop = self + + class _ProxyImpl(attributes.AttributeImpl): + accepts_scalar_loader = False + load_on_unexpire = True + collection = False + + @property + def uses_objects(self) -> bool: # type: ignore + return prop.uses_objects + + def __init__(self, key: str): + self.key = key + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + return prop.get_history(state, dict_, passive) + + if self.descriptor is None: + desc = getattr(mapper.class_, self.key, None) + if mapper._is_userland_descriptor(self.key, desc): + self.descriptor = desc + + if self.descriptor is None: + + def fset(obj: Any, value: Any) -> None: + setattr(obj, self.name, value) + + def fdel(obj: Any) -> None: + delattr(obj, self.name) + + def fget(obj: Any) -> Any: + return getattr(obj, self.name) + + self.descriptor = property(fget=fget, fset=fset, fdel=fdel) + + proxy_attr = attributes.create_proxied_attribute(self.descriptor)( + self.parent.class_, + self.key, + self.descriptor, + lambda: self._comparator_factory(mapper), + doc=self.doc, + original_property=self, + ) + proxy_attr.impl = _ProxyImpl(self.key) + mapper.class_manager.instrument_attribute(self.key, proxy_attr) + + +_CompositeAttrType = Union[ + str, + "Column[_T]", + "MappedColumn[_T]", + "InstrumentedAttribute[_T]", + "Mapped[_T]", +] + + +_CC = TypeVar("_CC", bound=Any) + + +_composite_getters: weakref.WeakKeyDictionary[ + Type[Any], Callable[[Any], Tuple[Any, ...]] +] = weakref.WeakKeyDictionary() + + +class CompositeProperty( + _MapsColumns[_CC], _IntrospectsAnnotations, DescriptorProperty[_CC] +): + """Defines a "composite" mapped attribute, representing a collection + of columns as one attribute. + + :class:`.CompositeProperty` is constructed using the :func:`.composite` + function. + + .. seealso:: + + :ref:`mapper_composite` + + """ + + composite_class: Union[Type[_CC], Callable[..., _CC]] + attrs: Tuple[_CompositeAttrType[Any], ...] + + _generated_composite_accessor: CallableReference[ + Optional[Callable[[_CC], Tuple[Any, ...]]] + ] + + comparator_factory: Type[Comparator[_CC]] + + def __init__( + self, + _class_or_attr: Union[ + None, Type[_CC], Callable[..., _CC], _CompositeAttrType[Any] + ] = None, + *attrs: _CompositeAttrType[Any], + attribute_options: Optional[_AttributeOptions] = None, + active_history: bool = False, + deferred: bool = False, + group: Optional[str] = None, + comparator_factory: Optional[Type[Comparator[_CC]]] = None, + info: Optional[_InfoType] = None, + **kwargs: Any, + ): + super().__init__(attribute_options=attribute_options) + + if isinstance(_class_or_attr, (Mapped, str, sql.ColumnElement)): + self.attrs = (_class_or_attr,) + attrs + # will initialize within declarative_scan + self.composite_class = None # type: ignore + else: + self.composite_class = _class_or_attr # type: ignore + self.attrs = attrs + + self.active_history = active_history + self.deferred = deferred + self.group = group + self.comparator_factory = ( + comparator_factory + if comparator_factory is not None + else self.__class__.Comparator + ) + self._generated_composite_accessor = None + if info is not None: + self.info.update(info) + + util.set_creation_order(self) + self._create_descriptor() + self._init_accessor() + + def instrument_class(self, mapper: Mapper[Any]) -> None: + super().instrument_class(mapper) + self._setup_event_handlers() + + def _composite_values_from_instance(self, value: _CC) -> Tuple[Any, ...]: + if self._generated_composite_accessor: + return self._generated_composite_accessor(value) + else: + try: + accessor = value.__composite_values__ + except AttributeError as ae: + raise sa_exc.InvalidRequestError( + f"Composite class {self.composite_class.__name__} is not " + f"a dataclass and does not define a __composite_values__()" + " method; can't get state" + ) from ae + else: + return accessor() # type: ignore + + def do_init(self) -> None: + """Initialization which occurs after the :class:`.Composite` + has been associated with its parent mapper. + + """ + self._setup_arguments_on_columns() + + _COMPOSITE_FGET = object() + + def _create_descriptor(self) -> None: + """Create the Python descriptor that will serve as + the access point on instances of the mapped class. + + """ + + def fget(instance: Any) -> Any: + dict_ = attributes.instance_dict(instance) + state = attributes.instance_state(instance) + + if self.key not in dict_: + # key not present. Iterate through related + # attributes, retrieve their values. This + # ensures they all load. + values = [ + getattr(instance, key) for key in self._attribute_keys + ] + + # current expected behavior here is that the composite is + # created on access if the object is persistent or if + # col attributes have non-None. This would be better + # if the composite were created unconditionally, + # but that would be a behavioral change. + if self.key not in dict_ and ( + state.key is not None or not _none_set.issuperset(values) + ): + dict_[self.key] = self.composite_class(*values) + state.manager.dispatch.refresh( + state, self._COMPOSITE_FGET, [self.key] + ) + + return dict_.get(self.key, None) + + def fset(instance: Any, value: Any) -> None: + dict_ = attributes.instance_dict(instance) + state = attributes.instance_state(instance) + attr = state.manager[self.key] + + if attr.dispatch._active_history: + previous = fget(instance) + else: + previous = dict_.get(self.key, LoaderCallableStatus.NO_VALUE) + + for fn in attr.dispatch.set: + value = fn(state, value, previous, attr.impl) + dict_[self.key] = value + if value is None: + for key in self._attribute_keys: + setattr(instance, key, None) + else: + for key, value in zip( + self._attribute_keys, + self._composite_values_from_instance(value), + ): + setattr(instance, key, value) + + def fdel(instance: Any) -> None: + state = attributes.instance_state(instance) + dict_ = attributes.instance_dict(instance) + attr = state.manager[self.key] + + if attr.dispatch._active_history: + previous = fget(instance) + dict_.pop(self.key, None) + else: + previous = dict_.pop(self.key, LoaderCallableStatus.NO_VALUE) + + attr = state.manager[self.key] + attr.dispatch.remove(state, previous, attr.impl) + for key in self._attribute_keys: + setattr(instance, key, None) + + self.descriptor = property(fget, fset, fdel) + + @util.preload_module("sqlalchemy.orm.properties") + def declarative_scan( + self, + decl_scan: _ClassScanMapperConfig, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + mapped_container: Optional[Type[Mapped[Any]]], + annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: + MappedColumn = util.preloaded.orm_properties.MappedColumn + if ( + self.composite_class is None + and extracted_mapped_annotation is None + ): + self._raise_for_required(key, cls) + argument = extracted_mapped_annotation + + if is_pep593(argument): + argument = typing_get_args(argument)[0] + + if argument and self.composite_class is None: + if isinstance(argument, str) or is_fwd_ref( + argument, check_generic=True + ): + if originating_module is None: + str_arg = ( + argument.__forward_arg__ + if hasattr(argument, "__forward_arg__") + else str(argument) + ) + raise sa_exc.ArgumentError( + f"Can't use forward ref {argument} for composite " + f"class argument; set up the type as Mapped[{str_arg}]" + ) + argument = de_stringify_annotation( + cls, argument, originating_module, include_generic=True + ) + + self.composite_class = argument + + if is_dataclass(self.composite_class): + self._setup_for_dataclass(registry, cls, originating_module, key) + else: + for attr in self.attrs: + if ( + isinstance(attr, (MappedColumn, schema.Column)) + and attr.name is None + ): + raise sa_exc.ArgumentError( + "Composite class column arguments must be named " + "unless a dataclass is used" + ) + self._init_accessor() + + def _init_accessor(self) -> None: + if is_dataclass(self.composite_class) and not hasattr( + self.composite_class, "__composite_values__" + ): + insp = inspect.signature(self.composite_class) + getter = operator.attrgetter( + *[p.name for p in insp.parameters.values()] + ) + if len(insp.parameters) == 1: + self._generated_composite_accessor = lambda obj: (getter(obj),) + else: + self._generated_composite_accessor = getter + + if ( + self.composite_class is not None + and isinstance(self.composite_class, type) + and self.composite_class not in _composite_getters + ): + if self._generated_composite_accessor is not None: + _composite_getters[self.composite_class] = ( + self._generated_composite_accessor + ) + elif hasattr(self.composite_class, "__composite_values__"): + _composite_getters[self.composite_class] = ( + lambda obj: obj.__composite_values__() + ) + + @util.preload_module("sqlalchemy.orm.properties") + @util.preload_module("sqlalchemy.orm.decl_base") + def _setup_for_dataclass( + self, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + ) -> None: + MappedColumn = util.preloaded.orm_properties.MappedColumn + + decl_base = util.preloaded.orm_decl_base + + insp = inspect.signature(self.composite_class) + for param, attr in itertools.zip_longest( + insp.parameters.values(), self.attrs + ): + if param is None: + raise sa_exc.ArgumentError( + f"number of composite attributes " + f"{len(self.attrs)} exceeds " + f"that of the number of attributes in class " + f"{self.composite_class.__name__} {len(insp.parameters)}" + ) + if attr is None: + # fill in missing attr spots with empty MappedColumn + attr = MappedColumn() + self.attrs += (attr,) + + if isinstance(attr, MappedColumn): + attr.declarative_scan_for_composite( + registry, + cls, + originating_module, + key, + param.name, + param.annotation, + ) + elif isinstance(attr, schema.Column): + decl_base._undefer_column_name(param.name, attr) + + @util.memoized_property + def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]: + return [getattr(self.parent.class_, prop.key) for prop in self.props] + + @util.memoized_property + @util.preload_module("orm.properties") + def props(self) -> Sequence[MapperProperty[Any]]: + props = [] + MappedColumn = util.preloaded.orm_properties.MappedColumn + + for attr in self.attrs: + if isinstance(attr, str): + prop = self.parent.get_property(attr, _configure_mappers=False) + elif isinstance(attr, schema.Column): + prop = self.parent._columntoproperty[attr] + elif isinstance(attr, MappedColumn): + prop = self.parent._columntoproperty[attr.column] + elif isinstance(attr, attributes.InstrumentedAttribute): + prop = attr.property + else: + prop = None + + if not isinstance(prop, MapperProperty): + raise sa_exc.ArgumentError( + "Composite expects Column objects or mapped " + f"attributes/attribute names as arguments, got: {attr!r}" + ) + + props.append(prop) + return props + + @util.non_memoized_property + @util.preload_module("orm.properties") + def columns(self) -> Sequence[Column[Any]]: + MappedColumn = util.preloaded.orm_properties.MappedColumn + return [ + a.column if isinstance(a, MappedColumn) else a + for a in self.attrs + if isinstance(a, (schema.Column, MappedColumn)) + ] + + @property + def mapper_property_to_assign(self) -> Optional[MapperProperty[_CC]]: + return self + + @property + def columns_to_assign(self) -> List[Tuple[schema.Column[Any], int]]: + return [(c, 0) for c in self.columns if c.table is None] + + @util.preload_module("orm.properties") + def _setup_arguments_on_columns(self) -> None: + """Propagate configuration arguments made on this composite + to the target columns, for those that apply. + + """ + ColumnProperty = util.preloaded.orm_properties.ColumnProperty + + for prop in self.props: + if not isinstance(prop, ColumnProperty): + continue + else: + cprop = prop + + cprop.active_history = self.active_history + if self.deferred: + cprop.deferred = self.deferred + cprop.strategy_key = (("deferred", True), ("instrument", True)) + cprop.group = self.group + + def _setup_event_handlers(self) -> None: + """Establish events that populate/expire the composite attribute.""" + + def load_handler( + state: InstanceState[Any], context: ORMCompileState + ) -> None: + _load_refresh_handler(state, context, None, is_refresh=False) + + def refresh_handler( + state: InstanceState[Any], + context: ORMCompileState, + to_load: Optional[Sequence[str]], + ) -> None: + # note this corresponds to sqlalchemy.ext.mutable load_attrs() + + if not to_load or ( + {self.key}.union(self._attribute_keys) + ).intersection(to_load): + _load_refresh_handler(state, context, to_load, is_refresh=True) + + def _load_refresh_handler( + state: InstanceState[Any], + context: ORMCompileState, + to_load: Optional[Sequence[str]], + is_refresh: bool, + ) -> None: + dict_ = state.dict + + # if context indicates we are coming from the + # fget() handler, this already set the value; skip the + # handler here. (other handlers like mutablecomposite will still + # want to catch it) + # there's an insufficiency here in that the fget() handler + # really should not be using the refresh event and there should + # be some other event that mutablecomposite can subscribe + # towards for this. + + if ( + not is_refresh or context is self._COMPOSITE_FGET + ) and self.key in dict_: + return + + # if column elements aren't loaded, skip. + # __get__() will initiate a load for those + # columns + for k in self._attribute_keys: + if k not in dict_: + return + + dict_[self.key] = self.composite_class( + *[state.dict[key] for key in self._attribute_keys] + ) + + def expire_handler( + state: InstanceState[Any], keys: Optional[Sequence[str]] + ) -> None: + if keys is None or set(self._attribute_keys).intersection(keys): + state.dict.pop(self.key, None) + + def insert_update_handler( + mapper: Mapper[Any], + connection: Connection, + state: InstanceState[Any], + ) -> None: + """After an insert or update, some columns may be expired due + to server side defaults, or re-populated due to client side + defaults. Pop out the composite value here so that it + recreates. + + """ + + state.dict.pop(self.key, None) + + event.listen( + self.parent, "after_insert", insert_update_handler, raw=True + ) + event.listen( + self.parent, "after_update", insert_update_handler, raw=True + ) + event.listen( + self.parent, "load", load_handler, raw=True, propagate=True + ) + event.listen( + self.parent, "refresh", refresh_handler, raw=True, propagate=True + ) + event.listen( + self.parent, "expire", expire_handler, raw=True, propagate=True + ) + + proxy_attr = self.parent.class_manager[self.key] + proxy_attr.impl.dispatch = proxy_attr.dispatch # type: ignore + proxy_attr.impl.dispatch._active_history = self.active_history + + # TODO: need a deserialize hook here + + @util.memoized_property + def _attribute_keys(self) -> Sequence[str]: + return [prop.key for prop in self.props] + + def _populate_composite_bulk_save_mappings_fn( + self, + ) -> Callable[[Dict[str, Any]], None]: + if self._generated_composite_accessor: + get_values = self._generated_composite_accessor + else: + + def get_values(val: Any) -> Tuple[Any]: + return val.__composite_values__() # type: ignore + + attrs = [prop.key for prop in self.props] + + def populate(dest_dict: Dict[str, Any]) -> None: + dest_dict.update( + { + key: val + for key, val in zip( + attrs, get_values(dest_dict.pop(self.key)) + ) + } + ) + + return populate + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + """Provided for userland code that uses attributes.get_history().""" + + added: List[Any] = [] + deleted: List[Any] = [] + + has_history = False + for prop in self.props: + key = prop.key + hist = state.manager[key].impl.get_history(state, dict_) + if hist.has_changes(): + has_history = True + + non_deleted = hist.non_deleted() + if non_deleted: + added.extend(non_deleted) + else: + added.append(None) + if hist.deleted: + deleted.extend(hist.deleted) + else: + deleted.append(None) + + if has_history: + return attributes.History( + [self.composite_class(*added)], + (), + [self.composite_class(*deleted)], + ) + else: + return attributes.History((), [self.composite_class(*added)], ()) + + def _comparator_factory( + self, mapper: Mapper[Any] + ) -> Composite.Comparator[_CC]: + return self.comparator_factory(self, mapper) + + class CompositeBundle(orm_util.Bundle[_T]): + def __init__( + self, + property_: Composite[_T], + expr: ClauseList, + ): + self.property = property_ + super().__init__(property_.key, *expr) + + def create_row_processor( + self, + query: Select[Any], + procs: Sequence[Callable[[Row[Any]], Any]], + labels: Sequence[str], + ) -> Callable[[Row[Any]], Any]: + def proc(row: Row[Any]) -> Any: + return self.property.composite_class( + *[proc(row) for proc in procs] + ) + + return proc + + class Comparator(PropComparator[_PT]): + """Produce boolean, comparison, and other operators for + :class:`.Composite` attributes. + + See the example in :ref:`composite_operations` for an overview + of usage , as well as the documentation for :class:`.PropComparator`. + + .. seealso:: + + :class:`.PropComparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + + # https://github.com/python/mypy/issues/4266 + __hash__ = None # type: ignore + + prop: RODescriptorReference[Composite[_PT]] + + @util.memoized_property + def clauses(self) -> ClauseList: + return expression.ClauseList( + group=False, *self._comparable_elements + ) + + def __clause_element__(self) -> CompositeProperty.CompositeBundle[_PT]: + return self.expression + + @util.memoized_property + def expression(self) -> CompositeProperty.CompositeBundle[_PT]: + clauses = self.clauses._annotate( + { + "parententity": self._parententity, + "parentmapper": self._parententity, + "proxy_key": self.prop.key, + } + ) + return CompositeProperty.CompositeBundle(self.prop, clauses) + + def _bulk_update_tuples( + self, value: Any + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: + if isinstance(value, BindParameter): + value = value.value + + values: Sequence[Any] + + if value is None: + values = [None for key in self.prop._attribute_keys] + elif isinstance(self.prop.composite_class, type) and isinstance( + value, self.prop.composite_class + ): + values = self.prop._composite_values_from_instance(value) + else: + raise sa_exc.ArgumentError( + "Can't UPDATE composite attribute %s to %r" + % (self.prop, value) + ) + + return list(zip(self._comparable_elements, values)) + + @util.memoized_property + def _comparable_elements(self) -> Sequence[QueryableAttribute[Any]]: + if self._adapt_to_entity: + return [ + getattr(self._adapt_to_entity.entity, prop.key) + for prop in self.prop._comparable_elements + ] + else: + return self.prop._comparable_elements + + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.eq, other) + + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + return self._compare(operators.ne, other) + + def __lt__(self, other: Any) -> ColumnElement[bool]: + return self._compare(operators.lt, other) + + def __gt__(self, other: Any) -> ColumnElement[bool]: + return self._compare(operators.gt, other) + + def __le__(self, other: Any) -> ColumnElement[bool]: + return self._compare(operators.le, other) + + def __ge__(self, other: Any) -> ColumnElement[bool]: + return self._compare(operators.ge, other) + + # what might be interesting would be if we create + # an instance of the composite class itself with + # the columns as data members, then use "hybrid style" comparison + # to create these comparisons. then your Point.__eq__() method could + # be where comparison behavior is defined for SQL also. Likely + # not a good choice for default behavior though, not clear how it would + # work w/ dataclasses, etc. also no demand for any of this anyway. + def _compare( + self, operator: OperatorType, other: Any + ) -> ColumnElement[bool]: + values: Sequence[Any] + if other is None: + values = [None] * len(self.prop._comparable_elements) + else: + values = self.prop._composite_values_from_instance(other) + comparisons = [ + operator(a, b) + for a, b in zip(self.prop._comparable_elements, values) + ] + if self._adapt_to_entity: + assert self.adapter is not None + comparisons = [self.adapter(x) for x in comparisons] + return sql.and_(*comparisons) + + def __str__(self) -> str: + return str(self.parent.class_.__name__) + "." + self.key + + +class Composite(CompositeProperty[_T], _DeclarativeMapped[_T]): + """Declarative-compatible front-end for the :class:`.CompositeProperty` + class. + + Public constructor is the :func:`_orm.composite` function. + + .. versionchanged:: 2.0 Added :class:`_orm.Composite` as a Declarative + compatible subclass of :class:`_orm.CompositeProperty`. + + .. seealso:: + + :ref:`mapper_composite` + + """ + + inherit_cache = True + """:meta private:""" + + +class ConcreteInheritedProperty(DescriptorProperty[_T]): + """A 'do nothing' :class:`.MapperProperty` that disables + an attribute on a concrete subclass that is only present + on the inherited mapper, not the concrete classes' mapper. + + Cases where this occurs include: + + * When the superclass mapper is mapped against a + "polymorphic union", which includes all attributes from + all subclasses. + * When a relationship() is configured on an inherited mapper, + but not on the subclass mapper. Concrete mappers require + that relationship() is configured explicitly on each + subclass. + + """ + + def _comparator_factory( + self, mapper: Mapper[Any] + ) -> Type[PropComparator[_T]]: + comparator_callable = None + + for m in self.parent.iterate_to_root(): + p = m._props[self.key] + if getattr(p, "comparator_factory", None) is not None: + comparator_callable = p.comparator_factory + break + assert comparator_callable is not None + return comparator_callable(p, mapper) # type: ignore + + def __init__(self) -> None: + super().__init__() + + def warn() -> NoReturn: + raise AttributeError( + "Concrete %s does not implement " + "attribute %r at the instance level. Add " + "this property explicitly to %s." + % (self.parent, self.key, self.parent) + ) + + class NoninheritedConcreteProp: + def __set__(s: Any, obj: Any, value: Any) -> NoReturn: + warn() + + def __delete__(s: Any, obj: Any) -> NoReturn: + warn() + + def __get__(s: Any, obj: Any, owner: Any) -> Any: + if obj is None: + return self.descriptor + warn() + + self.descriptor = NoninheritedConcreteProp() + + +class SynonymProperty(DescriptorProperty[_T]): + """Denote an attribute name as a synonym to a mapped property, + in that the attribute will mirror the value and expression behavior + of another attribute. + + :class:`.Synonym` is constructed using the :func:`_orm.synonym` + function. + + .. seealso:: + + :ref:`synonyms` - Overview of synonyms + + """ + + comparator_factory: Optional[Type[PropComparator[_T]]] + + def __init__( + self, + name: str, + map_column: Optional[bool] = None, + descriptor: Optional[Any] = None, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + attribute_options: Optional[_AttributeOptions] = None, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + ): + super().__init__(attribute_options=attribute_options) + + self.name = name + self.map_column = map_column + self.descriptor = descriptor + self.comparator_factory = comparator_factory + if doc: + self.doc = doc + elif descriptor and descriptor.__doc__: + self.doc = descriptor.__doc__ + else: + self.doc = None + if info: + self.info.update(info) + + util.set_creation_order(self) + + if not TYPE_CHECKING: + + @property + def uses_objects(self) -> bool: + return getattr(self.parent.class_, self.name).impl.uses_objects + + # TODO: when initialized, check _proxied_object, + # emit a warning if its not a column-based property + + @util.memoized_property + def _proxied_object( + self, + ) -> Union[MapperProperty[_T], SQLORMOperations[_T]]: + attr = getattr(self.parent.class_, self.name) + if not hasattr(attr, "property") or not isinstance( + attr.property, MapperProperty + ): + # attribute is a non-MapperProprerty proxy such as + # hybrid or association proxy + if isinstance(attr, attributes.QueryableAttribute): + return attr.comparator + elif isinstance(attr, SQLORMOperations): + # assocaition proxy comes here + return attr + + raise sa_exc.InvalidRequestError( + """synonym() attribute "%s.%s" only supports """ + """ORM mapped attributes, got %r""" + % (self.parent.class_.__name__, self.name, attr) + ) + return attr.property + + def _comparator_factory(self, mapper: Mapper[Any]) -> SQLORMOperations[_T]: + prop = self._proxied_object + + if isinstance(prop, MapperProperty): + if self.comparator_factory: + comp = self.comparator_factory(prop, mapper) + else: + comp = prop.comparator_factory(prop, mapper) + return comp + else: + return prop + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> History: + attr: QueryableAttribute[Any] = getattr(self.parent.class_, self.name) + return attr.impl.get_history(state, dict_, passive=passive) + + @util.preload_module("sqlalchemy.orm.properties") + def set_parent(self, parent: Mapper[Any], init: bool) -> None: + properties = util.preloaded.orm_properties + + if self.map_column: + # implement the 'map_column' option. + if self.key not in parent.persist_selectable.c: + raise sa_exc.ArgumentError( + "Can't compile synonym '%s': no column on table " + "'%s' named '%s'" + % ( + self.name, + parent.persist_selectable.description, + self.key, + ) + ) + elif ( + parent.persist_selectable.c[self.key] + in parent._columntoproperty + and parent._columntoproperty[ + parent.persist_selectable.c[self.key] + ].key + == self.name + ): + raise sa_exc.ArgumentError( + "Can't call map_column=True for synonym %r=%r, " + "a ColumnProperty already exists keyed to the name " + "%r for column %r" + % (self.key, self.name, self.name, self.key) + ) + p: ColumnProperty[Any] = properties.ColumnProperty( + parent.persist_selectable.c[self.key] + ) + parent._configure_property(self.name, p, init=init, setparent=True) + p._mapped_by_synonym = self.key + + self.parent = parent + + +class Synonym(SynonymProperty[_T], _DeclarativeMapped[_T]): + """Declarative front-end for the :class:`.SynonymProperty` class. + + Public constructor is the :func:`_orm.synonym` function. + + .. versionchanged:: 2.0 Added :class:`_orm.Synonym` as a Declarative + compatible subclass for :class:`_orm.SynonymProperty` + + .. seealso:: + + :ref:`synonyms` - Overview of synonyms + + """ + + inherit_cache = True + """:meta private:""" diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/dynamic.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/dynamic.py new file mode 100644 index 0000000..7496e5c --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/dynamic.py @@ -0,0 +1,298 @@ +# orm/dynamic.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +"""Dynamic collection API. + +Dynamic collections act like Query() objects for read operations and support +basic add/delete mutation. + +.. legacy:: the "dynamic" loader is a legacy feature, superseded by the + "write_only" loader. + + +""" + +from __future__ import annotations + +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import attributes +from . import exc as orm_exc +from . import relationships +from . import util as orm_util +from .base import PassiveFlag +from .query import Query +from .session import object_session +from .writeonly import AbstractCollectionWriter +from .writeonly import WriteOnlyAttributeImpl +from .writeonly import WriteOnlyHistory +from .writeonly import WriteOnlyLoader +from .. import util +from ..engine import result + + +if TYPE_CHECKING: + from . import QueryableAttribute + from .mapper import Mapper + from .relationships import _RelationshipOrderByArg + from .session import Session + from .state import InstanceState + from .util import AliasedClass + from ..event import _Dispatch + from ..sql.elements import ColumnElement + +_T = TypeVar("_T", bound=Any) + + +class DynamicCollectionHistory(WriteOnlyHistory[_T]): + def __init__( + self, + attr: DynamicAttributeImpl, + state: InstanceState[_T], + passive: PassiveFlag, + apply_to: Optional[DynamicCollectionHistory[_T]] = None, + ) -> None: + if apply_to: + coll = AppenderQuery(attr, state).autoflush(False) + self.unchanged_items = util.OrderedIdentitySet(coll) + self.added_items = apply_to.added_items + self.deleted_items = apply_to.deleted_items + self._reconcile_collection = True + else: + self.deleted_items = util.OrderedIdentitySet() + self.added_items = util.OrderedIdentitySet() + self.unchanged_items = util.OrderedIdentitySet() + self._reconcile_collection = False + + +class DynamicAttributeImpl(WriteOnlyAttributeImpl): + _supports_dynamic_iteration = True + collection_history_cls = DynamicCollectionHistory[Any] + query_class: Type[AppenderMixin[Any]] # type: ignore[assignment] + + def __init__( + self, + class_: Union[Type[Any], AliasedClass[Any]], + key: str, + dispatch: _Dispatch[QueryableAttribute[Any]], + target_mapper: Mapper[_T], + order_by: _RelationshipOrderByArg, + query_class: Optional[Type[AppenderMixin[_T]]] = None, + **kw: Any, + ) -> None: + attributes.AttributeImpl.__init__( + self, class_, key, None, dispatch, **kw + ) + self.target_mapper = target_mapper + if order_by: + self.order_by = tuple(order_by) + if not query_class: + self.query_class = AppenderQuery + elif AppenderMixin in query_class.mro(): + self.query_class = query_class + else: + self.query_class = mixin_user_query(query_class) + + +@relationships.RelationshipProperty.strategy_for(lazy="dynamic") +class DynaLoader(WriteOnlyLoader): + impl_class = DynamicAttributeImpl + + +class AppenderMixin(AbstractCollectionWriter[_T]): + """A mixin that expects to be mixing in a Query class with + AbstractAppender. + + + """ + + query_class: Optional[Type[Query[_T]]] = None + _order_by_clauses: Tuple[ColumnElement[Any], ...] + + def __init__( + self, attr: DynamicAttributeImpl, state: InstanceState[_T] + ) -> None: + Query.__init__( + self, # type: ignore[arg-type] + attr.target_mapper, + None, + ) + super().__init__(attr, state) + + @property + def session(self) -> Optional[Session]: + sess = object_session(self.instance) + if sess is not None and sess.autoflush and self.instance in sess: + sess.flush() + if not orm_util.has_identity(self.instance): + return None + else: + return sess + + @session.setter + def session(self, session: Session) -> None: + self.sess = session + + def _iter(self) -> Union[result.ScalarResult[_T], result.Result[_T]]: + sess = self.session + if sess is None: + state = attributes.instance_state(self.instance) + if state.detached: + util.warn( + "Instance %s is detached, dynamic relationship cannot " + "return a correct result. This warning will become " + "a DetachedInstanceError in a future release." + % (orm_util.state_str(state)) + ) + + return result.IteratorResult( + result.SimpleResultMetaData([self.attr.class_.__name__]), + self.attr._get_collection_history( # type: ignore[arg-type] + attributes.instance_state(self.instance), + PassiveFlag.PASSIVE_NO_INITIALIZE, + ).added_items, + _source_supports_scalars=True, + ).scalars() + else: + return self._generate(sess)._iter() + + if TYPE_CHECKING: + + def __iter__(self) -> Iterator[_T]: ... + + def __getitem__(self, index: Any) -> Union[_T, List[_T]]: + sess = self.session + if sess is None: + return self.attr._get_collection_history( + attributes.instance_state(self.instance), + PassiveFlag.PASSIVE_NO_INITIALIZE, + ).indexed(index) + else: + return self._generate(sess).__getitem__(index) # type: ignore[no-any-return] # noqa: E501 + + def count(self) -> int: + sess = self.session + if sess is None: + return len( + self.attr._get_collection_history( + attributes.instance_state(self.instance), + PassiveFlag.PASSIVE_NO_INITIALIZE, + ).added_items + ) + else: + return self._generate(sess).count() + + def _generate( + self, + sess: Optional[Session] = None, + ) -> Query[_T]: + # note we're returning an entirely new Query class instance + # here without any assignment capabilities; the class of this + # query is determined by the session. + instance = self.instance + if sess is None: + sess = object_session(instance) + if sess is None: + raise orm_exc.DetachedInstanceError( + "Parent instance %s is not bound to a Session, and no " + "contextual session is established; lazy load operation " + "of attribute '%s' cannot proceed" + % (orm_util.instance_str(instance), self.attr.key) + ) + + if self.query_class: + query = self.query_class(self.attr.target_mapper, session=sess) + else: + query = sess.query(self.attr.target_mapper) + + query._where_criteria = self._where_criteria + query._from_obj = self._from_obj + query._order_by_clauses = self._order_by_clauses + + return query + + def add_all(self, iterator: Iterable[_T]) -> None: + """Add an iterable of items to this :class:`_orm.AppenderQuery`. + + The given items will be persisted to the database in terms of + the parent instance's collection on the next flush. + + This method is provided to assist in delivering forwards-compatibility + with the :class:`_orm.WriteOnlyCollection` collection class. + + .. versionadded:: 2.0 + + """ + self._add_all_impl(iterator) + + def add(self, item: _T) -> None: + """Add an item to this :class:`_orm.AppenderQuery`. + + The given item will be persisted to the database in terms of + the parent instance's collection on the next flush. + + This method is provided to assist in delivering forwards-compatibility + with the :class:`_orm.WriteOnlyCollection` collection class. + + .. versionadded:: 2.0 + + """ + self._add_all_impl([item]) + + def extend(self, iterator: Iterable[_T]) -> None: + """Add an iterable of items to this :class:`_orm.AppenderQuery`. + + The given items will be persisted to the database in terms of + the parent instance's collection on the next flush. + + """ + self._add_all_impl(iterator) + + def append(self, item: _T) -> None: + """Append an item to this :class:`_orm.AppenderQuery`. + + The given item will be persisted to the database in terms of + the parent instance's collection on the next flush. + + """ + self._add_all_impl([item]) + + def remove(self, item: _T) -> None: + """Remove an item from this :class:`_orm.AppenderQuery`. + + The given item will be removed from the parent instance's collection on + the next flush. + + """ + self._remove_impl(item) + + +class AppenderQuery(AppenderMixin[_T], Query[_T]): # type: ignore[misc] + """A dynamic query that supports basic collection storage operations. + + Methods on :class:`.AppenderQuery` include all methods of + :class:`_orm.Query`, plus additional methods used for collection + persistence. + + + """ + + +def mixin_user_query(cls: Any) -> type[AppenderMixin[Any]]: + """Return a new class with AppenderQuery functionality layered over.""" + name = "Appender" + cls.__name__ + return type(name, (AppenderMixin, cls), {"query_class": cls}) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/evaluator.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/evaluator.py new file mode 100644 index 0000000..f264454 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/evaluator.py @@ -0,0 +1,368 @@ +# orm/evaluator.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +"""Evaluation functions used **INTERNALLY** by ORM DML use cases. + + +This module is **private, for internal use by SQLAlchemy**. + +.. versionchanged:: 2.0.4 renamed ``EvaluatorCompiler`` to + ``_EvaluatorCompiler``. + +""" + + +from __future__ import annotations + +from typing import Type + +from . import exc as orm_exc +from .base import LoaderCallableStatus +from .base import PassiveFlag +from .. import exc +from .. import inspect +from ..sql import and_ +from ..sql import operators +from ..sql.sqltypes import Integer +from ..sql.sqltypes import Numeric +from ..util import warn_deprecated + + +class UnevaluatableError(exc.InvalidRequestError): + pass + + +class _NoObject(operators.ColumnOperators): + def operate(self, *arg, **kw): + return None + + def reverse_operate(self, *arg, **kw): + return None + + +class _ExpiredObject(operators.ColumnOperators): + def operate(self, *arg, **kw): + return self + + def reverse_operate(self, *arg, **kw): + return self + + +_NO_OBJECT = _NoObject() +_EXPIRED_OBJECT = _ExpiredObject() + + +class _EvaluatorCompiler: + def __init__(self, target_cls=None): + self.target_cls = target_cls + + def process(self, clause, *clauses): + if clauses: + clause = and_(clause, *clauses) + + meth = getattr(self, f"visit_{clause.__visit_name__}", None) + if not meth: + raise UnevaluatableError( + f"Cannot evaluate {type(clause).__name__}" + ) + return meth(clause) + + def visit_grouping(self, clause): + return self.process(clause.element) + + def visit_null(self, clause): + return lambda obj: None + + def visit_false(self, clause): + return lambda obj: False + + def visit_true(self, clause): + return lambda obj: True + + def visit_column(self, clause): + try: + parentmapper = clause._annotations["parentmapper"] + except KeyError as ke: + raise UnevaluatableError( + f"Cannot evaluate column: {clause}" + ) from ke + + if self.target_cls and not issubclass( + self.target_cls, parentmapper.class_ + ): + raise UnevaluatableError( + "Can't evaluate criteria against " + f"alternate class {parentmapper.class_}" + ) + + parentmapper._check_configure() + + # we'd like to use "proxy_key" annotation to get the "key", however + # in relationship primaryjoin cases proxy_key is sometimes deannotated + # and sometimes apparently not present in the first place (?). + # While I can stop it from being deannotated (though need to see if + # this breaks other things), not sure right now about cases where it's + # not there in the first place. can fix at some later point. + # key = clause._annotations["proxy_key"] + + # for now, use the old way + try: + key = parentmapper._columntoproperty[clause].key + except orm_exc.UnmappedColumnError as err: + raise UnevaluatableError( + f"Cannot evaluate expression: {err}" + ) from err + + # note this used to fall back to a simple `getattr(obj, key)` evaluator + # if impl was None; as of #8656, we ensure mappers are configured + # so that impl is available + impl = parentmapper.class_manager[key].impl + + def get_corresponding_attr(obj): + if obj is None: + return _NO_OBJECT + state = inspect(obj) + dict_ = state.dict + + value = impl.get( + state, dict_, passive=PassiveFlag.PASSIVE_NO_FETCH + ) + if value is LoaderCallableStatus.PASSIVE_NO_RESULT: + return _EXPIRED_OBJECT + return value + + return get_corresponding_attr + + def visit_tuple(self, clause): + return self.visit_clauselist(clause) + + def visit_expression_clauselist(self, clause): + return self.visit_clauselist(clause) + + def visit_clauselist(self, clause): + evaluators = [self.process(clause) for clause in clause.clauses] + + dispatch = ( + f"visit_{clause.operator.__name__.rstrip('_')}_clauselist_op" + ) + meth = getattr(self, dispatch, None) + if meth: + return meth(clause.operator, evaluators, clause) + else: + raise UnevaluatableError( + f"Cannot evaluate clauselist with operator {clause.operator}" + ) + + def visit_binary(self, clause): + eval_left = self.process(clause.left) + eval_right = self.process(clause.right) + + dispatch = f"visit_{clause.operator.__name__.rstrip('_')}_binary_op" + meth = getattr(self, dispatch, None) + if meth: + return meth(clause.operator, eval_left, eval_right, clause) + else: + raise UnevaluatableError( + f"Cannot evaluate {type(clause).__name__} with " + f"operator {clause.operator}" + ) + + def visit_or_clauselist_op(self, operator, evaluators, clause): + def evaluate(obj): + has_null = False + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value: + return True + has_null = has_null or value is None + if has_null: + return None + return False + + return evaluate + + def visit_and_clauselist_op(self, operator, evaluators, clause): + def evaluate(obj): + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + + if not value: + if value is None or value is _NO_OBJECT: + return None + return False + return True + + return evaluate + + def visit_comma_op_clauselist_op(self, operator, evaluators, clause): + def evaluate(obj): + values = [] + for sub_evaluate in evaluators: + value = sub_evaluate(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value is None or value is _NO_OBJECT: + return None + values.append(value) + return tuple(values) + + return evaluate + + def visit_custom_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + if operator.python_impl: + return self._straight_evaluate( + operator, eval_left, eval_right, clause + ) + else: + raise UnevaluatableError( + f"Custom operator {operator.opstring!r} can't be evaluated " + "in Python unless it specifies a callable using " + "`.python_impl`." + ) + + def visit_is_binary_op(self, operator, eval_left, eval_right, clause): + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val == right_val + + return evaluate + + def visit_is_not_binary_op(self, operator, eval_left, eval_right, clause): + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + return left_val != right_val + + return evaluate + + def _straight_evaluate(self, operator, eval_left, eval_right, clause): + def evaluate(obj): + left_val = eval_left(obj) + right_val = eval_right(obj) + if left_val is _EXPIRED_OBJECT or right_val is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif left_val is None or right_val is None: + return None + + return operator(eval_left(obj), eval_right(obj)) + + return evaluate + + def _straight_evaluate_numeric_only( + self, operator, eval_left, eval_right, clause + ): + if clause.left.type._type_affinity not in ( + Numeric, + Integer, + ) or clause.right.type._type_affinity not in (Numeric, Integer): + raise UnevaluatableError( + f'Cannot evaluate math operator "{operator.__name__}" for ' + f"datatypes {clause.left.type}, {clause.right.type}" + ) + + return self._straight_evaluate(operator, eval_left, eval_right, clause) + + visit_add_binary_op = _straight_evaluate_numeric_only + visit_mul_binary_op = _straight_evaluate_numeric_only + visit_sub_binary_op = _straight_evaluate_numeric_only + visit_mod_binary_op = _straight_evaluate_numeric_only + visit_truediv_binary_op = _straight_evaluate_numeric_only + visit_lt_binary_op = _straight_evaluate + visit_le_binary_op = _straight_evaluate + visit_ne_binary_op = _straight_evaluate + visit_gt_binary_op = _straight_evaluate + visit_ge_binary_op = _straight_evaluate + visit_eq_binary_op = _straight_evaluate + + def visit_in_op_binary_op(self, operator, eval_left, eval_right, clause): + return self._straight_evaluate( + lambda a, b: a in b if a is not _NO_OBJECT else None, + eval_left, + eval_right, + clause, + ) + + def visit_not_in_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + return self._straight_evaluate( + lambda a, b: a not in b if a is not _NO_OBJECT else None, + eval_left, + eval_right, + clause, + ) + + def visit_concat_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + return self._straight_evaluate( + lambda a, b: a + b, eval_left, eval_right, clause + ) + + def visit_startswith_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + return self._straight_evaluate( + lambda a, b: a.startswith(b), eval_left, eval_right, clause + ) + + def visit_endswith_op_binary_op( + self, operator, eval_left, eval_right, clause + ): + return self._straight_evaluate( + lambda a, b: a.endswith(b), eval_left, eval_right, clause + ) + + def visit_unary(self, clause): + eval_inner = self.process(clause.element) + if clause.operator is operators.inv: + + def evaluate(obj): + value = eval_inner(obj) + if value is _EXPIRED_OBJECT: + return _EXPIRED_OBJECT + elif value is None: + return None + return not value + + return evaluate + raise UnevaluatableError( + f"Cannot evaluate {type(clause).__name__} " + f"with operator {clause.operator}" + ) + + def visit_bindparam(self, clause): + if clause.callable: + val = clause.callable() + else: + val = clause.value + return lambda obj: val + + +def __getattr__(name: str) -> Type[_EvaluatorCompiler]: + if name == "EvaluatorCompiler": + warn_deprecated( + "Direct use of 'EvaluatorCompiler' is not supported, and this " + "name will be removed in a future release. " + "'_EvaluatorCompiler' is for internal use only", + "2.0", + ) + return _EvaluatorCompiler + else: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/events.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/events.py new file mode 100644 index 0000000..1cd51bf --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/events.py @@ -0,0 +1,3259 @@ +# orm/events.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""ORM event interfaces. + +""" +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Collection +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import instrumentation +from . import interfaces +from . import mapperlib +from .attributes import QueryableAttribute +from .base import _mapper_or_none +from .base import NO_KEY +from .instrumentation import ClassManager +from .instrumentation import InstrumentationFactory +from .query import BulkDelete +from .query import BulkUpdate +from .query import Query +from .scoping import scoped_session +from .session import Session +from .session import sessionmaker +from .. import event +from .. import exc +from .. import util +from ..event import EventTarget +from ..event.registry import _ET +from ..util.compat import inspect_getfullargspec + +if TYPE_CHECKING: + from weakref import ReferenceType + + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _O + from ._typing import _T + from .attributes import Event + from .base import EventConstants + from .session import ORMExecuteState + from .session import SessionTransaction + from .unitofwork import UOWTransaction + from ..engine import Connection + from ..event.base import _Dispatch + from ..event.base import _HasEventsDispatch + from ..event.registry import _EventKey + from ..orm.collections import CollectionAdapter + from ..orm.context import QueryContext + from ..orm.decl_api import DeclarativeAttributeIntercept + from ..orm.decl_api import DeclarativeMeta + from ..orm.mapper import Mapper + from ..orm.state import InstanceState + +_KT = TypeVar("_KT", bound=Any) +_ET2 = TypeVar("_ET2", bound=EventTarget) + + +class InstrumentationEvents(event.Events[InstrumentationFactory]): + """Events related to class instrumentation events. + + The listeners here support being established against + any new style class, that is any object that is a subclass + of 'type'. Events will then be fired off for events + against that class. If the "propagate=True" flag is passed + to event.listen(), the event will fire off for subclasses + of that class as well. + + The Python ``type`` builtin is also accepted as a target, + which when used has the effect of events being emitted + for all classes. + + Note the "propagate" flag here is defaulted to ``True``, + unlike the other class level events where it defaults + to ``False``. This means that new subclasses will also + be the subject of these events, when a listener + is established on a superclass. + + """ + + _target_class_doc = "SomeBaseClass" + _dispatch_target = InstrumentationFactory + + @classmethod + def _accept_with( + cls, + target: Union[ + InstrumentationFactory, + Type[InstrumentationFactory], + ], + identifier: str, + ) -> Optional[ + Union[ + InstrumentationFactory, + Type[InstrumentationFactory], + ] + ]: + if isinstance(target, type): + return _InstrumentationEventsHold(target) # type: ignore [return-value] # noqa: E501 + else: + return None + + @classmethod + def _listen( + cls, event_key: _EventKey[_T], propagate: bool = True, **kw: Any + ) -> None: + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key._listen_fn, + ) + + def listen(target_cls: type, *arg: Any) -> Optional[Any]: + listen_cls = target() + + # if weakref were collected, however this is not something + # that normally happens. it was occurring during test teardown + # between mapper/registry/instrumentation_manager, however this + # interaction was changed to not rely upon the event system. + if listen_cls is None: + return None + + if propagate and issubclass(target_cls, listen_cls): + return fn(target_cls, *arg) + elif not propagate and target_cls is listen_cls: + return fn(target_cls, *arg) + else: + return None + + def remove(ref: ReferenceType[_T]) -> None: + key = event.registry._EventKey( # type: ignore [type-var] + None, + identifier, + listen, + instrumentation._instrumentation_factory, + ) + getattr( + instrumentation._instrumentation_factory.dispatch, identifier + ).remove(key) + + target = weakref.ref(target.class_, remove) + + event_key.with_dispatch_target( + instrumentation._instrumentation_factory + ).with_wrapper(listen).base_listen(**kw) + + @classmethod + def _clear(cls) -> None: + super()._clear() + instrumentation._instrumentation_factory.dispatch._clear() + + def class_instrument(self, cls: ClassManager[_O]) -> None: + """Called after the given class is instrumented. + + To get at the :class:`.ClassManager`, use + :func:`.manager_of_class`. + + """ + + def class_uninstrument(self, cls: ClassManager[_O]) -> None: + """Called before the given class is uninstrumented. + + To get at the :class:`.ClassManager`, use + :func:`.manager_of_class`. + + """ + + def attribute_instrument( + self, cls: ClassManager[_O], key: _KT, inst: _O + ) -> None: + """Called when an attribute is instrumented.""" + + +class _InstrumentationEventsHold: + """temporary marker object used to transfer from _accept_with() to + _listen() on the InstrumentationEvents class. + + """ + + def __init__(self, class_: type) -> None: + self.class_ = class_ + + dispatch = event.dispatcher(InstrumentationEvents) + + +class InstanceEvents(event.Events[ClassManager[Any]]): + """Define events specific to object lifecycle. + + e.g.:: + + from sqlalchemy import event + + def my_load_listener(target, context): + print("on load!") + + event.listen(SomeClass, 'load', my_load_listener) + + Available targets include: + + * mapped classes + * unmapped superclasses of mapped or to-be-mapped classes + (using the ``propagate=True`` flag) + * :class:`_orm.Mapper` objects + * the :class:`_orm.Mapper` class itself indicates listening for all + mappers. + + Instance events are closely related to mapper events, but + are more specific to the instance and its instrumentation, + rather than its system of persistence. + + When using :class:`.InstanceEvents`, several modifiers are + available to the :func:`.event.listen` function. + + :param propagate=False: When True, the event listener should + be applied to all inheriting classes as well as the + class which is the target of this listener. + :param raw=False: When True, the "target" argument passed + to applicable event listener functions will be the + instance's :class:`.InstanceState` management + object, rather than the mapped instance itself. + :param restore_load_context=False: Applies to the + :meth:`.InstanceEvents.load` and :meth:`.InstanceEvents.refresh` + events. Restores the loader context of the object when the event + hook is complete, so that ongoing eager load operations continue + to target the object appropriately. A warning is emitted if the + object is moved to a new loader context from within one of these + events if this flag is not set. + + .. versionadded:: 1.3.14 + + + """ + + _target_class_doc = "SomeClass" + + _dispatch_target = ClassManager + + @classmethod + def _new_classmanager_instance( + cls, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + classmanager: ClassManager[_O], + ) -> None: + _InstanceEventsHold.populate(class_, classmanager) + + @classmethod + @util.preload_module("sqlalchemy.orm") + def _accept_with( + cls, + target: Union[ + ClassManager[Any], + Type[ClassManager[Any]], + ], + identifier: str, + ) -> Optional[Union[ClassManager[Any], Type[ClassManager[Any]]]]: + orm = util.preloaded.orm + + if isinstance(target, ClassManager): + return target + elif isinstance(target, mapperlib.Mapper): + return target.class_manager + elif target is orm.mapper: # type: ignore [attr-defined] + util.warn_deprecated( + "The `sqlalchemy.orm.mapper()` symbol is deprecated and " + "will be removed in a future release. For the mapper-wide " + "event target, use the 'sqlalchemy.orm.Mapper' class.", + "2.0", + ) + return ClassManager + elif isinstance(target, type): + if issubclass(target, mapperlib.Mapper): + return ClassManager + else: + manager = instrumentation.opt_manager_of_class(target) + if manager: + return manager + else: + return _InstanceEventsHold(target) # type: ignore [return-value] # noqa: E501 + return None + + @classmethod + def _listen( + cls, + event_key: _EventKey[ClassManager[Any]], + raw: bool = False, + propagate: bool = False, + restore_load_context: bool = False, + **kw: Any, + ) -> None: + target, fn = (event_key.dispatch_target, event_key._listen_fn) + + if not raw or restore_load_context: + + def wrap( + state: InstanceState[_O], *arg: Any, **kw: Any + ) -> Optional[Any]: + if not raw: + target: Any = state.obj() + else: + target = state + if restore_load_context: + runid = state.runid + try: + return fn(target, *arg, **kw) + finally: + if restore_load_context: + state.runid = runid + + event_key = event_key.with_wrapper(wrap) + + event_key.base_listen(propagate=propagate, **kw) + + if propagate: + for mgr in target.subclass_managers(True): + event_key.with_dispatch_target(mgr).base_listen(propagate=True) + + @classmethod + def _clear(cls) -> None: + super()._clear() + _InstanceEventsHold._clear() + + def first_init(self, manager: ClassManager[_O], cls: Type[_O]) -> None: + """Called when the first instance of a particular mapping is called. + + This event is called when the ``__init__`` method of a class + is called the first time for that particular class. The event + invokes before ``__init__`` actually proceeds as well as before + the :meth:`.InstanceEvents.init` event is invoked. + + """ + + def init(self, target: _O, args: Any, kwargs: Any) -> None: + """Receive an instance when its constructor is called. + + This method is only called during a userland construction of + an object, in conjunction with the object's constructor, e.g. + its ``__init__`` method. It is not called when an object is + loaded from the database; see the :meth:`.InstanceEvents.load` + event in order to intercept a database load. + + The event is called before the actual ``__init__`` constructor + of the object is called. The ``kwargs`` dictionary may be + modified in-place in order to affect what is passed to + ``__init__``. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param args: positional arguments passed to the ``__init__`` method. + This is passed as a tuple and is currently immutable. + :param kwargs: keyword arguments passed to the ``__init__`` method. + This structure *can* be altered in place. + + .. seealso:: + + :meth:`.InstanceEvents.init_failure` + + :meth:`.InstanceEvents.load` + + """ + + def init_failure(self, target: _O, args: Any, kwargs: Any) -> None: + """Receive an instance when its constructor has been called, + and raised an exception. + + This method is only called during a userland construction of + an object, in conjunction with the object's constructor, e.g. + its ``__init__`` method. It is not called when an object is loaded + from the database. + + The event is invoked after an exception raised by the ``__init__`` + method is caught. After the event + is invoked, the original exception is re-raised outwards, so that + the construction of the object still raises an exception. The + actual exception and stack trace raised should be present in + ``sys.exc_info()``. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param args: positional arguments that were passed to the ``__init__`` + method. + :param kwargs: keyword arguments that were passed to the ``__init__`` + method. + + .. seealso:: + + :meth:`.InstanceEvents.init` + + :meth:`.InstanceEvents.load` + + """ + + def _sa_event_merge_wo_load( + self, target: _O, context: QueryContext + ) -> None: + """receive an object instance after it was the subject of a merge() + call, when load=False was passed. + + The target would be the already-loaded object in the Session which + would have had its attributes overwritten by the incoming object. This + overwrite operation does not use attribute events, instead just + populating dict directly. Therefore the purpose of this event is so + that extensions like sqlalchemy.ext.mutable know that object state has + changed and incoming state needs to be set up for "parents" etc. + + This functionality is acceptable to be made public in a later release. + + .. versionadded:: 1.4.41 + + """ + + def load(self, target: _O, context: QueryContext) -> None: + """Receive an object instance after it has been created via + ``__new__``, and after initial attribute population has + occurred. + + This typically occurs when the instance is created based on + incoming result rows, and is only called once for that + instance's lifetime. + + .. warning:: + + During a result-row load, this event is invoked when the + first row received for this instance is processed. When using + eager loading with collection-oriented attributes, the additional + rows that are to be loaded / processed in order to load subsequent + collection items have not occurred yet. This has the effect + both that collections will not be fully loaded, as well as that + if an operation occurs within this event handler that emits + another database load operation for the object, the "loading + context" for the object can change and interfere with the + existing eager loaders still in progress. + + Examples of what can cause the "loading context" to change within + the event handler include, but are not necessarily limited to: + + * accessing deferred attributes that weren't part of the row, + will trigger an "undefer" operation and refresh the object + + * accessing attributes on a joined-inheritance subclass that + weren't part of the row, will trigger a refresh operation. + + As of SQLAlchemy 1.3.14, a warning is emitted when this occurs. The + :paramref:`.InstanceEvents.restore_load_context` option may be + used on the event to prevent this warning; this will ensure that + the existing loading context is maintained for the object after the + event is called:: + + @event.listens_for( + SomeClass, "load", restore_load_context=True) + def on_load(instance, context): + instance.some_unloaded_attribute + + .. versionchanged:: 1.3.14 Added + :paramref:`.InstanceEvents.restore_load_context` + and :paramref:`.SessionEvents.restore_load_context` flags which + apply to "on load" events, which will ensure that the loading + context for an object is restored when the event hook is + complete; a warning is emitted if the load context of the object + changes without this flag being set. + + + The :meth:`.InstanceEvents.load` event is also available in a + class-method decorator format called :func:`_orm.reconstructor`. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param context: the :class:`.QueryContext` corresponding to the + current :class:`_query.Query` in progress. This argument may be + ``None`` if the load does not correspond to a :class:`_query.Query`, + such as during :meth:`.Session.merge`. + + .. seealso:: + + :ref:`mapped_class_load_events` + + :meth:`.InstanceEvents.init` + + :meth:`.InstanceEvents.refresh` + + :meth:`.SessionEvents.loaded_as_persistent` + + """ + + def refresh( + self, target: _O, context: QueryContext, attrs: Optional[Iterable[str]] + ) -> None: + """Receive an object instance after one or more attributes have + been refreshed from a query. + + Contrast this to the :meth:`.InstanceEvents.load` method, which + is invoked when the object is first loaded from a query. + + .. note:: This event is invoked within the loader process before + eager loaders may have been completed, and the object's state may + not be complete. Additionally, invoking row-level refresh + operations on the object will place the object into a new loader + context, interfering with the existing load context. See the note + on :meth:`.InstanceEvents.load` for background on making use of the + :paramref:`.InstanceEvents.restore_load_context` parameter, in + order to resolve this scenario. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param context: the :class:`.QueryContext` corresponding to the + current :class:`_query.Query` in progress. + :param attrs: sequence of attribute names which + were populated, or None if all column-mapped, non-deferred + attributes were populated. + + .. seealso:: + + :ref:`mapped_class_load_events` + + :meth:`.InstanceEvents.load` + + """ + + def refresh_flush( + self, + target: _O, + flush_context: UOWTransaction, + attrs: Optional[Iterable[str]], + ) -> None: + """Receive an object instance after one or more attributes that + contain a column-level default or onupdate handler have been refreshed + during persistence of the object's state. + + This event is the same as :meth:`.InstanceEvents.refresh` except + it is invoked within the unit of work flush process, and includes + only non-primary-key columns that have column level default or + onupdate handlers, including Python callables as well as server side + defaults and triggers which may be fetched via the RETURNING clause. + + .. note:: + + While the :meth:`.InstanceEvents.refresh_flush` event is triggered + for an object that was INSERTed as well as for an object that was + UPDATEd, the event is geared primarily towards the UPDATE process; + it is mostly an internal artifact that INSERT actions can also + trigger this event, and note that **primary key columns for an + INSERTed row are explicitly omitted** from this event. In order to + intercept the newly INSERTed state of an object, the + :meth:`.SessionEvents.pending_to_persistent` and + :meth:`.MapperEvents.after_insert` are better choices. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param flush_context: Internal :class:`.UOWTransaction` object + which handles the details of the flush. + :param attrs: sequence of attribute names which + were populated. + + .. seealso:: + + :ref:`mapped_class_load_events` + + :ref:`orm_server_defaults` + + :ref:`metadata_defaults_toplevel` + + """ + + def expire(self, target: _O, attrs: Optional[Iterable[str]]) -> None: + """Receive an object instance after its attributes or some subset + have been expired. + + 'keys' is a list of attribute names. If None, the entire + state was expired. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param attrs: sequence of attribute + names which were expired, or None if all attributes were + expired. + + """ + + def pickle(self, target: _O, state_dict: _InstanceDict) -> None: + """Receive an object instance when its associated state is + being pickled. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param state_dict: the dictionary returned by + :class:`.InstanceState.__getstate__`, containing the state + to be pickled. + + """ + + def unpickle(self, target: _O, state_dict: _InstanceDict) -> None: + """Receive an object instance after its associated state has + been unpickled. + + :param target: the mapped instance. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :param state_dict: the dictionary sent to + :class:`.InstanceState.__setstate__`, containing the state + dictionary which was pickled. + + """ + + +class _EventsHold(event.RefCollection[_ET]): + """Hold onto listeners against unmapped, uninstrumented classes. + + Establish _listen() for that class' mapper/instrumentation when + those objects are created for that class. + + """ + + all_holds: weakref.WeakKeyDictionary[Any, Any] + + def __init__( + self, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + ) -> None: + self.class_ = class_ + + @classmethod + def _clear(cls) -> None: + cls.all_holds.clear() + + class HoldEvents(Generic[_ET2]): + _dispatch_target: Optional[Type[_ET2]] = None + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET2], + raw: bool = False, + propagate: bool = False, + retval: bool = False, + **kw: Any, + ) -> None: + target = event_key.dispatch_target + + if target.class_ in target.all_holds: + collection = target.all_holds[target.class_] + else: + collection = target.all_holds[target.class_] = {} + + event.registry._stored_in_collection(event_key, target) + collection[event_key._key] = ( + event_key, + raw, + propagate, + retval, + kw, + ) + + if propagate: + stack = list(target.class_.__subclasses__()) + while stack: + subclass = stack.pop(0) + stack.extend(subclass.__subclasses__()) + subject = target.resolve(subclass) + if subject is not None: + # we are already going through __subclasses__() + # so leave generic propagate flag False + event_key.with_dispatch_target(subject).listen( + raw=raw, propagate=False, retval=retval, **kw + ) + + def remove(self, event_key: _EventKey[_ET]) -> None: + target = event_key.dispatch_target + + if isinstance(target, _EventsHold): + collection = target.all_holds[target.class_] + del collection[event_key._key] + + @classmethod + def populate( + cls, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + subject: Union[ClassManager[_O], Mapper[_O]], + ) -> None: + for subclass in class_.__mro__: + if subclass in cls.all_holds: + collection = cls.all_holds[subclass] + for ( + event_key, + raw, + propagate, + retval, + kw, + ) in collection.values(): + if propagate or subclass is class_: + # since we can't be sure in what order different + # classes in a hierarchy are triggered with + # populate(), we rely upon _EventsHold for all event + # assignment, instead of using the generic propagate + # flag. + event_key.with_dispatch_target(subject).listen( + raw=raw, propagate=False, retval=retval, **kw + ) + + +class _InstanceEventsHold(_EventsHold[_ET]): + all_holds: weakref.WeakKeyDictionary[Any, Any] = ( + weakref.WeakKeyDictionary() + ) + + def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]: + return instrumentation.opt_manager_of_class(class_) + + class HoldInstanceEvents(_EventsHold.HoldEvents[_ET], InstanceEvents): # type: ignore [misc] # noqa: E501 + pass + + dispatch = event.dispatcher(HoldInstanceEvents) + + +class MapperEvents(event.Events[mapperlib.Mapper[Any]]): + """Define events specific to mappings. + + e.g.:: + + from sqlalchemy import event + + def my_before_insert_listener(mapper, connection, target): + # execute a stored procedure upon INSERT, + # apply the value to the row to be inserted + target.calculated_value = connection.execute( + text("select my_special_function(%d)" % target.special_number) + ).scalar() + + # associate the listener function with SomeClass, + # to execute during the "before_insert" hook + event.listen( + SomeClass, 'before_insert', my_before_insert_listener) + + Available targets include: + + * mapped classes + * unmapped superclasses of mapped or to-be-mapped classes + (using the ``propagate=True`` flag) + * :class:`_orm.Mapper` objects + * the :class:`_orm.Mapper` class itself indicates listening for all + mappers. + + Mapper events provide hooks into critical sections of the + mapper, including those related to object instrumentation, + object loading, and object persistence. In particular, the + persistence methods :meth:`~.MapperEvents.before_insert`, + and :meth:`~.MapperEvents.before_update` are popular + places to augment the state being persisted - however, these + methods operate with several significant restrictions. The + user is encouraged to evaluate the + :meth:`.SessionEvents.before_flush` and + :meth:`.SessionEvents.after_flush` methods as more + flexible and user-friendly hooks in which to apply + additional database state during a flush. + + When using :class:`.MapperEvents`, several modifiers are + available to the :func:`.event.listen` function. + + :param propagate=False: When True, the event listener should + be applied to all inheriting mappers and/or the mappers of + inheriting classes, as well as any + mapper which is the target of this listener. + :param raw=False: When True, the "target" argument passed + to applicable event listener functions will be the + instance's :class:`.InstanceState` management + object, rather than the mapped instance itself. + :param retval=False: when True, the user-defined event function + must have a return value, the purpose of which is either to + control subsequent event propagation, or to otherwise alter + the operation in progress by the mapper. Possible return + values are: + + * ``sqlalchemy.orm.interfaces.EXT_CONTINUE`` - continue event + processing normally. + * ``sqlalchemy.orm.interfaces.EXT_STOP`` - cancel all subsequent + event handlers in the chain. + * other values - the return value specified by specific listeners. + + """ + + _target_class_doc = "SomeClass" + _dispatch_target = mapperlib.Mapper + + @classmethod + def _new_mapper_instance( + cls, + class_: Union[DeclarativeAttributeIntercept, DeclarativeMeta, type], + mapper: Mapper[_O], + ) -> None: + _MapperEventsHold.populate(class_, mapper) + + @classmethod + @util.preload_module("sqlalchemy.orm") + def _accept_with( + cls, + target: Union[mapperlib.Mapper[Any], Type[mapperlib.Mapper[Any]]], + identifier: str, + ) -> Optional[Union[mapperlib.Mapper[Any], Type[mapperlib.Mapper[Any]]]]: + orm = util.preloaded.orm + + if target is orm.mapper: # type: ignore [attr-defined] + util.warn_deprecated( + "The `sqlalchemy.orm.mapper()` symbol is deprecated and " + "will be removed in a future release. For the mapper-wide " + "event target, use the 'sqlalchemy.orm.Mapper' class.", + "2.0", + ) + return mapperlib.Mapper + elif isinstance(target, type): + if issubclass(target, mapperlib.Mapper): + return target + else: + mapper = _mapper_or_none(target) + if mapper is not None: + return mapper + else: + return _MapperEventsHold(target) + else: + return target + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET], + raw: bool = False, + retval: bool = False, + propagate: bool = False, + **kw: Any, + ) -> None: + target, identifier, fn = ( + event_key.dispatch_target, + event_key.identifier, + event_key._listen_fn, + ) + + if ( + identifier in ("before_configured", "after_configured") + and target is not mapperlib.Mapper + ): + util.warn( + "'before_configured' and 'after_configured' ORM events " + "only invoke with the Mapper class " + "as the target." + ) + + if not raw or not retval: + if not raw: + meth = getattr(cls, identifier) + try: + target_index = ( + inspect_getfullargspec(meth)[0].index("target") - 1 + ) + except ValueError: + target_index = None + + def wrap(*arg: Any, **kw: Any) -> Any: + if not raw and target_index is not None: + arg = list(arg) # type: ignore [assignment] + arg[target_index] = arg[target_index].obj() # type: ignore [index] # noqa: E501 + if not retval: + fn(*arg, **kw) + return interfaces.EXT_CONTINUE + else: + return fn(*arg, **kw) + + event_key = event_key.with_wrapper(wrap) + + if propagate: + for mapper in target.self_and_descendants: + event_key.with_dispatch_target(mapper).base_listen( + propagate=True, **kw + ) + else: + event_key.base_listen(**kw) + + @classmethod + def _clear(cls) -> None: + super()._clear() + _MapperEventsHold._clear() + + def instrument_class(self, mapper: Mapper[_O], class_: Type[_O]) -> None: + r"""Receive a class when the mapper is first constructed, + before instrumentation is applied to the mapped class. + + This event is the earliest phase of mapper construction. + Most attributes of the mapper are not yet initialized. To + receive an event within initial mapper construction where basic + state is available such as the :attr:`_orm.Mapper.attrs` collection, + the :meth:`_orm.MapperEvents.after_mapper_constructed` event may + be a better choice. + + This listener can either be applied to the :class:`_orm.Mapper` + class overall, or to any un-mapped class which serves as a base + for classes that will be mapped (using the ``propagate=True`` flag):: + + Base = declarative_base() + + @event.listens_for(Base, "instrument_class", propagate=True) + def on_new_class(mapper, cls_): + " ... " + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param class\_: the mapped class. + + .. seealso:: + + :meth:`_orm.MapperEvents.after_mapper_constructed` + + """ + + def after_mapper_constructed( + self, mapper: Mapper[_O], class_: Type[_O] + ) -> None: + """Receive a class and mapper when the :class:`_orm.Mapper` has been + fully constructed. + + This event is called after the initial constructor for + :class:`_orm.Mapper` completes. This occurs after the + :meth:`_orm.MapperEvents.instrument_class` event and after the + :class:`_orm.Mapper` has done an initial pass of its arguments + to generate its collection of :class:`_orm.MapperProperty` objects, + which are accessible via the :meth:`_orm.Mapper.get_property` + method and the :attr:`_orm.Mapper.iterate_properties` attribute. + + This event differs from the + :meth:`_orm.MapperEvents.before_mapper_configured` event in that it + is invoked within the constructor for :class:`_orm.Mapper`, rather + than within the :meth:`_orm.registry.configure` process. Currently, + this event is the only one which is appropriate for handlers that + wish to create additional mapped classes in response to the + construction of this :class:`_orm.Mapper`, which will be part of the + same configure step when :meth:`_orm.registry.configure` next runs. + + .. versionadded:: 2.0.2 + + .. seealso:: + + :ref:`examples_versioning` - an example which illustrates the use + of the :meth:`_orm.MapperEvents.before_mapper_configured` + event to create new mappers to record change-audit histories on + objects. + + """ + + def before_mapper_configured( + self, mapper: Mapper[_O], class_: Type[_O] + ) -> None: + """Called right before a specific mapper is to be configured. + + This event is intended to allow a specific mapper to be skipped during + the configure step, by returning the :attr:`.orm.interfaces.EXT_SKIP` + symbol which indicates to the :func:`.configure_mappers` call that this + particular mapper (or hierarchy of mappers, if ``propagate=True`` is + used) should be skipped in the current configuration run. When one or + more mappers are skipped, the he "new mappers" flag will remain set, + meaning the :func:`.configure_mappers` function will continue to be + called when mappers are used, to continue to try to configure all + available mappers. + + In comparison to the other configure-level events, + :meth:`.MapperEvents.before_configured`, + :meth:`.MapperEvents.after_configured`, and + :meth:`.MapperEvents.mapper_configured`, the + :meth;`.MapperEvents.before_mapper_configured` event provides for a + meaningful return value when it is registered with the ``retval=True`` + parameter. + + .. versionadded:: 1.3 + + e.g.:: + + from sqlalchemy.orm import EXT_SKIP + + Base = declarative_base() + + DontConfigureBase = declarative_base() + + @event.listens_for( + DontConfigureBase, + "before_mapper_configured", retval=True, propagate=True) + def dont_configure(mapper, cls): + return EXT_SKIP + + + .. seealso:: + + :meth:`.MapperEvents.before_configured` + + :meth:`.MapperEvents.after_configured` + + :meth:`.MapperEvents.mapper_configured` + + """ + + def mapper_configured(self, mapper: Mapper[_O], class_: Type[_O]) -> None: + r"""Called when a specific mapper has completed its own configuration + within the scope of the :func:`.configure_mappers` call. + + The :meth:`.MapperEvents.mapper_configured` event is invoked + for each mapper that is encountered when the + :func:`_orm.configure_mappers` function proceeds through the current + list of not-yet-configured mappers. + :func:`_orm.configure_mappers` is typically invoked + automatically as mappings are first used, as well as each time + new mappers have been made available and new mapper use is + detected. + + When the event is called, the mapper should be in its final + state, but **not including backrefs** that may be invoked from + other mappers; they might still be pending within the + configuration operation. Bidirectional relationships that + are instead configured via the + :paramref:`.orm.relationship.back_populates` argument + *will* be fully available, since this style of relationship does not + rely upon other possibly-not-configured mappers to know that they + exist. + + For an event that is guaranteed to have **all** mappers ready + to go including backrefs that are defined only on other + mappings, use the :meth:`.MapperEvents.after_configured` + event; this event invokes only after all known mappings have been + fully configured. + + The :meth:`.MapperEvents.mapper_configured` event, unlike + :meth:`.MapperEvents.before_configured` or + :meth:`.MapperEvents.after_configured`, + is called for each mapper/class individually, and the mapper is + passed to the event itself. It also is called exactly once for + a particular mapper. The event is therefore useful for + configurational steps that benefit from being invoked just once + on a specific mapper basis, which don't require that "backref" + configurations are necessarily ready yet. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param class\_: the mapped class. + + .. seealso:: + + :meth:`.MapperEvents.before_configured` + + :meth:`.MapperEvents.after_configured` + + :meth:`.MapperEvents.before_mapper_configured` + + """ + # TODO: need coverage for this event + + def before_configured(self) -> None: + """Called before a series of mappers have been configured. + + The :meth:`.MapperEvents.before_configured` event is invoked + each time the :func:`_orm.configure_mappers` function is + invoked, before the function has done any of its work. + :func:`_orm.configure_mappers` is typically invoked + automatically as mappings are first used, as well as each time + new mappers have been made available and new mapper use is + detected. + + This event can **only** be applied to the :class:`_orm.Mapper` class, + and not to individual mappings or mapped classes. It is only invoked + for all mappings as a whole:: + + from sqlalchemy.orm import Mapper + + @event.listens_for(Mapper, "before_configured") + def go(): + ... + + Contrast this event to :meth:`.MapperEvents.after_configured`, + which is invoked after the series of mappers has been configured, + as well as :meth:`.MapperEvents.before_mapper_configured` + and :meth:`.MapperEvents.mapper_configured`, which are both invoked + on a per-mapper basis. + + Theoretically this event is called once per + application, but is actually called any time new mappers + are to be affected by a :func:`_orm.configure_mappers` + call. If new mappings are constructed after existing ones have + already been used, this event will likely be called again. To ensure + that a particular event is only called once and no further, the + ``once=True`` argument (new in 0.9.4) can be applied:: + + from sqlalchemy.orm import mapper + + @event.listens_for(mapper, "before_configured", once=True) + def go(): + ... + + + .. seealso:: + + :meth:`.MapperEvents.before_mapper_configured` + + :meth:`.MapperEvents.mapper_configured` + + :meth:`.MapperEvents.after_configured` + + """ + + def after_configured(self) -> None: + """Called after a series of mappers have been configured. + + The :meth:`.MapperEvents.after_configured` event is invoked + each time the :func:`_orm.configure_mappers` function is + invoked, after the function has completed its work. + :func:`_orm.configure_mappers` is typically invoked + automatically as mappings are first used, as well as each time + new mappers have been made available and new mapper use is + detected. + + Contrast this event to the :meth:`.MapperEvents.mapper_configured` + event, which is called on a per-mapper basis while the configuration + operation proceeds; unlike that event, when this event is invoked, + all cross-configurations (e.g. backrefs) will also have been made + available for any mappers that were pending. + Also contrast to :meth:`.MapperEvents.before_configured`, + which is invoked before the series of mappers has been configured. + + This event can **only** be applied to the :class:`_orm.Mapper` class, + and not to individual mappings or + mapped classes. It is only invoked for all mappings as a whole:: + + from sqlalchemy.orm import Mapper + + @event.listens_for(Mapper, "after_configured") + def go(): + # ... + + Theoretically this event is called once per + application, but is actually called any time new mappers + have been affected by a :func:`_orm.configure_mappers` + call. If new mappings are constructed after existing ones have + already been used, this event will likely be called again. To ensure + that a particular event is only called once and no further, the + ``once=True`` argument (new in 0.9.4) can be applied:: + + from sqlalchemy.orm import mapper + + @event.listens_for(mapper, "after_configured", once=True) + def go(): + # ... + + .. seealso:: + + :meth:`.MapperEvents.before_mapper_configured` + + :meth:`.MapperEvents.mapper_configured` + + :meth:`.MapperEvents.before_configured` + + """ + + def before_insert( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: + """Receive an object instance before an INSERT statement + is emitted corresponding to that instance. + + .. note:: this event **only** applies to the + :ref:`session flush operation ` + and does **not** apply to the ORM DML operations described at + :ref:`orm_expression_update_delete`. To intercept ORM + DML events, use :meth:`_orm.SessionEvents.do_orm_execute`. + + This event is used to modify local, non-object related + attributes on the instance before an INSERT occurs, as well + as to emit additional SQL statements on the given + connection. + + The event is often called for a batch of objects of the + same class before their INSERT statements are emitted at + once in a later step. In the extremely rare case that + this is not desirable, the :class:`_orm.Mapper` object can be + configured with ``batch=False``, which will cause + batches of instances to be broken up into individual + (and more poorly performing) event->persist->event + steps. + + .. warning:: + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`_engine.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param connection: the :class:`_engine.Connection` being used to + emit INSERT statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being persisted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + .. seealso:: + + :ref:`session_persistence_events` + + """ + + def after_insert( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: + """Receive an object instance after an INSERT statement + is emitted corresponding to that instance. + + .. note:: this event **only** applies to the + :ref:`session flush operation ` + and does **not** apply to the ORM DML operations described at + :ref:`orm_expression_update_delete`. To intercept ORM + DML events, use :meth:`_orm.SessionEvents.do_orm_execute`. + + This event is used to modify in-Python-only + state on the instance after an INSERT occurs, as well + as to emit additional SQL statements on the given + connection. + + The event is often called for a batch of objects of the + same class after their INSERT statements have been + emitted at once in a previous step. In the extremely + rare case that this is not desirable, the + :class:`_orm.Mapper` object can be configured with ``batch=False``, + which will cause batches of instances to be broken up + into individual (and more poorly performing) + event->persist->event steps. + + .. warning:: + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`_engine.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param connection: the :class:`_engine.Connection` being used to + emit INSERT statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being persisted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + .. seealso:: + + :ref:`session_persistence_events` + + """ + + def before_update( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: + """Receive an object instance before an UPDATE statement + is emitted corresponding to that instance. + + .. note:: this event **only** applies to the + :ref:`session flush operation ` + and does **not** apply to the ORM DML operations described at + :ref:`orm_expression_update_delete`. To intercept ORM + DML events, use :meth:`_orm.SessionEvents.do_orm_execute`. + + This event is used to modify local, non-object related + attributes on the instance before an UPDATE occurs, as well + as to emit additional SQL statements on the given + connection. + + This method is called for all instances that are + marked as "dirty", *even those which have no net changes + to their column-based attributes*. An object is marked + as dirty when any of its column-based attributes have a + "set attribute" operation called or when any of its + collections are modified. If, at update time, no + column-based attributes have any net changes, no UPDATE + statement will be issued. This means that an instance + being sent to :meth:`~.MapperEvents.before_update` is + *not* a guarantee that an UPDATE statement will be + issued, although you can affect the outcome here by + modifying attributes so that a net change in value does + exist. + + To detect if the column-based attributes on the object have net + changes, and will therefore generate an UPDATE statement, use + ``object_session(instance).is_modified(instance, + include_collections=False)``. + + The event is often called for a batch of objects of the + same class before their UPDATE statements are emitted at + once in a later step. In the extremely rare case that + this is not desirable, the :class:`_orm.Mapper` can be + configured with ``batch=False``, which will cause + batches of instances to be broken up into individual + (and more poorly performing) event->persist->event + steps. + + .. warning:: + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`_engine.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param connection: the :class:`_engine.Connection` being used to + emit UPDATE statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being persisted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + .. seealso:: + + :ref:`session_persistence_events` + + """ + + def after_update( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: + """Receive an object instance after an UPDATE statement + is emitted corresponding to that instance. + + .. note:: this event **only** applies to the + :ref:`session flush operation ` + and does **not** apply to the ORM DML operations described at + :ref:`orm_expression_update_delete`. To intercept ORM + DML events, use :meth:`_orm.SessionEvents.do_orm_execute`. + + This event is used to modify in-Python-only + state on the instance after an UPDATE occurs, as well + as to emit additional SQL statements on the given + connection. + + This method is called for all instances that are + marked as "dirty", *even those which have no net changes + to their column-based attributes*, and for which + no UPDATE statement has proceeded. An object is marked + as dirty when any of its column-based attributes have a + "set attribute" operation called or when any of its + collections are modified. If, at update time, no + column-based attributes have any net changes, no UPDATE + statement will be issued. This means that an instance + being sent to :meth:`~.MapperEvents.after_update` is + *not* a guarantee that an UPDATE statement has been + issued. + + To detect if the column-based attributes on the object have net + changes, and therefore resulted in an UPDATE statement, use + ``object_session(instance).is_modified(instance, + include_collections=False)``. + + The event is often called for a batch of objects of the + same class after their UPDATE statements have been emitted at + once in a previous step. In the extremely rare case that + this is not desirable, the :class:`_orm.Mapper` can be + configured with ``batch=False``, which will cause + batches of instances to be broken up into individual + (and more poorly performing) event->persist->event + steps. + + .. warning:: + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`_engine.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param connection: the :class:`_engine.Connection` being used to + emit UPDATE statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being persisted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + .. seealso:: + + :ref:`session_persistence_events` + + """ + + def before_delete( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: + """Receive an object instance before a DELETE statement + is emitted corresponding to that instance. + + .. note:: this event **only** applies to the + :ref:`session flush operation ` + and does **not** apply to the ORM DML operations described at + :ref:`orm_expression_update_delete`. To intercept ORM + DML events, use :meth:`_orm.SessionEvents.do_orm_execute`. + + This event is used to emit additional SQL statements on + the given connection as well as to perform application + specific bookkeeping related to a deletion event. + + The event is often called for a batch of objects of the + same class before their DELETE statements are emitted at + once in a later step. + + .. warning:: + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`_engine.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param connection: the :class:`_engine.Connection` being used to + emit DELETE statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being deleted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + .. seealso:: + + :ref:`session_persistence_events` + + """ + + def after_delete( + self, mapper: Mapper[_O], connection: Connection, target: _O + ) -> None: + """Receive an object instance after a DELETE statement + has been emitted corresponding to that instance. + + .. note:: this event **only** applies to the + :ref:`session flush operation ` + and does **not** apply to the ORM DML operations described at + :ref:`orm_expression_update_delete`. To intercept ORM + DML events, use :meth:`_orm.SessionEvents.do_orm_execute`. + + This event is used to emit additional SQL statements on + the given connection as well as to perform application + specific bookkeeping related to a deletion event. + + The event is often called for a batch of objects of the + same class after their DELETE statements have been emitted at + once in a previous step. + + .. warning:: + + Mapper-level flush events only allow **very limited operations**, + on attributes local to the row being operated upon only, + as well as allowing any SQL to be emitted on the given + :class:`_engine.Connection`. **Please read fully** the notes + at :ref:`session_persistence_mapper` for guidelines on using + these methods; generally, the :meth:`.SessionEvents.before_flush` + method should be preferred for general on-flush changes. + + :param mapper: the :class:`_orm.Mapper` which is the target + of this event. + :param connection: the :class:`_engine.Connection` being used to + emit DELETE statements for this instance. This + provides a handle into the current transaction on the + target database specific to this instance. + :param target: the mapped instance being deleted. If + the event is configured with ``raw=True``, this will + instead be the :class:`.InstanceState` state-management + object associated with the instance. + :return: No return value is supported by this event. + + .. seealso:: + + :ref:`session_persistence_events` + + """ + + +class _MapperEventsHold(_EventsHold[_ET]): + all_holds = weakref.WeakKeyDictionary() + + def resolve( + self, class_: Union[Type[_T], _InternalEntityType[_T]] + ) -> Optional[Mapper[_T]]: + return _mapper_or_none(class_) + + class HoldMapperEvents(_EventsHold.HoldEvents[_ET], MapperEvents): # type: ignore [misc] # noqa: E501 + pass + + dispatch = event.dispatcher(HoldMapperEvents) + + +_sessionevents_lifecycle_event_names: Set[str] = set() + + +class SessionEvents(event.Events[Session]): + """Define events specific to :class:`.Session` lifecycle. + + e.g.:: + + from sqlalchemy import event + from sqlalchemy.orm import sessionmaker + + def my_before_commit(session): + print("before commit!") + + Session = sessionmaker() + + event.listen(Session, "before_commit", my_before_commit) + + The :func:`~.event.listen` function will accept + :class:`.Session` objects as well as the return result + of :class:`~.sessionmaker()` and :class:`~.scoped_session()`. + + Additionally, it accepts the :class:`.Session` class which + will apply listeners to all :class:`.Session` instances + globally. + + :param raw=False: When True, the "target" argument passed + to applicable event listener functions that work on individual + objects will be the instance's :class:`.InstanceState` management + object, rather than the mapped instance itself. + + .. versionadded:: 1.3.14 + + :param restore_load_context=False: Applies to the + :meth:`.SessionEvents.loaded_as_persistent` event. Restores the loader + context of the object when the event hook is complete, so that ongoing + eager load operations continue to target the object appropriately. A + warning is emitted if the object is moved to a new loader context from + within this event if this flag is not set. + + .. versionadded:: 1.3.14 + + """ + + _target_class_doc = "SomeSessionClassOrObject" + + _dispatch_target = Session + + def _lifecycle_event( # type: ignore [misc] + fn: Callable[[SessionEvents, Session, Any], None] + ) -> Callable[[SessionEvents, Session, Any], None]: + _sessionevents_lifecycle_event_names.add(fn.__name__) + return fn + + @classmethod + def _accept_with( # type: ignore [return] + cls, target: Any, identifier: str + ) -> Union[Session, type]: + if isinstance(target, scoped_session): + target = target.session_factory + if not isinstance(target, sessionmaker) and ( + not isinstance(target, type) or not issubclass(target, Session) + ): + raise exc.ArgumentError( + "Session event listen on a scoped_session " + "requires that its creation callable " + "is associated with the Session class." + ) + + if isinstance(target, sessionmaker): + return target.class_ + elif isinstance(target, type): + if issubclass(target, scoped_session): + return Session + elif issubclass(target, Session): + return target + elif isinstance(target, Session): + return target + elif hasattr(target, "_no_async_engine_events"): + target._no_async_engine_events() + else: + # allows alternate SessionEvents-like-classes to be consulted + return event.Events._accept_with(target, identifier) # type: ignore [return-value] # noqa: E501 + + @classmethod + def _listen( + cls, + event_key: Any, + *, + raw: bool = False, + restore_load_context: bool = False, + **kw: Any, + ) -> None: + is_instance_event = ( + event_key.identifier in _sessionevents_lifecycle_event_names + ) + + if is_instance_event: + if not raw or restore_load_context: + fn = event_key._listen_fn + + def wrap( + session: Session, + state: InstanceState[_O], + *arg: Any, + **kw: Any, + ) -> Optional[Any]: + if not raw: + target = state.obj() + if target is None: + # existing behavior is that if the object is + # garbage collected, no event is emitted + return None + else: + target = state # type: ignore [assignment] + if restore_load_context: + runid = state.runid + try: + return fn(session, target, *arg, **kw) + finally: + if restore_load_context: + state.runid = runid + + event_key = event_key.with_wrapper(wrap) + + event_key.base_listen(**kw) + + def do_orm_execute(self, orm_execute_state: ORMExecuteState) -> None: + """Intercept statement executions that occur on behalf of an + ORM :class:`.Session` object. + + This event is invoked for all top-level SQL statements invoked from the + :meth:`_orm.Session.execute` method, as well as related methods such as + :meth:`_orm.Session.scalars` and :meth:`_orm.Session.scalar`. As of + SQLAlchemy 1.4, all ORM queries that run through the + :meth:`_orm.Session.execute` method as well as related methods + :meth:`_orm.Session.scalars`, :meth:`_orm.Session.scalar` etc. + will participate in this event. + This event hook does **not** apply to the queries that are + emitted internally within the ORM flush process, i.e. the + process described at :ref:`session_flushing`. + + .. note:: The :meth:`_orm.SessionEvents.do_orm_execute` event hook + is triggered **for ORM statement executions only**, meaning those + invoked via the :meth:`_orm.Session.execute` and similar methods on + the :class:`_orm.Session` object. It does **not** trigger for + statements that are invoked by SQLAlchemy Core only, i.e. statements + invoked directly using :meth:`_engine.Connection.execute` or + otherwise originating from an :class:`_engine.Engine` object without + any :class:`_orm.Session` involved. To intercept **all** SQL + executions regardless of whether the Core or ORM APIs are in use, + see the event hooks at :class:`.ConnectionEvents`, such as + :meth:`.ConnectionEvents.before_execute` and + :meth:`.ConnectionEvents.before_cursor_execute`. + + Also, this event hook does **not** apply to queries that are + emitted internally within the ORM flush process, + i.e. the process described at :ref:`session_flushing`; to + intercept steps within the flush process, see the event + hooks described at :ref:`session_persistence_events` as + well as :ref:`session_persistence_mapper`. + + This event is a ``do_`` event, meaning it has the capability to replace + the operation that the :meth:`_orm.Session.execute` method normally + performs. The intended use for this includes sharding and + result-caching schemes which may seek to invoke the same statement + across multiple database connections, returning a result that is + merged from each of them, or which don't invoke the statement at all, + instead returning data from a cache. + + The hook intends to replace the use of the + ``Query._execute_and_instances`` method that could be subclassed prior + to SQLAlchemy 1.4. + + :param orm_execute_state: an instance of :class:`.ORMExecuteState` + which contains all information about the current execution, as well + as helper functions used to derive other commonly required + information. See that object for details. + + .. seealso:: + + :ref:`session_execute_events` - top level documentation on how + to use :meth:`_orm.SessionEvents.do_orm_execute` + + :class:`.ORMExecuteState` - the object passed to the + :meth:`_orm.SessionEvents.do_orm_execute` event which contains + all information about the statement to be invoked. It also + provides an interface to extend the current statement, options, + and parameters as well as an option that allows programmatic + invocation of the statement at any point. + + :ref:`examples_session_orm_events` - includes examples of using + :meth:`_orm.SessionEvents.do_orm_execute` + + :ref:`examples_caching` - an example of how to integrate + Dogpile caching with the ORM :class:`_orm.Session` making use + of the :meth:`_orm.SessionEvents.do_orm_execute` event hook. + + :ref:`examples_sharding` - the Horizontal Sharding example / + extension relies upon the + :meth:`_orm.SessionEvents.do_orm_execute` event hook to invoke a + SQL statement on multiple backends and return a merged result. + + + .. versionadded:: 1.4 + + """ + + def after_transaction_create( + self, session: Session, transaction: SessionTransaction + ) -> None: + """Execute when a new :class:`.SessionTransaction` is created. + + This event differs from :meth:`~.SessionEvents.after_begin` + in that it occurs for each :class:`.SessionTransaction` + overall, as opposed to when transactions are begun + on individual database connections. It is also invoked + for nested transactions and subtransactions, and is always + matched by a corresponding + :meth:`~.SessionEvents.after_transaction_end` event + (assuming normal operation of the :class:`.Session`). + + :param session: the target :class:`.Session`. + :param transaction: the target :class:`.SessionTransaction`. + + To detect if this is the outermost + :class:`.SessionTransaction`, as opposed to a "subtransaction" or a + SAVEPOINT, test that the :attr:`.SessionTransaction.parent` attribute + is ``None``:: + + @event.listens_for(session, "after_transaction_create") + def after_transaction_create(session, transaction): + if transaction.parent is None: + # work with top-level transaction + + To detect if the :class:`.SessionTransaction` is a SAVEPOINT, use the + :attr:`.SessionTransaction.nested` attribute:: + + @event.listens_for(session, "after_transaction_create") + def after_transaction_create(session, transaction): + if transaction.nested: + # work with SAVEPOINT transaction + + + .. seealso:: + + :class:`.SessionTransaction` + + :meth:`~.SessionEvents.after_transaction_end` + + """ + + def after_transaction_end( + self, session: Session, transaction: SessionTransaction + ) -> None: + """Execute when the span of a :class:`.SessionTransaction` ends. + + This event differs from :meth:`~.SessionEvents.after_commit` + in that it corresponds to all :class:`.SessionTransaction` + objects in use, including those for nested transactions + and subtransactions, and is always matched by a corresponding + :meth:`~.SessionEvents.after_transaction_create` event. + + :param session: the target :class:`.Session`. + :param transaction: the target :class:`.SessionTransaction`. + + To detect if this is the outermost + :class:`.SessionTransaction`, as opposed to a "subtransaction" or a + SAVEPOINT, test that the :attr:`.SessionTransaction.parent` attribute + is ``None``:: + + @event.listens_for(session, "after_transaction_create") + def after_transaction_end(session, transaction): + if transaction.parent is None: + # work with top-level transaction + + To detect if the :class:`.SessionTransaction` is a SAVEPOINT, use the + :attr:`.SessionTransaction.nested` attribute:: + + @event.listens_for(session, "after_transaction_create") + def after_transaction_end(session, transaction): + if transaction.nested: + # work with SAVEPOINT transaction + + + .. seealso:: + + :class:`.SessionTransaction` + + :meth:`~.SessionEvents.after_transaction_create` + + """ + + def before_commit(self, session: Session) -> None: + """Execute before commit is called. + + .. note:: + + The :meth:`~.SessionEvents.before_commit` hook is *not* per-flush, + that is, the :class:`.Session` can emit SQL to the database + many times within the scope of a transaction. + For interception of these events, use the + :meth:`~.SessionEvents.before_flush`, + :meth:`~.SessionEvents.after_flush`, or + :meth:`~.SessionEvents.after_flush_postexec` + events. + + :param session: The target :class:`.Session`. + + .. seealso:: + + :meth:`~.SessionEvents.after_commit` + + :meth:`~.SessionEvents.after_begin` + + :meth:`~.SessionEvents.after_transaction_create` + + :meth:`~.SessionEvents.after_transaction_end` + + """ + + def after_commit(self, session: Session) -> None: + """Execute after a commit has occurred. + + .. note:: + + The :meth:`~.SessionEvents.after_commit` hook is *not* per-flush, + that is, the :class:`.Session` can emit SQL to the database + many times within the scope of a transaction. + For interception of these events, use the + :meth:`~.SessionEvents.before_flush`, + :meth:`~.SessionEvents.after_flush`, or + :meth:`~.SessionEvents.after_flush_postexec` + events. + + .. note:: + + The :class:`.Session` is not in an active transaction + when the :meth:`~.SessionEvents.after_commit` event is invoked, + and therefore can not emit SQL. To emit SQL corresponding to + every transaction, use the :meth:`~.SessionEvents.before_commit` + event. + + :param session: The target :class:`.Session`. + + .. seealso:: + + :meth:`~.SessionEvents.before_commit` + + :meth:`~.SessionEvents.after_begin` + + :meth:`~.SessionEvents.after_transaction_create` + + :meth:`~.SessionEvents.after_transaction_end` + + """ + + def after_rollback(self, session: Session) -> None: + """Execute after a real DBAPI rollback has occurred. + + Note that this event only fires when the *actual* rollback against + the database occurs - it does *not* fire each time the + :meth:`.Session.rollback` method is called, if the underlying + DBAPI transaction has already been rolled back. In many + cases, the :class:`.Session` will not be in + an "active" state during this event, as the current + transaction is not valid. To acquire a :class:`.Session` + which is active after the outermost rollback has proceeded, + use the :meth:`.SessionEvents.after_soft_rollback` event, checking the + :attr:`.Session.is_active` flag. + + :param session: The target :class:`.Session`. + + """ + + def after_soft_rollback( + self, session: Session, previous_transaction: SessionTransaction + ) -> None: + """Execute after any rollback has occurred, including "soft" + rollbacks that don't actually emit at the DBAPI level. + + This corresponds to both nested and outer rollbacks, i.e. + the innermost rollback that calls the DBAPI's + rollback() method, as well as the enclosing rollback + calls that only pop themselves from the transaction stack. + + The given :class:`.Session` can be used to invoke SQL and + :meth:`.Session.query` operations after an outermost rollback + by first checking the :attr:`.Session.is_active` flag:: + + @event.listens_for(Session, "after_soft_rollback") + def do_something(session, previous_transaction): + if session.is_active: + session.execute(text("select * from some_table")) + + :param session: The target :class:`.Session`. + :param previous_transaction: The :class:`.SessionTransaction` + transactional marker object which was just closed. The current + :class:`.SessionTransaction` for the given :class:`.Session` is + available via the :attr:`.Session.transaction` attribute. + + """ + + def before_flush( + self, + session: Session, + flush_context: UOWTransaction, + instances: Optional[Sequence[_O]], + ) -> None: + """Execute before flush process has started. + + :param session: The target :class:`.Session`. + :param flush_context: Internal :class:`.UOWTransaction` object + which handles the details of the flush. + :param instances: Usually ``None``, this is the collection of + objects which can be passed to the :meth:`.Session.flush` method + (note this usage is deprecated). + + .. seealso:: + + :meth:`~.SessionEvents.after_flush` + + :meth:`~.SessionEvents.after_flush_postexec` + + :ref:`session_persistence_events` + + """ + + def after_flush( + self, session: Session, flush_context: UOWTransaction + ) -> None: + """Execute after flush has completed, but before commit has been + called. + + Note that the session's state is still in pre-flush, i.e. 'new', + 'dirty', and 'deleted' lists still show pre-flush state as well + as the history settings on instance attributes. + + .. warning:: This event runs after the :class:`.Session` has emitted + SQL to modify the database, but **before** it has altered its + internal state to reflect those changes, including that newly + inserted objects are placed into the identity map. ORM operations + emitted within this event such as loads of related items + may produce new identity map entries that will immediately + be replaced, sometimes causing confusing results. SQLAlchemy will + emit a warning for this condition as of version 1.3.9. + + :param session: The target :class:`.Session`. + :param flush_context: Internal :class:`.UOWTransaction` object + which handles the details of the flush. + + .. seealso:: + + :meth:`~.SessionEvents.before_flush` + + :meth:`~.SessionEvents.after_flush_postexec` + + :ref:`session_persistence_events` + + """ + + def after_flush_postexec( + self, session: Session, flush_context: UOWTransaction + ) -> None: + """Execute after flush has completed, and after the post-exec + state occurs. + + This will be when the 'new', 'dirty', and 'deleted' lists are in + their final state. An actual commit() may or may not have + occurred, depending on whether or not the flush started its own + transaction or participated in a larger transaction. + + :param session: The target :class:`.Session`. + :param flush_context: Internal :class:`.UOWTransaction` object + which handles the details of the flush. + + + .. seealso:: + + :meth:`~.SessionEvents.before_flush` + + :meth:`~.SessionEvents.after_flush` + + :ref:`session_persistence_events` + + """ + + def after_begin( + self, + session: Session, + transaction: SessionTransaction, + connection: Connection, + ) -> None: + """Execute after a transaction is begun on a connection. + + .. note:: This event is called within the process of the + :class:`_orm.Session` modifying its own internal state. + To invoke SQL operations within this hook, use the + :class:`_engine.Connection` provided to the event; + do not run SQL operations using the :class:`_orm.Session` + directly. + + :param session: The target :class:`.Session`. + :param transaction: The :class:`.SessionTransaction`. + :param connection: The :class:`_engine.Connection` object + which will be used for SQL statements. + + .. seealso:: + + :meth:`~.SessionEvents.before_commit` + + :meth:`~.SessionEvents.after_commit` + + :meth:`~.SessionEvents.after_transaction_create` + + :meth:`~.SessionEvents.after_transaction_end` + + """ + + @_lifecycle_event + def before_attach(self, session: Session, instance: _O) -> None: + """Execute before an instance is attached to a session. + + This is called before an add, delete or merge causes + the object to be part of the session. + + .. seealso:: + + :meth:`~.SessionEvents.after_attach` + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def after_attach(self, session: Session, instance: _O) -> None: + """Execute after an instance is attached to a session. + + This is called after an add, delete or merge. + + .. note:: + + As of 0.8, this event fires off *after* the item + has been fully associated with the session, which is + different than previous releases. For event + handlers that require the object not yet + be part of session state (such as handlers which + may autoflush while the target object is not + yet complete) consider the + new :meth:`.before_attach` event. + + .. seealso:: + + :meth:`~.SessionEvents.before_attach` + + :ref:`session_lifecycle_events` + + """ + + @event._legacy_signature( + "0.9", + ["session", "query", "query_context", "result"], + lambda update_context: ( + update_context.session, + update_context.query, + None, + update_context.result, + ), + ) + def after_bulk_update(self, update_context: _O) -> None: + """Event for after the legacy :meth:`_orm.Query.update` method + has been called. + + .. legacy:: The :meth:`_orm.SessionEvents.after_bulk_update` method + is a legacy event hook as of SQLAlchemy 2.0. The event + **does not participate** in :term:`2.0 style` invocations + using :func:`_dml.update` documented at + :ref:`orm_queryguide_update_delete_where`. For 2.0 style use, + the :meth:`_orm.SessionEvents.do_orm_execute` hook will intercept + these calls. + + :param update_context: an "update context" object which contains + details about the update, including these attributes: + + * ``session`` - the :class:`.Session` involved + * ``query`` -the :class:`_query.Query` + object that this update operation + was called upon. + * ``values`` The "values" dictionary that was passed to + :meth:`_query.Query.update`. + * ``result`` the :class:`_engine.CursorResult` + returned as a result of the + bulk UPDATE operation. + + .. versionchanged:: 1.4 the update_context no longer has a + ``QueryContext`` object associated with it. + + .. seealso:: + + :meth:`.QueryEvents.before_compile_update` + + :meth:`.SessionEvents.after_bulk_delete` + + """ + + @event._legacy_signature( + "0.9", + ["session", "query", "query_context", "result"], + lambda delete_context: ( + delete_context.session, + delete_context.query, + None, + delete_context.result, + ), + ) + def after_bulk_delete(self, delete_context: _O) -> None: + """Event for after the legacy :meth:`_orm.Query.delete` method + has been called. + + .. legacy:: The :meth:`_orm.SessionEvents.after_bulk_delete` method + is a legacy event hook as of SQLAlchemy 2.0. The event + **does not participate** in :term:`2.0 style` invocations + using :func:`_dml.delete` documented at + :ref:`orm_queryguide_update_delete_where`. For 2.0 style use, + the :meth:`_orm.SessionEvents.do_orm_execute` hook will intercept + these calls. + + :param delete_context: a "delete context" object which contains + details about the update, including these attributes: + + * ``session`` - the :class:`.Session` involved + * ``query`` -the :class:`_query.Query` + object that this update operation + was called upon. + * ``result`` the :class:`_engine.CursorResult` + returned as a result of the + bulk DELETE operation. + + .. versionchanged:: 1.4 the update_context no longer has a + ``QueryContext`` object associated with it. + + .. seealso:: + + :meth:`.QueryEvents.before_compile_delete` + + :meth:`.SessionEvents.after_bulk_update` + + """ + + @_lifecycle_event + def transient_to_pending(self, session: Session, instance: _O) -> None: + """Intercept the "transient to pending" transition for a specific + object. + + This event is a specialization of the + :meth:`.SessionEvents.after_attach` event which is only invoked + for this specific transition. It is invoked typically during the + :meth:`.Session.add` call. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def pending_to_transient(self, session: Session, instance: _O) -> None: + """Intercept the "pending to transient" transition for a specific + object. + + This less common transition occurs when an pending object that has + not been flushed is evicted from the session; this can occur + when the :meth:`.Session.rollback` method rolls back the transaction, + or when the :meth:`.Session.expunge` method is used. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def persistent_to_transient(self, session: Session, instance: _O) -> None: + """Intercept the "persistent to transient" transition for a specific + object. + + This less common transition occurs when an pending object that has + has been flushed is evicted from the session; this can occur + when the :meth:`.Session.rollback` method rolls back the transaction. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def pending_to_persistent(self, session: Session, instance: _O) -> None: + """Intercept the "pending to persistent"" transition for a specific + object. + + This event is invoked within the flush process, and is + similar to scanning the :attr:`.Session.new` collection within + the :meth:`.SessionEvents.after_flush` event. However, in this + case the object has already been moved to the persistent state + when the event is called. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def detached_to_persistent(self, session: Session, instance: _O) -> None: + """Intercept the "detached to persistent" transition for a specific + object. + + This event is a specialization of the + :meth:`.SessionEvents.after_attach` event which is only invoked + for this specific transition. It is invoked typically during the + :meth:`.Session.add` call, as well as during the + :meth:`.Session.delete` call if the object was not previously + associated with the + :class:`.Session` (note that an object marked as "deleted" remains + in the "persistent" state until the flush proceeds). + + .. note:: + + If the object becomes persistent as part of a call to + :meth:`.Session.delete`, the object is **not** yet marked as + deleted when this event is called. To detect deleted objects, + check the ``deleted`` flag sent to the + :meth:`.SessionEvents.persistent_to_detached` to event after the + flush proceeds, or check the :attr:`.Session.deleted` collection + within the :meth:`.SessionEvents.before_flush` event if deleted + objects need to be intercepted before the flush. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def loaded_as_persistent(self, session: Session, instance: _O) -> None: + """Intercept the "loaded as persistent" transition for a specific + object. + + This event is invoked within the ORM loading process, and is invoked + very similarly to the :meth:`.InstanceEvents.load` event. However, + the event here is linkable to a :class:`.Session` class or instance, + rather than to a mapper or class hierarchy, and integrates + with the other session lifecycle events smoothly. The object + is guaranteed to be present in the session's identity map when + this event is called. + + .. note:: This event is invoked within the loader process before + eager loaders may have been completed, and the object's state may + not be complete. Additionally, invoking row-level refresh + operations on the object will place the object into a new loader + context, interfering with the existing load context. See the note + on :meth:`.InstanceEvents.load` for background on making use of the + :paramref:`.SessionEvents.restore_load_context` parameter, which + works in the same manner as that of + :paramref:`.InstanceEvents.restore_load_context`, in order to + resolve this scenario. + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def persistent_to_deleted(self, session: Session, instance: _O) -> None: + """Intercept the "persistent to deleted" transition for a specific + object. + + This event is invoked when a persistent object's identity + is deleted from the database within a flush, however the object + still remains associated with the :class:`.Session` until the + transaction completes. + + If the transaction is rolled back, the object moves again + to the persistent state, and the + :meth:`.SessionEvents.deleted_to_persistent` event is called. + If the transaction is committed, the object becomes detached, + which will emit the :meth:`.SessionEvents.deleted_to_detached` + event. + + Note that while the :meth:`.Session.delete` method is the primary + public interface to mark an object as deleted, many objects + get deleted due to cascade rules, which are not always determined + until flush time. Therefore, there's no way to catch + every object that will be deleted until the flush has proceeded. + the :meth:`.SessionEvents.persistent_to_deleted` event is therefore + invoked at the end of a flush. + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def deleted_to_persistent(self, session: Session, instance: _O) -> None: + """Intercept the "deleted to persistent" transition for a specific + object. + + This transition occurs only when an object that's been deleted + successfully in a flush is restored due to a call to + :meth:`.Session.rollback`. The event is not called under + any other circumstances. + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def deleted_to_detached(self, session: Session, instance: _O) -> None: + """Intercept the "deleted to detached" transition for a specific + object. + + This event is invoked when a deleted object is evicted + from the session. The typical case when this occurs is when + the transaction for a :class:`.Session` in which the object + was deleted is committed; the object moves from the deleted + state to the detached state. + + It is also invoked for objects that were deleted in a flush + when the :meth:`.Session.expunge_all` or :meth:`.Session.close` + events are called, as well as if the object is individually + expunged from its deleted state via :meth:`.Session.expunge`. + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + @_lifecycle_event + def persistent_to_detached(self, session: Session, instance: _O) -> None: + """Intercept the "persistent to detached" transition for a specific + object. + + This event is invoked when a persistent object is evicted + from the session. There are many conditions that cause this + to happen, including: + + * using a method such as :meth:`.Session.expunge` + or :meth:`.Session.close` + + * Calling the :meth:`.Session.rollback` method, when the object + was part of an INSERT statement for that session's transaction + + + :param session: target :class:`.Session` + + :param instance: the ORM-mapped instance being operated upon. + + :param deleted: boolean. If True, indicates this object moved + to the detached state because it was marked as deleted and flushed. + + + .. seealso:: + + :ref:`session_lifecycle_events` + + """ + + +class AttributeEvents(event.Events[QueryableAttribute[Any]]): + r"""Define events for object attributes. + + These are typically defined on the class-bound descriptor for the + target class. + + For example, to register a listener that will receive the + :meth:`_orm.AttributeEvents.append` event:: + + from sqlalchemy import event + + @event.listens_for(MyClass.collection, 'append', propagate=True) + def my_append_listener(target, value, initiator): + print("received append event for target: %s" % target) + + + Listeners have the option to return a possibly modified version of the + value, when the :paramref:`.AttributeEvents.retval` flag is passed to + :func:`.event.listen` or :func:`.event.listens_for`, such as below, + illustrated using the :meth:`_orm.AttributeEvents.set` event:: + + def validate_phone(target, value, oldvalue, initiator): + "Strip non-numeric characters from a phone number" + + return re.sub(r'\D', '', value) + + # setup listener on UserContact.phone attribute, instructing + # it to use the return value + listen(UserContact.phone, 'set', validate_phone, retval=True) + + A validation function like the above can also raise an exception + such as :exc:`ValueError` to halt the operation. + + The :paramref:`.AttributeEvents.propagate` flag is also important when + applying listeners to mapped classes that also have mapped subclasses, + as when using mapper inheritance patterns:: + + + @event.listens_for(MySuperClass.attr, 'set', propagate=True) + def receive_set(target, value, initiator): + print("value set: %s" % target) + + The full list of modifiers available to the :func:`.event.listen` + and :func:`.event.listens_for` functions are below. + + :param active_history=False: When True, indicates that the + "set" event would like to receive the "old" value being + replaced unconditionally, even if this requires firing off + database loads. Note that ``active_history`` can also be + set directly via :func:`.column_property` and + :func:`_orm.relationship`. + + :param propagate=False: When True, the listener function will + be established not just for the class attribute given, but + for attributes of the same name on all current subclasses + of that class, as well as all future subclasses of that + class, using an additional listener that listens for + instrumentation events. + :param raw=False: When True, the "target" argument to the + event will be the :class:`.InstanceState` management + object, rather than the mapped instance itself. + :param retval=False: when True, the user-defined event + listening must return the "value" argument from the + function. This gives the listening function the opportunity + to change the value that is ultimately used for a "set" + or "append" event. + + """ + + _target_class_doc = "SomeClass.some_attribute" + _dispatch_target = QueryableAttribute + + @staticmethod + def _set_dispatch( + cls: Type[_HasEventsDispatch[Any]], dispatch_cls: Type[_Dispatch[Any]] + ) -> _Dispatch[Any]: + dispatch = event.Events._set_dispatch(cls, dispatch_cls) + dispatch_cls._active_history = False + return dispatch + + @classmethod + def _accept_with( + cls, + target: Union[QueryableAttribute[Any], Type[QueryableAttribute[Any]]], + identifier: str, + ) -> Union[QueryableAttribute[Any], Type[QueryableAttribute[Any]]]: + # TODO: coverage + if isinstance(target, interfaces.MapperProperty): + return getattr(target.parent.class_, target.key) + else: + return target + + @classmethod + def _listen( # type: ignore [override] + cls, + event_key: _EventKey[QueryableAttribute[Any]], + active_history: bool = False, + raw: bool = False, + retval: bool = False, + propagate: bool = False, + include_key: bool = False, + ) -> None: + target, fn = event_key.dispatch_target, event_key._listen_fn + + if active_history: + target.dispatch._active_history = True + + if not raw or not retval or not include_key: + + def wrap(target: InstanceState[_O], *arg: Any, **kw: Any) -> Any: + if not raw: + target = target.obj() # type: ignore [assignment] + if not retval: + if arg: + value = arg[0] + else: + value = None + if include_key: + fn(target, *arg, **kw) + else: + fn(target, *arg) + return value + else: + if include_key: + return fn(target, *arg, **kw) + else: + return fn(target, *arg) + + event_key = event_key.with_wrapper(wrap) + + event_key.base_listen(propagate=propagate) + + if propagate: + manager = instrumentation.manager_of_class(target.class_) + + for mgr in manager.subclass_managers(True): # type: ignore [no-untyped-call] # noqa: E501 + event_key.with_dispatch_target(mgr[target.key]).base_listen( + propagate=True + ) + if active_history: + mgr[target.key].dispatch._active_history = True + + def append( + self, + target: _O, + value: _T, + initiator: Event, + *, + key: EventConstants = NO_KEY, + ) -> Optional[_T]: + """Receive a collection append event. + + The append event is invoked for each element as it is appended + to the collection. This occurs for single-item appends as well + as for a "bulk replace" operation. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value being appended. If this listener + is registered with ``retval=True``, the listener + function must return this value, or a new value which + replaces it. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. May be modified + from its original value by backref handlers in order to control + chained event propagation, as well as be inspected for information + about the source of the event. + :param key: When the event is established using the + :paramref:`.AttributeEvents.include_key` parameter set to + True, this will be the key used in the operation, such as + ``collection[some_key_or_index] = value``. + The parameter is not passed + to the event at all if the the + :paramref:`.AttributeEvents.include_key` + was not used to set up the event; this is to allow backwards + compatibility with existing event handlers that don't include the + ``key`` parameter. + + .. versionadded:: 2.0 + + :return: if the event was registered with ``retval=True``, + the given value, or a new effective value, should be returned. + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + :meth:`.AttributeEvents.bulk_replace` + + """ + + def append_wo_mutation( + self, + target: _O, + value: _T, + initiator: Event, + *, + key: EventConstants = NO_KEY, + ) -> None: + """Receive a collection append event where the collection was not + actually mutated. + + This event differs from :meth:`_orm.AttributeEvents.append` in that + it is fired off for de-duplicating collections such as sets and + dictionaries, when the object already exists in the target collection. + The event does not have a return value and the identity of the + given object cannot be changed. + + The event is used for cascading objects into a :class:`_orm.Session` + when the collection has already been mutated via a backref event. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value that would be appended if the object did not + already exist in the collection. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. May be modified + from its original value by backref handlers in order to control + chained event propagation, as well as be inspected for information + about the source of the event. + :param key: When the event is established using the + :paramref:`.AttributeEvents.include_key` parameter set to + True, this will be the key used in the operation, such as + ``collection[some_key_or_index] = value``. + The parameter is not passed + to the event at all if the the + :paramref:`.AttributeEvents.include_key` + was not used to set up the event; this is to allow backwards + compatibility with existing event handlers that don't include the + ``key`` parameter. + + .. versionadded:: 2.0 + + :return: No return value is defined for this event. + + .. versionadded:: 1.4.15 + + """ + + def bulk_replace( + self, + target: _O, + values: Iterable[_T], + initiator: Event, + *, + keys: Optional[Iterable[EventConstants]] = None, + ) -> None: + """Receive a collection 'bulk replace' event. + + This event is invoked for a sequence of values as they are incoming + to a bulk collection set operation, which can be + modified in place before the values are treated as ORM objects. + This is an "early hook" that runs before the bulk replace routine + attempts to reconcile which objects are already present in the + collection and which are being removed by the net replace operation. + + It is typical that this method be combined with use of the + :meth:`.AttributeEvents.append` event. When using both of these + events, note that a bulk replace operation will invoke + the :meth:`.AttributeEvents.append` event for all new items, + even after :meth:`.AttributeEvents.bulk_replace` has been invoked + for the collection as a whole. In order to determine if an + :meth:`.AttributeEvents.append` event is part of a bulk replace, + use the symbol :attr:`~.attributes.OP_BULK_REPLACE` to test the + incoming initiator:: + + from sqlalchemy.orm.attributes import OP_BULK_REPLACE + + @event.listens_for(SomeObject.collection, "bulk_replace") + def process_collection(target, values, initiator): + values[:] = [_make_value(value) for value in values] + + @event.listens_for(SomeObject.collection, "append", retval=True) + def process_collection(target, value, initiator): + # make sure bulk_replace didn't already do it + if initiator is None or initiator.op is not OP_BULK_REPLACE: + return _make_value(value) + else: + return value + + .. versionadded:: 1.2 + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: a sequence (e.g. a list) of the values being set. The + handler can modify this list in place. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. + :param keys: When the event is established using the + :paramref:`.AttributeEvents.include_key` parameter set to + True, this will be the sequence of keys used in the operation, + typically only for a dictionary update. The parameter is not passed + to the event at all if the the + :paramref:`.AttributeEvents.include_key` + was not used to set up the event; this is to allow backwards + compatibility with existing event handlers that don't include the + ``key`` parameter. + + .. versionadded:: 2.0 + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + + """ + + def remove( + self, + target: _O, + value: _T, + initiator: Event, + *, + key: EventConstants = NO_KEY, + ) -> None: + """Receive a collection remove event. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value being removed. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. May be modified + from its original value by backref handlers in order to control + chained event propagation. + + :param key: When the event is established using the + :paramref:`.AttributeEvents.include_key` parameter set to + True, this will be the key used in the operation, such as + ``del collection[some_key_or_index]``. The parameter is not passed + to the event at all if the the + :paramref:`.AttributeEvents.include_key` + was not used to set up the event; this is to allow backwards + compatibility with existing event handlers that don't include the + ``key`` parameter. + + .. versionadded:: 2.0 + + :return: No return value is defined for this event. + + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + """ + + def set( + self, target: _O, value: _T, oldvalue: _T, initiator: Event + ) -> None: + """Receive a scalar set event. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value being set. If this listener + is registered with ``retval=True``, the listener + function must return this value, or a new value which + replaces it. + :param oldvalue: the previous value being replaced. This + may also be the symbol ``NEVER_SET`` or ``NO_VALUE``. + If the listener is registered with ``active_history=True``, + the previous value of the attribute will be loaded from + the database if the existing value is currently unloaded + or expired. + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. May be modified + from its original value by backref handlers in order to control + chained event propagation. + + :return: if the event was registered with ``retval=True``, + the given value, or a new effective value, should be returned. + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + """ + + def init_scalar( + self, target: _O, value: _T, dict_: Dict[Any, Any] + ) -> None: + r"""Receive a scalar "init" event. + + This event is invoked when an uninitialized, unpersisted scalar + attribute is accessed, e.g. read:: + + + x = my_object.some_attribute + + The ORM's default behavior when this occurs for an un-initialized + attribute is to return the value ``None``; note this differs from + Python's usual behavior of raising ``AttributeError``. The + event here can be used to customize what value is actually returned, + with the assumption that the event listener would be mirroring + a default generator that is configured on the Core + :class:`_schema.Column` + object as well. + + Since a default generator on a :class:`_schema.Column` + might also produce + a changing value such as a timestamp, the + :meth:`.AttributeEvents.init_scalar` + event handler can also be used to **set** the newly returned value, so + that a Core-level default generation function effectively fires off + only once, but at the moment the attribute is accessed on the + non-persisted object. Normally, no change to the object's state + is made when an uninitialized attribute is accessed (much older + SQLAlchemy versions did in fact change the object's state). + + If a default generator on a column returned a particular constant, + a handler might be used as follows:: + + SOME_CONSTANT = 3.1415926 + + class MyClass(Base): + # ... + + some_attribute = Column(Numeric, default=SOME_CONSTANT) + + @event.listens_for( + MyClass.some_attribute, "init_scalar", + retval=True, propagate=True) + def _init_some_attribute(target, dict_, value): + dict_['some_attribute'] = SOME_CONSTANT + return SOME_CONSTANT + + Above, we initialize the attribute ``MyClass.some_attribute`` to the + value of ``SOME_CONSTANT``. The above code includes the following + features: + + * By setting the value ``SOME_CONSTANT`` in the given ``dict_``, + we indicate that this value is to be persisted to the database. + This supersedes the use of ``SOME_CONSTANT`` in the default generator + for the :class:`_schema.Column`. The ``active_column_defaults.py`` + example given at :ref:`examples_instrumentation` illustrates using + the same approach for a changing default, e.g. a timestamp + generator. In this particular example, it is not strictly + necessary to do this since ``SOME_CONSTANT`` would be part of the + INSERT statement in either case. + + * By establishing the ``retval=True`` flag, the value we return + from the function will be returned by the attribute getter. + Without this flag, the event is assumed to be a passive observer + and the return value of our function is ignored. + + * The ``propagate=True`` flag is significant if the mapped class + includes inheriting subclasses, which would also make use of this + event listener. Without this flag, an inheriting subclass will + not use our event handler. + + In the above example, the attribute set event + :meth:`.AttributeEvents.set` as well as the related validation feature + provided by :obj:`_orm.validates` is **not** invoked when we apply our + value to the given ``dict_``. To have these events to invoke in + response to our newly generated value, apply the value to the given + object as a normal attribute set operation:: + + SOME_CONSTANT = 3.1415926 + + @event.listens_for( + MyClass.some_attribute, "init_scalar", + retval=True, propagate=True) + def _init_some_attribute(target, dict_, value): + # will also fire off attribute set events + target.some_attribute = SOME_CONSTANT + return SOME_CONSTANT + + When multiple listeners are set up, the generation of the value + is "chained" from one listener to the next by passing the value + returned by the previous listener that specifies ``retval=True`` + as the ``value`` argument of the next listener. + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param value: the value that is to be returned before this event + listener were invoked. This value begins as the value ``None``, + however will be the return value of the previous event handler + function if multiple listeners are present. + :param dict\_: the attribute dictionary of this mapped object. + This is normally the ``__dict__`` of the object, but in all cases + represents the destination that the attribute system uses to get + at the actual value of this attribute. Placing the value in this + dictionary has the effect that the value will be used in the + INSERT statement generated by the unit of work. + + + .. seealso:: + + :meth:`.AttributeEvents.init_collection` - collection version + of this event + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + :ref:`examples_instrumentation` - see the + ``active_column_defaults.py`` example. + + """ + + def init_collection( + self, + target: _O, + collection: Type[Collection[Any]], + collection_adapter: CollectionAdapter, + ) -> None: + """Receive a 'collection init' event. + + This event is triggered for a collection-based attribute, when + the initial "empty collection" is first generated for a blank + attribute, as well as for when the collection is replaced with + a new one, such as via a set event. + + E.g., given that ``User.addresses`` is a relationship-based + collection, the event is triggered here:: + + u1 = User() + u1.addresses.append(a1) # <- new collection + + and also during replace operations:: + + u1.addresses = [a2, a3] # <- new collection + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + :param collection: the new collection. This will always be generated + from what was specified as + :paramref:`_orm.relationship.collection_class`, and will always + be empty. + :param collection_adapter: the :class:`.CollectionAdapter` that will + mediate internal access to the collection. + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + :meth:`.AttributeEvents.init_scalar` - "scalar" version of this + event. + + """ + + def dispose_collection( + self, + target: _O, + collection: Collection[Any], + collection_adapter: CollectionAdapter, + ) -> None: + """Receive a 'collection dispose' event. + + This event is triggered for a collection-based attribute when + a collection is replaced, that is:: + + u1.addresses.append(a1) + + u1.addresses = [a2, a3] # <- old collection is disposed + + The old collection received will contain its previous contents. + + .. versionchanged:: 1.2 The collection passed to + :meth:`.AttributeEvents.dispose_collection` will now have its + contents before the dispose intact; previously, the collection + would be empty. + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + """ + + def modified(self, target: _O, initiator: Event) -> None: + """Receive a 'modified' event. + + This event is triggered when the :func:`.attributes.flag_modified` + function is used to trigger a modify event on an attribute without + any specific value being set. + + .. versionadded:: 1.2 + + :param target: the object instance receiving the event. + If the listener is registered with ``raw=True``, this will + be the :class:`.InstanceState` object. + + :param initiator: An instance of :class:`.attributes.Event` + representing the initiation of the event. + + .. seealso:: + + :class:`.AttributeEvents` - background on listener options such + as propagation to subclasses. + + """ + + +class QueryEvents(event.Events[Query[Any]]): + """Represent events within the construction of a :class:`_query.Query` + object. + + .. legacy:: The :class:`_orm.QueryEvents` event methods are legacy + as of SQLAlchemy 2.0, and only apply to direct use of the + :class:`_orm.Query` object. They are not used for :term:`2.0 style` + statements. For events to intercept and modify 2.0 style ORM use, + use the :meth:`_orm.SessionEvents.do_orm_execute` hook. + + + The :class:`_orm.QueryEvents` hooks are now superseded by the + :meth:`_orm.SessionEvents.do_orm_execute` event hook. + + """ + + _target_class_doc = "SomeQuery" + _dispatch_target = Query + + def before_compile(self, query: Query[Any]) -> None: + """Receive the :class:`_query.Query` + object before it is composed into a + core :class:`_expression.Select` object. + + .. deprecated:: 1.4 The :meth:`_orm.QueryEvents.before_compile` event + is superseded by the much more capable + :meth:`_orm.SessionEvents.do_orm_execute` hook. In version 1.4, + the :meth:`_orm.QueryEvents.before_compile` event is **no longer + used** for ORM-level attribute loads, such as loads of deferred + or expired attributes as well as relationship loaders. See the + new examples in :ref:`examples_session_orm_events` which + illustrate new ways of intercepting and modifying ORM queries + for the most common purpose of adding arbitrary filter criteria. + + + This event is intended to allow changes to the query given:: + + @event.listens_for(Query, "before_compile", retval=True) + def no_deleted(query): + for desc in query.column_descriptions: + if desc['type'] is User: + entity = desc['entity'] + query = query.filter(entity.deleted == False) + return query + + The event should normally be listened with the ``retval=True`` + parameter set, so that the modified query may be returned. + + The :meth:`.QueryEvents.before_compile` event by default + will disallow "baked" queries from caching a query, if the event + hook returns a new :class:`_query.Query` object. + This affects both direct + use of the baked query extension as well as its operation within + lazy loaders and eager loaders for relationships. In order to + re-establish the query being cached, apply the event adding the + ``bake_ok`` flag:: + + @event.listens_for( + Query, "before_compile", retval=True, bake_ok=True) + def my_event(query): + for desc in query.column_descriptions: + if desc['type'] is User: + entity = desc['entity'] + query = query.filter(entity.deleted == False) + return query + + When ``bake_ok`` is set to True, the event hook will only be invoked + once, and not called for subsequent invocations of a particular query + that is being cached. + + .. versionadded:: 1.3.11 - added the "bake_ok" flag to the + :meth:`.QueryEvents.before_compile` event and disallowed caching via + the "baked" extension from occurring for event handlers that + return a new :class:`_query.Query` object if this flag is not set. + + .. seealso:: + + :meth:`.QueryEvents.before_compile_update` + + :meth:`.QueryEvents.before_compile_delete` + + :ref:`baked_with_before_compile` + + """ + + def before_compile_update( + self, query: Query[Any], update_context: BulkUpdate + ) -> None: + """Allow modifications to the :class:`_query.Query` object within + :meth:`_query.Query.update`. + + .. deprecated:: 1.4 The :meth:`_orm.QueryEvents.before_compile_update` + event is superseded by the much more capable + :meth:`_orm.SessionEvents.do_orm_execute` hook. + + Like the :meth:`.QueryEvents.before_compile` event, if the event + is to be used to alter the :class:`_query.Query` object, it should + be configured with ``retval=True``, and the modified + :class:`_query.Query` object returned, as in :: + + @event.listens_for(Query, "before_compile_update", retval=True) + def no_deleted(query, update_context): + for desc in query.column_descriptions: + if desc['type'] is User: + entity = desc['entity'] + query = query.filter(entity.deleted == False) + + update_context.values['timestamp'] = datetime.utcnow() + return query + + The ``.values`` dictionary of the "update context" object can also + be modified in place as illustrated above. + + :param query: a :class:`_query.Query` instance; this is also + the ``.query`` attribute of the given "update context" + object. + + :param update_context: an "update context" object which is + the same kind of object as described in + :paramref:`.QueryEvents.after_bulk_update.update_context`. + The object has a ``.values`` attribute in an UPDATE context which is + the dictionary of parameters passed to :meth:`_query.Query.update`. + This + dictionary can be modified to alter the VALUES clause of the + resulting UPDATE statement. + + .. versionadded:: 1.2.17 + + .. seealso:: + + :meth:`.QueryEvents.before_compile` + + :meth:`.QueryEvents.before_compile_delete` + + + """ + + def before_compile_delete( + self, query: Query[Any], delete_context: BulkDelete + ) -> None: + """Allow modifications to the :class:`_query.Query` object within + :meth:`_query.Query.delete`. + + .. deprecated:: 1.4 The :meth:`_orm.QueryEvents.before_compile_delete` + event is superseded by the much more capable + :meth:`_orm.SessionEvents.do_orm_execute` hook. + + Like the :meth:`.QueryEvents.before_compile` event, this event + should be configured with ``retval=True``, and the modified + :class:`_query.Query` object returned, as in :: + + @event.listens_for(Query, "before_compile_delete", retval=True) + def no_deleted(query, delete_context): + for desc in query.column_descriptions: + if desc['type'] is User: + entity = desc['entity'] + query = query.filter(entity.deleted == False) + return query + + :param query: a :class:`_query.Query` instance; this is also + the ``.query`` attribute of the given "delete context" + object. + + :param delete_context: a "delete context" object which is + the same kind of object as described in + :paramref:`.QueryEvents.after_bulk_delete.delete_context`. + + .. versionadded:: 1.2.17 + + .. seealso:: + + :meth:`.QueryEvents.before_compile` + + :meth:`.QueryEvents.before_compile_update` + + + """ + + @classmethod + def _listen( + cls, + event_key: _EventKey[_ET], + retval: bool = False, + bake_ok: bool = False, + **kw: Any, + ) -> None: + fn = event_key._listen_fn + + if not retval: + + def wrap(*arg: Any, **kw: Any) -> Any: + if not retval: + query = arg[0] + fn(*arg, **kw) + return query + else: + return fn(*arg, **kw) + + event_key = event_key.with_wrapper(wrap) + else: + # don't assume we can apply an attribute to the callable + def wrap(*arg: Any, **kw: Any) -> Any: + return fn(*arg, **kw) + + event_key = event_key.with_wrapper(wrap) + + wrap._bake_ok = bake_ok # type: ignore [attr-defined] + + event_key.base_listen(**kw) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/exc.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/exc.py new file mode 100644 index 0000000..39dd540 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/exc.py @@ -0,0 +1,228 @@ +# orm/exc.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""SQLAlchemy ORM exceptions.""" + +from __future__ import annotations + +from typing import Any +from typing import Optional +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar + +from .util import _mapper_property_as_plain_name +from .. import exc as sa_exc +from .. import util +from ..exc import MultipleResultsFound # noqa +from ..exc import NoResultFound # noqa + +if TYPE_CHECKING: + from .interfaces import LoaderStrategy + from .interfaces import MapperProperty + from .state import InstanceState + +_T = TypeVar("_T", bound=Any) + +NO_STATE = (AttributeError, KeyError) +"""Exception types that may be raised by instrumentation implementations.""" + + +class StaleDataError(sa_exc.SQLAlchemyError): + """An operation encountered database state that is unaccounted for. + + Conditions which cause this to happen include: + + * A flush may have attempted to update or delete rows + and an unexpected number of rows were matched during + the UPDATE or DELETE statement. Note that when + version_id_col is used, rows in UPDATE or DELETE statements + are also matched against the current known version + identifier. + + * A mapped object with version_id_col was refreshed, + and the version number coming back from the database does + not match that of the object itself. + + * A object is detached from its parent object, however + the object was previously attached to a different parent + identity which was garbage collected, and a decision + cannot be made if the new parent was really the most + recent "parent". + + """ + + +ConcurrentModificationError = StaleDataError + + +class FlushError(sa_exc.SQLAlchemyError): + """A invalid condition was detected during flush().""" + + +class UnmappedError(sa_exc.InvalidRequestError): + """Base for exceptions that involve expected mappings not present.""" + + +class ObjectDereferencedError(sa_exc.SQLAlchemyError): + """An operation cannot complete due to an object being garbage + collected. + + """ + + +class DetachedInstanceError(sa_exc.SQLAlchemyError): + """An attempt to access unloaded attributes on a + mapped instance that is detached.""" + + code = "bhk3" + + +class UnmappedInstanceError(UnmappedError): + """An mapping operation was requested for an unknown instance.""" + + @util.preload_module("sqlalchemy.orm.base") + def __init__(self, obj: object, msg: Optional[str] = None): + base = util.preloaded.orm_base + + if not msg: + try: + base.class_mapper(type(obj)) + name = _safe_cls_name(type(obj)) + msg = ( + "Class %r is mapped, but this instance lacks " + "instrumentation. This occurs when the instance " + "is created before sqlalchemy.orm.mapper(%s) " + "was called." % (name, name) + ) + except UnmappedClassError: + msg = f"Class '{_safe_cls_name(type(obj))}' is not mapped" + if isinstance(obj, type): + msg += ( + "; was a class (%s) supplied where an instance was " + "required?" % _safe_cls_name(obj) + ) + UnmappedError.__init__(self, msg) + + def __reduce__(self) -> Any: + return self.__class__, (None, self.args[0]) + + +class UnmappedClassError(UnmappedError): + """An mapping operation was requested for an unknown class.""" + + def __init__(self, cls: Type[_T], msg: Optional[str] = None): + if not msg: + msg = _default_unmapped(cls) + UnmappedError.__init__(self, msg) + + def __reduce__(self) -> Any: + return self.__class__, (None, self.args[0]) + + +class ObjectDeletedError(sa_exc.InvalidRequestError): + """A refresh operation failed to retrieve the database + row corresponding to an object's known primary key identity. + + A refresh operation proceeds when an expired attribute is + accessed on an object, or when :meth:`_query.Query.get` is + used to retrieve an object which is, upon retrieval, detected + as expired. A SELECT is emitted for the target row + based on primary key; if no row is returned, this + exception is raised. + + The true meaning of this exception is simply that + no row exists for the primary key identifier associated + with a persistent object. The row may have been + deleted, or in some cases the primary key updated + to a new value, outside of the ORM's management of the target + object. + + """ + + @util.preload_module("sqlalchemy.orm.base") + def __init__(self, state: InstanceState[Any], msg: Optional[str] = None): + base = util.preloaded.orm_base + + if not msg: + msg = ( + "Instance '%s' has been deleted, or its " + "row is otherwise not present." % base.state_str(state) + ) + + sa_exc.InvalidRequestError.__init__(self, msg) + + def __reduce__(self) -> Any: + return self.__class__, (None, self.args[0]) + + +class UnmappedColumnError(sa_exc.InvalidRequestError): + """Mapping operation was requested on an unknown column.""" + + +class LoaderStrategyException(sa_exc.InvalidRequestError): + """A loader strategy for an attribute does not exist.""" + + def __init__( + self, + applied_to_property_type: Type[Any], + requesting_property: MapperProperty[Any], + applies_to: Optional[Type[MapperProperty[Any]]], + actual_strategy_type: Optional[Type[LoaderStrategy]], + strategy_key: Tuple[Any, ...], + ): + if actual_strategy_type is None: + sa_exc.InvalidRequestError.__init__( + self, + "Can't find strategy %s for %s" + % (strategy_key, requesting_property), + ) + else: + assert applies_to is not None + sa_exc.InvalidRequestError.__init__( + self, + 'Can\'t apply "%s" strategy to property "%s", ' + 'which is a "%s"; this loader strategy is intended ' + 'to be used with a "%s".' + % ( + util.clsname_as_plain_name(actual_strategy_type), + requesting_property, + _mapper_property_as_plain_name(applied_to_property_type), + _mapper_property_as_plain_name(applies_to), + ), + ) + + +def _safe_cls_name(cls: Type[Any]) -> str: + cls_name: Optional[str] + try: + cls_name = ".".join((cls.__module__, cls.__name__)) + except AttributeError: + cls_name = getattr(cls, "__name__", None) + if cls_name is None: + cls_name = repr(cls) + return cls_name + + +@util.preload_module("sqlalchemy.orm.base") +def _default_unmapped(cls: Type[Any]) -> Optional[str]: + base = util.preloaded.orm_base + + try: + mappers = base.manager_of_class(cls).mappers # type: ignore + except ( + UnmappedClassError, + TypeError, + ) + NO_STATE: + mappers = {} + name = _safe_cls_name(cls) + + if not mappers: + return f"Class '{name}' is not mapped" + else: + return None diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/identity.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/identity.py new file mode 100644 index 0000000..23682f7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/identity.py @@ -0,0 +1,302 @@ +# orm/identity.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +import weakref + +from . import util as orm_util +from .. import exc as sa_exc + +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from .state import InstanceState + + +_T = TypeVar("_T", bound=Any) + +_O = TypeVar("_O", bound=object) + + +class IdentityMap: + _wr: weakref.ref[IdentityMap] + + _dict: Dict[_IdentityKeyType[Any], Any] + _modified: Set[InstanceState[Any]] + + def __init__(self) -> None: + self._dict = {} + self._modified = set() + self._wr = weakref.ref(self) + + def _kill(self) -> None: + self._add_unpresent = _killed # type: ignore + + def all_states(self) -> List[InstanceState[Any]]: + raise NotImplementedError() + + def contains_state(self, state: InstanceState[Any]) -> bool: + raise NotImplementedError() + + def __contains__(self, key: _IdentityKeyType[Any]) -> bool: + raise NotImplementedError() + + def safe_discard(self, state: InstanceState[Any]) -> None: + raise NotImplementedError() + + def __getitem__(self, key: _IdentityKeyType[_O]) -> _O: + raise NotImplementedError() + + def get( + self, key: _IdentityKeyType[_O], default: Optional[_O] = None + ) -> Optional[_O]: + raise NotImplementedError() + + def fast_get_state( + self, key: _IdentityKeyType[_O] + ) -> Optional[InstanceState[_O]]: + raise NotImplementedError() + + def keys(self) -> Iterable[_IdentityKeyType[Any]]: + return self._dict.keys() + + def values(self) -> Iterable[object]: + raise NotImplementedError() + + def replace(self, state: InstanceState[_O]) -> Optional[InstanceState[_O]]: + raise NotImplementedError() + + def add(self, state: InstanceState[Any]) -> bool: + raise NotImplementedError() + + def _fast_discard(self, state: InstanceState[Any]) -> None: + raise NotImplementedError() + + def _add_unpresent( + self, state: InstanceState[Any], key: _IdentityKeyType[Any] + ) -> None: + """optional inlined form of add() which can assume item isn't present + in the map""" + self.add(state) + + def _manage_incoming_state(self, state: InstanceState[Any]) -> None: + state._instance_dict = self._wr + + if state.modified: + self._modified.add(state) + + def _manage_removed_state(self, state: InstanceState[Any]) -> None: + del state._instance_dict + if state.modified: + self._modified.discard(state) + + def _dirty_states(self) -> Set[InstanceState[Any]]: + return self._modified + + def check_modified(self) -> bool: + """return True if any InstanceStates present have been marked + as 'modified'. + + """ + return bool(self._modified) + + def has_key(self, key: _IdentityKeyType[Any]) -> bool: + return key in self + + def __len__(self) -> int: + return len(self._dict) + + +class WeakInstanceDict(IdentityMap): + _dict: Dict[_IdentityKeyType[Any], InstanceState[Any]] + + def __getitem__(self, key: _IdentityKeyType[_O]) -> _O: + state = cast("InstanceState[_O]", self._dict[key]) + o = state.obj() + if o is None: + raise KeyError(key) + return o + + def __contains__(self, key: _IdentityKeyType[Any]) -> bool: + try: + if key in self._dict: + state = self._dict[key] + o = state.obj() + else: + return False + except KeyError: + return False + else: + return o is not None + + def contains_state(self, state: InstanceState[Any]) -> bool: + if state.key in self._dict: + if TYPE_CHECKING: + assert state.key is not None + try: + return self._dict[state.key] is state + except KeyError: + return False + else: + return False + + def replace( + self, state: InstanceState[Any] + ) -> Optional[InstanceState[Any]]: + assert state.key is not None + if state.key in self._dict: + try: + existing = existing_non_none = self._dict[state.key] + except KeyError: + # catch gc removed the key after we just checked for it + existing = None + else: + if existing_non_none is not state: + self._manage_removed_state(existing_non_none) + else: + return None + else: + existing = None + + self._dict[state.key] = state + self._manage_incoming_state(state) + return existing + + def add(self, state: InstanceState[Any]) -> bool: + key = state.key + assert key is not None + # inline of self.__contains__ + if key in self._dict: + try: + existing_state = self._dict[key] + except KeyError: + # catch gc removed the key after we just checked for it + pass + else: + if existing_state is not state: + o = existing_state.obj() + if o is not None: + raise sa_exc.InvalidRequestError( + "Can't attach instance " + "%s; another instance with key %s is already " + "present in this session." + % (orm_util.state_str(state), state.key) + ) + else: + return False + self._dict[key] = state + self._manage_incoming_state(state) + return True + + def _add_unpresent( + self, state: InstanceState[Any], key: _IdentityKeyType[Any] + ) -> None: + # inlined form of add() called by loading.py + self._dict[key] = state + state._instance_dict = self._wr + + def fast_get_state( + self, key: _IdentityKeyType[_O] + ) -> Optional[InstanceState[_O]]: + return self._dict.get(key) + + def get( + self, key: _IdentityKeyType[_O], default: Optional[_O] = None + ) -> Optional[_O]: + if key not in self._dict: + return default + try: + state = cast("InstanceState[_O]", self._dict[key]) + except KeyError: + # catch gc removed the key after we just checked for it + return default + else: + o = state.obj() + if o is None: + return default + return o + + def items(self) -> List[Tuple[_IdentityKeyType[Any], InstanceState[Any]]]: + values = self.all_states() + result = [] + for state in values: + value = state.obj() + key = state.key + assert key is not None + if value is not None: + result.append((key, value)) + return result + + def values(self) -> List[object]: + values = self.all_states() + result = [] + for state in values: + value = state.obj() + if value is not None: + result.append(value) + + return result + + def __iter__(self) -> Iterator[_IdentityKeyType[Any]]: + return iter(self.keys()) + + def all_states(self) -> List[InstanceState[Any]]: + return list(self._dict.values()) + + def _fast_discard(self, state: InstanceState[Any]) -> None: + # used by InstanceState for state being + # GC'ed, inlines _managed_removed_state + key = state.key + assert key is not None + try: + st = self._dict[key] + except KeyError: + # catch gc removed the key after we just checked for it + pass + else: + if st is state: + self._dict.pop(key, None) + + def discard(self, state: InstanceState[Any]) -> None: + self.safe_discard(state) + + def safe_discard(self, state: InstanceState[Any]) -> None: + key = state.key + if key in self._dict: + assert key is not None + try: + st = self._dict[key] + except KeyError: + # catch gc removed the key after we just checked for it + pass + else: + if st is state: + self._dict.pop(key, None) + self._manage_removed_state(state) + + +def _killed(state: InstanceState[Any], key: _IdentityKeyType[Any]) -> NoReturn: + # external function to avoid creating cycles when assigned to + # the IdentityMap + raise sa_exc.InvalidRequestError( + "Object %s cannot be converted to 'persistent' state, as this " + "identity map is no longer valid. Has the owning Session " + "been closed?" % orm_util.state_str(state), + code="lkrp", + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/instrumentation.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/instrumentation.py new file mode 100644 index 0000000..e9fe843 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/instrumentation.py @@ -0,0 +1,754 @@ +# orm/instrumentation.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Defines SQLAlchemy's system of class instrumentation. + +This module is usually not directly visible to user applications, but +defines a large part of the ORM's interactivity. + +instrumentation.py deals with registration of end-user classes +for state tracking. It interacts closely with state.py +and attributes.py which establish per-instance and per-class-attribute +instrumentation, respectively. + +The class instrumentation system can be customized on a per-class +or global basis using the :mod:`sqlalchemy.ext.instrumentation` +module, which provides the means to build and specify +alternate instrumentation forms. + +.. versionchanged: 0.8 + The instrumentation extension system was moved out of the + ORM and into the external :mod:`sqlalchemy.ext.instrumentation` + package. When that package is imported, it installs + itself within sqlalchemy.orm so that its more comprehensive + resolution mechanics take effect. + +""" + + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import base +from . import collections +from . import exc +from . import interfaces +from . import state +from ._typing import _O +from .attributes import _is_collection_attribute_impl +from .. import util +from ..event import EventTarget +from ..util import HasMemoized +from ..util.typing import Literal +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ._typing import _RegistryType + from .attributes import AttributeImpl + from .attributes import QueryableAttribute + from .collections import _AdaptedCollectionProtocol + from .collections import _CollectionFactoryType + from .decl_base import _MapperConfig + from .events import InstanceEvents + from .mapper import Mapper + from .state import InstanceState + from ..event import dispatcher + +_T = TypeVar("_T", bound=Any) +DEL_ATTR = util.symbol("DEL_ATTR") + + +class _ExpiredAttributeLoaderProto(Protocol): + def __call__( + self, + state: state.InstanceState[Any], + toload: Set[str], + passive: base.PassiveFlag, + ) -> None: ... + + +class _ManagerFactory(Protocol): + def __call__(self, class_: Type[_O]) -> ClassManager[_O]: ... + + +class ClassManager( + HasMemoized, + Dict[str, "QueryableAttribute[Any]"], + Generic[_O], + EventTarget, +): + """Tracks state information at the class level.""" + + dispatch: dispatcher[ClassManager[_O]] + + MANAGER_ATTR = base.DEFAULT_MANAGER_ATTR + STATE_ATTR = base.DEFAULT_STATE_ATTR + + _state_setter = staticmethod(util.attrsetter(STATE_ATTR)) + + expired_attribute_loader: _ExpiredAttributeLoaderProto + "previously known as deferred_scalar_loader" + + init_method: Optional[Callable[..., None]] + original_init: Optional[Callable[..., None]] = None + + factory: Optional[_ManagerFactory] + + declarative_scan: Optional[weakref.ref[_MapperConfig]] = None + + registry: _RegistryType + + if not TYPE_CHECKING: + # starts as None during setup + registry = None + + class_: Type[_O] + + _bases: List[ClassManager[Any]] + + @property + @util.deprecated( + "1.4", + message="The ClassManager.deferred_scalar_loader attribute is now " + "named expired_attribute_loader", + ) + def deferred_scalar_loader(self): + return self.expired_attribute_loader + + @deferred_scalar_loader.setter + @util.deprecated( + "1.4", + message="The ClassManager.deferred_scalar_loader attribute is now " + "named expired_attribute_loader", + ) + def deferred_scalar_loader(self, obj): + self.expired_attribute_loader = obj + + def __init__(self, class_): + self.class_ = class_ + self.info = {} + self.new_init = None + self.local_attrs = {} + self.originals = {} + self._finalized = False + self.factory = None + self.init_method = None + + self._bases = [ + mgr + for mgr in cast( + "List[Optional[ClassManager[Any]]]", + [ + opt_manager_of_class(base) + for base in self.class_.__bases__ + if isinstance(base, type) + ], + ) + if mgr is not None + ] + + for base_ in self._bases: + self.update(base_) + + cast( + "InstanceEvents", self.dispatch._events + )._new_classmanager_instance(class_, self) + + for basecls in class_.__mro__: + mgr = opt_manager_of_class(basecls) + if mgr is not None: + self.dispatch._update(mgr.dispatch) + + self.manage() + + if "__del__" in class_.__dict__: + util.warn( + "__del__() method on class %s will " + "cause unreachable cycles and memory leaks, " + "as SQLAlchemy instrumentation often creates " + "reference cycles. Please remove this method." % class_ + ) + + def _update_state( + self, + finalize: bool = False, + mapper: Optional[Mapper[_O]] = None, + registry: Optional[_RegistryType] = None, + declarative_scan: Optional[_MapperConfig] = None, + expired_attribute_loader: Optional[ + _ExpiredAttributeLoaderProto + ] = None, + init_method: Optional[Callable[..., None]] = None, + ) -> None: + if mapper: + self.mapper = mapper # + if registry: + registry._add_manager(self) + if declarative_scan: + self.declarative_scan = weakref.ref(declarative_scan) + if expired_attribute_loader: + self.expired_attribute_loader = expired_attribute_loader + + if init_method: + assert not self._finalized, ( + "class is already instrumented, " + "init_method %s can't be applied" % init_method + ) + self.init_method = init_method + + if not self._finalized: + self.original_init = ( + self.init_method + if self.init_method is not None + and self.class_.__init__ is object.__init__ + else self.class_.__init__ + ) + + if finalize and not self._finalized: + self._finalize() + + def _finalize(self) -> None: + if self._finalized: + return + self._finalized = True + + self._instrument_init() + + _instrumentation_factory.dispatch.class_instrument(self.class_) + + def __hash__(self) -> int: # type: ignore[override] + return id(self) + + def __eq__(self, other: Any) -> bool: + return other is self + + @property + def is_mapped(self) -> bool: + return "mapper" in self.__dict__ + + @HasMemoized.memoized_attribute + def _all_key_set(self): + return frozenset(self) + + @HasMemoized.memoized_attribute + def _collection_impl_keys(self): + return frozenset( + [attr.key for attr in self.values() if attr.impl.collection] + ) + + @HasMemoized.memoized_attribute + def _scalar_loader_impls(self): + return frozenset( + [ + attr.impl + for attr in self.values() + if attr.impl.accepts_scalar_loader + ] + ) + + @HasMemoized.memoized_attribute + def _loader_impls(self): + return frozenset([attr.impl for attr in self.values()]) + + @util.memoized_property + def mapper(self) -> Mapper[_O]: + # raises unless self.mapper has been assigned + raise exc.UnmappedClassError(self.class_) + + def _all_sqla_attributes(self, exclude=None): + """return an iterator of all classbound attributes that are + implement :class:`.InspectionAttr`. + + This includes :class:`.QueryableAttribute` as well as extension + types such as :class:`.hybrid_property` and + :class:`.AssociationProxy`. + + """ + + found: Dict[str, Any] = {} + + # constraints: + # 1. yield keys in cls.__dict__ order + # 2. if a subclass has the same key as a superclass, include that + # key as part of the ordering of the superclass, because an + # overridden key is usually installed by the mapper which is going + # on a different ordering + # 3. don't use getattr() as this fires off descriptors + + for supercls in self.class_.__mro__[0:-1]: + inherits = supercls.__mro__[1] + for key in supercls.__dict__: + found.setdefault(key, supercls) + if key in inherits.__dict__: + continue + val = found[key].__dict__[key] + if ( + isinstance(val, interfaces.InspectionAttr) + and val.is_attribute + ): + yield key, val + + def _get_class_attr_mro(self, key, default=None): + """return an attribute on the class without tripping it.""" + + for supercls in self.class_.__mro__: + if key in supercls.__dict__: + return supercls.__dict__[key] + else: + return default + + def _attr_has_impl(self, key: str) -> bool: + """Return True if the given attribute is fully initialized. + + i.e. has an impl. + """ + + return key in self and self[key].impl is not None + + def _subclass_manager(self, cls: Type[_T]) -> ClassManager[_T]: + """Create a new ClassManager for a subclass of this ClassManager's + class. + + This is called automatically when attributes are instrumented so that + the attributes can be propagated to subclasses against their own + class-local manager, without the need for mappers etc. to have already + pre-configured managers for the full class hierarchy. Mappers + can post-configure the auto-generated ClassManager when needed. + + """ + return register_class(cls, finalize=False) + + def _instrument_init(self): + self.new_init = _generate_init(self.class_, self, self.original_init) + self.install_member("__init__", self.new_init) + + @util.memoized_property + def _state_constructor(self) -> Type[state.InstanceState[_O]]: + self.dispatch.first_init(self, self.class_) + return state.InstanceState + + def manage(self): + """Mark this instance as the manager for its class.""" + + setattr(self.class_, self.MANAGER_ATTR, self) + + @util.hybridmethod + def manager_getter(self): + return _default_manager_getter + + @util.hybridmethod + def state_getter(self): + """Return a (instance) -> InstanceState callable. + + "state getter" callables should raise either KeyError or + AttributeError if no InstanceState could be found for the + instance. + """ + + return _default_state_getter + + @util.hybridmethod + def dict_getter(self): + return _default_dict_getter + + def instrument_attribute( + self, + key: str, + inst: QueryableAttribute[Any], + propagated: bool = False, + ) -> None: + if propagated: + if key in self.local_attrs: + return # don't override local attr with inherited attr + else: + self.local_attrs[key] = inst + self.install_descriptor(key, inst) + self._reset_memoizations() + self[key] = inst + + for cls in self.class_.__subclasses__(): + manager = self._subclass_manager(cls) + manager.instrument_attribute(key, inst, True) + + def subclass_managers(self, recursive): + for cls in self.class_.__subclasses__(): + mgr = opt_manager_of_class(cls) + if mgr is not None and mgr is not self: + yield mgr + if recursive: + yield from mgr.subclass_managers(True) + + def post_configure_attribute(self, key): + _instrumentation_factory.dispatch.attribute_instrument( + self.class_, key, self[key] + ) + + def uninstrument_attribute(self, key, propagated=False): + if key not in self: + return + if propagated: + if key in self.local_attrs: + return # don't get rid of local attr + else: + del self.local_attrs[key] + self.uninstall_descriptor(key) + self._reset_memoizations() + del self[key] + for cls in self.class_.__subclasses__(): + manager = opt_manager_of_class(cls) + if manager: + manager.uninstrument_attribute(key, True) + + def unregister(self) -> None: + """remove all instrumentation established by this ClassManager.""" + + for key in list(self.originals): + self.uninstall_member(key) + + self.mapper = None + self.dispatch = None # type: ignore + self.new_init = None + self.info.clear() + + for key in list(self): + if key in self.local_attrs: + self.uninstrument_attribute(key) + + if self.MANAGER_ATTR in self.class_.__dict__: + delattr(self.class_, self.MANAGER_ATTR) + + def install_descriptor( + self, key: str, inst: QueryableAttribute[Any] + ) -> None: + if key in (self.STATE_ATTR, self.MANAGER_ATTR): + raise KeyError( + "%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." % key + ) + setattr(self.class_, key, inst) + + def uninstall_descriptor(self, key: str) -> None: + delattr(self.class_, key) + + def install_member(self, key: str, implementation: Any) -> None: + if key in (self.STATE_ATTR, self.MANAGER_ATTR): + raise KeyError( + "%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." % key + ) + self.originals.setdefault(key, self.class_.__dict__.get(key, DEL_ATTR)) + setattr(self.class_, key, implementation) + + def uninstall_member(self, key: str) -> None: + original = self.originals.pop(key, None) + if original is not DEL_ATTR: + setattr(self.class_, key, original) + else: + delattr(self.class_, key) + + def instrument_collection_class( + self, key: str, collection_class: Type[Collection[Any]] + ) -> _CollectionFactoryType: + return collections.prepare_instrumentation(collection_class) + + def initialize_collection( + self, + key: str, + state: InstanceState[_O], + factory: _CollectionFactoryType, + ) -> Tuple[collections.CollectionAdapter, _AdaptedCollectionProtocol]: + user_data = factory() + impl = self.get_impl(key) + assert _is_collection_attribute_impl(impl) + adapter = collections.CollectionAdapter(impl, state, user_data) + return adapter, user_data + + def is_instrumented(self, key: str, search: bool = False) -> bool: + if search: + return key in self + else: + return key in self.local_attrs + + def get_impl(self, key: str) -> AttributeImpl: + return self[key].impl + + @property + def attributes(self) -> Iterable[Any]: + return iter(self.values()) + + # InstanceState management + + def new_instance(self, state: Optional[InstanceState[_O]] = None) -> _O: + # here, we would prefer _O to be bound to "object" + # so that mypy sees that __new__ is present. currently + # it's bound to Any as there were other problems not having + # it that way but these can be revisited + instance = self.class_.__new__(self.class_) + if state is None: + state = self._state_constructor(instance, self) + self._state_setter(instance, state) + return instance + + def setup_instance( + self, instance: _O, state: Optional[InstanceState[_O]] = None + ) -> None: + if state is None: + state = self._state_constructor(instance, self) + self._state_setter(instance, state) + + def teardown_instance(self, instance: _O) -> None: + delattr(instance, self.STATE_ATTR) + + def _serialize( + self, state: InstanceState[_O], state_dict: Dict[str, Any] + ) -> _SerializeManager: + return _SerializeManager(state, state_dict) + + def _new_state_if_none( + self, instance: _O + ) -> Union[Literal[False], InstanceState[_O]]: + """Install a default InstanceState if none is present. + + A private convenience method used by the __init__ decorator. + + """ + if hasattr(instance, self.STATE_ATTR): + return False + elif self.class_ is not instance.__class__ and self.is_mapped: + # this will create a new ClassManager for the + # subclass, without a mapper. This is likely a + # user error situation but allow the object + # to be constructed, so that it is usable + # in a non-ORM context at least. + return self._subclass_manager( + instance.__class__ + )._new_state_if_none(instance) + else: + state = self._state_constructor(instance, self) + self._state_setter(instance, state) + return state + + def has_state(self, instance: _O) -> bool: + return hasattr(instance, self.STATE_ATTR) + + def has_parent( + self, state: InstanceState[_O], key: str, optimistic: bool = False + ) -> bool: + """TODO""" + return self.get_impl(key).hasparent(state, optimistic=optimistic) + + def __bool__(self) -> bool: + """All ClassManagers are non-zero regardless of attribute state.""" + return True + + def __repr__(self) -> str: + return "<%s of %r at %x>" % ( + self.__class__.__name__, + self.class_, + id(self), + ) + + +class _SerializeManager: + """Provide serialization of a :class:`.ClassManager`. + + The :class:`.InstanceState` uses ``__init__()`` on serialize + and ``__call__()`` on deserialize. + + """ + + def __init__(self, state: state.InstanceState[Any], d: Dict[str, Any]): + self.class_ = state.class_ + manager = state.manager + manager.dispatch.pickle(state, d) + + def __call__(self, state, inst, state_dict): + state.manager = manager = opt_manager_of_class(self.class_) + if manager is None: + raise exc.UnmappedInstanceError( + inst, + "Cannot deserialize object of type %r - " + "no mapper() has " + "been configured for this class within the current " + "Python process!" % self.class_, + ) + elif manager.is_mapped and not manager.mapper.configured: + manager.mapper._check_configure() + + # setup _sa_instance_state ahead of time so that + # unpickle events can access the object normally. + # see [ticket:2362] + if inst is not None: + manager.setup_instance(inst, state) + manager.dispatch.unpickle(state, state_dict) + + +class InstrumentationFactory(EventTarget): + """Factory for new ClassManager instances.""" + + dispatch: dispatcher[InstrumentationFactory] + + def create_manager_for_cls(self, class_: Type[_O]) -> ClassManager[_O]: + assert class_ is not None + assert opt_manager_of_class(class_) is None + + # give a more complicated subclass + # a chance to do what it wants here + manager, factory = self._locate_extended_factory(class_) + + if factory is None: + factory = ClassManager + manager = ClassManager(class_) + else: + assert manager is not None + + self._check_conflicts(class_, factory) + + manager.factory = factory + + return manager + + def _locate_extended_factory( + self, class_: Type[_O] + ) -> Tuple[Optional[ClassManager[_O]], Optional[_ManagerFactory]]: + """Overridden by a subclass to do an extended lookup.""" + return None, None + + def _check_conflicts( + self, class_: Type[_O], factory: Callable[[Type[_O]], ClassManager[_O]] + ) -> None: + """Overridden by a subclass to test for conflicting factories.""" + + def unregister(self, class_: Type[_O]) -> None: + manager = manager_of_class(class_) + manager.unregister() + self.dispatch.class_uninstrument(class_) + + +# this attribute is replaced by sqlalchemy.ext.instrumentation +# when imported. +_instrumentation_factory = InstrumentationFactory() + +# these attributes are replaced by sqlalchemy.ext.instrumentation +# when a non-standard InstrumentationManager class is first +# used to instrument a class. +instance_state = _default_state_getter = base.instance_state + +instance_dict = _default_dict_getter = base.instance_dict + +manager_of_class = _default_manager_getter = base.manager_of_class +opt_manager_of_class = _default_opt_manager_getter = base.opt_manager_of_class + + +def register_class( + class_: Type[_O], + finalize: bool = True, + mapper: Optional[Mapper[_O]] = None, + registry: Optional[_RegistryType] = None, + declarative_scan: Optional[_MapperConfig] = None, + expired_attribute_loader: Optional[_ExpiredAttributeLoaderProto] = None, + init_method: Optional[Callable[..., None]] = None, +) -> ClassManager[_O]: + """Register class instrumentation. + + Returns the existing or newly created class manager. + + """ + + manager = opt_manager_of_class(class_) + if manager is None: + manager = _instrumentation_factory.create_manager_for_cls(class_) + manager._update_state( + mapper=mapper, + registry=registry, + declarative_scan=declarative_scan, + expired_attribute_loader=expired_attribute_loader, + init_method=init_method, + finalize=finalize, + ) + + return manager + + +def unregister_class(class_): + """Unregister class instrumentation.""" + + _instrumentation_factory.unregister(class_) + + +def is_instrumented(instance, key): + """Return True if the given attribute on the given instance is + instrumented by the attributes package. + + This function may be used regardless of instrumentation + applied directly to the class, i.e. no descriptors are required. + + """ + return manager_of_class(instance.__class__).is_instrumented( + key, search=True + ) + + +def _generate_init(class_, class_manager, original_init): + """Build an __init__ decorator that triggers ClassManager events.""" + + # TODO: we should use the ClassManager's notion of the + # original '__init__' method, once ClassManager is fixed + # to always reference that. + + if original_init is None: + original_init = class_.__init__ + + # Go through some effort here and don't change the user's __init__ + # calling signature, including the unlikely case that it has + # a return value. + # FIXME: need to juggle local names to avoid constructor argument + # clashes. + func_body = """\ +def __init__(%(apply_pos)s): + new_state = class_manager._new_state_if_none(%(self_arg)s) + if new_state: + return new_state._initialize_instance(%(apply_kw)s) + else: + return original_init(%(apply_kw)s) +""" + func_vars = util.format_argspec_init(original_init, grouped=False) + func_text = func_body % func_vars + + func_defaults = getattr(original_init, "__defaults__", None) + func_kw_defaults = getattr(original_init, "__kwdefaults__", None) + + env = locals().copy() + env["__name__"] = __name__ + exec(func_text, env) + __init__ = env["__init__"] + __init__.__doc__ = original_init.__doc__ + __init__._sa_original_init = original_init + + if func_defaults: + __init__.__defaults__ = func_defaults + if func_kw_defaults: + __init__.__kwdefaults__ = func_kw_defaults + + return __init__ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/interfaces.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/interfaces.py new file mode 100644 index 0000000..36336e7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/interfaces.py @@ -0,0 +1,1469 @@ +# orm/interfaces.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +""" + +Contains various base classes used throughout the ORM. + +Defines some key base classes prominent within the internals. + +This module and the classes within are mostly private, though some attributes +are exposed when inspecting mappings. + +""" + +from __future__ import annotations + +import collections +import dataclasses +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import Generic +from typing import Iterator +from typing import List +from typing import NamedTuple +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import exc as orm_exc +from . import path_registry +from .base import _MappedAttribute as _MappedAttribute +from .base import EXT_CONTINUE as EXT_CONTINUE # noqa: F401 +from .base import EXT_SKIP as EXT_SKIP # noqa: F401 +from .base import EXT_STOP as EXT_STOP # noqa: F401 +from .base import InspectionAttr as InspectionAttr # noqa: F401 +from .base import InspectionAttrInfo as InspectionAttrInfo +from .base import MANYTOMANY as MANYTOMANY # noqa: F401 +from .base import MANYTOONE as MANYTOONE # noqa: F401 +from .base import NO_KEY as NO_KEY # noqa: F401 +from .base import NO_VALUE as NO_VALUE # noqa: F401 +from .base import NotExtension as NotExtension # noqa: F401 +from .base import ONETOMANY as ONETOMANY # noqa: F401 +from .base import RelationshipDirection as RelationshipDirection # noqa: F401 +from .base import SQLORMOperations +from .. import ColumnElement +from .. import exc as sa_exc +from .. import inspection +from .. import util +from ..sql import operators +from ..sql import roles +from ..sql import visitors +from ..sql.base import _NoArg +from ..sql.base import ExecutableOption +from ..sql.cache_key import HasCacheKey +from ..sql.operators import ColumnOperators +from ..sql.schema import Column +from ..sql.type_api import TypeEngine +from ..util import warn_deprecated +from ..util.typing import RODescriptorReference +from ..util.typing import TypedDict + +if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _ORMAdapterProto + from .attributes import InstrumentedAttribute + from .base import Mapped + from .context import _MapperEntity + from .context import ORMCompileState + from .context import QueryContext + from .decl_api import RegistryType + from .decl_base import _ClassScanMapperConfig + from .loading import _PopulatorDict + from .mapper import Mapper + from .path_registry import AbstractEntityRegistry + from .query import Query + from .session import Session + from .state import InstanceState + from .strategy_options import _LoadElement + from .util import AliasedInsp + from .util import ORMAdapter + from ..engine.result import Result + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _InfoType + from ..sql.operators import OperatorType + from ..sql.visitors import _TraverseInternalsType + from ..util.typing import _AnnotationScanType + +_StrategyKey = Tuple[Any, ...] + +_T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) + +_TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]") + + +class ORMStatementRole(roles.StatementRole): + __slots__ = () + _role_name = ( + "Executable SQL or text() construct, including ORM aware objects" + ) + + +class ORMColumnsClauseRole( + roles.ColumnsClauseRole, roles.TypedColumnsClauseRole[_T] +): + __slots__ = () + _role_name = "ORM mapped entity, aliased entity, or Column expression" + + +class ORMEntityColumnsClauseRole(ORMColumnsClauseRole[_T]): + __slots__ = () + _role_name = "ORM mapped or aliased entity" + + +class ORMFromClauseRole(roles.StrictFromClauseRole): + __slots__ = () + _role_name = "ORM mapped entity, aliased entity, or FROM expression" + + +class ORMColumnDescription(TypedDict): + name: str + # TODO: add python_type and sql_type here; combining them + # into "type" is a bad idea + type: Union[Type[Any], TypeEngine[Any]] + aliased: bool + expr: _ColumnsClauseArgument[Any] + entity: Optional[_ColumnsClauseArgument[Any]] + + +class _IntrospectsAnnotations: + __slots__ = () + + @classmethod + def _mapper_property_name(cls) -> str: + return cls.__name__ + + def found_in_pep593_annotated(self) -> Any: + """return a copy of this object to use in declarative when the + object is found inside of an Annotated object.""" + + raise NotImplementedError( + f"Use of the {self._mapper_property_name()!r} " + "construct inside of an Annotated object is not yet supported." + ) + + def declarative_scan( + self, + decl_scan: _ClassScanMapperConfig, + registry: RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + mapped_container: Optional[Type[Mapped[Any]]], + annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: + """Perform class-specific initializaton at early declarative scanning + time. + + .. versionadded:: 2.0 + + """ + + def _raise_for_required(self, key: str, cls: Type[Any]) -> NoReturn: + raise sa_exc.ArgumentError( + f"Python typing annotation is required for attribute " + f'"{cls.__name__}.{key}" when primary argument(s) for ' + f'"{self._mapper_property_name()}" ' + "construct are None or not present" + ) + + +class _AttributeOptions(NamedTuple): + """define Python-local attribute behavior options common to all + :class:`.MapperProperty` objects. + + Currently this includes dataclass-generation arguments. + + .. versionadded:: 2.0 + + """ + + dataclasses_init: Union[_NoArg, bool] + dataclasses_repr: Union[_NoArg, bool] + dataclasses_default: Union[_NoArg, Any] + dataclasses_default_factory: Union[_NoArg, Callable[[], Any]] + dataclasses_compare: Union[_NoArg, bool] + dataclasses_kw_only: Union[_NoArg, bool] + + def _as_dataclass_field(self, key: str) -> Any: + """Return a ``dataclasses.Field`` object given these arguments.""" + + kw: Dict[str, Any] = {} + if self.dataclasses_default_factory is not _NoArg.NO_ARG: + kw["default_factory"] = self.dataclasses_default_factory + if self.dataclasses_default is not _NoArg.NO_ARG: + kw["default"] = self.dataclasses_default + if self.dataclasses_init is not _NoArg.NO_ARG: + kw["init"] = self.dataclasses_init + if self.dataclasses_repr is not _NoArg.NO_ARG: + kw["repr"] = self.dataclasses_repr + if self.dataclasses_compare is not _NoArg.NO_ARG: + kw["compare"] = self.dataclasses_compare + if self.dataclasses_kw_only is not _NoArg.NO_ARG: + kw["kw_only"] = self.dataclasses_kw_only + + if "default" in kw and callable(kw["default"]): + # callable defaults are ambiguous. deprecate them in favour of + # insert_default or default_factory. #9936 + warn_deprecated( + f"Callable object passed to the ``default`` parameter for " + f"attribute {key!r} in a ORM-mapped Dataclasses context is " + "ambiguous, " + "and this use will raise an error in a future release. " + "If this callable is intended to produce Core level INSERT " + "default values for an underlying ``Column``, use " + "the ``mapped_column.insert_default`` parameter instead. " + "To establish this callable as providing a default value " + "for instances of the dataclass itself, use the " + "``default_factory`` dataclasses parameter.", + "2.0", + ) + + if ( + "init" in kw + and not kw["init"] + and "default" in kw + and not callable(kw["default"]) # ignore callable defaults. #9936 + and "default_factory" not in kw # illegal but let dc.field raise + ): + # fix for #9879 + default = kw.pop("default") + kw["default_factory"] = lambda: default + + return dataclasses.field(**kw) + + @classmethod + def _get_arguments_for_make_dataclass( + cls, + key: str, + annotation: _AnnotationScanType, + mapped_container: Optional[Any], + elem: _T, + ) -> Union[ + Tuple[str, _AnnotationScanType], + Tuple[str, _AnnotationScanType, dataclasses.Field[Any]], + ]: + """given attribute key, annotation, and value from a class, return + the argument tuple we would pass to dataclasses.make_dataclass() + for this attribute. + + """ + if isinstance(elem, _DCAttributeOptions): + dc_field = elem._attribute_options._as_dataclass_field(key) + + return (key, annotation, dc_field) + elif elem is not _NoArg.NO_ARG: + # why is typing not erroring on this? + return (key, annotation, elem) + elif mapped_container is not None: + # it's Mapped[], but there's no "element", which means declarative + # did not actually do anything for this field. this shouldn't + # happen. + # previously, this would occur because _scan_attributes would + # skip a field that's on an already mapped superclass, but it + # would still include it in the annotations, leading + # to issue #8718 + + assert False, "Mapped[] received without a mapping declaration" + + else: + # plain dataclass field, not mapped. Is only possible + # if __allow_unmapped__ is set up. I can see this mode causing + # problems... + return (key, annotation) + + +_DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions( + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, +) + +_DEFAULT_READONLY_ATTRIBUTE_OPTIONS = _AttributeOptions( + False, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, + _NoArg.NO_ARG, +) + + +class _DCAttributeOptions: + """mixin for descriptors or configurational objects that include dataclass + field options. + + This includes :class:`.MapperProperty`, :class:`._MapsColumn` within + the ORM, but also includes :class:`.AssociationProxy` within ext. + Can in theory be used for other descriptors that serve a similar role + as association proxy. (*maybe* hybrids, not sure yet.) + + """ + + __slots__ = () + + _attribute_options: _AttributeOptions + """behavioral options for ORM-enabled Python attributes + + .. versionadded:: 2.0 + + """ + + _has_dataclass_arguments: bool + + +class _MapsColumns(_DCAttributeOptions, _MappedAttribute[_T]): + """interface for declarative-capable construct that delivers one or more + Column objects to the declarative process to be part of a Table. + """ + + __slots__ = () + + @property + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: + """return a MapperProperty to be assigned to the declarative mapping""" + raise NotImplementedError() + + @property + def columns_to_assign(self) -> List[Tuple[Column[_T], int]]: + """A list of Column objects that should be declaratively added to the + new Table object. + + """ + raise NotImplementedError() + + +# NOTE: MapperProperty needs to extend _MappedAttribute so that declarative +# typing works, i.e. "Mapped[A] = relationship()". This introduces an +# inconvenience which is that all the MapperProperty objects are treated +# as descriptors by typing tools, which are misled by this as assignment / +# access to a descriptor attribute wants to move through __get__. +# Therefore, references to MapperProperty as an instance variable, such +# as in PropComparator, may have some special typing workarounds such as the +# use of sqlalchemy.util.typing.DescriptorReference to avoid mis-interpretation +# by typing tools +@inspection._self_inspects +class MapperProperty( + HasCacheKey, + _DCAttributeOptions, + _MappedAttribute[_T], + InspectionAttrInfo, + util.MemoizedSlots, +): + """Represent a particular class attribute mapped by :class:`_orm.Mapper`. + + The most common occurrences of :class:`.MapperProperty` are the + mapped :class:`_schema.Column`, which is represented in a mapping as + an instance of :class:`.ColumnProperty`, + and a reference to another class produced by :func:`_orm.relationship`, + represented in the mapping as an instance of + :class:`.Relationship`. + + """ + + __slots__ = ( + "_configure_started", + "_configure_finished", + "_attribute_options", + "_has_dataclass_arguments", + "parent", + "key", + "info", + "doc", + ) + + _cache_key_traversal: _TraverseInternalsType = [ + ("parent", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("key", visitors.ExtendedInternalTraversal.dp_string), + ] + + if not TYPE_CHECKING: + cascade = None + + is_property = True + """Part of the InspectionAttr interface; states this object is a + mapper property. + + """ + + comparator: PropComparator[_T] + """The :class:`_orm.PropComparator` instance that implements SQL + expression construction on behalf of this mapped attribute.""" + + key: str + """name of class attribute""" + + parent: Mapper[Any] + """the :class:`.Mapper` managing this property.""" + + _is_relationship = False + + _links_to_entity: bool + """True if this MapperProperty refers to a mapped entity. + + Should only be True for Relationship, False for all others. + + """ + + doc: Optional[str] + """optional documentation string""" + + info: _InfoType + """Info dictionary associated with the object, allowing user-defined + data to be associated with this :class:`.InspectionAttr`. + + The dictionary is generated when first accessed. Alternatively, + it can be specified as a constructor argument to the + :func:`.column_property`, :func:`_orm.relationship`, or :func:`.composite` + functions. + + .. seealso:: + + :attr:`.QueryableAttribute.info` + + :attr:`.SchemaItem.info` + + """ + + def _memoized_attr_info(self) -> _InfoType: + """Info dictionary associated with the object, allowing user-defined + data to be associated with this :class:`.InspectionAttr`. + + The dictionary is generated when first accessed. Alternatively, + it can be specified as a constructor argument to the + :func:`.column_property`, :func:`_orm.relationship`, or + :func:`.composite` + functions. + + .. seealso:: + + :attr:`.QueryableAttribute.info` + + :attr:`.SchemaItem.info` + + """ + return {} + + def setup( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + adapter: Optional[ORMAdapter], + **kwargs: Any, + ) -> None: + """Called by Query for the purposes of constructing a SQL statement. + + Each MapperProperty associated with the target mapper processes the + statement referenced by the query context, adding columns and/or + criterion as appropriate. + + """ + + def create_row_processor( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + mapper: Mapper[Any], + result: Result[Any], + adapter: Optional[ORMAdapter], + populators: _PopulatorDict, + ) -> None: + """Produce row processing functions and append to the given + set of populators lists. + + """ + + def cascade_iterator( + self, + type_: str, + state: InstanceState[Any], + dict_: _InstanceDict, + visited_states: Set[InstanceState[Any]], + halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None, + ) -> Iterator[ + Tuple[object, Mapper[Any], InstanceState[Any], _InstanceDict] + ]: + """Iterate through instances related to the given instance for + a particular 'cascade', starting with this MapperProperty. + + Return an iterator3-tuples (instance, mapper, state). + + Note that the 'cascade' collection on this MapperProperty is + checked first for the given type before cascade_iterator is called. + + This method typically only applies to Relationship. + + """ + + return iter(()) + + def set_parent(self, parent: Mapper[Any], init: bool) -> None: + """Set the parent mapper that references this MapperProperty. + + This method is overridden by some subclasses to perform extra + setup when the mapper is first known. + + """ + self.parent = parent + + def instrument_class(self, mapper: Mapper[Any]) -> None: + """Hook called by the Mapper to the property to initiate + instrumentation of the class attribute managed by this + MapperProperty. + + The MapperProperty here will typically call out to the + attributes module to set up an InstrumentedAttribute. + + This step is the first of two steps to set up an InstrumentedAttribute, + and is called early in the mapper setup process. + + The second step is typically the init_class_attribute step, + called from StrategizedProperty via the post_instrument_class() + hook. This step assigns additional state to the InstrumentedAttribute + (specifically the "impl") which has been determined after the + MapperProperty has determined what kind of persistence + management it needs to do (e.g. scalar, object, collection, etc). + + """ + + def __init__( + self, + attribute_options: Optional[_AttributeOptions] = None, + _assume_readonly_dc_attributes: bool = False, + ) -> None: + self._configure_started = False + self._configure_finished = False + + if _assume_readonly_dc_attributes: + default_attrs = _DEFAULT_READONLY_ATTRIBUTE_OPTIONS + else: + default_attrs = _DEFAULT_ATTRIBUTE_OPTIONS + + if attribute_options and attribute_options != default_attrs: + self._has_dataclass_arguments = True + self._attribute_options = attribute_options + else: + self._has_dataclass_arguments = False + self._attribute_options = default_attrs + + def init(self) -> None: + """Called after all mappers are created to assemble + relationships between mappers and perform other post-mapper-creation + initialization steps. + + + """ + self._configure_started = True + self.do_init() + self._configure_finished = True + + @property + def class_attribute(self) -> InstrumentedAttribute[_T]: + """Return the class-bound descriptor corresponding to this + :class:`.MapperProperty`. + + This is basically a ``getattr()`` call:: + + return getattr(self.parent.class_, self.key) + + I.e. if this :class:`.MapperProperty` were named ``addresses``, + and the class to which it is mapped is ``User``, this sequence + is possible:: + + >>> from sqlalchemy import inspect + >>> mapper = inspect(User) + >>> addresses_property = mapper.attrs.addresses + >>> addresses_property.class_attribute is User.addresses + True + >>> User.addresses.property is addresses_property + True + + + """ + + return getattr(self.parent.class_, self.key) # type: ignore + + def do_init(self) -> None: + """Perform subclass-specific initialization post-mapper-creation + steps. + + This is a template method called by the ``MapperProperty`` + object's init() method. + + """ + + def post_instrument_class(self, mapper: Mapper[Any]) -> None: + """Perform instrumentation adjustments that need to occur + after init() has completed. + + The given Mapper is the Mapper invoking the operation, which + may not be the same Mapper as self.parent in an inheritance + scenario; however, Mapper will always at least be a sub-mapper of + self.parent. + + This method is typically used by StrategizedProperty, which delegates + it to LoaderStrategy.init_class_attribute() to perform final setup + on the class-bound InstrumentedAttribute. + + """ + + def merge( + self, + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Dict[Any, object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: + """Merge the attribute represented by this ``MapperProperty`` + from source to destination object. + + """ + + def __repr__(self) -> str: + return "<%s at 0x%x; %s>" % ( + self.__class__.__name__, + id(self), + getattr(self, "key", "no key"), + ) + + +@inspection._self_inspects +class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): + r"""Defines SQL operations for ORM mapped attributes. + + SQLAlchemy allows for operators to + be redefined at both the Core and ORM level. :class:`.PropComparator` + is the base class of operator redefinition for ORM-level operations, + including those of :class:`.ColumnProperty`, + :class:`.Relationship`, and :class:`.Composite`. + + User-defined subclasses of :class:`.PropComparator` may be created. The + built-in Python comparison and math operator methods, such as + :meth:`.operators.ColumnOperators.__eq__`, + :meth:`.operators.ColumnOperators.__lt__`, and + :meth:`.operators.ColumnOperators.__add__`, can be overridden to provide + new operator behavior. The custom :class:`.PropComparator` is passed to + the :class:`.MapperProperty` instance via the ``comparator_factory`` + argument. In each case, + the appropriate subclass of :class:`.PropComparator` should be used:: + + # definition of custom PropComparator subclasses + + from sqlalchemy.orm.properties import \ + ColumnProperty,\ + Composite,\ + Relationship + + class MyColumnComparator(ColumnProperty.Comparator): + def __eq__(self, other): + return self.__clause_element__() == other + + class MyRelationshipComparator(Relationship.Comparator): + def any(self, expression): + "define the 'any' operation" + # ... + + class MyCompositeComparator(Composite.Comparator): + def __gt__(self, other): + "redefine the 'greater than' operation" + + return sql.and_(*[a>b for a, b in + zip(self.__clause_element__().clauses, + other.__composite_values__())]) + + + # application of custom PropComparator subclasses + + from sqlalchemy.orm import column_property, relationship, composite + from sqlalchemy import Column, String + + class SomeMappedClass(Base): + some_column = column_property(Column("some_column", String), + comparator_factory=MyColumnComparator) + + some_relationship = relationship(SomeOtherClass, + comparator_factory=MyRelationshipComparator) + + some_composite = composite( + Column("a", String), Column("b", String), + comparator_factory=MyCompositeComparator + ) + + Note that for column-level operator redefinition, it's usually + simpler to define the operators at the Core level, using the + :attr:`.TypeEngine.comparator_factory` attribute. See + :ref:`types_operators` for more detail. + + .. seealso:: + + :class:`.ColumnProperty.Comparator` + + :class:`.Relationship.Comparator` + + :class:`.Composite.Comparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + + __slots__ = "prop", "_parententity", "_adapt_to_entity" + + __visit_name__ = "orm_prop_comparator" + + _parententity: _InternalEntityType[Any] + _adapt_to_entity: Optional[AliasedInsp[Any]] + prop: RODescriptorReference[MapperProperty[_T_co]] + + def __init__( + self, + prop: MapperProperty[_T], + parentmapper: _InternalEntityType[Any], + adapt_to_entity: Optional[AliasedInsp[Any]] = None, + ): + self.prop = prop + self._parententity = adapt_to_entity or parentmapper + self._adapt_to_entity = adapt_to_entity + + @util.non_memoized_property + def property(self) -> MapperProperty[_T_co]: + """Return the :class:`.MapperProperty` associated with this + :class:`.PropComparator`. + + + Return values here will commonly be instances of + :class:`.ColumnProperty` or :class:`.Relationship`. + + + """ + return self.prop + + def __clause_element__(self) -> roles.ColumnsClauseRole: + raise NotImplementedError("%r" % self) + + def _bulk_update_tuples( + self, value: Any + ) -> Sequence[Tuple[_DMLColumnArgument, Any]]: + """Receive a SQL expression that represents a value in the SET + clause of an UPDATE statement. + + Return a tuple that can be passed to a :class:`_expression.Update` + construct. + + """ + + return [(cast("_DMLColumnArgument", self.__clause_element__()), value)] + + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> PropComparator[_T_co]: + """Return a copy of this PropComparator which will use the given + :class:`.AliasedInsp` to produce corresponding expressions. + """ + return self.__class__(self.prop, self._parententity, adapt_to_entity) + + @util.ro_non_memoized_property + def _parentmapper(self) -> Mapper[Any]: + """legacy; this is renamed to _parententity to be + compatible with QueryableAttribute.""" + return self._parententity.mapper + + def _criterion_exists( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[Any]: + return self.prop.comparator._criterion_exists(criterion, **kwargs) + + @util.ro_non_memoized_property + def adapter(self) -> Optional[_ORMAdapterProto]: + """Produce a callable that adapts column expressions + to suit an aliased version of this comparator. + + """ + if self._adapt_to_entity is None: + return None + else: + return self._adapt_to_entity._orm_adapt_element + + @util.ro_non_memoized_property + def info(self) -> _InfoType: + return self.prop.info + + @staticmethod + def _any_op(a: Any, b: Any, **kwargs: Any) -> Any: + return a.any(b, **kwargs) + + @staticmethod + def _has_op(left: Any, other: Any, **kwargs: Any) -> Any: + return left.has(other, **kwargs) + + @staticmethod + def _of_type_op(a: Any, class_: Any) -> Any: + return a.of_type(class_) + + any_op = cast(operators.OperatorType, _any_op) + has_op = cast(operators.OperatorType, _has_op) + of_type_op = cast(operators.OperatorType, _of_type_op) + + if typing.TYPE_CHECKING: + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: ... + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: ... + + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T_co]: + r"""Redefine this object in terms of a polymorphic subclass, + :func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased` + construct. + + Returns a new PropComparator from which further criterion can be + evaluated. + + e.g.:: + + query.join(Company.employees.of_type(Engineer)).\ + filter(Engineer.name=='foo') + + :param \class_: a class or mapper indicating that criterion will be + against this specific subclass. + + .. seealso:: + + :ref:`orm_queryguide_joining_relationships_aliased` - in the + :ref:`queryguide_toplevel` + + :ref:`inheritance_of_type` + + """ + + return self.operate(PropComparator.of_type_op, class_) # type: ignore + + def and_( + self, *criteria: _ColumnExpressionArgument[bool] + ) -> PropComparator[bool]: + """Add additional criteria to the ON clause that's represented by this + relationship attribute. + + E.g.:: + + + stmt = select(User).join( + User.addresses.and_(Address.email_address != 'foo') + ) + + stmt = select(User).options( + joinedload(User.addresses.and_(Address.email_address != 'foo')) + ) + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`orm_queryguide_join_on_augmented` + + :ref:`loader_option_criteria` + + :func:`.with_loader_criteria` + + """ + return self.operate(operators.and_, *criteria) # type: ignore + + def any( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + r"""Return a SQL expression representing true if this element + references a member which meets the given criterion. + + The usual implementation of ``any()`` is + :meth:`.Relationship.Comparator.any`. + + :param criterion: an optional ClauseElement formulated against the + member class' table or attributes. + + :param \**kwargs: key/value pairs corresponding to member class + attribute names which will be compared via equality to the + corresponding values. + + """ + + return self.operate(PropComparator.any_op, criterion, **kwargs) + + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + r"""Return a SQL expression representing true if this element + references a member which meets the given criterion. + + The usual implementation of ``has()`` is + :meth:`.Relationship.Comparator.has`. + + :param criterion: an optional ClauseElement formulated against the + member class' table or attributes. + + :param \**kwargs: key/value pairs corresponding to member class + attribute names which will be compared via equality to the + corresponding values. + + """ + + return self.operate(PropComparator.has_op, criterion, **kwargs) + + +class StrategizedProperty(MapperProperty[_T]): + """A MapperProperty which uses selectable strategies to affect + loading behavior. + + There is a single strategy selected by default. Alternate + strategies can be selected at Query time through the usage of + ``StrategizedOption`` objects via the Query.options() method. + + The mechanics of StrategizedProperty are used for every Query + invocation for every mapped attribute participating in that Query, + to determine first how the attribute will be rendered in SQL + and secondly how the attribute will retrieve a value from a result + row and apply it to a mapped object. The routines here are very + performance-critical. + + """ + + __slots__ = ( + "_strategies", + "strategy", + "_wildcard_token", + "_default_path_loader_key", + "strategy_key", + ) + inherit_cache = True + strategy_wildcard_key: ClassVar[str] + + strategy_key: _StrategyKey + + _strategies: Dict[_StrategyKey, LoaderStrategy] + + def _memoized_attr__wildcard_token(self) -> Tuple[str]: + return ( + f"{self.strategy_wildcard_key}:{path_registry._WILDCARD_TOKEN}", + ) + + def _memoized_attr__default_path_loader_key( + self, + ) -> Tuple[str, Tuple[str]]: + return ( + "loader", + (f"{self.strategy_wildcard_key}:{path_registry._DEFAULT_TOKEN}",), + ) + + def _get_context_loader( + self, context: ORMCompileState, path: AbstractEntityRegistry + ) -> Optional[_LoadElement]: + load: Optional[_LoadElement] = None + + search_path = path[self] + + # search among: exact match, "attr.*", "default" strategy + # if any. + for path_key in ( + search_path._loader_key, + search_path._wildcard_path_loader_key, + search_path._default_path_loader_key, + ): + if path_key in context.attributes: + load = context.attributes[path_key] + break + + # note that if strategy_options.Load is placing non-actionable + # objects in the context like defaultload(), we would + # need to continue the loop here if we got such an + # option as below. + # if load.strategy or load.local_opts: + # break + + return load + + def _get_strategy(self, key: _StrategyKey) -> LoaderStrategy: + try: + return self._strategies[key] + except KeyError: + pass + + # run outside to prevent transfer of exception context + cls = self._strategy_lookup(self, *key) + # this previously was setting self._strategies[cls], that's + # a bad idea; should use strategy key at all times because every + # strategy has multiple keys at this point + self._strategies[key] = strategy = cls(self, key) + return strategy + + def setup( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + adapter: Optional[ORMAdapter], + **kwargs: Any, + ) -> None: + loader = self._get_context_loader(context, path) + if loader and loader.strategy: + strat = self._get_strategy(loader.strategy) + else: + strat = self.strategy + strat.setup_query( + context, query_entity, path, loader, adapter, **kwargs + ) + + def create_row_processor( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + mapper: Mapper[Any], + result: Result[Any], + adapter: Optional[ORMAdapter], + populators: _PopulatorDict, + ) -> None: + loader = self._get_context_loader(context, path) + if loader and loader.strategy: + strat = self._get_strategy(loader.strategy) + else: + strat = self.strategy + strat.create_row_processor( + context, + query_entity, + path, + loader, + mapper, + result, + adapter, + populators, + ) + + def do_init(self) -> None: + self._strategies = {} + self.strategy = self._get_strategy(self.strategy_key) + + def post_instrument_class(self, mapper: Mapper[Any]) -> None: + if ( + not self.parent.non_primary + and not mapper.class_manager._attr_has_impl(self.key) + ): + self.strategy.init_class_attribute(mapper) + + _all_strategies: collections.defaultdict[ + Type[MapperProperty[Any]], Dict[_StrategyKey, Type[LoaderStrategy]] + ] = collections.defaultdict(dict) + + @classmethod + def strategy_for(cls, **kw: Any) -> Callable[[_TLS], _TLS]: + def decorate(dec_cls: _TLS) -> _TLS: + # ensure each subclass of the strategy has its + # own _strategy_keys collection + if "_strategy_keys" not in dec_cls.__dict__: + dec_cls._strategy_keys = [] + key = tuple(sorted(kw.items())) + cls._all_strategies[cls][key] = dec_cls + dec_cls._strategy_keys.append(key) + return dec_cls + + return decorate + + @classmethod + def _strategy_lookup( + cls, requesting_property: MapperProperty[Any], *key: Any + ) -> Type[LoaderStrategy]: + requesting_property.parent._with_polymorphic_mappers + + for prop_cls in cls.__mro__: + if prop_cls in cls._all_strategies: + if TYPE_CHECKING: + assert issubclass(prop_cls, MapperProperty) + strategies = cls._all_strategies[prop_cls] + try: + return strategies[key] + except KeyError: + pass + + for property_type, strats in cls._all_strategies.items(): + if key in strats: + intended_property_type = property_type + actual_strategy = strats[key] + break + else: + intended_property_type = None + actual_strategy = None + + raise orm_exc.LoaderStrategyException( + cls, + requesting_property, + intended_property_type, + actual_strategy, + key, + ) + + +class ORMOption(ExecutableOption): + """Base class for option objects that are passed to ORM queries. + + These options may be consumed by :meth:`.Query.options`, + :meth:`.Select.options`, or in a more general sense by any + :meth:`.Executable.options` method. They are interpreted at + statement compile time or execution time in modern use. The + deprecated :class:`.MapperOption` is consumed at ORM query construction + time. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _is_legacy_option = False + + propagate_to_loaders = False + """if True, indicate this option should be carried along + to "secondary" SELECT statements that occur for relationship + lazy loaders as well as attribute load / refresh operations. + + """ + + _is_core = False + + _is_user_defined = False + + _is_compile_state = False + + _is_criteria_option = False + + _is_strategy_option = False + + def _adapt_cached_option_to_uncached_option( + self, context: QueryContext, uncached_opt: ORMOption + ) -> ORMOption: + """adapt this option to the "uncached" version of itself in a + loader strategy context. + + given "self" which is an option from a cached query, as well as the + corresponding option from the uncached version of the same query, + return the option we should use in a new query, in the context of a + loader strategy being asked to load related rows on behalf of that + cached query, which is assumed to be building a new query based on + entities passed to us from the cached query. + + Currently this routine chooses between "self" and "uncached" without + manufacturing anything new. If the option is itself a loader strategy + option which has a path, that path needs to match to the entities being + passed to us by the cached query, so the :class:`_orm.Load` subclass + overrides this to return "self". For all other options, we return the + uncached form which may have changing state, such as a + with_loader_criteria() option which will very often have new state. + + This routine could in the future involve + generating a new option based on both inputs if use cases arise, + such as if with_loader_criteria() needed to match up to + ``AliasedClass`` instances given in the parent query. + + However, longer term it might be better to restructure things such that + ``AliasedClass`` entities are always matched up on their cache key, + instead of identity, in things like paths and such, so that this whole + issue of "the uncached option does not match the entities" goes away. + However this would make ``PathRegistry`` more complicated and difficult + to debug as well as potentially less performant in that it would be + hashing enormous cache keys rather than a simple AliasedInsp. UNLESS, + we could get cache keys overall to be reliably hashed into something + like an md5 key. + + .. versionadded:: 1.4.41 + + """ + if uncached_opt is not None: + return uncached_opt + else: + return self + + +class CompileStateOption(HasCacheKey, ORMOption): + """base for :class:`.ORMOption` classes that affect the compilation of + a SQL query and therefore need to be part of the cache key. + + .. note:: :class:`.CompileStateOption` is generally non-public and + should not be used as a base class for user-defined options; instead, + use :class:`.UserDefinedOption`, which is easier to use as it does not + interact with ORM compilation internals or caching. + + :class:`.CompileStateOption` defines an internal attribute + ``_is_compile_state=True`` which has the effect of the ORM compilation + routines for SELECT and other statements will call upon these options when + a SQL string is being compiled. As such, these classes implement + :class:`.HasCacheKey` and need to provide robust ``_cache_key_traversal`` + structures. + + The :class:`.CompileStateOption` class is used to implement the ORM + :class:`.LoaderOption` and :class:`.CriteriaOption` classes. + + .. versionadded:: 1.4.28 + + + """ + + __slots__ = () + + _is_compile_state = True + + def process_compile_state(self, compile_state: ORMCompileState) -> None: + """Apply a modification to a given :class:`.ORMCompileState`. + + This method is part of the implementation of a particular + :class:`.CompileStateOption` and is only invoked internally + when an ORM query is compiled. + + """ + + def process_compile_state_replaced_entities( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + ) -> None: + """Apply a modification to a given :class:`.ORMCompileState`, + given entities that were replaced by with_only_columns() or + with_entities(). + + This method is part of the implementation of a particular + :class:`.CompileStateOption` and is only invoked internally + when an ORM query is compiled. + + .. versionadded:: 1.4.19 + + """ + + +class LoaderOption(CompileStateOption): + """Describe a loader modification to an ORM statement at compilation time. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + def process_compile_state_replaced_entities( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + ) -> None: + self.process_compile_state(compile_state) + + +class CriteriaOption(CompileStateOption): + """Describe a WHERE criteria modification to an ORM statement at + compilation time. + + .. versionadded:: 1.4 + + """ + + __slots__ = () + + _is_criteria_option = True + + def get_global_criteria(self, attributes: Dict[str, Any]) -> None: + """update additional entity criteria options in the given + attributes dictionary. + + """ + + +class UserDefinedOption(ORMOption): + """Base class for a user-defined option that can be consumed from the + :meth:`.SessionEvents.do_orm_execute` event hook. + + """ + + __slots__ = ("payload",) + + _is_legacy_option = False + + _is_user_defined = True + + propagate_to_loaders = False + """if True, indicate this option should be carried along + to "secondary" Query objects produced during lazy loads + or refresh operations. + + """ + + def __init__(self, payload: Optional[Any] = None): + self.payload = payload + + +@util.deprecated_cls( + "1.4", + "The :class:`.MapperOption class is deprecated and will be removed " + "in a future release. For " + "modifications to queries on a per-execution basis, use the " + ":class:`.UserDefinedOption` class to establish state within a " + ":class:`.Query` or other Core statement, then use the " + ":meth:`.SessionEvents.before_orm_execute` hook to consume them.", + constructor=None, +) +class MapperOption(ORMOption): + """Describe a modification to a Query""" + + __slots__ = () + + _is_legacy_option = True + + propagate_to_loaders = False + """if True, indicate this option should be carried along + to "secondary" Query objects produced during lazy loads + or refresh operations. + + """ + + def process_query(self, query: Query[Any]) -> None: + """Apply a modification to the given :class:`_query.Query`.""" + + def process_query_conditionally(self, query: Query[Any]) -> None: + """same as process_query(), except that this option may not + apply to the given query. + + This is typically applied during a lazy load or scalar refresh + operation to propagate options stated in the original Query to the + new Query being used for the load. It occurs for those options that + specify propagate_to_loaders=True. + + """ + + self.process_query(query) + + +class LoaderStrategy: + """Describe the loading behavior of a StrategizedProperty object. + + The ``LoaderStrategy`` interacts with the querying process in three + ways: + + * it controls the configuration of the ``InstrumentedAttribute`` + placed on a class to handle the behavior of the attribute. this + may involve setting up class-level callable functions to fire + off a select operation when the attribute is first accessed + (i.e. a lazy load) + + * it processes the ``QueryContext`` at statement construction time, + where it can modify the SQL statement that is being produced. + For example, simple column attributes will add their represented + column to the list of selected columns, a joined eager loader + may establish join clauses to add to the statement. + + * It produces "row processor" functions at result fetching time. + These "row processor" functions populate a particular attribute + on a particular mapped instance. + + """ + + __slots__ = ( + "parent_property", + "is_class_level", + "parent", + "key", + "strategy_key", + "strategy_opts", + ) + + _strategy_keys: ClassVar[List[_StrategyKey]] + + def __init__( + self, parent: MapperProperty[Any], strategy_key: _StrategyKey + ): + self.parent_property = parent + self.is_class_level = False + self.parent = self.parent_property.parent + self.key = self.parent_property.key + self.strategy_key = strategy_key + self.strategy_opts = dict(strategy_key) + + def init_class_attribute(self, mapper: Mapper[Any]) -> None: + pass + + def setup_query( + self, + compile_state: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + loadopt: Optional[_LoadElement], + adapter: Optional[ORMAdapter], + **kwargs: Any, + ) -> None: + """Establish column and other state for a given QueryContext. + + This method fulfills the contract specified by MapperProperty.setup(). + + StrategizedProperty delegates its setup() method + directly to this method. + + """ + + def create_row_processor( + self, + context: ORMCompileState, + query_entity: _MapperEntity, + path: AbstractEntityRegistry, + loadopt: Optional[_LoadElement], + mapper: Mapper[Any], + result: Result[Any], + adapter: Optional[ORMAdapter], + populators: _PopulatorDict, + ) -> None: + """Establish row processing functions for a given QueryContext. + + This method fulfills the contract specified by + MapperProperty.create_row_processor(). + + StrategizedProperty delegates its create_row_processor() method + directly to this method. + + """ + + def __str__(self) -> str: + return str(self.parent_property) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/loading.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/loading.py new file mode 100644 index 0000000..4e2cb82 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/loading.py @@ -0,0 +1,1665 @@ +# orm/loading.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""private module containing functions used to convert database +rows into object instances and associated state. + +the functions here are called primarily by Query, Mapper, +as well as some of the attribute loading strategies. + +""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import attributes +from . import exc as orm_exc +from . import path_registry +from .base import _DEFER_FOR_STATE +from .base import _RAISE_FOR_STATE +from .base import _SET_DEFERRED_EXPIRED +from .base import PassiveFlag +from .context import FromStatement +from .context import ORMCompileState +from .context import QueryContext +from .util import _none_set +from .util import state_str +from .. import exc as sa_exc +from .. import util +from ..engine import result_tuple +from ..engine.result import ChunkedIteratorResult +from ..engine.result import FrozenResult +from ..engine.result import SimpleResultMetaData +from ..sql import select +from ..sql import util as sql_util +from ..sql.selectable import ForUpdateArg +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import SelectState +from ..util import EMPTY_DICT + +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from .base import LoaderCallableStatus + from .interfaces import ORMOption + from .mapper import Mapper + from .query import Query + from .session import Session + from .state import InstanceState + from ..engine.cursor import CursorResult + from ..engine.interfaces import _ExecuteOptions + from ..engine.result import Result + from ..sql import Select + +_T = TypeVar("_T", bound=Any) +_O = TypeVar("_O", bound=object) +_new_runid = util.counter() + + +_PopulatorDict = Dict[str, List[Tuple[str, Any]]] + + +def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]: + """Return a :class:`.Result` given an ORM query context. + + :param cursor: a :class:`.CursorResult`, generated by a statement + which came from :class:`.ORMCompileState` + + :param context: a :class:`.QueryContext` object + + :return: a :class:`.Result` object representing ORM results + + .. versionchanged:: 1.4 The instances() function now uses + :class:`.Result` objects and has an all new interface. + + """ + + context.runid = _new_runid() + + if context.top_level_context: + is_top_level = False + context.post_load_paths = context.top_level_context.post_load_paths + else: + is_top_level = True + context.post_load_paths = {} + + compile_state = context.compile_state + filtered = compile_state._has_mapper_entities + single_entity = ( + not context.load_options._only_return_tuples + and len(compile_state._entities) == 1 + and compile_state._entities[0].supports_single_entity + ) + + try: + (process, labels, extra) = list( + zip( + *[ + query_entity.row_processor(context, cursor) + for query_entity in context.compile_state._entities + ] + ) + ) + + if context.yield_per and ( + context.loaders_require_buffering + or context.loaders_require_uniquing + ): + raise sa_exc.InvalidRequestError( + "Can't use yield_per with eager loaders that require uniquing " + "or row buffering, e.g. joinedload() against collections " + "or subqueryload(). Consider the selectinload() strategy " + "for better flexibility in loading objects." + ) + + except Exception: + with util.safe_reraise(): + cursor.close() + + def _no_unique(entry): + raise sa_exc.InvalidRequestError( + "Can't use the ORM yield_per feature in conjunction with unique()" + ) + + def _not_hashable(datatype, *, legacy=False, uncertain=False): + if not legacy: + + def go(obj): + if uncertain: + try: + return hash(obj) + except: + pass + + raise sa_exc.InvalidRequestError( + "Can't apply uniqueness to row tuple containing value of " + f"""type {datatype!r}; { + 'the values returned appear to be' + if uncertain + else 'this datatype produces' + } non-hashable values""" + ) + + return go + elif not uncertain: + return id + else: + _use_id = False + + def go(obj): + nonlocal _use_id + + if not _use_id: + try: + return hash(obj) + except: + pass + + # in #10459, we considered using a warning here, however + # as legacy query uses result.unique() in all cases, this + # would lead to too many warning cases. + _use_id = True + + return id(obj) + + return go + + unique_filters = [ + ( + _no_unique + if context.yield_per + else ( + _not_hashable( + ent.column.type, # type: ignore + legacy=context.load_options._legacy_uniquing, + uncertain=ent._null_column_type, + ) + if ( + not ent.use_id_for_hash + and (ent._non_hashable_value or ent._null_column_type) + ) + else id if ent.use_id_for_hash else None + ) + ) + for ent in context.compile_state._entities + ] + + row_metadata = SimpleResultMetaData( + labels, extra, _unique_filters=unique_filters + ) + + def chunks(size): # type: ignore + while True: + yield_per = size + + context.partials = {} + + if yield_per: + fetch = cursor.fetchmany(yield_per) + + if not fetch: + break + else: + fetch = cursor._raw_all_rows() + + if single_entity: + proc = process[0] + rows = [proc(row) for row in fetch] + else: + rows = [ + tuple([proc(row) for proc in process]) for row in fetch + ] + + # if we are the originating load from a query, meaning we + # aren't being called as a result of a nested "post load", + # iterate through all the collected post loaders and fire them + # off. Previously this used to work recursively, however that + # prevented deeply nested structures from being loadable + if is_top_level: + if yield_per: + # if using yield per, memoize the state of the + # collection so that it can be restored + top_level_post_loads = list( + context.post_load_paths.items() + ) + + while context.post_load_paths: + post_loads = list(context.post_load_paths.items()) + context.post_load_paths.clear() + for path, post_load in post_loads: + post_load.invoke(context, path) + + if yield_per: + context.post_load_paths.clear() + context.post_load_paths.update(top_level_post_loads) + + yield rows + + if not yield_per: + break + + if context.execution_options.get("prebuffer_rows", False): + # this is a bit of a hack at the moment. + # I would rather have some option in the result to pre-buffer + # internally. + _prebuffered = list(chunks(None)) + + def chunks(size): + return iter(_prebuffered) + + result = ChunkedIteratorResult( + row_metadata, + chunks, + source_supports_scalars=single_entity, + raw=cursor, + dynamic_yield_per=cursor.context._is_server_side, + ) + + # filtered and single_entity are used to indicate to legacy Query that the + # query has ORM entities, so legacy deduping and scalars should be called + # on the result. + result._attributes = result._attributes.union( + dict(filtered=filtered, is_single_entity=single_entity) + ) + + # multi_row_eager_loaders OTOH is specific to joinedload. + if context.compile_state.multi_row_eager_loaders: + + def require_unique(obj): + raise sa_exc.InvalidRequestError( + "The unique() method must be invoked on this Result, " + "as it contains results that include joined eager loads " + "against collections" + ) + + result._unique_filter_state = (None, require_unique) + + if context.yield_per: + result.yield_per(context.yield_per) + + return result + + +@util.preload_module("sqlalchemy.orm.context") +def merge_frozen_result(session, statement, frozen_result, load=True): + """Merge a :class:`_engine.FrozenResult` back into a :class:`_orm.Session`, + returning a new :class:`_engine.Result` object with :term:`persistent` + objects. + + See the section :ref:`do_orm_execute_re_executing` for an example. + + .. seealso:: + + :ref:`do_orm_execute_re_executing` + + :meth:`_engine.Result.freeze` + + :class:`_engine.FrozenResult` + + """ + querycontext = util.preloaded.orm_context + + if load: + # flush current contents if we expect to load data + session._autoflush() + + ctx = querycontext.ORMSelectCompileState._create_entities_collection( + statement, legacy=False + ) + + autoflush = session.autoflush + try: + session.autoflush = False + mapped_entities = [ + i + for i, e in enumerate(ctx._entities) + if isinstance(e, querycontext._MapperEntity) + ] + keys = [ent._label_name for ent in ctx._entities] + + keyed_tuple = result_tuple( + keys, [ent._extra_entities for ent in ctx._entities] + ) + + result = [] + for newrow in frozen_result.rewrite_rows(): + for i in mapped_entities: + if newrow[i] is not None: + newrow[i] = session._merge( + attributes.instance_state(newrow[i]), + attributes.instance_dict(newrow[i]), + load=load, + _recursive={}, + _resolve_conflict_map={}, + ) + + result.append(keyed_tuple(newrow)) + + return frozen_result.with_new_rows(result) + finally: + session.autoflush = autoflush + + +@util.became_legacy_20( + ":func:`_orm.merge_result`", + alternative="The function as well as the method on :class:`_orm.Query` " + "is superseded by the :func:`_orm.merge_frozen_result` function.", +) +@util.preload_module("sqlalchemy.orm.context") +def merge_result( + query: Query[Any], + iterator: Union[FrozenResult, Iterable[Sequence[Any]], Iterable[object]], + load: bool = True, +) -> Union[FrozenResult, Iterable[Any]]: + """Merge a result into the given :class:`.Query` object's Session. + + See :meth:`_orm.Query.merge_result` for top-level documentation on this + function. + + """ + + querycontext = util.preloaded.orm_context + + session = query.session + if load: + # flush current contents if we expect to load data + session._autoflush() + + # TODO: need test coverage and documentation for the FrozenResult + # use case. + if isinstance(iterator, FrozenResult): + frozen_result = iterator + iterator = iter(frozen_result.data) + else: + frozen_result = None + + ctx = querycontext.ORMSelectCompileState._create_entities_collection( + query, legacy=True + ) + + autoflush = session.autoflush + try: + session.autoflush = False + single_entity = not frozen_result and len(ctx._entities) == 1 + + if single_entity: + if isinstance(ctx._entities[0], querycontext._MapperEntity): + result = [ + session._merge( + attributes.instance_state(instance), + attributes.instance_dict(instance), + load=load, + _recursive={}, + _resolve_conflict_map={}, + ) + for instance in iterator + ] + else: + result = list(iterator) + else: + mapped_entities = [ + i + for i, e in enumerate(ctx._entities) + if isinstance(e, querycontext._MapperEntity) + ] + result = [] + keys = [ent._label_name for ent in ctx._entities] + + keyed_tuple = result_tuple( + keys, [ent._extra_entities for ent in ctx._entities] + ) + + for row in iterator: + newrow = list(row) + for i in mapped_entities: + if newrow[i] is not None: + newrow[i] = session._merge( + attributes.instance_state(newrow[i]), + attributes.instance_dict(newrow[i]), + load=load, + _recursive={}, + _resolve_conflict_map={}, + ) + result.append(keyed_tuple(newrow)) + + if frozen_result: + return frozen_result.with_new_rows(result) + else: + return iter(result) + finally: + session.autoflush = autoflush + + +def get_from_identity( + session: Session, + mapper: Mapper[_O], + key: _IdentityKeyType[_O], + passive: PassiveFlag, +) -> Union[LoaderCallableStatus, Optional[_O]]: + """Look up the given key in the given session's identity map, + check the object for expired state if found. + + """ + instance = session.identity_map.get(key) + if instance is not None: + state = attributes.instance_state(instance) + + if mapper.inherits and not state.mapper.isa(mapper): + return attributes.PASSIVE_CLASS_MISMATCH + + # expired - ensure it still exists + if state.expired: + if not passive & attributes.SQL_OK: + # TODO: no coverage here + return attributes.PASSIVE_NO_RESULT + elif not passive & attributes.RELATED_OBJECT_OK: + # this mode is used within a flush and the instance's + # expired state will be checked soon enough, if necessary. + # also used by immediateloader for a mutually-dependent + # o2m->m2m load, :ticket:`6301` + return instance + try: + state._load_expired(state, passive) + except orm_exc.ObjectDeletedError: + session._remove_newly_deleted([state]) + return None + return instance + else: + return None + + +def load_on_ident( + session: Session, + statement: Union[Select, FromStatement], + key: Optional[_IdentityKeyType], + *, + load_options: Optional[Sequence[ORMOption]] = None, + refresh_state: Optional[InstanceState[Any]] = None, + with_for_update: Optional[ForUpdateArg] = None, + only_load_props: Optional[Iterable[str]] = None, + no_autoflush: bool = False, + bind_arguments: Mapping[str, Any] = util.EMPTY_DICT, + execution_options: _ExecuteOptions = util.EMPTY_DICT, + require_pk_cols: bool = False, + is_user_refresh: bool = False, +): + """Load the given identity key from the database.""" + if key is not None: + ident = key[1] + identity_token = key[2] + else: + ident = identity_token = None + + return load_on_pk_identity( + session, + statement, + ident, + load_options=load_options, + refresh_state=refresh_state, + with_for_update=with_for_update, + only_load_props=only_load_props, + identity_token=identity_token, + no_autoflush=no_autoflush, + bind_arguments=bind_arguments, + execution_options=execution_options, + require_pk_cols=require_pk_cols, + is_user_refresh=is_user_refresh, + ) + + +def load_on_pk_identity( + session: Session, + statement: Union[Select, FromStatement], + primary_key_identity: Optional[Tuple[Any, ...]], + *, + load_options: Optional[Sequence[ORMOption]] = None, + refresh_state: Optional[InstanceState[Any]] = None, + with_for_update: Optional[ForUpdateArg] = None, + only_load_props: Optional[Iterable[str]] = None, + identity_token: Optional[Any] = None, + no_autoflush: bool = False, + bind_arguments: Mapping[str, Any] = util.EMPTY_DICT, + execution_options: _ExecuteOptions = util.EMPTY_DICT, + require_pk_cols: bool = False, + is_user_refresh: bool = False, +): + """Load the given primary key identity from the database.""" + + query = statement + q = query._clone() + + assert not q._is_lambda_element + + if load_options is None: + load_options = QueryContext.default_load_options + + if ( + statement._compile_options + is SelectState.default_select_compile_options + ): + compile_options = ORMCompileState.default_compile_options + else: + compile_options = statement._compile_options + + if primary_key_identity is not None: + mapper = query._propagate_attrs["plugin_subject"] + + (_get_clause, _get_params) = mapper._get_clause + + # None present in ident - turn those comparisons + # into "IS NULL" + if None in primary_key_identity: + nones = { + _get_params[col].key + for col, value in zip(mapper.primary_key, primary_key_identity) + if value is None + } + + _get_clause = sql_util.adapt_criterion_to_null(_get_clause, nones) + + if len(nones) == len(primary_key_identity): + util.warn( + "fully NULL primary key identity cannot load any " + "object. This condition may raise an error in a future " + "release." + ) + + q._where_criteria = ( + sql_util._deep_annotate(_get_clause, {"_orm_adapt": True}), + ) + + params = { + _get_params[primary_key].key: id_val + for id_val, primary_key in zip( + primary_key_identity, mapper.primary_key + ) + } + else: + params = None + + if with_for_update is not None: + version_check = True + q._for_update_arg = with_for_update + elif query._for_update_arg is not None: + version_check = True + q._for_update_arg = query._for_update_arg + else: + version_check = False + + if require_pk_cols and only_load_props: + if not refresh_state: + raise sa_exc.ArgumentError( + "refresh_state is required when require_pk_cols is present" + ) + + refresh_state_prokeys = refresh_state.mapper._primary_key_propkeys + has_changes = { + key + for key in refresh_state_prokeys.difference(only_load_props) + if refresh_state.attrs[key].history.has_changes() + } + if has_changes: + # raise if pending pk changes are present. + # technically, this could be limited to the case where we have + # relationships in the only_load_props collection to be refreshed + # also (and only ones that have a secondary eager loader, at that). + # however, the error is in place across the board so that behavior + # here is easier to predict. The use case it prevents is one + # of mutating PK attrs, leaving them unflushed, + # calling session.refresh(), and expecting those attrs to remain + # still unflushed. It seems likely someone doing all those + # things would be better off having the PK attributes flushed + # to the database before tinkering like that (session.refresh() is + # tinkering). + raise sa_exc.InvalidRequestError( + f"Please flush pending primary key changes on " + "attributes " + f"{has_changes} for mapper {refresh_state.mapper} before " + "proceeding with a refresh" + ) + + # overall, the ORM has no internal flow right now for "dont load the + # primary row of an object at all, but fire off + # selectinload/subqueryload/immediateload for some relationships". + # It would probably be a pretty big effort to add such a flow. So + # here, the case for #8703 is introduced; user asks to refresh some + # relationship attributes only which are + # selectinload/subqueryload/immediateload/ etc. (not joinedload). + # ORM complains there's no columns in the primary row to load. + # So here, we just add the PK cols if that + # case is detected, so that there is a SELECT emitted for the primary + # row. + # + # Let's just state right up front, for this one little case, + # the ORM here is adding a whole extra SELECT just to satisfy + # limitations in the internal flow. This is really not a thing + # SQLAlchemy finds itself doing like, ever, obviously, we are + # constantly working to *remove* SELECTs we don't need. We + # rationalize this for now based on 1. session.refresh() is not + # commonly used 2. session.refresh() with only relationship attrs is + # even less commonly used 3. the SELECT in question is very low + # latency. + # + # to add the flow to not include the SELECT, the quickest way + # might be to just manufacture a single-row result set to send off to + # instances(), but we'd have to weave that into context.py and all + # that. For 2.0.0, we have enough big changes to navigate for now. + # + mp = refresh_state.mapper._props + for p in only_load_props: + if mp[p]._is_relationship: + only_load_props = refresh_state_prokeys.union(only_load_props) + break + + if refresh_state and refresh_state.load_options: + compile_options += {"_current_path": refresh_state.load_path.parent} + q = q.options(*refresh_state.load_options) + + new_compile_options, load_options = _set_get_options( + compile_options, + load_options, + version_check=version_check, + only_load_props=only_load_props, + refresh_state=refresh_state, + identity_token=identity_token, + is_user_refresh=is_user_refresh, + ) + + q._compile_options = new_compile_options + q._order_by = None + + if no_autoflush: + load_options += {"_autoflush": False} + + execution_options = util.EMPTY_DICT.merge_with( + execution_options, {"_sa_orm_load_options": load_options} + ) + result = ( + session.execute( + q, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + .unique() + .scalars() + ) + + try: + return result.one() + except orm_exc.NoResultFound: + return None + + +def _set_get_options( + compile_opt, + load_opt, + populate_existing=None, + version_check=None, + only_load_props=None, + refresh_state=None, + identity_token=None, + is_user_refresh=None, +): + compile_options = {} + load_options = {} + if version_check: + load_options["_version_check"] = version_check + if populate_existing: + load_options["_populate_existing"] = populate_existing + if refresh_state: + load_options["_refresh_state"] = refresh_state + compile_options["_for_refresh_state"] = True + if only_load_props: + compile_options["_only_load_props"] = frozenset(only_load_props) + if identity_token: + load_options["_identity_token"] = identity_token + + if is_user_refresh: + load_options["_is_user_refresh"] = is_user_refresh + if load_options: + load_opt += load_options + if compile_options: + compile_opt += compile_options + + return compile_opt, load_opt + + +def _setup_entity_query( + compile_state, + mapper, + query_entity, + path, + adapter, + column_collection, + with_polymorphic=None, + only_load_props=None, + polymorphic_discriminator=None, + **kw, +): + if with_polymorphic: + poly_properties = mapper._iterate_polymorphic_properties( + with_polymorphic + ) + else: + poly_properties = mapper._polymorphic_properties + + quick_populators = {} + + path.set(compile_state.attributes, "memoized_setups", quick_populators) + + # for the lead entities in the path, e.g. not eager loads, and + # assuming a user-passed aliased class, e.g. not a from_self() or any + # implicit aliasing, don't add columns to the SELECT that aren't + # in the thing that's aliased. + check_for_adapt = adapter and len(path) == 1 and path[-1].is_aliased_class + + for value in poly_properties: + if only_load_props and value.key not in only_load_props: + continue + value.setup( + compile_state, + query_entity, + path, + adapter, + only_load_props=only_load_props, + column_collection=column_collection, + memoized_populators=quick_populators, + check_for_adapt=check_for_adapt, + **kw, + ) + + if ( + polymorphic_discriminator is not None + and polymorphic_discriminator is not mapper.polymorphic_on + ): + if adapter: + pd = adapter.columns[polymorphic_discriminator] + else: + pd = polymorphic_discriminator + column_collection.append(pd) + + +def _warn_for_runid_changed(state): + util.warn( + "Loading context for %s has changed within a load/refresh " + "handler, suggesting a row refresh operation took place. If this " + "event handler is expected to be " + "emitting row refresh operations within an existing load or refresh " + "operation, set restore_load_context=True when establishing the " + "listener to ensure the context remains unchanged when the event " + "handler completes." % (state_str(state),) + ) + + +def _instance_processor( + query_entity, + mapper, + context, + result, + path, + adapter, + only_load_props=None, + refresh_state=None, + polymorphic_discriminator=None, + _polymorphic_from=None, +): + """Produce a mapper level row processor callable + which processes rows into mapped instances.""" + + # note that this method, most of which exists in a closure + # called _instance(), resists being broken out, as + # attempts to do so tend to add significant function + # call overhead. _instance() is the most + # performance-critical section in the whole ORM. + + identity_class = mapper._identity_class + compile_state = context.compile_state + + # look for "row getter" functions that have been assigned along + # with the compile state that were cached from a previous load. + # these are operator.itemgetter() objects that each will extract a + # particular column from each row. + + getter_key = ("getters", mapper) + getters = path.get(compile_state.attributes, getter_key, None) + + if getters is None: + # no getters, so go through a list of attributes we are loading for, + # and the ones that are column based will have already put information + # for us in another collection "memoized_setups", which represents the + # output of the LoaderStrategy.setup_query() method. We can just as + # easily call LoaderStrategy.create_row_processor for each, but by + # getting it all at once from setup_query we save another method call + # per attribute. + props = mapper._prop_set + if only_load_props is not None: + props = props.intersection( + mapper._props[k] for k in only_load_props + ) + + quick_populators = path.get( + context.attributes, "memoized_setups", EMPTY_DICT + ) + + todo = [] + cached_populators = { + "new": [], + "quick": [], + "deferred": [], + "expire": [], + "existing": [], + "eager": [], + } + + if refresh_state is None: + # we can also get the "primary key" tuple getter function + pk_cols = mapper.primary_key + + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + primary_key_getter = result._tuple_getter(pk_cols) + else: + primary_key_getter = None + + getters = { + "cached_populators": cached_populators, + "todo": todo, + "primary_key_getter": primary_key_getter, + } + for prop in props: + if prop in quick_populators: + # this is an inlined path just for column-based attributes. + col = quick_populators[prop] + if col is _DEFER_FOR_STATE: + cached_populators["new"].append( + (prop.key, prop._deferred_column_loader) + ) + elif col is _SET_DEFERRED_EXPIRED: + # note that in this path, we are no longer + # searching in the result to see if the column might + # be present in some unexpected way. + cached_populators["expire"].append((prop.key, False)) + elif col is _RAISE_FOR_STATE: + cached_populators["new"].append( + (prop.key, prop._raise_column_loader) + ) + else: + getter = None + if adapter: + # this logic had been removed for all 1.4 releases + # up until 1.4.18; the adapter here is particularly + # the compound eager adapter which isn't accommodated + # in the quick_populators right now. The "fallback" + # logic below instead took over in many more cases + # until issue #6596 was identified. + + # note there is still an issue where this codepath + # produces no "getter" for cases where a joined-inh + # mapping includes a labeled column property, meaning + # KeyError is caught internally and we fall back to + # _getter(col), which works anyway. The adapter + # here for joined inh without any aliasing might not + # be useful. Tests which see this include + # test.orm.inheritance.test_basic -> + # EagerTargetingTest.test_adapt_stringency + # OptimizedLoadTest.test_column_expression_joined + # PolymorphicOnNotLocalTest.test_polymorphic_on_column_prop # noqa: E501 + # + + adapted_col = adapter.columns[col] + if adapted_col is not None: + getter = result._getter(adapted_col, False) + if not getter: + getter = result._getter(col, False) + if getter: + cached_populators["quick"].append((prop.key, getter)) + else: + # fall back to the ColumnProperty itself, which + # will iterate through all of its columns + # to see if one fits + prop.create_row_processor( + context, + query_entity, + path, + mapper, + result, + adapter, + cached_populators, + ) + else: + # loader strategies like subqueryload, selectinload, + # joinedload, basically relationships, these need to interact + # with the context each time to work correctly. + todo.append(prop) + + path.set(compile_state.attributes, getter_key, getters) + + cached_populators = getters["cached_populators"] + + populators = {key: list(value) for key, value in cached_populators.items()} + for prop in getters["todo"]: + prop.create_row_processor( + context, query_entity, path, mapper, result, adapter, populators + ) + + propagated_loader_options = context.propagated_loader_options + load_path = ( + context.compile_state.current_path + path + if context.compile_state.current_path.path + else path + ) + + session_identity_map = context.session.identity_map + + populate_existing = context.populate_existing or mapper.always_refresh + load_evt = bool(mapper.class_manager.dispatch.load) + refresh_evt = bool(mapper.class_manager.dispatch.refresh) + persistent_evt = bool(context.session.dispatch.loaded_as_persistent) + if persistent_evt: + loaded_as_persistent = context.session.dispatch.loaded_as_persistent + instance_state = attributes.instance_state + instance_dict = attributes.instance_dict + session_id = context.session.hash_key + runid = context.runid + identity_token = context.identity_token + + version_check = context.version_check + if version_check: + version_id_col = mapper.version_id_col + if version_id_col is not None: + if adapter: + version_id_col = adapter.columns[version_id_col] + version_id_getter = result._getter(version_id_col) + else: + version_id_getter = None + + if not refresh_state and _polymorphic_from is not None: + key = ("loader", path.path) + + if key in context.attributes and context.attributes[key].strategy == ( + ("selectinload_polymorphic", True), + ): + option_entities = context.attributes[key].local_opts["entities"] + else: + option_entities = None + selectin_load_via = mapper._should_selectin_load( + option_entities, + _polymorphic_from, + ) + + if selectin_load_via and selectin_load_via is not _polymorphic_from: + # only_load_props goes w/ refresh_state only, and in a refresh + # we are a single row query for the exact entity; polymorphic + # loading does not apply + assert only_load_props is None + + callable_ = _load_subclass_via_in( + context, + path, + selectin_load_via, + _polymorphic_from, + option_entities, + ) + PostLoad.callable_for_path( + context, + load_path, + selectin_load_via.mapper, + selectin_load_via, + callable_, + selectin_load_via, + ) + + post_load = PostLoad.for_context(context, load_path, only_load_props) + + if refresh_state: + refresh_identity_key = refresh_state.key + if refresh_identity_key is None: + # super-rare condition; a refresh is being called + # on a non-instance-key instance; this is meant to only + # occur within a flush() + refresh_identity_key = mapper._identity_key_from_state( + refresh_state + ) + else: + refresh_identity_key = None + + primary_key_getter = getters["primary_key_getter"] + + if mapper.allow_partial_pks: + is_not_primary_key = _none_set.issuperset + else: + is_not_primary_key = _none_set.intersection + + def _instance(row): + # determine the state that we'll be populating + if refresh_identity_key: + # fixed state that we're refreshing + state = refresh_state + instance = state.obj() + dict_ = instance_dict(instance) + isnew = state.runid != runid + currentload = True + loaded_instance = False + else: + # look at the row, see if that identity is in the + # session, or we have to create a new one + identitykey = ( + identity_class, + primary_key_getter(row), + identity_token, + ) + + instance = session_identity_map.get(identitykey) + + if instance is not None: + # existing instance + state = instance_state(instance) + dict_ = instance_dict(instance) + + isnew = state.runid != runid + currentload = not isnew + loaded_instance = False + + if version_check and version_id_getter and not currentload: + _validate_version_id( + mapper, state, dict_, row, version_id_getter + ) + + else: + # create a new instance + + # check for non-NULL values in the primary key columns, + # else no entity is returned for the row + if is_not_primary_key(identitykey[1]): + return None + + isnew = True + currentload = True + loaded_instance = True + + instance = mapper.class_manager.new_instance() + + dict_ = instance_dict(instance) + state = instance_state(instance) + state.key = identitykey + state.identity_token = identity_token + + # attach instance to session. + state.session_id = session_id + session_identity_map._add_unpresent(state, identitykey) + + effective_populate_existing = populate_existing + if refresh_state is state: + effective_populate_existing = True + + # populate. this looks at whether this state is new + # for this load or was existing, and whether or not this + # row is the first row with this identity. + if currentload or effective_populate_existing: + # full population routines. Objects here are either + # just created, or we are doing a populate_existing + + # be conservative about setting load_path when populate_existing + # is in effect; want to maintain options from the original + # load. see test_expire->test_refresh_maintains_deferred_options + if isnew and ( + propagated_loader_options or not effective_populate_existing + ): + state.load_options = propagated_loader_options + state.load_path = load_path + + _populate_full( + context, + row, + state, + dict_, + isnew, + load_path, + loaded_instance, + effective_populate_existing, + populators, + ) + + if isnew: + # state.runid should be equal to context.runid / runid + # here, however for event checks we are being more conservative + # and checking against existing run id + # assert state.runid == runid + + existing_runid = state.runid + + if loaded_instance: + if load_evt: + state.manager.dispatch.load(state, context) + if state.runid != existing_runid: + _warn_for_runid_changed(state) + if persistent_evt: + loaded_as_persistent(context.session, state) + if state.runid != existing_runid: + _warn_for_runid_changed(state) + elif refresh_evt: + state.manager.dispatch.refresh( + state, context, only_load_props + ) + if state.runid != runid: + _warn_for_runid_changed(state) + + if effective_populate_existing or state.modified: + if refresh_state and only_load_props: + state._commit(dict_, only_load_props) + else: + state._commit_all(dict_, session_identity_map) + + if post_load: + post_load.add_state(state, True) + + else: + # partial population routines, for objects that were already + # in the Session, but a row matches them; apply eager loaders + # on existing objects, etc. + unloaded = state.unloaded + isnew = state not in context.partials + + if not isnew or unloaded or populators["eager"]: + # state is having a partial set of its attributes + # refreshed. Populate those attributes, + # and add to the "context.partials" collection. + + to_load = _populate_partial( + context, + row, + state, + dict_, + isnew, + load_path, + unloaded, + populators, + ) + + if isnew: + if refresh_evt: + existing_runid = state.runid + state.manager.dispatch.refresh(state, context, to_load) + if state.runid != existing_runid: + _warn_for_runid_changed(state) + + state._commit(dict_, to_load) + + if post_load and context.invoke_all_eagers: + post_load.add_state(state, False) + + return instance + + if mapper.polymorphic_map and not _polymorphic_from and not refresh_state: + # if we are doing polymorphic, dispatch to a different _instance() + # method specific to the subclass mapper + def ensure_no_pk(row): + identitykey = ( + identity_class, + primary_key_getter(row), + identity_token, + ) + if not is_not_primary_key(identitykey[1]): + return identitykey + else: + return None + + _instance = _decorate_polymorphic_switch( + _instance, + context, + query_entity, + mapper, + result, + path, + polymorphic_discriminator, + adapter, + ensure_no_pk, + ) + + return _instance + + +def _load_subclass_via_in( + context, path, entity, polymorphic_from, option_entities +): + mapper = entity.mapper + + # TODO: polymorphic_from seems to be a Mapper in all cases. + # this is likely not needed, but as we dont have typing in loading.py + # yet, err on the safe side + polymorphic_from_mapper = polymorphic_from.mapper + not_against_basemost = polymorphic_from_mapper.inherits is not None + + zero_idx = len(mapper.base_mapper.primary_key) == 1 + + if entity.is_aliased_class or not_against_basemost: + q, enable_opt, disable_opt = mapper._subclass_load_via_in( + entity, polymorphic_from + ) + else: + q, enable_opt, disable_opt = mapper._subclass_load_via_in_mapper + + def do_load(context, path, states, load_only, effective_entity): + if not option_entities: + # filter out states for those that would have selectinloaded + # from another loader + # TODO: we are currently ignoring the case where the + # "selectin_polymorphic" option is used, as this is much more + # complex / specific / very uncommon API use + states = [ + (s, v) + for s, v in states + if s.mapper._would_selectin_load_only_from_given_mapper(mapper) + ] + + if not states: + return + + orig_query = context.query + + if path.parent: + enable_opt_lcl = enable_opt._prepend_path(path) + disable_opt_lcl = disable_opt._prepend_path(path) + else: + enable_opt_lcl = enable_opt + disable_opt_lcl = disable_opt + options = ( + (enable_opt_lcl,) + orig_query._with_options + (disable_opt_lcl,) + ) + + q2 = q.options(*options) + + q2._compile_options = context.compile_state.default_compile_options + q2._compile_options += {"_current_path": path.parent} + + if context.populate_existing: + q2 = q2.execution_options(populate_existing=True) + + context.session.execute( + q2, + dict( + primary_keys=[ + state.key[1][0] if zero_idx else state.key[1] + for state, load_attrs in states + ] + ), + ).unique().scalars().all() + + return do_load + + +def _populate_full( + context, + row, + state, + dict_, + isnew, + load_path, + loaded_instance, + populate_existing, + populators, +): + if isnew: + # first time we are seeing a row with this identity. + state.runid = context.runid + + for key, getter in populators["quick"]: + dict_[key] = getter(row) + if populate_existing: + for key, set_callable in populators["expire"]: + dict_.pop(key, None) + if set_callable: + state.expired_attributes.add(key) + else: + for key, set_callable in populators["expire"]: + if set_callable: + state.expired_attributes.add(key) + + for key, populator in populators["new"]: + populator(state, dict_, row) + + elif load_path != state.load_path: + # new load path, e.g. object is present in more than one + # column position in a series of rows + state.load_path = load_path + + # if we have data, and the data isn't in the dict, OK, let's put + # it in. + for key, getter in populators["quick"]: + if key not in dict_: + dict_[key] = getter(row) + + # otherwise treat like an "already seen" row + for key, populator in populators["existing"]: + populator(state, dict_, row) + # TODO: allow "existing" populator to know this is + # a new path for the state: + # populator(state, dict_, row, new_path=True) + + else: + # have already seen rows with this identity in this same path. + for key, populator in populators["existing"]: + populator(state, dict_, row) + + # TODO: same path + # populator(state, dict_, row, new_path=False) + + +def _populate_partial( + context, row, state, dict_, isnew, load_path, unloaded, populators +): + if not isnew: + if unloaded: + # extra pass, see #8166 + for key, getter in populators["quick"]: + if key in unloaded: + dict_[key] = getter(row) + + to_load = context.partials[state] + for key, populator in populators["existing"]: + if key in to_load: + populator(state, dict_, row) + else: + to_load = unloaded + context.partials[state] = to_load + + for key, getter in populators["quick"]: + if key in to_load: + dict_[key] = getter(row) + for key, set_callable in populators["expire"]: + if key in to_load: + dict_.pop(key, None) + if set_callable: + state.expired_attributes.add(key) + for key, populator in populators["new"]: + if key in to_load: + populator(state, dict_, row) + + for key, populator in populators["eager"]: + if key not in unloaded: + populator(state, dict_, row) + + return to_load + + +def _validate_version_id(mapper, state, dict_, row, getter): + if mapper._get_state_attr_by_column( + state, dict_, mapper.version_id_col + ) != getter(row): + raise orm_exc.StaleDataError( + "Instance '%s' has version id '%s' which " + "does not match database-loaded version id '%s'." + % ( + state_str(state), + mapper._get_state_attr_by_column( + state, dict_, mapper.version_id_col + ), + getter(row), + ) + ) + + +def _decorate_polymorphic_switch( + instance_fn, + context, + query_entity, + mapper, + result, + path, + polymorphic_discriminator, + adapter, + ensure_no_pk, +): + if polymorphic_discriminator is not None: + polymorphic_on = polymorphic_discriminator + else: + polymorphic_on = mapper.polymorphic_on + if polymorphic_on is None: + return instance_fn + + if adapter: + polymorphic_on = adapter.columns[polymorphic_on] + + def configure_subclass_mapper(discriminator): + try: + sub_mapper = mapper.polymorphic_map[discriminator] + except KeyError: + raise AssertionError( + "No such polymorphic_identity %r is defined" % discriminator + ) + else: + if sub_mapper is mapper: + return None + elif not sub_mapper.isa(mapper): + return False + + return _instance_processor( + query_entity, + sub_mapper, + context, + result, + path, + adapter, + _polymorphic_from=mapper, + ) + + polymorphic_instances = util.PopulateDict(configure_subclass_mapper) + + getter = result._getter(polymorphic_on) + + def polymorphic_instance(row): + discriminator = getter(row) + if discriminator is not None: + _instance = polymorphic_instances[discriminator] + if _instance: + return _instance(row) + elif _instance is False: + identitykey = ensure_no_pk(row) + + if identitykey: + raise sa_exc.InvalidRequestError( + "Row with identity key %s can't be loaded into an " + "object; the polymorphic discriminator column '%s' " + "refers to %s, which is not a sub-mapper of " + "the requested %s" + % ( + identitykey, + polymorphic_on, + mapper.polymorphic_map[discriminator], + mapper, + ) + ) + else: + return None + else: + return instance_fn(row) + else: + identitykey = ensure_no_pk(row) + + if identitykey: + raise sa_exc.InvalidRequestError( + "Row with identity key %s can't be loaded into an " + "object; the polymorphic discriminator column '%s' is " + "NULL" % (identitykey, polymorphic_on) + ) + else: + return None + + return polymorphic_instance + + +class PostLoad: + """Track loaders and states for "post load" operations.""" + + __slots__ = "loaders", "states", "load_keys" + + def __init__(self): + self.loaders = {} + self.states = util.OrderedDict() + self.load_keys = None + + def add_state(self, state, overwrite): + # the states for a polymorphic load here are all shared + # within a single PostLoad object among multiple subtypes. + # Filtering of callables on a per-subclass basis needs to be done at + # the invocation level + self.states[state] = overwrite + + def invoke(self, context, path): + if not self.states: + return + path = path_registry.PathRegistry.coerce(path) + for ( + effective_context, + token, + limit_to_mapper, + loader, + arg, + kw, + ) in self.loaders.values(): + states = [ + (state, overwrite) + for state, overwrite in self.states.items() + if state.manager.mapper.isa(limit_to_mapper) + ] + if states: + loader( + effective_context, path, states, self.load_keys, *arg, **kw + ) + self.states.clear() + + @classmethod + def for_context(cls, context, path, only_load_props): + pl = context.post_load_paths.get(path.path) + if pl is not None and only_load_props: + pl.load_keys = only_load_props + return pl + + @classmethod + def path_exists(self, context, path, key): + return ( + path.path in context.post_load_paths + and key in context.post_load_paths[path.path].loaders + ) + + @classmethod + def callable_for_path( + cls, context, path, limit_to_mapper, token, loader_callable, *arg, **kw + ): + if path.path in context.post_load_paths: + pl = context.post_load_paths[path.path] + else: + pl = context.post_load_paths[path.path] = PostLoad() + pl.loaders[token] = ( + context, + token, + limit_to_mapper, + loader_callable, + arg, + kw, + ) + + +def load_scalar_attributes(mapper, state, attribute_names, passive): + """initiate a column-based attribute refresh operation.""" + + # assert mapper is _state_mapper(state) + session = state.session + if not session: + raise orm_exc.DetachedInstanceError( + "Instance %s is not bound to a Session; " + "attribute refresh operation cannot proceed" % (state_str(state)) + ) + + no_autoflush = bool(passive & attributes.NO_AUTOFLUSH) + + # in the case of inheritance, particularly concrete and abstract + # concrete inheritance, the class manager might have some keys + # of attributes on the superclass that we didn't actually map. + # These could be mapped as "concrete, don't load" or could be completely + # excluded from the mapping and we know nothing about them. Filter them + # here to prevent them from coming through. + if attribute_names: + attribute_names = attribute_names.intersection(mapper.attrs.keys()) + + if mapper.inherits and not mapper.concrete: + # load based on committed attributes in the object, formed into + # a truncated SELECT that only includes relevant tables. does not + # currently use state.key + statement = mapper._optimized_get_statement(state, attribute_names) + if statement is not None: + # undefer() isn't needed here because statement has the + # columns needed already, this implicitly undefers that column + stmt = FromStatement(mapper, statement) + + return load_on_ident( + session, + stmt, + None, + only_load_props=attribute_names, + refresh_state=state, + no_autoflush=no_autoflush, + ) + + # normal load, use state.key as the identity to SELECT + has_key = bool(state.key) + + if has_key: + identity_key = state.key + else: + # this codepath is rare - only valid when inside a flush, and the + # object is becoming persistent but hasn't yet been assigned + # an identity_key. + # check here to ensure we have the attrs we need. + pk_attrs = [ + mapper._columntoproperty[col].key for col in mapper.primary_key + ] + if state.expired_attributes.intersection(pk_attrs): + raise sa_exc.InvalidRequestError( + "Instance %s cannot be refreshed - it's not " + " persistent and does not " + "contain a full primary key." % state_str(state) + ) + identity_key = mapper._identity_key_from_state(state) + + if ( + _none_set.issubset(identity_key) and not mapper.allow_partial_pks + ) or _none_set.issuperset(identity_key): + util.warn_limited( + "Instance %s to be refreshed doesn't " + "contain a full primary key - can't be refreshed " + "(and shouldn't be expired, either).", + state_str(state), + ) + return + + result = load_on_ident( + session, + select(mapper).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL), + identity_key, + refresh_state=state, + only_load_props=attribute_names, + no_autoflush=no_autoflush, + ) + + # if instance is pending, a refresh operation + # may not complete (even if PK attributes are assigned) + if has_key and result is None: + raise orm_exc.ObjectDeletedError(state) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/mapped_collection.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/mapped_collection.py new file mode 100644 index 0000000..13c6b68 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/mapped_collection.py @@ -0,0 +1,560 @@ +# orm/mapped_collection.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import operator +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import base +from .collections import collection +from .collections import collection_adapter +from .. import exc as sa_exc +from .. import util +from ..sql import coercions +from ..sql import expression +from ..sql import roles +from ..util.typing import Literal + +if TYPE_CHECKING: + from . import AttributeEventToken + from . import Mapper + from .collections import CollectionAdapter + from ..sql.elements import ColumnElement + +_KT = TypeVar("_KT", bound=Any) +_VT = TypeVar("_VT", bound=Any) + +_F = TypeVar("_F", bound=Callable[[Any], Any]) + + +class _PlainColumnGetter(Generic[_KT]): + """Plain column getter, stores collection of Column objects + directly. + + Serializes to a :class:`._SerializableColumnGetterV2` + which has more expensive __call__() performance + and some rare caveats. + + """ + + __slots__ = ("cols", "composite") + + def __init__(self, cols: Sequence[ColumnElement[_KT]]) -> None: + self.cols = cols + self.composite = len(cols) > 1 + + def __reduce__( + self, + ) -> Tuple[ + Type[_SerializableColumnGetterV2[_KT]], + Tuple[Sequence[Tuple[Optional[str], Optional[str]]]], + ]: + return _SerializableColumnGetterV2._reduce_from_cols(self.cols) + + def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]: + return self.cols + + def __call__(self, value: _KT) -> Union[_KT, Tuple[_KT, ...]]: + state = base.instance_state(value) + m = base._state_mapper(state) + + key: List[_KT] = [ + m._get_state_attr_by_column(state, state.dict, col) + for col in self._cols(m) + ] + if self.composite: + return tuple(key) + else: + obj = key[0] + if obj is None: + return _UNMAPPED_AMBIGUOUS_NONE + else: + return obj + + +class _SerializableColumnGetterV2(_PlainColumnGetter[_KT]): + """Updated serializable getter which deals with + multi-table mapped classes. + + Two extremely unusual cases are not supported. + Mappings which have tables across multiple metadata + objects, or which are mapped to non-Table selectables + linked across inheriting mappers may fail to function + here. + + """ + + __slots__ = ("colkeys",) + + def __init__( + self, colkeys: Sequence[Tuple[Optional[str], Optional[str]]] + ) -> None: + self.colkeys = colkeys + self.composite = len(colkeys) > 1 + + def __reduce__( + self, + ) -> Tuple[ + Type[_SerializableColumnGetterV2[_KT]], + Tuple[Sequence[Tuple[Optional[str], Optional[str]]]], + ]: + return self.__class__, (self.colkeys,) + + @classmethod + def _reduce_from_cols(cls, cols: Sequence[ColumnElement[_KT]]) -> Tuple[ + Type[_SerializableColumnGetterV2[_KT]], + Tuple[Sequence[Tuple[Optional[str], Optional[str]]]], + ]: + def _table_key(c: ColumnElement[_KT]) -> Optional[str]: + if not isinstance(c.table, expression.TableClause): + return None + else: + return c.table.key # type: ignore + + colkeys = [(c.key, _table_key(c)) for c in cols] + return _SerializableColumnGetterV2, (colkeys,) + + def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]: + cols: List[ColumnElement[_KT]] = [] + metadata = getattr(mapper.local_table, "metadata", None) + for ckey, tkey in self.colkeys: + if tkey is None or metadata is None or tkey not in metadata: + cols.append(mapper.local_table.c[ckey]) # type: ignore + else: + cols.append(metadata.tables[tkey].c[ckey]) + return cols + + +def column_keyed_dict( + mapping_spec: Union[Type[_KT], Callable[[_KT], _VT]], + *, + ignore_unpopulated_attribute: bool = False, +) -> Type[KeyFuncDict[_KT, _KT]]: + """A dictionary-based collection type with column-based keying. + + .. versionchanged:: 2.0 Renamed :data:`.column_mapped_collection` to + :class:`.column_keyed_dict`. + + Returns a :class:`.KeyFuncDict` factory which will produce new + dictionary keys based on the value of a particular :class:`.Column`-mapped + attribute on ORM mapped instances to be added to the dictionary. + + .. note:: the value of the target attribute must be assigned with its + value at the time that the object is being added to the + dictionary collection. Additionally, changes to the key attribute + are **not tracked**, which means the key in the dictionary is not + automatically synchronized with the key value on the target object + itself. See :ref:`key_collections_mutations` for further details. + + .. seealso:: + + :ref:`orm_dictionary_collection` - background on use + + :param mapping_spec: a :class:`_schema.Column` object that is expected + to be mapped by the target mapper to a particular attribute on the + mapped class, the value of which on a particular instance is to be used + as the key for a new dictionary entry for that instance. + :param ignore_unpopulated_attribute: if True, and the mapped attribute + indicated by the given :class:`_schema.Column` target attribute + on an object is not populated at all, the operation will be silently + skipped. By default, an error is raised. + + .. versionadded:: 2.0 an error is raised by default if the attribute + being used for the dictionary key is determined that it was never + populated with any value. The + :paramref:`_orm.column_keyed_dict.ignore_unpopulated_attribute` + parameter may be set which will instead indicate that this condition + should be ignored, and the append operation silently skipped. + This is in contrast to the behavior of the 1.x series which would + erroneously populate the value in the dictionary with an arbitrary key + value of ``None``. + + + """ + cols = [ + coercions.expect(roles.ColumnArgumentRole, q, argname="mapping_spec") + for q in util.to_list(mapping_spec) + ] + keyfunc = _PlainColumnGetter(cols) + return _mapped_collection_cls( + keyfunc, + ignore_unpopulated_attribute=ignore_unpopulated_attribute, + ) + + +_UNMAPPED_AMBIGUOUS_NONE = object() + + +class _AttrGetter: + __slots__ = ("attr_name", "getter") + + def __init__(self, attr_name: str): + self.attr_name = attr_name + self.getter = operator.attrgetter(attr_name) + + def __call__(self, mapped_object: Any) -> Any: + obj = self.getter(mapped_object) + if obj is None: + state = base.instance_state(mapped_object) + mp = state.mapper + if self.attr_name in mp.attrs: + dict_ = state.dict + obj = dict_.get(self.attr_name, base.NO_VALUE) + if obj is None: + return _UNMAPPED_AMBIGUOUS_NONE + else: + return _UNMAPPED_AMBIGUOUS_NONE + + return obj + + def __reduce__(self) -> Tuple[Type[_AttrGetter], Tuple[str]]: + return _AttrGetter, (self.attr_name,) + + +def attribute_keyed_dict( + attr_name: str, *, ignore_unpopulated_attribute: bool = False +) -> Type[KeyFuncDict[Any, Any]]: + """A dictionary-based collection type with attribute-based keying. + + .. versionchanged:: 2.0 Renamed :data:`.attribute_mapped_collection` to + :func:`.attribute_keyed_dict`. + + Returns a :class:`.KeyFuncDict` factory which will produce new + dictionary keys based on the value of a particular named attribute on + ORM mapped instances to be added to the dictionary. + + .. note:: the value of the target attribute must be assigned with its + value at the time that the object is being added to the + dictionary collection. Additionally, changes to the key attribute + are **not tracked**, which means the key in the dictionary is not + automatically synchronized with the key value on the target object + itself. See :ref:`key_collections_mutations` for further details. + + .. seealso:: + + :ref:`orm_dictionary_collection` - background on use + + :param attr_name: string name of an ORM-mapped attribute + on the mapped class, the value of which on a particular instance + is to be used as the key for a new dictionary entry for that instance. + :param ignore_unpopulated_attribute: if True, and the target attribute + on an object is not populated at all, the operation will be silently + skipped. By default, an error is raised. + + .. versionadded:: 2.0 an error is raised by default if the attribute + being used for the dictionary key is determined that it was never + populated with any value. The + :paramref:`_orm.attribute_keyed_dict.ignore_unpopulated_attribute` + parameter may be set which will instead indicate that this condition + should be ignored, and the append operation silently skipped. + This is in contrast to the behavior of the 1.x series which would + erroneously populate the value in the dictionary with an arbitrary key + value of ``None``. + + + """ + + return _mapped_collection_cls( + _AttrGetter(attr_name), + ignore_unpopulated_attribute=ignore_unpopulated_attribute, + ) + + +def keyfunc_mapping( + keyfunc: _F, + *, + ignore_unpopulated_attribute: bool = False, +) -> Type[KeyFuncDict[_KT, Any]]: + """A dictionary-based collection type with arbitrary keying. + + .. versionchanged:: 2.0 Renamed :data:`.mapped_collection` to + :func:`.keyfunc_mapping`. + + Returns a :class:`.KeyFuncDict` factory with a keying function + generated from keyfunc, a callable that takes an entity and returns a + key value. + + .. note:: the given keyfunc is called only once at the time that the + target object is being added to the collection. Changes to the + effective value returned by the function are not tracked. + + + .. seealso:: + + :ref:`orm_dictionary_collection` - background on use + + :param keyfunc: a callable that will be passed the ORM-mapped instance + which should then generate a new key to use in the dictionary. + If the value returned is :attr:`.LoaderCallableStatus.NO_VALUE`, an error + is raised. + :param ignore_unpopulated_attribute: if True, and the callable returns + :attr:`.LoaderCallableStatus.NO_VALUE` for a particular instance, the + operation will be silently skipped. By default, an error is raised. + + .. versionadded:: 2.0 an error is raised by default if the callable + being used for the dictionary key returns + :attr:`.LoaderCallableStatus.NO_VALUE`, which in an ORM attribute + context indicates an attribute that was never populated with any value. + The :paramref:`_orm.mapped_collection.ignore_unpopulated_attribute` + parameter may be set which will instead indicate that this condition + should be ignored, and the append operation silently skipped. This is + in contrast to the behavior of the 1.x series which would erroneously + populate the value in the dictionary with an arbitrary key value of + ``None``. + + + """ + return _mapped_collection_cls( + keyfunc, ignore_unpopulated_attribute=ignore_unpopulated_attribute + ) + + +class KeyFuncDict(Dict[_KT, _VT]): + """Base for ORM mapped dictionary classes. + + Extends the ``dict`` type with additional methods needed by SQLAlchemy ORM + collection classes. Use of :class:`_orm.KeyFuncDict` is most directly + by using the :func:`.attribute_keyed_dict` or + :func:`.column_keyed_dict` class factories. + :class:`_orm.KeyFuncDict` may also serve as the base for user-defined + custom dictionary classes. + + .. versionchanged:: 2.0 Renamed :class:`.MappedCollection` to + :class:`.KeyFuncDict`. + + .. seealso:: + + :func:`_orm.attribute_keyed_dict` + + :func:`_orm.column_keyed_dict` + + :ref:`orm_dictionary_collection` + + :ref:`orm_custom_collection` + + + """ + + def __init__( + self, + keyfunc: _F, + *dict_args: Any, + ignore_unpopulated_attribute: bool = False, + ) -> None: + """Create a new collection with keying provided by keyfunc. + + keyfunc may be any callable that takes an object and returns an object + for use as a dictionary key. + + The keyfunc will be called every time the ORM needs to add a member by + value-only (such as when loading instances from the database) or + remove a member. The usual cautions about dictionary keying apply- + ``keyfunc(object)`` should return the same output for the life of the + collection. Keying based on mutable properties can result in + unreachable instances "lost" in the collection. + + """ + self.keyfunc = keyfunc + self.ignore_unpopulated_attribute = ignore_unpopulated_attribute + super().__init__(*dict_args) + + @classmethod + def _unreduce( + cls, + keyfunc: _F, + values: Dict[_KT, _KT], + adapter: Optional[CollectionAdapter] = None, + ) -> "KeyFuncDict[_KT, _KT]": + mp: KeyFuncDict[_KT, _KT] = KeyFuncDict(keyfunc) + mp.update(values) + # note that the adapter sets itself up onto this collection + # when its `__setstate__` method is called + return mp + + def __reduce__( + self, + ) -> Tuple[ + Callable[[_KT, _KT], KeyFuncDict[_KT, _KT]], + Tuple[Any, Union[Dict[_KT, _KT], Dict[_KT, _KT]], CollectionAdapter], + ]: + return ( + KeyFuncDict._unreduce, + ( + self.keyfunc, + dict(self), + collection_adapter(self), + ), + ) + + @util.preload_module("sqlalchemy.orm.attributes") + def _raise_for_unpopulated( + self, + value: _KT, + initiator: Union[AttributeEventToken, Literal[None, False]] = None, + *, + warn_only: bool, + ) -> None: + mapper = base.instance_state(value).mapper + + attributes = util.preloaded.orm_attributes + + if not isinstance(initiator, attributes.AttributeEventToken): + relationship = "unknown relationship" + elif initiator.key in mapper.attrs: + relationship = f"{mapper.attrs[initiator.key]}" + else: + relationship = initiator.key + + if warn_only: + util.warn( + f"Attribute keyed dictionary value for " + f"attribute '{relationship}' was None; this will raise " + "in a future release. " + f"To skip this assignment entirely, " + f'Set the "ignore_unpopulated_attribute=True" ' + f"parameter on the mapped collection factory." + ) + else: + raise sa_exc.InvalidRequestError( + "In event triggered from population of " + f"attribute '{relationship}' " + "(potentially from a backref), " + f"can't populate value in KeyFuncDict; " + "dictionary key " + f"derived from {base.instance_str(value)} is not " + f"populated. Ensure appropriate state is set up on " + f"the {base.instance_str(value)} object " + f"before assigning to the {relationship} attribute. " + f"To skip this assignment entirely, " + f'Set the "ignore_unpopulated_attribute=True" ' + f"parameter on the mapped collection factory." + ) + + @collection.appender # type: ignore[misc] + @collection.internally_instrumented # type: ignore[misc] + def set( + self, + value: _KT, + _sa_initiator: Union[AttributeEventToken, Literal[None, False]] = None, + ) -> None: + """Add an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + + if key is base.NO_VALUE: + if not self.ignore_unpopulated_attribute: + self._raise_for_unpopulated( + value, _sa_initiator, warn_only=False + ) + else: + return + elif key is _UNMAPPED_AMBIGUOUS_NONE: + if not self.ignore_unpopulated_attribute: + self._raise_for_unpopulated( + value, _sa_initiator, warn_only=True + ) + key = None + else: + return + + self.__setitem__(key, value, _sa_initiator) # type: ignore[call-arg] + + @collection.remover # type: ignore[misc] + @collection.internally_instrumented # type: ignore[misc] + def remove( + self, + value: _KT, + _sa_initiator: Union[AttributeEventToken, Literal[None, False]] = None, + ) -> None: + """Remove an item by value, consulting the keyfunc for the key.""" + + key = self.keyfunc(value) + + if key is base.NO_VALUE: + if not self.ignore_unpopulated_attribute: + self._raise_for_unpopulated( + value, _sa_initiator, warn_only=False + ) + return + elif key is _UNMAPPED_AMBIGUOUS_NONE: + if not self.ignore_unpopulated_attribute: + self._raise_for_unpopulated( + value, _sa_initiator, warn_only=True + ) + key = None + else: + return + + # Let self[key] raise if key is not in this collection + # testlib.pragma exempt:__ne__ + if self[key] != value: + raise sa_exc.InvalidRequestError( + "Can not remove '%s': collection holds '%s' for key '%s'. " + "Possible cause: is the KeyFuncDict key function " + "based on mutable properties or properties that only obtain " + "values after flush?" % (value, self[key], key) + ) + self.__delitem__(key, _sa_initiator) # type: ignore[call-arg] + + +def _mapped_collection_cls( + keyfunc: _F, ignore_unpopulated_attribute: bool +) -> Type[KeyFuncDict[_KT, _KT]]: + class _MKeyfuncMapped(KeyFuncDict[_KT, _KT]): + def __init__(self, *dict_args: Any) -> None: + super().__init__( + keyfunc, + *dict_args, + ignore_unpopulated_attribute=ignore_unpopulated_attribute, + ) + + return _MKeyfuncMapped + + +MappedCollection = KeyFuncDict +"""A synonym for :class:`.KeyFuncDict`. + +.. versionchanged:: 2.0 Renamed :class:`.MappedCollection` to + :class:`.KeyFuncDict`. + +""" + +mapped_collection = keyfunc_mapping +"""A synonym for :func:`_orm.keyfunc_mapping`. + +.. versionchanged:: 2.0 Renamed :data:`.mapped_collection` to + :func:`_orm.keyfunc_mapping` + +""" + +attribute_mapped_collection = attribute_keyed_dict +"""A synonym for :func:`_orm.attribute_keyed_dict`. + +.. versionchanged:: 2.0 Renamed :data:`.attribute_mapped_collection` to + :func:`_orm.attribute_keyed_dict` + +""" + +column_mapped_collection = column_keyed_dict +"""A synonym for :func:`_orm.column_keyed_dict. + +.. versionchanged:: 2.0 Renamed :func:`.column_mapped_collection` to + :func:`_orm.column_keyed_dict` + +""" diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/mapper.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/mapper.py new file mode 100644 index 0000000..0caed0e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/mapper.py @@ -0,0 +1,4420 @@ +# orm/mapper.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Logic to map Python classes to and from selectables. + +Defines the :class:`~sqlalchemy.orm.mapper.Mapper` class, the central +configurational unit which associates a class with a database table. + +This is a semi-private module; the main configurational API of the ORM is +available in :class:`~sqlalchemy.orm.`. + +""" +from __future__ import annotations + +from collections import deque +from functools import reduce +from itertools import chain +import sys +import threading +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Deque +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import exc as orm_exc +from . import instrumentation +from . import loading +from . import properties +from . import util as orm_util +from ._typing import _O +from .base import _class_to_mapper +from .base import _parse_mapper_argument +from .base import _state_mapper +from .base import PassiveFlag +from .base import state_str +from .interfaces import _MappedAttribute +from .interfaces import EXT_SKIP +from .interfaces import InspectionAttr +from .interfaces import MapperProperty +from .interfaces import ORMEntityColumnsClauseRole +from .interfaces import ORMFromClauseRole +from .interfaces import StrategizedProperty +from .path_registry import PathRegistry +from .. import event +from .. import exc as sa_exc +from .. import inspection +from .. import log +from .. import schema +from .. import sql +from .. import util +from ..event import dispatcher +from ..event import EventTarget +from ..sql import base as sql_base +from ..sql import coercions +from ..sql import expression +from ..sql import operators +from ..sql import roles +from ..sql import TableClause +from ..sql import util as sql_util +from ..sql import visitors +from ..sql.cache_key import MemoizedHasCacheKey +from ..sql.elements import KeyedColumnElement +from ..sql.schema import Column +from ..sql.schema import Table +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util import HasMemoized +from ..util import HasMemoized_ro_memoized_attribute +from ..util.typing import Literal + +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _ORMColumnExprArgument + from ._typing import _RegistryType + from .decl_api import registry + from .dependency import DependencyProcessor + from .descriptor_props import CompositeProperty + from .descriptor_props import SynonymProperty + from .events import MapperEvents + from .instrumentation import ClassManager + from .path_registry import CachingEntityRegistry + from .properties import ColumnProperty + from .relationships import RelationshipProperty + from .state import InstanceState + from .util import ORMAdapter + from ..engine import Row + from ..engine import RowMapping + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _EquivalentColumnMap + from ..sql.base import ReadOnlyColumnCollection + from ..sql.elements import ColumnClause + from ..sql.elements import ColumnElement + from ..sql.selectable import FromClause + from ..util import OrderedSet + + +_T = TypeVar("_T", bound=Any) +_MP = TypeVar("_MP", bound="MapperProperty[Any]") +_Fn = TypeVar("_Fn", bound="Callable[..., Any]") + + +_WithPolymorphicArg = Union[ + Literal["*"], + Tuple[ + Union[Literal["*"], Sequence[Union["Mapper[Any]", Type[Any]]]], + Optional["FromClause"], + ], + Sequence[Union["Mapper[Any]", Type[Any]]], +] + + +_mapper_registries: weakref.WeakKeyDictionary[_RegistryType, bool] = ( + weakref.WeakKeyDictionary() +) + + +def _all_registries() -> Set[registry]: + with _CONFIGURE_MUTEX: + return set(_mapper_registries) + + +def _unconfigured_mappers() -> Iterator[Mapper[Any]]: + for reg in _all_registries(): + yield from reg._mappers_to_configure() + + +_already_compiling = False + + +# a constant returned by _get_attr_by_column to indicate +# this mapper is not handling an attribute for a particular +# column +NO_ATTRIBUTE = util.symbol("NO_ATTRIBUTE") + +# lock used to synchronize the "mapper configure" step +_CONFIGURE_MUTEX = threading.RLock() + + +@inspection._self_inspects +@log.class_logger +class Mapper( + ORMFromClauseRole, + ORMEntityColumnsClauseRole[_O], + MemoizedHasCacheKey, + InspectionAttr, + log.Identified, + inspection.Inspectable["Mapper[_O]"], + EventTarget, + Generic[_O], +): + """Defines an association between a Python class and a database table or + other relational structure, so that ORM operations against the class may + proceed. + + The :class:`_orm.Mapper` object is instantiated using mapping methods + present on the :class:`_orm.registry` object. For information + about instantiating new :class:`_orm.Mapper` objects, see + :ref:`orm_mapping_classes_toplevel`. + + """ + + dispatch: dispatcher[Mapper[_O]] + + _dispose_called = False + _configure_failed: Any = False + _ready_for_configure = False + + @util.deprecated_params( + non_primary=( + "1.3", + "The :paramref:`.mapper.non_primary` parameter is deprecated, " + "and will be removed in a future release. The functionality " + "of non primary mappers is now better suited using the " + ":class:`.AliasedClass` construct, which can also be used " + "as the target of a :func:`_orm.relationship` in 1.3.", + ), + ) + def __init__( + self, + class_: Type[_O], + local_table: Optional[FromClause] = None, + properties: Optional[Mapping[str, MapperProperty[Any]]] = None, + primary_key: Optional[Iterable[_ORMColumnExprArgument[Any]]] = None, + non_primary: bool = False, + inherits: Optional[Union[Mapper[Any], Type[Any]]] = None, + inherit_condition: Optional[_ColumnExpressionArgument[bool]] = None, + inherit_foreign_keys: Optional[ + Sequence[_ORMColumnExprArgument[Any]] + ] = None, + always_refresh: bool = False, + version_id_col: Optional[_ORMColumnExprArgument[Any]] = None, + version_id_generator: Optional[ + Union[Literal[False], Callable[[Any], Any]] + ] = None, + polymorphic_on: Optional[ + Union[_ORMColumnExprArgument[Any], str, MapperProperty[Any]] + ] = None, + _polymorphic_map: Optional[Dict[Any, Mapper[Any]]] = None, + polymorphic_identity: Optional[Any] = None, + concrete: bool = False, + with_polymorphic: Optional[_WithPolymorphicArg] = None, + polymorphic_abstract: bool = False, + polymorphic_load: Optional[Literal["selectin", "inline"]] = None, + allow_partial_pks: bool = True, + batch: bool = True, + column_prefix: Optional[str] = None, + include_properties: Optional[Sequence[str]] = None, + exclude_properties: Optional[Sequence[str]] = None, + passive_updates: bool = True, + passive_deletes: bool = False, + confirm_deleted_rows: bool = True, + eager_defaults: Literal[True, False, "auto"] = "auto", + legacy_is_orphan: bool = False, + _compiled_cache_size: int = 100, + ): + r"""Direct constructor for a new :class:`_orm.Mapper` object. + + The :class:`_orm.Mapper` constructor is not called directly, and + is normally invoked through the + use of the :class:`_orm.registry` object through either the + :ref:`Declarative ` or + :ref:`Imperative ` mapping styles. + + .. versionchanged:: 2.0 The public facing ``mapper()`` function is + removed; for a classical mapping configuration, use the + :meth:`_orm.registry.map_imperatively` method. + + Parameters documented below may be passed to either the + :meth:`_orm.registry.map_imperatively` method, or may be passed in the + ``__mapper_args__`` declarative class attribute described at + :ref:`orm_declarative_mapper_options`. + + :param class\_: The class to be mapped. When using Declarative, + this argument is automatically passed as the declared class + itself. + + :param local_table: The :class:`_schema.Table` or other + :class:`_sql.FromClause` (i.e. selectable) to which the class is + mapped. May be ``None`` if this mapper inherits from another mapper + using single-table inheritance. When using Declarative, this + argument is automatically passed by the extension, based on what is + configured via the :attr:`_orm.DeclarativeBase.__table__` attribute + or via the :class:`_schema.Table` produced as a result of + the :attr:`_orm.DeclarativeBase.__tablename__` attribute being + present. + + :param polymorphic_abstract: Indicates this class will be mapped in a + polymorphic hierarchy, but not directly instantiated. The class is + mapped normally, except that it has no requirement for a + :paramref:`_orm.Mapper.polymorphic_identity` within an inheritance + hierarchy. The class however must be part of a polymorphic + inheritance scheme which uses + :paramref:`_orm.Mapper.polymorphic_on` at the base. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`orm_inheritance_abstract_poly` + + :param always_refresh: If True, all query operations for this mapped + class will overwrite all data within object instances that already + exist within the session, erasing any in-memory changes with + whatever information was loaded from the database. Usage of this + flag is highly discouraged; as an alternative, see the method + :meth:`_query.Query.populate_existing`. + + :param allow_partial_pks: Defaults to True. Indicates that a + composite primary key with some NULL values should be considered as + possibly existing within the database. This affects whether a + mapper will assign an incoming row to an existing identity, as well + as if :meth:`.Session.merge` will check the database first for a + particular primary key value. A "partial primary key" can occur if + one has mapped to an OUTER JOIN, for example. + + :param batch: Defaults to ``True``, indicating that save operations + of multiple entities can be batched together for efficiency. + Setting to False indicates + that an instance will be fully saved before saving the next + instance. This is used in the extremely rare case that a + :class:`.MapperEvents` listener requires being called + in between individual row persistence operations. + + :param column_prefix: A string which will be prepended + to the mapped attribute name when :class:`_schema.Column` + objects are automatically assigned as attributes to the + mapped class. Does not affect :class:`.Column` objects that + are mapped explicitly in the :paramref:`.Mapper.properties` + dictionary. + + This parameter is typically useful with imperative mappings + that keep the :class:`.Table` object separate. Below, assuming + the ``user_table`` :class:`.Table` object has columns named + ``user_id``, ``user_name``, and ``password``:: + + class User(Base): + __table__ = user_table + __mapper_args__ = {'column_prefix':'_'} + + The above mapping will assign the ``user_id``, ``user_name``, and + ``password`` columns to attributes named ``_user_id``, + ``_user_name``, and ``_password`` on the mapped ``User`` class. + + The :paramref:`.Mapper.column_prefix` parameter is uncommon in + modern use. For dealing with reflected tables, a more flexible + approach to automating a naming scheme is to intercept the + :class:`.Column` objects as they are reflected; see the section + :ref:`mapper_automated_reflection_schemes` for notes on this usage + pattern. + + :param concrete: If True, indicates this mapper should use concrete + table inheritance with its parent mapper. + + See the section :ref:`concrete_inheritance` for an example. + + :param confirm_deleted_rows: defaults to True; when a DELETE occurs + of one more rows based on specific primary keys, a warning is + emitted when the number of rows matched does not equal the number + of rows expected. This parameter may be set to False to handle the + case where database ON DELETE CASCADE rules may be deleting some of + those rows automatically. The warning may be changed to an + exception in a future release. + + :param eager_defaults: if True, the ORM will immediately fetch the + value of server-generated default values after an INSERT or UPDATE, + rather than leaving them as expired to be fetched on next access. + This can be used for event schemes where the server-generated values + are needed immediately before the flush completes. + + The fetch of values occurs either by using ``RETURNING`` inline + with the ``INSERT`` or ``UPDATE`` statement, or by adding an + additional ``SELECT`` statement subsequent to the ``INSERT`` or + ``UPDATE``, if the backend does not support ``RETURNING``. + + The use of ``RETURNING`` is extremely performant in particular for + ``INSERT`` statements where SQLAlchemy can take advantage of + :ref:`insertmanyvalues `, whereas the use of + an additional ``SELECT`` is relatively poor performing, adding + additional SQL round trips which would be unnecessary if these new + attributes are not to be accessed in any case. + + For this reason, :paramref:`.Mapper.eager_defaults` defaults to the + string value ``"auto"``, which indicates that server defaults for + INSERT should be fetched using ``RETURNING`` if the backing database + supports it and if the dialect in use supports "insertmanyreturning" + for an INSERT statement. If the backing database does not support + ``RETURNING`` or "insertmanyreturning" is not available, server + defaults will not be fetched. + + .. versionchanged:: 2.0.0rc1 added the "auto" option for + :paramref:`.Mapper.eager_defaults` + + .. seealso:: + + :ref:`orm_server_defaults` + + .. versionchanged:: 2.0.0 RETURNING now works with multiple rows + INSERTed at once using the + :ref:`insertmanyvalues ` feature, which + among other things allows the :paramref:`.Mapper.eager_defaults` + feature to be very performant on supporting backends. + + :param exclude_properties: A list or set of string column names to + be excluded from mapping. + + .. seealso:: + + :ref:`include_exclude_cols` + + :param include_properties: An inclusive list or set of string column + names to map. + + .. seealso:: + + :ref:`include_exclude_cols` + + :param inherits: A mapped class or the corresponding + :class:`_orm.Mapper` + of one indicating a superclass to which this :class:`_orm.Mapper` + should *inherit* from. The mapped class here must be a subclass + of the other mapper's class. When using Declarative, this argument + is passed automatically as a result of the natural class + hierarchy of the declared classes. + + .. seealso:: + + :ref:`inheritance_toplevel` + + :param inherit_condition: For joined table inheritance, a SQL + expression which will + define how the two tables are joined; defaults to a natural join + between the two tables. + + :param inherit_foreign_keys: When ``inherit_condition`` is used and + the columns present are missing a :class:`_schema.ForeignKey` + configuration, this parameter can be used to specify which columns + are "foreign". In most cases can be left as ``None``. + + :param legacy_is_orphan: Boolean, defaults to ``False``. + When ``True``, specifies that "legacy" orphan consideration + is to be applied to objects mapped by this mapper, which means + that a pending (that is, not persistent) object is auto-expunged + from an owning :class:`.Session` only when it is de-associated + from *all* parents that specify a ``delete-orphan`` cascade towards + this mapper. The new default behavior is that the object is + auto-expunged when it is de-associated with *any* of its parents + that specify ``delete-orphan`` cascade. This behavior is more + consistent with that of a persistent object, and allows behavior to + be consistent in more scenarios independently of whether or not an + orphan object has been flushed yet or not. + + See the change note and example at :ref:`legacy_is_orphan_addition` + for more detail on this change. + + :param non_primary: Specify that this :class:`_orm.Mapper` + is in addition + to the "primary" mapper, that is, the one used for persistence. + The :class:`_orm.Mapper` created here may be used for ad-hoc + mapping of the class to an alternate selectable, for loading + only. + + .. seealso:: + + :ref:`relationship_aliased_class` - the new pattern that removes + the need for the :paramref:`_orm.Mapper.non_primary` flag. + + :param passive_deletes: Indicates DELETE behavior of foreign key + columns when a joined-table inheritance entity is being deleted. + Defaults to ``False`` for a base mapper; for an inheriting mapper, + defaults to ``False`` unless the value is set to ``True`` + on the superclass mapper. + + When ``True``, it is assumed that ON DELETE CASCADE is configured + on the foreign key relationships that link this mapper's table + to its superclass table, so that when the unit of work attempts + to delete the entity, it need only emit a DELETE statement for the + superclass table, and not this table. + + When ``False``, a DELETE statement is emitted for this mapper's + table individually. If the primary key attributes local to this + table are unloaded, then a SELECT must be emitted in order to + validate these attributes; note that the primary key columns + of a joined-table subclass are not part of the "primary key" of + the object as a whole. + + Note that a value of ``True`` is **always** forced onto the + subclass mappers; that is, it's not possible for a superclass + to specify passive_deletes without this taking effect for + all subclass mappers. + + .. seealso:: + + :ref:`passive_deletes` - description of similar feature as + used with :func:`_orm.relationship` + + :paramref:`.mapper.passive_updates` - supporting ON UPDATE + CASCADE for joined-table inheritance mappers + + :param passive_updates: Indicates UPDATE behavior of foreign key + columns when a primary key column changes on a joined-table + inheritance mapping. Defaults to ``True``. + + When True, it is assumed that ON UPDATE CASCADE is configured on + the foreign key in the database, and that the database will handle + propagation of an UPDATE from a source column to dependent columns + on joined-table rows. + + When False, it is assumed that the database does not enforce + referential integrity and will not be issuing its own CASCADE + operation for an update. The unit of work process will + emit an UPDATE statement for the dependent columns during a + primary key change. + + .. seealso:: + + :ref:`passive_updates` - description of a similar feature as + used with :func:`_orm.relationship` + + :paramref:`.mapper.passive_deletes` - supporting ON DELETE + CASCADE for joined-table inheritance mappers + + :param polymorphic_load: Specifies "polymorphic loading" behavior + for a subclass in an inheritance hierarchy (joined and single + table inheritance only). Valid values are: + + * "'inline'" - specifies this class should be part of + the "with_polymorphic" mappers, e.g. its columns will be included + in a SELECT query against the base. + + * "'selectin'" - specifies that when instances of this class + are loaded, an additional SELECT will be emitted to retrieve + the columns specific to this subclass. The SELECT uses + IN to fetch multiple subclasses at once. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`with_polymorphic_mapper_config` + + :ref:`polymorphic_selectin` + + :param polymorphic_on: Specifies the column, attribute, or + SQL expression used to determine the target class for an + incoming row, when inheriting classes are present. + + May be specified as a string attribute name, or as a SQL + expression such as a :class:`_schema.Column` or in a Declarative + mapping a :func:`_orm.mapped_column` object. It is typically + expected that the SQL expression corresponds to a column in the + base-most mapped :class:`.Table`:: + + class Employee(Base): + __tablename__ = 'employee' + + id: Mapped[int] = mapped_column(primary_key=True) + discriminator: Mapped[str] = mapped_column(String(50)) + + __mapper_args__ = { + "polymorphic_on":discriminator, + "polymorphic_identity":"employee" + } + + It may also be specified + as a SQL expression, as in this example where we + use the :func:`.case` construct to provide a conditional + approach:: + + class Employee(Base): + __tablename__ = 'employee' + + id: Mapped[int] = mapped_column(primary_key=True) + discriminator: Mapped[str] = mapped_column(String(50)) + + __mapper_args__ = { + "polymorphic_on":case( + (discriminator == "EN", "engineer"), + (discriminator == "MA", "manager"), + else_="employee"), + "polymorphic_identity":"employee" + } + + It may also refer to any attribute using its string name, + which is of particular use when using annotated column + configurations:: + + class Employee(Base): + __tablename__ = 'employee' + + id: Mapped[int] = mapped_column(primary_key=True) + discriminator: Mapped[str] + + __mapper_args__ = { + "polymorphic_on": "discriminator", + "polymorphic_identity": "employee" + } + + When setting ``polymorphic_on`` to reference an + attribute or expression that's not present in the + locally mapped :class:`_schema.Table`, yet the value + of the discriminator should be persisted to the database, + the value of the + discriminator is not automatically set on new + instances; this must be handled by the user, + either through manual means or via event listeners. + A typical approach to establishing such a listener + looks like:: + + from sqlalchemy import event + from sqlalchemy.orm import object_mapper + + @event.listens_for(Employee, "init", propagate=True) + def set_identity(instance, *arg, **kw): + mapper = object_mapper(instance) + instance.discriminator = mapper.polymorphic_identity + + Where above, we assign the value of ``polymorphic_identity`` + for the mapped class to the ``discriminator`` attribute, + thus persisting the value to the ``discriminator`` column + in the database. + + .. warning:: + + Currently, **only one discriminator column may be set**, typically + on the base-most class in the hierarchy. "Cascading" polymorphic + columns are not yet supported. + + .. seealso:: + + :ref:`inheritance_toplevel` + + :param polymorphic_identity: Specifies the value which + identifies this particular class as returned by the column expression + referred to by the :paramref:`_orm.Mapper.polymorphic_on` setting. As + rows are received, the value corresponding to the + :paramref:`_orm.Mapper.polymorphic_on` column expression is compared + to this value, indicating which subclass should be used for the newly + reconstructed object. + + .. seealso:: + + :ref:`inheritance_toplevel` + + :param properties: A dictionary mapping the string names of object + attributes to :class:`.MapperProperty` instances, which define the + persistence behavior of that attribute. Note that + :class:`_schema.Column` + objects present in + the mapped :class:`_schema.Table` are automatically placed into + ``ColumnProperty`` instances upon mapping, unless overridden. + When using Declarative, this argument is passed automatically, + based on all those :class:`.MapperProperty` instances declared + in the declared class body. + + .. seealso:: + + :ref:`orm_mapping_properties` - in the + :ref:`orm_mapping_classes_toplevel` + + :param primary_key: A list of :class:`_schema.Column` + objects, or alternatively string names of attribute names which + refer to :class:`_schema.Column`, which define + the primary key to be used against this mapper's selectable unit. + This is normally simply the primary key of the ``local_table``, but + can be overridden here. + + .. versionchanged:: 2.0.2 :paramref:`_orm.Mapper.primary_key` + arguments may be indicated as string attribute names as well. + + .. seealso:: + + :ref:`mapper_primary_key` - background and example use + + :param version_id_col: A :class:`_schema.Column` + that will be used to keep a running version id of rows + in the table. This is used to detect concurrent updates or + the presence of stale data in a flush. The methodology is to + detect if an UPDATE statement does not match the last known + version id, a + :class:`~sqlalchemy.orm.exc.StaleDataError` exception is + thrown. + By default, the column must be of :class:`.Integer` type, + unless ``version_id_generator`` specifies an alternative version + generator. + + .. seealso:: + + :ref:`mapper_version_counter` - discussion of version counting + and rationale. + + :param version_id_generator: Define how new version ids should + be generated. Defaults to ``None``, which indicates that + a simple integer counting scheme be employed. To provide a custom + versioning scheme, provide a callable function of the form:: + + def generate_version(version): + return next_version + + Alternatively, server-side versioning functions such as triggers, + or programmatic versioning schemes outside of the version id + generator may be used, by specifying the value ``False``. + Please see :ref:`server_side_version_counter` for a discussion + of important points when using this option. + + .. seealso:: + + :ref:`custom_version_counter` + + :ref:`server_side_version_counter` + + + :param with_polymorphic: A tuple in the form ``(, + )`` indicating the default style of "polymorphic" + loading, that is, which tables are queried at once. is + any single or list of mappers and/or classes indicating the + inherited classes that should be loaded at once. The special value + ``'*'`` may be used to indicate all descending classes should be + loaded immediately. The second tuple argument + indicates a selectable that will be used to query for multiple + classes. + + The :paramref:`_orm.Mapper.polymorphic_load` parameter may be + preferable over the use of :paramref:`_orm.Mapper.with_polymorphic` + in modern mappings to indicate a per-subclass technique of + indicating polymorphic loading styles. + + .. seealso:: + + :ref:`with_polymorphic_mapper_config` + + """ + self.class_ = util.assert_arg_type(class_, type, "class_") + self._sort_key = "%s.%s" % ( + self.class_.__module__, + self.class_.__name__, + ) + + self._primary_key_argument = util.to_list(primary_key) + self.non_primary = non_primary + + self.always_refresh = always_refresh + + if isinstance(version_id_col, MapperProperty): + self.version_id_prop = version_id_col + self.version_id_col = None + else: + self.version_id_col = ( + coercions.expect( + roles.ColumnArgumentOrKeyRole, + version_id_col, + argname="version_id_col", + ) + if version_id_col is not None + else None + ) + + if version_id_generator is False: + self.version_id_generator = False + elif version_id_generator is None: + self.version_id_generator = lambda x: (x or 0) + 1 + else: + self.version_id_generator = version_id_generator + + self.concrete = concrete + self.single = False + + if inherits is not None: + self.inherits = _parse_mapper_argument(inherits) + else: + self.inherits = None + + if local_table is not None: + self.local_table = coercions.expect( + roles.StrictFromClauseRole, + local_table, + disable_inspection=True, + argname="local_table", + ) + elif self.inherits: + # note this is a new flow as of 2.0 so that + # .local_table need not be Optional + self.local_table = self.inherits.local_table + self.single = True + else: + raise sa_exc.ArgumentError( + f"Mapper[{self.class_.__name__}(None)] has None for a " + "primary table argument and does not specify 'inherits'" + ) + + if inherit_condition is not None: + self.inherit_condition = coercions.expect( + roles.OnClauseRole, inherit_condition + ) + else: + self.inherit_condition = None + + self.inherit_foreign_keys = inherit_foreign_keys + self._init_properties = dict(properties) if properties else {} + self._delete_orphans = [] + self.batch = batch + self.eager_defaults = eager_defaults + self.column_prefix = column_prefix + + # interim - polymorphic_on is further refined in + # _configure_polymorphic_setter + self.polymorphic_on = ( + coercions.expect( # type: ignore + roles.ColumnArgumentOrKeyRole, + polymorphic_on, + argname="polymorphic_on", + ) + if polymorphic_on is not None + else None + ) + self.polymorphic_abstract = polymorphic_abstract + self._dependency_processors = [] + self.validators = util.EMPTY_DICT + self.passive_updates = passive_updates + self.passive_deletes = passive_deletes + self.legacy_is_orphan = legacy_is_orphan + self._clause_adapter = None + self._requires_row_aliasing = False + self._inherits_equated_pairs = None + self._memoized_values = {} + self._compiled_cache_size = _compiled_cache_size + self._reconstructor = None + self.allow_partial_pks = allow_partial_pks + + if self.inherits and not self.concrete: + self.confirm_deleted_rows = False + else: + self.confirm_deleted_rows = confirm_deleted_rows + + self._set_with_polymorphic(with_polymorphic) + self.polymorphic_load = polymorphic_load + + # our 'polymorphic identity', a string name that when located in a + # result set row indicates this Mapper should be used to construct + # the object instance for that row. + self.polymorphic_identity = polymorphic_identity + + # a dictionary of 'polymorphic identity' names, associating those + # names with Mappers that will be used to construct object instances + # upon a select operation. + if _polymorphic_map is None: + self.polymorphic_map = {} + else: + self.polymorphic_map = _polymorphic_map + + if include_properties is not None: + self.include_properties = util.to_set(include_properties) + else: + self.include_properties = None + if exclude_properties: + self.exclude_properties = util.to_set(exclude_properties) + else: + self.exclude_properties = None + + # prevent this mapper from being constructed + # while a configure_mappers() is occurring (and defer a + # configure_mappers() until construction succeeds) + with _CONFIGURE_MUTEX: + cast("MapperEvents", self.dispatch._events)._new_mapper_instance( + class_, self + ) + self._configure_inheritance() + self._configure_class_instrumentation() + self._configure_properties() + self._configure_polymorphic_setter() + self._configure_pks() + self.registry._flag_new_mapper(self) + self._log("constructed") + self._expire_memoizations() + + self.dispatch.after_mapper_constructed(self, self.class_) + + def _prefer_eager_defaults(self, dialect, table): + if self.eager_defaults == "auto": + if not table.implicit_returning: + return False + + return ( + table in self._server_default_col_keys + and dialect.insert_executemany_returning + ) + else: + return self.eager_defaults + + def _gen_cache_key(self, anon_map, bindparams): + return (self,) + + # ### BEGIN + # ATTRIBUTE DECLARATIONS START HERE + + is_mapper = True + """Part of the inspection API.""" + + represents_outer_join = False + + registry: _RegistryType + + @property + def mapper(self) -> Mapper[_O]: + """Part of the inspection API. + + Returns self. + + """ + return self + + @property + def entity(self): + r"""Part of the inspection API. + + Returns self.class\_. + + """ + return self.class_ + + class_: Type[_O] + """The class to which this :class:`_orm.Mapper` is mapped.""" + + _identity_class: Type[_O] + + _delete_orphans: List[Tuple[str, Type[Any]]] + _dependency_processors: List[DependencyProcessor] + _memoized_values: Dict[Any, Callable[[], Any]] + _inheriting_mappers: util.WeakSequence[Mapper[Any]] + _all_tables: Set[TableClause] + _polymorphic_attr_key: Optional[str] + + _pks_by_table: Dict[FromClause, OrderedSet[ColumnClause[Any]]] + _cols_by_table: Dict[FromClause, OrderedSet[ColumnElement[Any]]] + + _props: util.OrderedDict[str, MapperProperty[Any]] + _init_properties: Dict[str, MapperProperty[Any]] + + _columntoproperty: _ColumnMapping + + _set_polymorphic_identity: Optional[Callable[[InstanceState[_O]], None]] + _validate_polymorphic_identity: Optional[ + Callable[[Mapper[_O], InstanceState[_O], _InstanceDict], None] + ] + + tables: Sequence[TableClause] + """A sequence containing the collection of :class:`_schema.Table` + or :class:`_schema.TableClause` objects which this :class:`_orm.Mapper` + is aware of. + + If the mapper is mapped to a :class:`_expression.Join`, or an + :class:`_expression.Alias` + representing a :class:`_expression.Select`, the individual + :class:`_schema.Table` + objects that comprise the full construct will be represented here. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + validators: util.immutabledict[str, Tuple[str, Dict[str, Any]]] + """An immutable dictionary of attributes which have been decorated + using the :func:`_orm.validates` decorator. + + The dictionary contains string attribute names as keys + mapped to the actual validation method. + + """ + + always_refresh: bool + allow_partial_pks: bool + version_id_col: Optional[ColumnElement[Any]] + + with_polymorphic: Optional[ + Tuple[ + Union[Literal["*"], Sequence[Union[Mapper[Any], Type[Any]]]], + Optional[FromClause], + ] + ] + + version_id_generator: Optional[Union[Literal[False], Callable[[Any], Any]]] + + local_table: FromClause + """The immediate :class:`_expression.FromClause` to which this + :class:`_orm.Mapper` refers. + + Typically is an instance of :class:`_schema.Table`, may be any + :class:`.FromClause`. + + The "local" table is the + selectable that the :class:`_orm.Mapper` is directly responsible for + managing from an attribute access and flush perspective. For + non-inheriting mappers, :attr:`.Mapper.local_table` will be the same + as :attr:`.Mapper.persist_selectable`. For inheriting mappers, + :attr:`.Mapper.local_table` refers to the specific portion of + :attr:`.Mapper.persist_selectable` that includes the columns to which + this :class:`.Mapper` is loading/persisting, such as a particular + :class:`.Table` within a join. + + .. seealso:: + + :attr:`_orm.Mapper.persist_selectable`. + + :attr:`_orm.Mapper.selectable`. + + """ + + persist_selectable: FromClause + """The :class:`_expression.FromClause` to which this :class:`_orm.Mapper` + is mapped. + + Typically is an instance of :class:`_schema.Table`, may be any + :class:`.FromClause`. + + The :attr:`_orm.Mapper.persist_selectable` is similar to + :attr:`.Mapper.local_table`, but represents the :class:`.FromClause` that + represents the inheriting class hierarchy overall in an inheritance + scenario. + + :attr.`.Mapper.persist_selectable` is also separate from the + :attr:`.Mapper.selectable` attribute, the latter of which may be an + alternate subquery used for selecting columns. + :attr.`.Mapper.persist_selectable` is oriented towards columns that + will be written on a persist operation. + + .. seealso:: + + :attr:`_orm.Mapper.selectable`. + + :attr:`_orm.Mapper.local_table`. + + """ + + inherits: Optional[Mapper[Any]] + """References the :class:`_orm.Mapper` which this :class:`_orm.Mapper` + inherits from, if any. + + """ + + inherit_condition: Optional[ColumnElement[bool]] + + configured: bool = False + """Represent ``True`` if this :class:`_orm.Mapper` has been configured. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + .. seealso:: + + :func:`.configure_mappers`. + + """ + + concrete: bool + """Represent ``True`` if this :class:`_orm.Mapper` is a concrete + inheritance mapper. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + primary_key: Tuple[Column[Any], ...] + """An iterable containing the collection of :class:`_schema.Column` + objects + which comprise the 'primary key' of the mapped table, from the + perspective of this :class:`_orm.Mapper`. + + This list is against the selectable in + :attr:`_orm.Mapper.persist_selectable`. + In the case of inheriting mappers, some columns may be managed by a + superclass mapper. For example, in the case of a + :class:`_expression.Join`, the + primary key is determined by all of the primary key columns across all + tables referenced by the :class:`_expression.Join`. + + The list is also not necessarily the same as the primary key column + collection associated with the underlying tables; the :class:`_orm.Mapper` + features a ``primary_key`` argument that can override what the + :class:`_orm.Mapper` considers as primary key columns. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + class_manager: ClassManager[_O] + """The :class:`.ClassManager` which maintains event listeners + and class-bound descriptors for this :class:`_orm.Mapper`. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + single: bool + """Represent ``True`` if this :class:`_orm.Mapper` is a single table + inheritance mapper. + + :attr:`_orm.Mapper.local_table` will be ``None`` if this flag is set. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + non_primary: bool + """Represent ``True`` if this :class:`_orm.Mapper` is a "non-primary" + mapper, e.g. a mapper that is used only to select rows but not for + persistence management. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + polymorphic_on: Optional[KeyedColumnElement[Any]] + """The :class:`_schema.Column` or SQL expression specified as the + ``polymorphic_on`` argument + for this :class:`_orm.Mapper`, within an inheritance scenario. + + This attribute is normally a :class:`_schema.Column` instance but + may also be an expression, such as one derived from + :func:`.cast`. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + polymorphic_map: Dict[Any, Mapper[Any]] + """A mapping of "polymorphic identity" identifiers mapped to + :class:`_orm.Mapper` instances, within an inheritance scenario. + + The identifiers can be of any type which is comparable to the + type of column represented by :attr:`_orm.Mapper.polymorphic_on`. + + An inheritance chain of mappers will all reference the same + polymorphic map object. The object is used to correlate incoming + result rows to target mappers. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + polymorphic_identity: Optional[Any] + """Represent an identifier which is matched against the + :attr:`_orm.Mapper.polymorphic_on` column during result row loading. + + Used only with inheritance, this object can be of any type which is + comparable to the type of column represented by + :attr:`_orm.Mapper.polymorphic_on`. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + base_mapper: Mapper[Any] + """The base-most :class:`_orm.Mapper` in an inheritance chain. + + In a non-inheriting scenario, this attribute will always be this + :class:`_orm.Mapper`. In an inheritance scenario, it references + the :class:`_orm.Mapper` which is parent to all other :class:`_orm.Mapper` + objects in the inheritance chain. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + columns: ReadOnlyColumnCollection[str, Column[Any]] + """A collection of :class:`_schema.Column` or other scalar expression + objects maintained by this :class:`_orm.Mapper`. + + The collection behaves the same as that of the ``c`` attribute on + any :class:`_schema.Table` object, + except that only those columns included in + this mapping are present, and are keyed based on the attribute name + defined in the mapping, not necessarily the ``key`` attribute of the + :class:`_schema.Column` itself. Additionally, scalar expressions mapped + by :func:`.column_property` are also present here. + + This is a *read only* attribute determined during mapper construction. + Behavior is undefined if directly modified. + + """ + + c: ReadOnlyColumnCollection[str, Column[Any]] + """A synonym for :attr:`_orm.Mapper.columns`.""" + + @util.non_memoized_property + @util.deprecated("1.3", "Use .persist_selectable") + def mapped_table(self): + return self.persist_selectable + + @util.memoized_property + def _path_registry(self) -> CachingEntityRegistry: + return PathRegistry.per_mapper(self) + + def _configure_inheritance(self): + """Configure settings related to inheriting and/or inherited mappers + being present.""" + + # a set of all mappers which inherit from this one. + self._inheriting_mappers = util.WeakSequence() + + if self.inherits: + if not issubclass(self.class_, self.inherits.class_): + raise sa_exc.ArgumentError( + "Class '%s' does not inherit from '%s'" + % (self.class_.__name__, self.inherits.class_.__name__) + ) + + self.dispatch._update(self.inherits.dispatch) + + if self.non_primary != self.inherits.non_primary: + np = not self.non_primary and "primary" or "non-primary" + raise sa_exc.ArgumentError( + "Inheritance of %s mapper for class '%s' is " + "only allowed from a %s mapper" + % (np, self.class_.__name__, np) + ) + + if self.single: + self.persist_selectable = self.inherits.persist_selectable + elif self.local_table is not self.inherits.local_table: + if self.concrete: + self.persist_selectable = self.local_table + for mapper in self.iterate_to_root(): + if mapper.polymorphic_on is not None: + mapper._requires_row_aliasing = True + else: + if self.inherit_condition is None: + # figure out inherit condition from our table to the + # immediate table of the inherited mapper, not its + # full table which could pull in other stuff we don't + # want (allows test/inheritance.InheritTest4 to pass) + try: + self.inherit_condition = sql_util.join_condition( + self.inherits.local_table, self.local_table + ) + except sa_exc.NoForeignKeysError as nfe: + assert self.inherits.local_table is not None + assert self.local_table is not None + raise sa_exc.NoForeignKeysError( + "Can't determine the inherit condition " + "between inherited table '%s' and " + "inheriting " + "table '%s'; tables have no " + "foreign key relationships established. " + "Please ensure the inheriting table has " + "a foreign key relationship to the " + "inherited " + "table, or provide an " + "'on clause' using " + "the 'inherit_condition' mapper argument." + % ( + self.inherits.local_table.description, + self.local_table.description, + ) + ) from nfe + except sa_exc.AmbiguousForeignKeysError as afe: + assert self.inherits.local_table is not None + assert self.local_table is not None + raise sa_exc.AmbiguousForeignKeysError( + "Can't determine the inherit condition " + "between inherited table '%s' and " + "inheriting " + "table '%s'; tables have more than one " + "foreign key relationship established. " + "Please specify the 'on clause' using " + "the 'inherit_condition' mapper argument." + % ( + self.inherits.local_table.description, + self.local_table.description, + ) + ) from afe + assert self.inherits.persist_selectable is not None + self.persist_selectable = sql.join( + self.inherits.persist_selectable, + self.local_table, + self.inherit_condition, + ) + + fks = util.to_set(self.inherit_foreign_keys) + self._inherits_equated_pairs = sql_util.criterion_as_pairs( + self.persist_selectable.onclause, + consider_as_foreign_keys=fks, + ) + else: + self.persist_selectable = self.local_table + + if self.polymorphic_identity is None: + self._identity_class = self.class_ + + if ( + not self.polymorphic_abstract + and self.inherits.base_mapper.polymorphic_on is not None + ): + util.warn( + f"{self} does not indicate a 'polymorphic_identity', " + "yet is part of an inheritance hierarchy that has a " + f"'polymorphic_on' column of " + f"'{self.inherits.base_mapper.polymorphic_on}'. " + "If this is an intermediary class that should not be " + "instantiated, the class may either be left unmapped, " + "or may include the 'polymorphic_abstract=True' " + "parameter in its Mapper arguments. To leave the " + "class unmapped when using Declarative, set the " + "'__abstract__ = True' attribute on the class." + ) + elif self.concrete: + self._identity_class = self.class_ + else: + self._identity_class = self.inherits._identity_class + + if self.version_id_col is None: + self.version_id_col = self.inherits.version_id_col + self.version_id_generator = self.inherits.version_id_generator + elif ( + self.inherits.version_id_col is not None + and self.version_id_col is not self.inherits.version_id_col + ): + util.warn( + "Inheriting version_id_col '%s' does not match inherited " + "version_id_col '%s' and will not automatically populate " + "the inherited versioning column. " + "version_id_col should only be specified on " + "the base-most mapper that includes versioning." + % ( + self.version_id_col.description, + self.inherits.version_id_col.description, + ) + ) + + self.polymorphic_map = self.inherits.polymorphic_map + self.batch = self.inherits.batch + self.inherits._inheriting_mappers.append(self) + self.base_mapper = self.inherits.base_mapper + self.passive_updates = self.inherits.passive_updates + self.passive_deletes = ( + self.inherits.passive_deletes or self.passive_deletes + ) + self._all_tables = self.inherits._all_tables + + if self.polymorphic_identity is not None: + if self.polymorphic_identity in self.polymorphic_map: + util.warn( + "Reassigning polymorphic association for identity %r " + "from %r to %r: Check for duplicate use of %r as " + "value for polymorphic_identity." + % ( + self.polymorphic_identity, + self.polymorphic_map[self.polymorphic_identity], + self, + self.polymorphic_identity, + ) + ) + self.polymorphic_map[self.polymorphic_identity] = self + + if self.polymorphic_load and self.concrete: + raise sa_exc.ArgumentError( + "polymorphic_load is not currently supported " + "with concrete table inheritance" + ) + if self.polymorphic_load == "inline": + self.inherits._add_with_polymorphic_subclass(self) + elif self.polymorphic_load == "selectin": + pass + elif self.polymorphic_load is not None: + raise sa_exc.ArgumentError( + "unknown argument for polymorphic_load: %r" + % self.polymorphic_load + ) + + else: + self._all_tables = set() + self.base_mapper = self + assert self.local_table is not None + self.persist_selectable = self.local_table + if self.polymorphic_identity is not None: + self.polymorphic_map[self.polymorphic_identity] = self + self._identity_class = self.class_ + + if self.persist_selectable is None: + raise sa_exc.ArgumentError( + "Mapper '%s' does not have a persist_selectable specified." + % self + ) + + def _set_with_polymorphic( + self, with_polymorphic: Optional[_WithPolymorphicArg] + ) -> None: + if with_polymorphic == "*": + self.with_polymorphic = ("*", None) + elif isinstance(with_polymorphic, (tuple, list)): + if isinstance(with_polymorphic[0], (str, tuple, list)): + self.with_polymorphic = cast( + """Tuple[ + Union[ + Literal["*"], + Sequence[Union["Mapper[Any]", Type[Any]]], + ], + Optional["FromClause"], + ]""", + with_polymorphic, + ) + else: + self.with_polymorphic = (with_polymorphic, None) + elif with_polymorphic is not None: + raise sa_exc.ArgumentError( + f"Invalid setting for with_polymorphic: {with_polymorphic!r}" + ) + else: + self.with_polymorphic = None + + if self.with_polymorphic and self.with_polymorphic[1] is not None: + self.with_polymorphic = ( + self.with_polymorphic[0], + coercions.expect( + roles.StrictFromClauseRole, + self.with_polymorphic[1], + allow_select=True, + ), + ) + + if self.configured: + self._expire_memoizations() + + def _add_with_polymorphic_subclass(self, mapper): + subcl = mapper.class_ + if self.with_polymorphic is None: + self._set_with_polymorphic((subcl,)) + elif self.with_polymorphic[0] != "*": + assert isinstance(self.with_polymorphic[0], tuple) + self._set_with_polymorphic( + (self.with_polymorphic[0] + (subcl,), self.with_polymorphic[1]) + ) + + def _set_concrete_base(self, mapper): + """Set the given :class:`_orm.Mapper` as the 'inherits' for this + :class:`_orm.Mapper`, assuming this :class:`_orm.Mapper` is concrete + and does not already have an inherits.""" + + assert self.concrete + assert not self.inherits + assert isinstance(mapper, Mapper) + self.inherits = mapper + self.inherits.polymorphic_map.update(self.polymorphic_map) + self.polymorphic_map = self.inherits.polymorphic_map + for mapper in self.iterate_to_root(): + if mapper.polymorphic_on is not None: + mapper._requires_row_aliasing = True + self.batch = self.inherits.batch + for mp in self.self_and_descendants: + mp.base_mapper = self.inherits.base_mapper + self.inherits._inheriting_mappers.append(self) + self.passive_updates = self.inherits.passive_updates + self._all_tables = self.inherits._all_tables + + for key, prop in mapper._props.items(): + if key not in self._props and not self._should_exclude( + key, key, local=False, column=None + ): + self._adapt_inherited_property(key, prop, False) + + def _set_polymorphic_on(self, polymorphic_on): + self.polymorphic_on = polymorphic_on + self._configure_polymorphic_setter(True) + + def _configure_class_instrumentation(self): + """If this mapper is to be a primary mapper (i.e. the + non_primary flag is not set), associate this Mapper with the + given class and entity name. + + Subsequent calls to ``class_mapper()`` for the ``class_`` / ``entity`` + name combination will return this mapper. Also decorate the + `__init__` method on the mapped class to include optional + auto-session attachment logic. + + """ + + # we expect that declarative has applied the class manager + # already and set up a registry. if this is None, + # this raises as of 2.0. + manager = attributes.opt_manager_of_class(self.class_) + + if self.non_primary: + if not manager or not manager.is_mapped: + raise sa_exc.InvalidRequestError( + "Class %s has no primary mapper configured. Configure " + "a primary mapper first before setting up a non primary " + "Mapper." % self.class_ + ) + self.class_manager = manager + + assert manager.registry is not None + self.registry = manager.registry + self._identity_class = manager.mapper._identity_class + manager.registry._add_non_primary_mapper(self) + return + + if manager is None or not manager.registry: + raise sa_exc.InvalidRequestError( + "The _mapper() function and Mapper() constructor may not be " + "invoked directly outside of a declarative registry." + " Please use the sqlalchemy.orm.registry.map_imperatively() " + "function for a classical mapping." + ) + + self.dispatch.instrument_class(self, self.class_) + + # this invokes the class_instrument event and sets up + # the __init__ method. documented behavior is that this must + # occur after the instrument_class event above. + # yes two events with the same two words reversed and different APIs. + # :( + + manager = instrumentation.register_class( + self.class_, + mapper=self, + expired_attribute_loader=util.partial( + loading.load_scalar_attributes, self + ), + # finalize flag means instrument the __init__ method + # and call the class_instrument event + finalize=True, + ) + + self.class_manager = manager + + assert manager.registry is not None + self.registry = manager.registry + + # The remaining members can be added by any mapper, + # e_name None or not. + if manager.mapper is None: + return + + event.listen(manager, "init", _event_on_init, raw=True) + + for key, method in util.iterate_attributes(self.class_): + if key == "__init__" and hasattr(method, "_sa_original_init"): + method = method._sa_original_init + if hasattr(method, "__func__"): + method = method.__func__ + if callable(method): + if hasattr(method, "__sa_reconstructor__"): + self._reconstructor = method + event.listen(manager, "load", _event_on_load, raw=True) + elif hasattr(method, "__sa_validators__"): + validation_opts = method.__sa_validation_opts__ + for name in method.__sa_validators__: + if name in self.validators: + raise sa_exc.InvalidRequestError( + "A validation function for mapped " + "attribute %r on mapper %s already exists." + % (name, self) + ) + self.validators = self.validators.union( + {name: (method, validation_opts)} + ) + + def _set_dispose_flags(self) -> None: + self.configured = True + self._ready_for_configure = True + self._dispose_called = True + + self.__dict__.pop("_configure_failed", None) + + def _str_arg_to_mapped_col(self, argname: str, key: str) -> Column[Any]: + try: + prop = self._props[key] + except KeyError as err: + raise sa_exc.ArgumentError( + f"Can't determine {argname} column '{key}' - " + "no attribute is mapped to this name." + ) from err + try: + expr = prop.expression + except AttributeError as ae: + raise sa_exc.ArgumentError( + f"Can't determine {argname} column '{key}'; " + "property does not refer to a single mapped Column" + ) from ae + if not isinstance(expr, Column): + raise sa_exc.ArgumentError( + f"Can't determine {argname} column '{key}'; " + "property does not refer to a single " + "mapped Column" + ) + return expr + + def _configure_pks(self) -> None: + self.tables = sql_util.find_tables(self.persist_selectable) + + self._all_tables.update(t for t in self.tables) + + self._pks_by_table = {} + self._cols_by_table = {} + + all_cols = util.column_set( + chain(*[col.proxy_set for col in self._columntoproperty]) + ) + + pk_cols = util.column_set(c for c in all_cols if c.primary_key) + + # identify primary key columns which are also mapped by this mapper. + for fc in set(self.tables).union([self.persist_selectable]): + if fc.primary_key and pk_cols.issuperset(fc.primary_key): + # ordering is important since it determines the ordering of + # mapper.primary_key (and therefore query.get()) + self._pks_by_table[fc] = util.ordered_column_set( # type: ignore # noqa: E501 + fc.primary_key + ).intersection( + pk_cols + ) + self._cols_by_table[fc] = util.ordered_column_set(fc.c).intersection( # type: ignore # noqa: E501 + all_cols + ) + + if self._primary_key_argument: + coerced_pk_arg = [ + ( + self._str_arg_to_mapped_col("primary_key", c) + if isinstance(c, str) + else c + ) + for c in ( + coercions.expect( + roles.DDLConstraintColumnRole, + coerce_pk, + argname="primary_key", + ) + for coerce_pk in self._primary_key_argument + ) + ] + else: + coerced_pk_arg = None + + # if explicit PK argument sent, add those columns to the + # primary key mappings + if coerced_pk_arg: + for k in coerced_pk_arg: + if k.table not in self._pks_by_table: + self._pks_by_table[k.table] = util.OrderedSet() + self._pks_by_table[k.table].add(k) + + # otherwise, see that we got a full PK for the mapped table + elif ( + self.persist_selectable not in self._pks_by_table + or len(self._pks_by_table[self.persist_selectable]) == 0 + ): + raise sa_exc.ArgumentError( + "Mapper %s could not assemble any primary " + "key columns for mapped table '%s'" + % (self, self.persist_selectable.description) + ) + elif self.local_table not in self._pks_by_table and isinstance( + self.local_table, schema.Table + ): + util.warn( + "Could not assemble any primary " + "keys for locally mapped table '%s' - " + "no rows will be persisted in this Table." + % self.local_table.description + ) + + if ( + self.inherits + and not self.concrete + and not self._primary_key_argument + ): + # if inheriting, the "primary key" for this mapper is + # that of the inheriting (unless concrete or explicit) + self.primary_key = self.inherits.primary_key + else: + # determine primary key from argument or persist_selectable pks + primary_key: Collection[ColumnElement[Any]] + + if coerced_pk_arg: + primary_key = [ + cc if cc is not None else c + for cc, c in ( + (self.persist_selectable.corresponding_column(c), c) + for c in coerced_pk_arg + ) + ] + else: + # if heuristically determined PKs, reduce to the minimal set + # of columns by eliminating FK->PK pairs for a multi-table + # expression. May over-reduce for some kinds of UNIONs + # / CTEs; use explicit PK argument for these special cases + primary_key = sql_util.reduce_columns( + self._pks_by_table[self.persist_selectable], + ignore_nonexistent_tables=True, + ) + + if len(primary_key) == 0: + raise sa_exc.ArgumentError( + "Mapper %s could not assemble any primary " + "key columns for mapped table '%s'" + % (self, self.persist_selectable.description) + ) + + self.primary_key = tuple(primary_key) + self._log("Identified primary key columns: %s", primary_key) + + # determine cols that aren't expressed within our tables; mark these + # as "read only" properties which are refreshed upon INSERT/UPDATE + self._readonly_props = { + self._columntoproperty[col] + for col in self._columntoproperty + if self._columntoproperty[col] not in self._identity_key_props + and ( + not hasattr(col, "table") + or col.table not in self._cols_by_table + ) + } + + def _configure_properties(self) -> None: + self.columns = self.c = sql_base.ColumnCollection() # type: ignore + + # object attribute names mapped to MapperProperty objects + self._props = util.OrderedDict() + + # table columns mapped to MapperProperty + self._columntoproperty = _ColumnMapping(self) + + explicit_col_props_by_column: Dict[ + KeyedColumnElement[Any], Tuple[str, ColumnProperty[Any]] + ] = {} + explicit_col_props_by_key: Dict[str, ColumnProperty[Any]] = {} + + # step 1: go through properties that were explicitly passed + # in the properties dictionary. For Columns that are local, put them + # aside in a separate collection we will reconcile with the Table + # that's given. For other properties, set them up in _props now. + if self._init_properties: + for key, prop_arg in self._init_properties.items(): + if not isinstance(prop_arg, MapperProperty): + possible_col_prop = self._make_prop_from_column( + key, prop_arg + ) + else: + possible_col_prop = prop_arg + + # issue #8705. if the explicit property is actually a + # Column that is local to the local Table, don't set it up + # in ._props yet, integrate it into the order given within + # the Table. + + _map_as_property_now = True + if isinstance(possible_col_prop, properties.ColumnProperty): + for given_col in possible_col_prop.columns: + if self.local_table.c.contains_column(given_col): + _map_as_property_now = False + explicit_col_props_by_key[key] = possible_col_prop + explicit_col_props_by_column[given_col] = ( + key, + possible_col_prop, + ) + + if _map_as_property_now: + self._configure_property( + key, + possible_col_prop, + init=False, + ) + + # step 2: pull properties from the inherited mapper. reconcile + # columns with those which are explicit above. for properties that + # are only in the inheriting mapper, set them up as local props + if self.inherits: + for key, inherited_prop in self.inherits._props.items(): + if self._should_exclude(key, key, local=False, column=None): + continue + + incoming_prop = explicit_col_props_by_key.get(key) + if incoming_prop: + new_prop = self._reconcile_prop_with_incoming_columns( + key, + inherited_prop, + warn_only=False, + incoming_prop=incoming_prop, + ) + explicit_col_props_by_key[key] = new_prop + + for inc_col in incoming_prop.columns: + explicit_col_props_by_column[inc_col] = ( + key, + new_prop, + ) + elif key not in self._props: + self._adapt_inherited_property(key, inherited_prop, False) + + # step 3. Iterate through all columns in the persist selectable. + # this includes not only columns in the local table / fromclause, + # but also those columns in the superclass table if we are joined + # inh or single inh mapper. map these columns as well. additional + # reconciliation against inherited columns occurs here also. + + for column in self.persist_selectable.columns: + if column in explicit_col_props_by_column: + # column was explicitly passed to properties; configure + # it now in the order in which it corresponds to the + # Table / selectable + key, prop = explicit_col_props_by_column[column] + self._configure_property(key, prop, init=False) + continue + + elif column in self._columntoproperty: + continue + + column_key = (self.column_prefix or "") + column.key + if self._should_exclude( + column.key, + column_key, + local=self.local_table.c.contains_column(column), + column=column, + ): + continue + + # adjust the "key" used for this column to that + # of the inheriting mapper + for mapper in self.iterate_to_root(): + if column in mapper._columntoproperty: + column_key = mapper._columntoproperty[column].key + + self._configure_property( + column_key, + column, + init=False, + setparent=True, + ) + + def _configure_polymorphic_setter(self, init=False): + """Configure an attribute on the mapper representing the + 'polymorphic_on' column, if applicable, and not + already generated by _configure_properties (which is typical). + + Also create a setter function which will assign this + attribute to the value of the 'polymorphic_identity' + upon instance construction, also if applicable. This + routine will run when an instance is created. + + """ + setter = False + polymorphic_key: Optional[str] = None + + if self.polymorphic_on is not None: + setter = True + + if isinstance(self.polymorphic_on, str): + # polymorphic_on specified as a string - link + # it to mapped ColumnProperty + try: + self.polymorphic_on = self._props[self.polymorphic_on] + except KeyError as err: + raise sa_exc.ArgumentError( + "Can't determine polymorphic_on " + "value '%s' - no attribute is " + "mapped to this name." % self.polymorphic_on + ) from err + + if self.polymorphic_on in self._columntoproperty: + # polymorphic_on is a column that is already mapped + # to a ColumnProperty + prop = self._columntoproperty[self.polymorphic_on] + elif isinstance(self.polymorphic_on, MapperProperty): + # polymorphic_on is directly a MapperProperty, + # ensure it's a ColumnProperty + if not isinstance( + self.polymorphic_on, properties.ColumnProperty + ): + raise sa_exc.ArgumentError( + "Only direct column-mapped " + "property or SQL expression " + "can be passed for polymorphic_on" + ) + prop = self.polymorphic_on + else: + # polymorphic_on is a Column or SQL expression and + # doesn't appear to be mapped. this means it can be 1. + # only present in the with_polymorphic selectable or + # 2. a totally standalone SQL expression which we'd + # hope is compatible with this mapper's persist_selectable + col = self.persist_selectable.corresponding_column( + self.polymorphic_on + ) + if col is None: + # polymorphic_on doesn't derive from any + # column/expression isn't present in the mapped + # table. we will make a "hidden" ColumnProperty + # for it. Just check that if it's directly a + # schema.Column and we have with_polymorphic, it's + # likely a user error if the schema.Column isn't + # represented somehow in either persist_selectable or + # with_polymorphic. Otherwise as of 0.7.4 we + # just go with it and assume the user wants it + # that way (i.e. a CASE statement) + setter = False + instrument = False + col = self.polymorphic_on + if isinstance(col, schema.Column) and ( + self.with_polymorphic is None + or self.with_polymorphic[1] is None + or self.with_polymorphic[1].corresponding_column(col) + is None + ): + raise sa_exc.InvalidRequestError( + "Could not map polymorphic_on column " + "'%s' to the mapped table - polymorphic " + "loads will not function properly" + % col.description + ) + else: + # column/expression that polymorphic_on derives from + # is present in our mapped table + # and is probably mapped, but polymorphic_on itself + # is not. This happens when + # the polymorphic_on is only directly present in the + # with_polymorphic selectable, as when use + # polymorphic_union. + # we'll make a separate ColumnProperty for it. + instrument = True + key = getattr(col, "key", None) + if key: + if self._should_exclude(key, key, False, col): + raise sa_exc.InvalidRequestError( + "Cannot exclude or override the " + "discriminator column %r" % key + ) + else: + self.polymorphic_on = col = col.label("_sa_polymorphic_on") + key = col.key + + prop = properties.ColumnProperty(col, _instrument=instrument) + self._configure_property(key, prop, init=init, setparent=True) + + # the actual polymorphic_on should be the first public-facing + # column in the property + self.polymorphic_on = prop.columns[0] + polymorphic_key = prop.key + else: + # no polymorphic_on was set. + # check inheriting mappers for one. + for mapper in self.iterate_to_root(): + # determine if polymorphic_on of the parent + # should be propagated here. If the col + # is present in our mapped table, or if our mapped + # table is the same as the parent (i.e. single table + # inheritance), we can use it + if mapper.polymorphic_on is not None: + if self.persist_selectable is mapper.persist_selectable: + self.polymorphic_on = mapper.polymorphic_on + else: + self.polymorphic_on = ( + self.persist_selectable + ).corresponding_column(mapper.polymorphic_on) + # we can use the parent mapper's _set_polymorphic_identity + # directly; it ensures the polymorphic_identity of the + # instance's mapper is used so is portable to subclasses. + if self.polymorphic_on is not None: + self._set_polymorphic_identity = ( + mapper._set_polymorphic_identity + ) + self._polymorphic_attr_key = ( + mapper._polymorphic_attr_key + ) + self._validate_polymorphic_identity = ( + mapper._validate_polymorphic_identity + ) + else: + self._set_polymorphic_identity = None + self._polymorphic_attr_key = None + return + + if self.polymorphic_abstract and self.polymorphic_on is None: + raise sa_exc.InvalidRequestError( + "The Mapper.polymorphic_abstract parameter may only be used " + "on a mapper hierarchy which includes the " + "Mapper.polymorphic_on parameter at the base of the hierarchy." + ) + + if setter: + + def _set_polymorphic_identity(state): + dict_ = state.dict + # TODO: what happens if polymorphic_on column attribute name + # does not match .key? + + polymorphic_identity = ( + state.manager.mapper.polymorphic_identity + ) + if ( + polymorphic_identity is None + and state.manager.mapper.polymorphic_abstract + ): + raise sa_exc.InvalidRequestError( + f"Can't instantiate class for {state.manager.mapper}; " + "mapper is marked polymorphic_abstract=True" + ) + + state.get_impl(polymorphic_key).set( + state, + dict_, + polymorphic_identity, + None, + ) + + self._polymorphic_attr_key = polymorphic_key + + def _validate_polymorphic_identity(mapper, state, dict_): + if ( + polymorphic_key in dict_ + and dict_[polymorphic_key] + not in mapper._acceptable_polymorphic_identities + ): + util.warn_limited( + "Flushing object %s with " + "incompatible polymorphic identity %r; the " + "object may not refresh and/or load correctly", + (state_str(state), dict_[polymorphic_key]), + ) + + self._set_polymorphic_identity = _set_polymorphic_identity + self._validate_polymorphic_identity = ( + _validate_polymorphic_identity + ) + else: + self._polymorphic_attr_key = None + self._set_polymorphic_identity = None + + _validate_polymorphic_identity = None + + @HasMemoized.memoized_attribute + def _version_id_prop(self): + if self.version_id_col is not None: + return self._columntoproperty[self.version_id_col] + else: + return None + + @HasMemoized.memoized_attribute + def _acceptable_polymorphic_identities(self): + identities = set() + + stack = deque([self]) + while stack: + item = stack.popleft() + if item.persist_selectable is self.persist_selectable: + identities.add(item.polymorphic_identity) + stack.extend(item._inheriting_mappers) + + return identities + + @HasMemoized.memoized_attribute + def _prop_set(self): + return frozenset(self._props.values()) + + @util.preload_module("sqlalchemy.orm.descriptor_props") + def _adapt_inherited_property(self, key, prop, init): + descriptor_props = util.preloaded.orm_descriptor_props + + if not self.concrete: + self._configure_property(key, prop, init=False, setparent=False) + elif key not in self._props: + # determine if the class implements this attribute; if not, + # or if it is implemented by the attribute that is handling the + # given superclass-mapped property, then we need to report that we + # can't use this at the instance level since we are a concrete + # mapper and we don't map this. don't trip user-defined + # descriptors that might have side effects when invoked. + implementing_attribute = self.class_manager._get_class_attr_mro( + key, prop + ) + if implementing_attribute is prop or ( + isinstance( + implementing_attribute, attributes.InstrumentedAttribute + ) + and implementing_attribute._parententity is prop.parent + ): + self._configure_property( + key, + descriptor_props.ConcreteInheritedProperty(), + init=init, + setparent=True, + ) + + @util.preload_module("sqlalchemy.orm.descriptor_props") + def _configure_property( + self, + key: str, + prop_arg: Union[KeyedColumnElement[Any], MapperProperty[Any]], + *, + init: bool = True, + setparent: bool = True, + warn_for_existing: bool = False, + ) -> MapperProperty[Any]: + descriptor_props = util.preloaded.orm_descriptor_props + self._log( + "_configure_property(%s, %s)", key, prop_arg.__class__.__name__ + ) + + if not isinstance(prop_arg, MapperProperty): + prop: MapperProperty[Any] = self._property_from_column( + key, prop_arg + ) + else: + prop = prop_arg + + if isinstance(prop, properties.ColumnProperty): + col = self.persist_selectable.corresponding_column(prop.columns[0]) + + # if the column is not present in the mapped table, + # test if a column has been added after the fact to the + # parent table (or their parent, etc.) [ticket:1570] + if col is None and self.inherits: + path = [self] + for m in self.inherits.iterate_to_root(): + col = m.local_table.corresponding_column(prop.columns[0]) + if col is not None: + for m2 in path: + m2.persist_selectable._refresh_for_new_column(col) + col = self.persist_selectable.corresponding_column( + prop.columns[0] + ) + break + path.append(m) + + # subquery expression, column not present in the mapped + # selectable. + if col is None: + col = prop.columns[0] + + # column is coming in after _readonly_props was + # initialized; check for 'readonly' + if hasattr(self, "_readonly_props") and ( + not hasattr(col, "table") + or col.table not in self._cols_by_table + ): + self._readonly_props.add(prop) + + else: + # if column is coming in after _cols_by_table was + # initialized, ensure the col is in the right set + if ( + hasattr(self, "_cols_by_table") + and col.table in self._cols_by_table + and col not in self._cols_by_table[col.table] + ): + self._cols_by_table[col.table].add(col) + + # if this properties.ColumnProperty represents the "polymorphic + # discriminator" column, mark it. We'll need this when rendering + # columns in SELECT statements. + if not hasattr(prop, "_is_polymorphic_discriminator"): + prop._is_polymorphic_discriminator = ( + col is self.polymorphic_on + or prop.columns[0] is self.polymorphic_on + ) + + if isinstance(col, expression.Label): + # new in 1.4, get column property against expressions + # to be addressable in subqueries + col.key = col._tq_key_label = key + + self.columns.add(col, key) + + for col in prop.columns: + for proxy_col in col.proxy_set: + self._columntoproperty[proxy_col] = prop + + if getattr(prop, "key", key) != key: + util.warn( + f"ORM mapped property {self.class_.__name__}.{prop.key} being " + "assigned to attribute " + f"{key!r} is already associated with " + f"attribute {prop.key!r}. The attribute will be de-associated " + f"from {prop.key!r}." + ) + + prop.key = key + + if setparent: + prop.set_parent(self, init) + + if key in self._props and getattr( + self._props[key], "_mapped_by_synonym", False + ): + syn = self._props[key]._mapped_by_synonym + raise sa_exc.ArgumentError( + "Can't call map_column=True for synonym %r=%r, " + "a ColumnProperty already exists keyed to the name " + "%r for column %r" % (syn, key, key, syn) + ) + + # replacement cases + + # case one: prop is replacing a prop that we have mapped. this is + # independent of whatever might be in the actual class dictionary + if ( + key in self._props + and not isinstance( + self._props[key], descriptor_props.ConcreteInheritedProperty + ) + and not isinstance(prop, descriptor_props.SynonymProperty) + ): + if warn_for_existing: + util.warn_deprecated( + f"User-placed attribute {self.class_.__name__}.{key} on " + f"{self} is replacing an existing ORM-mapped attribute. " + "Behavior is not fully defined in this case. This " + "use is deprecated and will raise an error in a future " + "release", + "2.0", + ) + oldprop = self._props[key] + self._path_registry.pop(oldprop, None) + + # case two: prop is replacing an attribute on the class of some kind. + # we have to be more careful here since it's normal when using + # Declarative that all the "declared attributes" on the class + # get replaced. + elif ( + warn_for_existing + and self.class_.__dict__.get(key, None) is not None + and not isinstance(prop, descriptor_props.SynonymProperty) + and not isinstance( + self._props.get(key, None), + descriptor_props.ConcreteInheritedProperty, + ) + ): + util.warn_deprecated( + f"User-placed attribute {self.class_.__name__}.{key} on " + f"{self} is replacing an existing class-bound " + "attribute of the same name. " + "Behavior is not fully defined in this case. This " + "use is deprecated and will raise an error in a future " + "release", + "2.0", + ) + + self._props[key] = prop + + if not self.non_primary: + prop.instrument_class(self) + + for mapper in self._inheriting_mappers: + mapper._adapt_inherited_property(key, prop, init) + + if init: + prop.init() + prop.post_instrument_class(self) + + if self.configured: + self._expire_memoizations() + + return prop + + def _make_prop_from_column( + self, + key: str, + column: Union[ + Sequence[KeyedColumnElement[Any]], KeyedColumnElement[Any] + ], + ) -> ColumnProperty[Any]: + columns = util.to_list(column) + mapped_column = [] + for c in columns: + mc = self.persist_selectable.corresponding_column(c) + if mc is None: + mc = self.local_table.corresponding_column(c) + if mc is not None: + # if the column is in the local table but not the + # mapped table, this corresponds to adding a + # column after the fact to the local table. + # [ticket:1523] + self.persist_selectable._refresh_for_new_column(mc) + mc = self.persist_selectable.corresponding_column(c) + if mc is None: + raise sa_exc.ArgumentError( + "When configuring property '%s' on %s, " + "column '%s' is not represented in the mapper's " + "table. Use the `column_property()` function to " + "force this column to be mapped as a read-only " + "attribute." % (key, self, c) + ) + mapped_column.append(mc) + return properties.ColumnProperty(*mapped_column) + + def _reconcile_prop_with_incoming_columns( + self, + key: str, + existing_prop: MapperProperty[Any], + warn_only: bool, + incoming_prop: Optional[ColumnProperty[Any]] = None, + single_column: Optional[KeyedColumnElement[Any]] = None, + ) -> ColumnProperty[Any]: + if incoming_prop and ( + self.concrete + or not isinstance(existing_prop, properties.ColumnProperty) + ): + return incoming_prop + + existing_column = existing_prop.columns[0] + + if incoming_prop and existing_column in incoming_prop.columns: + return incoming_prop + + if incoming_prop is None: + assert single_column is not None + incoming_column = single_column + equated_pair_key = (existing_prop.columns[0], incoming_column) + else: + assert single_column is None + incoming_column = incoming_prop.columns[0] + equated_pair_key = (incoming_column, existing_prop.columns[0]) + + if ( + ( + not self._inherits_equated_pairs + or (equated_pair_key not in self._inherits_equated_pairs) + ) + and not existing_column.shares_lineage(incoming_column) + and existing_column is not self.version_id_col + and incoming_column is not self.version_id_col + ): + msg = ( + "Implicitly combining column %s with column " + "%s under attribute '%s'. Please configure one " + "or more attributes for these same-named columns " + "explicitly." + % ( + existing_prop.columns[-1], + incoming_column, + key, + ) + ) + if warn_only: + util.warn(msg) + else: + raise sa_exc.InvalidRequestError(msg) + + # existing properties.ColumnProperty from an inheriting + # mapper. make a copy and append our column to it + # breakpoint() + new_prop = existing_prop.copy() + + new_prop.columns.insert(0, incoming_column) + self._log( + "inserting column to existing list " + "in properties.ColumnProperty %s", + key, + ) + return new_prop # type: ignore + + @util.preload_module("sqlalchemy.orm.descriptor_props") + def _property_from_column( + self, + key: str, + column: KeyedColumnElement[Any], + ) -> ColumnProperty[Any]: + """generate/update a :class:`.ColumnProperty` given a + :class:`_schema.Column` or other SQL expression object.""" + + descriptor_props = util.preloaded.orm_descriptor_props + + prop = self._props.get(key) + + if isinstance(prop, properties.ColumnProperty): + return self._reconcile_prop_with_incoming_columns( + key, + prop, + single_column=column, + warn_only=prop.parent is not self, + ) + elif prop is None or isinstance( + prop, descriptor_props.ConcreteInheritedProperty + ): + return self._make_prop_from_column(key, column) + else: + raise sa_exc.ArgumentError( + "WARNING: when configuring property '%s' on %s, " + "column '%s' conflicts with property '%r'. " + "To resolve this, map the column to the class under a " + "different name in the 'properties' dictionary. Or, " + "to remove all awareness of the column entirely " + "(including its availability as a foreign key), " + "use the 'include_properties' or 'exclude_properties' " + "mapper arguments to control specifically which table " + "columns get mapped." % (key, self, column.key, prop) + ) + + @util.langhelpers.tag_method_for_warnings( + "This warning originated from the `configure_mappers()` process, " + "which was invoked automatically in response to a user-initiated " + "operation.", + sa_exc.SAWarning, + ) + def _check_configure(self) -> None: + if self.registry._new_mappers: + _configure_registries({self.registry}, cascade=True) + + def _post_configure_properties(self) -> None: + """Call the ``init()`` method on all ``MapperProperties`` + attached to this mapper. + + This is a deferred configuration step which is intended + to execute once all mappers have been constructed. + + """ + + self._log("_post_configure_properties() started") + l = [(key, prop) for key, prop in self._props.items()] + for key, prop in l: + self._log("initialize prop %s", key) + + if prop.parent is self and not prop._configure_started: + prop.init() + + if prop._configure_finished: + prop.post_instrument_class(self) + + self._log("_post_configure_properties() complete") + self.configured = True + + def add_properties(self, dict_of_properties): + """Add the given dictionary of properties to this mapper, + using `add_property`. + + """ + for key, value in dict_of_properties.items(): + self.add_property(key, value) + + def add_property( + self, key: str, prop: Union[Column[Any], MapperProperty[Any]] + ) -> None: + """Add an individual MapperProperty to this mapper. + + If the mapper has not been configured yet, just adds the + property to the initial properties dictionary sent to the + constructor. If this Mapper has already been configured, then + the given MapperProperty is configured immediately. + + """ + prop = self._configure_property( + key, prop, init=self.configured, warn_for_existing=True + ) + assert isinstance(prop, MapperProperty) + self._init_properties[key] = prop + + def _expire_memoizations(self) -> None: + for mapper in self.iterate_to_root(): + mapper._reset_memoizations() + + @property + def _log_desc(self) -> str: + return ( + "(" + + self.class_.__name__ + + "|" + + ( + self.local_table is not None + and self.local_table.description + or str(self.local_table) + ) + + (self.non_primary and "|non-primary" or "") + + ")" + ) + + def _log(self, msg: str, *args: Any) -> None: + self.logger.info("%s " + msg, *((self._log_desc,) + args)) + + def _log_debug(self, msg: str, *args: Any) -> None: + self.logger.debug("%s " + msg, *((self._log_desc,) + args)) + + def __repr__(self) -> str: + return "" % (id(self), self.class_.__name__) + + def __str__(self) -> str: + return "Mapper[%s%s(%s)]" % ( + self.class_.__name__, + self.non_primary and " (non-primary)" or "", + ( + self.local_table.description + if self.local_table is not None + else self.persist_selectable.description + ), + ) + + def _is_orphan(self, state: InstanceState[_O]) -> bool: + orphan_possible = False + for mapper in self.iterate_to_root(): + for key, cls in mapper._delete_orphans: + orphan_possible = True + + has_parent = attributes.manager_of_class(cls).has_parent( + state, key, optimistic=state.has_identity + ) + + if self.legacy_is_orphan and has_parent: + return False + elif not self.legacy_is_orphan and not has_parent: + return True + + if self.legacy_is_orphan: + return orphan_possible + else: + return False + + def has_property(self, key: str) -> bool: + return key in self._props + + def get_property( + self, key: str, _configure_mappers: bool = False + ) -> MapperProperty[Any]: + """return a MapperProperty associated with the given key.""" + + if _configure_mappers: + self._check_configure() + + try: + return self._props[key] + except KeyError as err: + raise sa_exc.InvalidRequestError( + f"Mapper '{self}' has no property '{key}'. If this property " + "was indicated from other mappers or configure events, ensure " + "registry.configure() has been called." + ) from err + + def get_property_by_column( + self, column: ColumnElement[_T] + ) -> MapperProperty[_T]: + """Given a :class:`_schema.Column` object, return the + :class:`.MapperProperty` which maps this column.""" + + return self._columntoproperty[column] + + @property + def iterate_properties(self): + """return an iterator of all MapperProperty objects.""" + + return iter(self._props.values()) + + def _mappers_from_spec( + self, spec: Any, selectable: Optional[FromClause] + ) -> Sequence[Mapper[Any]]: + """given a with_polymorphic() argument, return the set of mappers it + represents. + + Trims the list of mappers to just those represented within the given + selectable, if present. This helps some more legacy-ish mappings. + + """ + if spec == "*": + mappers = list(self.self_and_descendants) + elif spec: + mapper_set = set() + for m in util.to_list(spec): + m = _class_to_mapper(m) + if not m.isa(self): + raise sa_exc.InvalidRequestError( + "%r does not inherit from %r" % (m, self) + ) + + if selectable is None: + mapper_set.update(m.iterate_to_root()) + else: + mapper_set.add(m) + mappers = [m for m in self.self_and_descendants if m in mapper_set] + else: + mappers = [] + + if selectable is not None: + tables = set( + sql_util.find_tables(selectable, include_aliases=True) + ) + mappers = [m for m in mappers if m.local_table in tables] + return mappers + + def _selectable_from_mappers( + self, mappers: Iterable[Mapper[Any]], innerjoin: bool + ) -> FromClause: + """given a list of mappers (assumed to be within this mapper's + inheritance hierarchy), construct an outerjoin amongst those mapper's + mapped tables. + + """ + from_obj = self.persist_selectable + for m in mappers: + if m is self: + continue + if m.concrete: + raise sa_exc.InvalidRequestError( + "'with_polymorphic()' requires 'selectable' argument " + "when concrete-inheriting mappers are used." + ) + elif not m.single: + if innerjoin: + from_obj = from_obj.join( + m.local_table, m.inherit_condition + ) + else: + from_obj = from_obj.outerjoin( + m.local_table, m.inherit_condition + ) + + return from_obj + + @HasMemoized.memoized_attribute + def _version_id_has_server_side_value(self) -> bool: + vid_col = self.version_id_col + + if vid_col is None: + return False + + elif not isinstance(vid_col, Column): + return True + else: + return vid_col.server_default is not None or ( + vid_col.default is not None + and ( + not vid_col.default.is_scalar + and not vid_col.default.is_callable + ) + ) + + @HasMemoized.memoized_attribute + def _single_table_criterion(self): + if self.single and self.inherits and self.polymorphic_on is not None: + return self.polymorphic_on._annotate( + {"parententity": self, "parentmapper": self} + ).in_( + [ + m.polymorphic_identity + for m in self.self_and_descendants + if not m.polymorphic_abstract + ] + ) + else: + return None + + @HasMemoized.memoized_attribute + def _has_aliased_polymorphic_fromclause(self): + """return True if with_polymorphic[1] is an aliased fromclause, + like a subquery. + + As of #8168, polymorphic adaption with ORMAdapter is used only + if this is present. + + """ + return self.with_polymorphic and isinstance( + self.with_polymorphic[1], + expression.AliasedReturnsRows, + ) + + @HasMemoized.memoized_attribute + def _should_select_with_poly_adapter(self): + """determine if _MapperEntity or _ORMColumnEntity will need to use + polymorphic adaption when setting up a SELECT as well as fetching + rows for mapped classes and subclasses against this Mapper. + + moved here from context.py for #8456 to generalize the ruleset + for this condition. + + """ + + # this has been simplified as of #8456. + # rule is: if we have a with_polymorphic or a concrete-style + # polymorphic selectable, *or* if the base mapper has either of those, + # we turn on the adaption thing. if not, we do *no* adaption. + # + # (UPDATE for #8168: the above comment was not accurate, as we were + # still saying "do polymorphic" if we were using an auto-generated + # flattened JOIN for with_polymorphic.) + # + # this splits the behavior among the "regular" joined inheritance + # and single inheritance mappers, vs. the "weird / difficult" + # concrete and joined inh mappings that use a with_polymorphic of + # some kind or polymorphic_union. + # + # note we have some tests in test_polymorphic_rel that query against + # a subclass, then refer to the superclass that has a with_polymorphic + # on it (such as test_join_from_polymorphic_explicit_aliased_three). + # these tests actually adapt the polymorphic selectable (like, the + # UNION or the SELECT subquery with JOIN in it) to be just the simple + # subclass table. Hence even if we are a "plain" inheriting mapper + # but our base has a wpoly on it, we turn on adaption. This is a + # legacy case we should probably disable. + # + # + # UPDATE: simplified way more as of #8168. polymorphic adaption + # is turned off even if with_polymorphic is set, as long as there + # is no user-defined aliased selectable / subquery configured. + # this scales back the use of polymorphic adaption in practice + # to basically no cases except for concrete inheritance with a + # polymorphic base class. + # + return ( + self._has_aliased_polymorphic_fromclause + or self._requires_row_aliasing + or (self.base_mapper._has_aliased_polymorphic_fromclause) + or self.base_mapper._requires_row_aliasing + ) + + @HasMemoized.memoized_attribute + def _with_polymorphic_mappers(self) -> Sequence[Mapper[Any]]: + self._check_configure() + + if not self.with_polymorphic: + return [] + return self._mappers_from_spec(*self.with_polymorphic) + + @HasMemoized.memoized_attribute + def _post_inspect(self): + """This hook is invoked by attribute inspection. + + E.g. when Query calls: + + coercions.expect(roles.ColumnsClauseRole, ent, keep_inspect=True) + + This allows the inspection process run a configure mappers hook. + + """ + self._check_configure() + + @HasMemoized_ro_memoized_attribute + def _with_polymorphic_selectable(self) -> FromClause: + if not self.with_polymorphic: + return self.persist_selectable + + spec, selectable = self.with_polymorphic + if selectable is not None: + return selectable + else: + return self._selectable_from_mappers( + self._mappers_from_spec(spec, selectable), False + ) + + with_polymorphic_mappers = _with_polymorphic_mappers + """The list of :class:`_orm.Mapper` objects included in the + default "polymorphic" query. + + """ + + @HasMemoized_ro_memoized_attribute + def _insert_cols_evaluating_none(self): + return { + table: frozenset( + col for col in columns if col.type.should_evaluate_none + ) + for table, columns in self._cols_by_table.items() + } + + @HasMemoized.memoized_attribute + def _insert_cols_as_none(self): + return { + table: frozenset( + col.key + for col in columns + if not col.primary_key + and not col.server_default + and not col.default + and not col.type.should_evaluate_none + ) + for table, columns in self._cols_by_table.items() + } + + @HasMemoized.memoized_attribute + def _propkey_to_col(self): + return { + table: {self._columntoproperty[col].key: col for col in columns} + for table, columns in self._cols_by_table.items() + } + + @HasMemoized.memoized_attribute + def _pk_keys_by_table(self): + return { + table: frozenset([col.key for col in pks]) + for table, pks in self._pks_by_table.items() + } + + @HasMemoized.memoized_attribute + def _pk_attr_keys_by_table(self): + return { + table: frozenset([self._columntoproperty[col].key for col in pks]) + for table, pks in self._pks_by_table.items() + } + + @HasMemoized.memoized_attribute + def _server_default_cols( + self, + ) -> Mapping[FromClause, FrozenSet[Column[Any]]]: + return { + table: frozenset( + [ + col + for col in cast("Iterable[Column[Any]]", columns) + if col.server_default is not None + or ( + col.default is not None + and col.default.is_clause_element + ) + ] + ) + for table, columns in self._cols_by_table.items() + } + + @HasMemoized.memoized_attribute + def _server_onupdate_default_cols( + self, + ) -> Mapping[FromClause, FrozenSet[Column[Any]]]: + return { + table: frozenset( + [ + col + for col in cast("Iterable[Column[Any]]", columns) + if col.server_onupdate is not None + or ( + col.onupdate is not None + and col.onupdate.is_clause_element + ) + ] + ) + for table, columns in self._cols_by_table.items() + } + + @HasMemoized.memoized_attribute + def _server_default_col_keys(self) -> Mapping[FromClause, FrozenSet[str]]: + return { + table: frozenset(col.key for col in cols if col.key is not None) + for table, cols in self._server_default_cols.items() + } + + @HasMemoized.memoized_attribute + def _server_onupdate_default_col_keys( + self, + ) -> Mapping[FromClause, FrozenSet[str]]: + return { + table: frozenset(col.key for col in cols if col.key is not None) + for table, cols in self._server_onupdate_default_cols.items() + } + + @HasMemoized.memoized_attribute + def _server_default_plus_onupdate_propkeys(self) -> Set[str]: + result: Set[str] = set() + + col_to_property = self._columntoproperty + for table, columns in self._server_default_cols.items(): + result.update( + col_to_property[col].key + for col in columns.intersection(col_to_property) + ) + for table, columns in self._server_onupdate_default_cols.items(): + result.update( + col_to_property[col].key + for col in columns.intersection(col_to_property) + ) + return result + + @HasMemoized.memoized_instancemethod + def __clause_element__(self): + annotations: Dict[str, Any] = { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + } + if self.persist_selectable is not self.local_table: + # joined table inheritance, with polymorphic selectable, + # etc. + annotations["dml_table"] = self.local_table._annotate( + { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + } + )._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + + return self.selectable._annotate(annotations)._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + + @util.memoized_property + def select_identity_token(self): + return ( + expression.null() + ._annotate( + { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + "identity_token": True, + } + ) + ._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + ) + + @property + def selectable(self) -> FromClause: + """The :class:`_schema.FromClause` construct this + :class:`_orm.Mapper` selects from by default. + + Normally, this is equivalent to :attr:`.persist_selectable`, unless + the ``with_polymorphic`` feature is in use, in which case the + full "polymorphic" selectable is returned. + + """ + return self._with_polymorphic_selectable + + def _with_polymorphic_args( + self, + spec: Any = None, + selectable: Union[Literal[False, None], FromClause] = False, + innerjoin: bool = False, + ) -> Tuple[Sequence[Mapper[Any]], FromClause]: + if selectable not in (None, False): + selectable = coercions.expect( + roles.StrictFromClauseRole, selectable, allow_select=True + ) + + if self.with_polymorphic: + if not spec: + spec = self.with_polymorphic[0] + if selectable is False: + selectable = self.with_polymorphic[1] + elif selectable is False: + selectable = None + mappers = self._mappers_from_spec(spec, selectable) + if selectable is not None: + return mappers, selectable + else: + return mappers, self._selectable_from_mappers(mappers, innerjoin) + + @HasMemoized.memoized_attribute + def _polymorphic_properties(self): + return list( + self._iterate_polymorphic_properties( + self._with_polymorphic_mappers + ) + ) + + @property + def _all_column_expressions(self): + poly_properties = self._polymorphic_properties + adapter = self._polymorphic_adapter + + return [ + adapter.columns[c] if adapter else c + for prop in poly_properties + if isinstance(prop, properties.ColumnProperty) + and prop._renders_in_subqueries + for c in prop.columns + ] + + def _columns_plus_keys(self, polymorphic_mappers=()): + if polymorphic_mappers: + poly_properties = self._iterate_polymorphic_properties( + polymorphic_mappers + ) + else: + poly_properties = self._polymorphic_properties + + return [ + (prop.key, prop.columns[0]) + for prop in poly_properties + if isinstance(prop, properties.ColumnProperty) + ] + + @HasMemoized.memoized_attribute + def _polymorphic_adapter(self) -> Optional[orm_util.ORMAdapter]: + if self._has_aliased_polymorphic_fromclause: + return orm_util.ORMAdapter( + orm_util._TraceAdaptRole.MAPPER_POLYMORPHIC_ADAPTER, + self, + selectable=self.selectable, + equivalents=self._equivalent_columns, + limit_on_entity=False, + ) + else: + return None + + def _iterate_polymorphic_properties(self, mappers=None): + """Return an iterator of MapperProperty objects which will render into + a SELECT.""" + if mappers is None: + mappers = self._with_polymorphic_mappers + + if not mappers: + for c in self.iterate_properties: + yield c + else: + # in the polymorphic case, filter out discriminator columns + # from other mappers, as these are sometimes dependent on that + # mapper's polymorphic selectable (which we don't want rendered) + for c in util.unique_list( + chain( + *[ + list(mapper.iterate_properties) + for mapper in [self] + mappers + ] + ) + ): + if getattr(c, "_is_polymorphic_discriminator", False) and ( + self.polymorphic_on is None + or c.columns[0] is not self.polymorphic_on + ): + continue + yield c + + @HasMemoized.memoized_attribute + def attrs(self) -> util.ReadOnlyProperties[MapperProperty[Any]]: + """A namespace of all :class:`.MapperProperty` objects + associated this mapper. + + This is an object that provides each property based on + its key name. For instance, the mapper for a + ``User`` class which has ``User.name`` attribute would + provide ``mapper.attrs.name``, which would be the + :class:`.ColumnProperty` representing the ``name`` + column. The namespace object can also be iterated, + which would yield each :class:`.MapperProperty`. + + :class:`_orm.Mapper` has several pre-filtered views + of this attribute which limit the types of properties + returned, including :attr:`.synonyms`, :attr:`.column_attrs`, + :attr:`.relationships`, and :attr:`.composites`. + + .. warning:: + + The :attr:`_orm.Mapper.attrs` accessor namespace is an + instance of :class:`.OrderedProperties`. This is + a dictionary-like object which includes a small number of + named methods such as :meth:`.OrderedProperties.items` + and :meth:`.OrderedProperties.values`. When + accessing attributes dynamically, favor using the dict-access + scheme, e.g. ``mapper.attrs[somename]`` over + ``getattr(mapper.attrs, somename)`` to avoid name collisions. + + .. seealso:: + + :attr:`_orm.Mapper.all_orm_descriptors` + + """ + + self._check_configure() + return util.ReadOnlyProperties(self._props) + + @HasMemoized.memoized_attribute + def all_orm_descriptors(self) -> util.ReadOnlyProperties[InspectionAttr]: + """A namespace of all :class:`.InspectionAttr` attributes associated + with the mapped class. + + These attributes are in all cases Python :term:`descriptors` + associated with the mapped class or its superclasses. + + This namespace includes attributes that are mapped to the class + as well as attributes declared by extension modules. + It includes any Python descriptor type that inherits from + :class:`.InspectionAttr`. This includes + :class:`.QueryableAttribute`, as well as extension types such as + :class:`.hybrid_property`, :class:`.hybrid_method` and + :class:`.AssociationProxy`. + + To distinguish between mapped attributes and extension attributes, + the attribute :attr:`.InspectionAttr.extension_type` will refer + to a constant that distinguishes between different extension types. + + The sorting of the attributes is based on the following rules: + + 1. Iterate through the class and its superclasses in order from + subclass to superclass (i.e. iterate through ``cls.__mro__``) + + 2. For each class, yield the attributes in the order in which they + appear in ``__dict__``, with the exception of those in step + 3 below. In Python 3.6 and above this ordering will be the + same as that of the class' construction, with the exception + of attributes that were added after the fact by the application + or the mapper. + + 3. If a certain attribute key is also in the superclass ``__dict__``, + then it's included in the iteration for that class, and not the + class in which it first appeared. + + The above process produces an ordering that is deterministic in terms + of the order in which attributes were assigned to the class. + + .. versionchanged:: 1.3.19 ensured deterministic ordering for + :meth:`_orm.Mapper.all_orm_descriptors`. + + When dealing with a :class:`.QueryableAttribute`, the + :attr:`.QueryableAttribute.property` attribute refers to the + :class:`.MapperProperty` property, which is what you get when + referring to the collection of mapped properties via + :attr:`_orm.Mapper.attrs`. + + .. warning:: + + The :attr:`_orm.Mapper.all_orm_descriptors` + accessor namespace is an + instance of :class:`.OrderedProperties`. This is + a dictionary-like object which includes a small number of + named methods such as :meth:`.OrderedProperties.items` + and :meth:`.OrderedProperties.values`. When + accessing attributes dynamically, favor using the dict-access + scheme, e.g. ``mapper.all_orm_descriptors[somename]`` over + ``getattr(mapper.all_orm_descriptors, somename)`` to avoid name + collisions. + + .. seealso:: + + :attr:`_orm.Mapper.attrs` + + """ + return util.ReadOnlyProperties( + dict(self.class_manager._all_sqla_attributes()) + ) + + @HasMemoized.memoized_attribute + @util.preload_module("sqlalchemy.orm.descriptor_props") + def _pk_synonyms(self) -> Dict[str, str]: + """return a dictionary of {syn_attribute_name: pk_attr_name} for + all synonyms that refer to primary key columns + + """ + descriptor_props = util.preloaded.orm_descriptor_props + + pk_keys = {prop.key for prop in self._identity_key_props} + + return { + syn.key: syn.name + for k, syn in self._props.items() + if isinstance(syn, descriptor_props.SynonymProperty) + and syn.name in pk_keys + } + + @HasMemoized.memoized_attribute + @util.preload_module("sqlalchemy.orm.descriptor_props") + def synonyms(self) -> util.ReadOnlyProperties[SynonymProperty[Any]]: + """Return a namespace of all :class:`.Synonym` + properties maintained by this :class:`_orm.Mapper`. + + .. seealso:: + + :attr:`_orm.Mapper.attrs` - namespace of all + :class:`.MapperProperty` + objects. + + """ + descriptor_props = util.preloaded.orm_descriptor_props + + return self._filter_properties(descriptor_props.SynonymProperty) + + @property + def entity_namespace(self): + return self.class_ + + @HasMemoized.memoized_attribute + def column_attrs(self) -> util.ReadOnlyProperties[ColumnProperty[Any]]: + """Return a namespace of all :class:`.ColumnProperty` + properties maintained by this :class:`_orm.Mapper`. + + .. seealso:: + + :attr:`_orm.Mapper.attrs` - namespace of all + :class:`.MapperProperty` + objects. + + """ + return self._filter_properties(properties.ColumnProperty) + + @HasMemoized.memoized_attribute + @util.preload_module("sqlalchemy.orm.relationships") + def relationships( + self, + ) -> util.ReadOnlyProperties[RelationshipProperty[Any]]: + """A namespace of all :class:`.Relationship` properties + maintained by this :class:`_orm.Mapper`. + + .. warning:: + + the :attr:`_orm.Mapper.relationships` accessor namespace is an + instance of :class:`.OrderedProperties`. This is + a dictionary-like object which includes a small number of + named methods such as :meth:`.OrderedProperties.items` + and :meth:`.OrderedProperties.values`. When + accessing attributes dynamically, favor using the dict-access + scheme, e.g. ``mapper.relationships[somename]`` over + ``getattr(mapper.relationships, somename)`` to avoid name + collisions. + + .. seealso:: + + :attr:`_orm.Mapper.attrs` - namespace of all + :class:`.MapperProperty` + objects. + + """ + return self._filter_properties( + util.preloaded.orm_relationships.RelationshipProperty + ) + + @HasMemoized.memoized_attribute + @util.preload_module("sqlalchemy.orm.descriptor_props") + def composites(self) -> util.ReadOnlyProperties[CompositeProperty[Any]]: + """Return a namespace of all :class:`.Composite` + properties maintained by this :class:`_orm.Mapper`. + + .. seealso:: + + :attr:`_orm.Mapper.attrs` - namespace of all + :class:`.MapperProperty` + objects. + + """ + return self._filter_properties( + util.preloaded.orm_descriptor_props.CompositeProperty + ) + + def _filter_properties( + self, type_: Type[_MP] + ) -> util.ReadOnlyProperties[_MP]: + self._check_configure() + return util.ReadOnlyProperties( + util.OrderedDict( + (k, v) for k, v in self._props.items() if isinstance(v, type_) + ) + ) + + @HasMemoized.memoized_attribute + def _get_clause(self): + """create a "get clause" based on the primary key. this is used + by query.get() and many-to-one lazyloads to load this item + by primary key. + + """ + params = [ + ( + primary_key, + sql.bindparam("pk_%d" % idx, type_=primary_key.type), + ) + for idx, primary_key in enumerate(self.primary_key, 1) + ] + return ( + sql.and_(*[k == v for (k, v) in params]), + util.column_dict(params), + ) + + @HasMemoized.memoized_attribute + def _equivalent_columns(self) -> _EquivalentColumnMap: + """Create a map of all equivalent columns, based on + the determination of column pairs that are equated to + one another based on inherit condition. This is designed + to work with the queries that util.polymorphic_union + comes up with, which often don't include the columns from + the base table directly (including the subclass table columns + only). + + The resulting structure is a dictionary of columns mapped + to lists of equivalent columns, e.g.:: + + { + tablea.col1: + {tableb.col1, tablec.col1}, + tablea.col2: + {tabled.col2} + } + + """ + result: _EquivalentColumnMap = {} + + def visit_binary(binary): + if binary.operator == operators.eq: + if binary.left in result: + result[binary.left].add(binary.right) + else: + result[binary.left] = {binary.right} + if binary.right in result: + result[binary.right].add(binary.left) + else: + result[binary.right] = {binary.left} + + for mapper in self.base_mapper.self_and_descendants: + if mapper.inherit_condition is not None: + visitors.traverse( + mapper.inherit_condition, {}, {"binary": visit_binary} + ) + + return result + + def _is_userland_descriptor(self, assigned_name: str, obj: Any) -> bool: + if isinstance( + obj, + ( + _MappedAttribute, + instrumentation.ClassManager, + expression.ColumnElement, + ), + ): + return False + else: + return assigned_name not in self._dataclass_fields + + @HasMemoized.memoized_attribute + def _dataclass_fields(self): + return [f.name for f in util.dataclass_fields(self.class_)] + + def _should_exclude(self, name, assigned_name, local, column): + """determine whether a particular property should be implicitly + present on the class. + + This occurs when properties are propagated from an inherited class, or + are applied from the columns present in the mapped table. + + """ + + if column is not None and sql_base._never_select_column(column): + return True + + # check for class-bound attributes and/or descriptors, + # either local or from an inherited class + # ignore dataclass field default values + if local: + if self.class_.__dict__.get( + assigned_name, None + ) is not None and self._is_userland_descriptor( + assigned_name, self.class_.__dict__[assigned_name] + ): + return True + else: + attr = self.class_manager._get_class_attr_mro(assigned_name, None) + if attr is not None and self._is_userland_descriptor( + assigned_name, attr + ): + return True + + if ( + self.include_properties is not None + and name not in self.include_properties + and (column is None or column not in self.include_properties) + ): + self._log("not including property %s" % (name)) + return True + + if self.exclude_properties is not None and ( + name in self.exclude_properties + or (column is not None and column in self.exclude_properties) + ): + self._log("excluding property %s" % (name)) + return True + + return False + + def common_parent(self, other: Mapper[Any]) -> bool: + """Return true if the given mapper shares a + common inherited parent as this mapper.""" + + return self.base_mapper is other.base_mapper + + def is_sibling(self, other: Mapper[Any]) -> bool: + """return true if the other mapper is an inheriting sibling to this + one. common parent but different branch + + """ + return ( + self.base_mapper is other.base_mapper + and not self.isa(other) + and not other.isa(self) + ) + + def _canload( + self, state: InstanceState[Any], allow_subtypes: bool + ) -> bool: + s = self.primary_mapper() + if self.polymorphic_on is not None or allow_subtypes: + return _state_mapper(state).isa(s) + else: + return _state_mapper(state) is s + + def isa(self, other: Mapper[Any]) -> bool: + """Return True if the this mapper inherits from the given mapper.""" + + m: Optional[Mapper[Any]] = self + while m and m is not other: + m = m.inherits + return bool(m) + + def iterate_to_root(self) -> Iterator[Mapper[Any]]: + m: Optional[Mapper[Any]] = self + while m: + yield m + m = m.inherits + + @HasMemoized.memoized_attribute + def self_and_descendants(self) -> Sequence[Mapper[Any]]: + """The collection including this mapper and all descendant mappers. + + This includes not just the immediately inheriting mappers but + all their inheriting mappers as well. + + """ + descendants = [] + stack = deque([self]) + while stack: + item = stack.popleft() + descendants.append(item) + stack.extend(item._inheriting_mappers) + return util.WeakSequence(descendants) + + def polymorphic_iterator(self) -> Iterator[Mapper[Any]]: + """Iterate through the collection including this mapper and + all descendant mappers. + + This includes not just the immediately inheriting mappers but + all their inheriting mappers as well. + + To iterate through an entire hierarchy, use + ``mapper.base_mapper.polymorphic_iterator()``. + + """ + return iter(self.self_and_descendants) + + def primary_mapper(self) -> Mapper[Any]: + """Return the primary mapper corresponding to this mapper's class key + (class).""" + + return self.class_manager.mapper + + @property + def primary_base_mapper(self) -> Mapper[Any]: + return self.class_manager.mapper.base_mapper + + def _result_has_identity_key(self, result, adapter=None): + pk_cols: Sequence[ColumnClause[Any]] = self.primary_key + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + rk = result.keys() + for col in pk_cols: + if col not in rk: + return False + else: + return True + + def identity_key_from_row( + self, + row: Optional[Union[Row[Any], RowMapping]], + identity_token: Optional[Any] = None, + adapter: Optional[ORMAdapter] = None, + ) -> _IdentityKeyType[_O]: + """Return an identity-map key for use in storing/retrieving an + item from the identity map. + + :param row: A :class:`.Row` or :class:`.RowMapping` produced from a + result set that selected from the ORM mapped primary key columns. + + .. versionchanged:: 2.0 + :class:`.Row` or :class:`.RowMapping` are accepted + for the "row" argument + + """ + pk_cols: Sequence[ColumnClause[Any]] = self.primary_key + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + + if hasattr(row, "_mapping"): + mapping = row._mapping # type: ignore + else: + mapping = cast("Mapping[Any, Any]", row) + + return ( + self._identity_class, + tuple(mapping[column] for column in pk_cols), # type: ignore + identity_token, + ) + + def identity_key_from_primary_key( + self, + primary_key: Tuple[Any, ...], + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[_O]: + """Return an identity-map key for use in storing/retrieving an + item from an identity map. + + :param primary_key: A list of values indicating the identifier. + + """ + return ( + self._identity_class, + tuple(primary_key), + identity_token, + ) + + def identity_key_from_instance(self, instance: _O) -> _IdentityKeyType[_O]: + """Return the identity key for the given instance, based on + its primary key attributes. + + If the instance's state is expired, calling this method + will result in a database check to see if the object has been deleted. + If the row no longer exists, + :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. + + This value is typically also found on the instance state under the + attribute name `key`. + + """ + state = attributes.instance_state(instance) + return self._identity_key_from_state(state, PassiveFlag.PASSIVE_OFF) + + def _identity_key_from_state( + self, + state: InstanceState[_O], + passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE, + ) -> _IdentityKeyType[_O]: + dict_ = state.dict + manager = state.manager + return ( + self._identity_class, + tuple( + [ + manager[prop.key].impl.get(state, dict_, passive) + for prop in self._identity_key_props + ] + ), + state.identity_token, + ) + + def primary_key_from_instance(self, instance: _O) -> Tuple[Any, ...]: + """Return the list of primary key values for the given + instance. + + If the instance's state is expired, calling this method + will result in a database check to see if the object has been deleted. + If the row no longer exists, + :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. + + """ + state = attributes.instance_state(instance) + identity_key = self._identity_key_from_state( + state, PassiveFlag.PASSIVE_OFF + ) + return identity_key[1] + + @HasMemoized.memoized_attribute + def _persistent_sortkey_fn(self): + key_fns = [col.type.sort_key_function for col in self.primary_key] + + if set(key_fns).difference([None]): + + def key(state): + return tuple( + key_fn(val) if key_fn is not None else val + for key_fn, val in zip(key_fns, state.key[1]) + ) + + else: + + def key(state): + return state.key[1] + + return key + + @HasMemoized.memoized_attribute + def _identity_key_props(self): + return [self._columntoproperty[col] for col in self.primary_key] + + @HasMemoized.memoized_attribute + def _all_pk_cols(self): + collection: Set[ColumnClause[Any]] = set() + for table in self.tables: + collection.update(self._pks_by_table[table]) + return collection + + @HasMemoized.memoized_attribute + def _should_undefer_in_wildcard(self): + cols: Set[ColumnElement[Any]] = set(self.primary_key) + if self.polymorphic_on is not None: + cols.add(self.polymorphic_on) + return cols + + @HasMemoized.memoized_attribute + def _primary_key_propkeys(self): + return {self._columntoproperty[col].key for col in self._all_pk_cols} + + def _get_state_attr_by_column( + self, + state: InstanceState[_O], + dict_: _InstanceDict, + column: ColumnElement[Any], + passive: PassiveFlag = PassiveFlag.PASSIVE_RETURN_NO_VALUE, + ) -> Any: + prop = self._columntoproperty[column] + return state.manager[prop.key].impl.get(state, dict_, passive=passive) + + def _set_committed_state_attr_by_column(self, state, dict_, column, value): + prop = self._columntoproperty[column] + state.manager[prop.key].impl.set_committed_value(state, dict_, value) + + def _set_state_attr_by_column(self, state, dict_, column, value): + prop = self._columntoproperty[column] + state.manager[prop.key].impl.set(state, dict_, value, None) + + def _get_committed_attr_by_column(self, obj, column): + state = attributes.instance_state(obj) + dict_ = attributes.instance_dict(obj) + return self._get_committed_state_attr_by_column( + state, dict_, column, passive=PassiveFlag.PASSIVE_OFF + ) + + def _get_committed_state_attr_by_column( + self, state, dict_, column, passive=PassiveFlag.PASSIVE_RETURN_NO_VALUE + ): + prop = self._columntoproperty[column] + return state.manager[prop.key].impl.get_committed_value( + state, dict_, passive=passive + ) + + def _optimized_get_statement(self, state, attribute_names): + """assemble a WHERE clause which retrieves a given state by primary + key, using a minimized set of tables. + + Applies to a joined-table inheritance mapper where the + requested attribute names are only present on joined tables, + not the base table. The WHERE clause attempts to include + only those tables to minimize joins. + + """ + props = self._props + + col_attribute_names = set(attribute_names).intersection( + state.mapper.column_attrs.keys() + ) + tables: Set[FromClause] = set( + chain( + *[ + sql_util.find_tables(c, check_columns=True) + for key in col_attribute_names + for c in props[key].columns + ] + ) + ) + + if self.base_mapper.local_table in tables: + return None + + def visit_binary(binary): + leftcol = binary.left + rightcol = binary.right + if leftcol is None or rightcol is None: + return + + if leftcol.table not in tables: + leftval = self._get_committed_state_attr_by_column( + state, + state.dict, + leftcol, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, + ) + if leftval in orm_util._none_set: + raise _OptGetColumnsNotAvailable() + binary.left = sql.bindparam( + None, leftval, type_=binary.right.type + ) + elif rightcol.table not in tables: + rightval = self._get_committed_state_attr_by_column( + state, + state.dict, + rightcol, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, + ) + if rightval in orm_util._none_set: + raise _OptGetColumnsNotAvailable() + binary.right = sql.bindparam( + None, rightval, type_=binary.right.type + ) + + allconds: List[ColumnElement[bool]] = [] + + start = False + + # as of #7507, from the lowest base table on upwards, + # we include all intermediary tables. + + for mapper in reversed(list(self.iterate_to_root())): + if mapper.local_table in tables: + start = True + elif not isinstance(mapper.local_table, expression.TableClause): + return None + if start and not mapper.single: + assert mapper.inherits + assert not mapper.concrete + assert mapper.inherit_condition is not None + allconds.append(mapper.inherit_condition) + tables.add(mapper.local_table) + + # only the bottom table needs its criteria to be altered to fit + # the primary key ident - the rest of the tables upwards to the + # descendant-most class should all be present and joined to each + # other. + try: + _traversed = visitors.cloned_traverse( + allconds[0], {}, {"binary": visit_binary} + ) + except _OptGetColumnsNotAvailable: + return None + else: + allconds[0] = _traversed + + cond = sql.and_(*allconds) + + cols = [] + for key in col_attribute_names: + cols.extend(props[key].columns) + return ( + sql.select(*cols) + .where(cond) + .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + ) + + def _iterate_to_target_viawpoly(self, mapper): + if self.isa(mapper): + prev = self + for m in self.iterate_to_root(): + yield m + + if m is not prev and prev not in m._with_polymorphic_mappers: + break + + prev = m + if m is mapper: + break + + @HasMemoized.memoized_attribute + def _would_selectinload_combinations_cache(self): + return {} + + def _would_selectin_load_only_from_given_mapper(self, super_mapper): + """return True if this mapper would "selectin" polymorphic load based + on the given super mapper, and not from a setting from a subclass. + + given:: + + class A: + ... + + class B(A): + __mapper_args__ = {"polymorphic_load": "selectin"} + + class C(B): + ... + + class D(B): + __mapper_args__ = {"polymorphic_load": "selectin"} + + ``inspect(C)._would_selectin_load_only_from_given_mapper(inspect(B))`` + returns True, because C does selectin loading because of B's setting. + + OTOH, ``inspect(D) + ._would_selectin_load_only_from_given_mapper(inspect(B))`` + returns False, because D does selectin loading because of its own + setting; when we are doing a selectin poly load from B, we want to + filter out D because it would already have its own selectin poly load + set up separately. + + Added as part of #9373. + + """ + cache = self._would_selectinload_combinations_cache + + try: + return cache[super_mapper] + except KeyError: + pass + + # assert that given object is a supermapper, meaning we already + # strong reference it directly or indirectly. this allows us + # to not worry that we are creating new strongrefs to unrelated + # mappers or other objects. + assert self.isa(super_mapper) + + mapper = super_mapper + for m in self._iterate_to_target_viawpoly(mapper): + if m.polymorphic_load == "selectin": + retval = m is super_mapper + break + else: + retval = False + + cache[super_mapper] = retval + return retval + + def _should_selectin_load(self, enabled_via_opt, polymorphic_from): + if not enabled_via_opt: + # common case, takes place for all polymorphic loads + mapper = polymorphic_from + for m in self._iterate_to_target_viawpoly(mapper): + if m.polymorphic_load == "selectin": + return m + else: + # uncommon case, selectin load options were used + enabled_via_opt = set(enabled_via_opt) + enabled_via_opt_mappers = {e.mapper: e for e in enabled_via_opt} + for entity in enabled_via_opt.union([polymorphic_from]): + mapper = entity.mapper + for m in self._iterate_to_target_viawpoly(mapper): + if ( + m.polymorphic_load == "selectin" + or m in enabled_via_opt_mappers + ): + return enabled_via_opt_mappers.get(m, m) + + return None + + @util.preload_module("sqlalchemy.orm.strategy_options") + def _subclass_load_via_in(self, entity, polymorphic_from): + """Assemble a that can load the columns local to + this subclass as a SELECT with IN. + + """ + strategy_options = util.preloaded.orm_strategy_options + + assert self.inherits + + if self.polymorphic_on is not None: + polymorphic_prop = self._columntoproperty[self.polymorphic_on] + keep_props = set([polymorphic_prop] + self._identity_key_props) + else: + keep_props = set(self._identity_key_props) + + disable_opt = strategy_options.Load(entity) + enable_opt = strategy_options.Load(entity) + + classes_to_include = {self} + m: Optional[Mapper[Any]] = self.inherits + while ( + m is not None + and m is not polymorphic_from + and m.polymorphic_load == "selectin" + ): + classes_to_include.add(m) + m = m.inherits + + for prop in self.attrs: + # skip prop keys that are not instrumented on the mapped class. + # this is primarily the "_sa_polymorphic_on" property that gets + # created for an ad-hoc polymorphic_on SQL expression, issue #8704 + if prop.key not in self.class_manager: + continue + + if prop.parent in classes_to_include or prop in keep_props: + # "enable" options, to turn on the properties that we want to + # load by default (subject to options from the query) + if not isinstance(prop, StrategizedProperty): + continue + + enable_opt = enable_opt._set_generic_strategy( + # convert string name to an attribute before passing + # to loader strategy. note this must be in terms + # of given entity, such as AliasedClass, etc. + (getattr(entity.entity_namespace, prop.key),), + dict(prop.strategy_key), + _reconcile_to_other=True, + ) + else: + # "disable" options, to turn off the properties from the + # superclass that we *don't* want to load, applied after + # the options from the query to override them + disable_opt = disable_opt._set_generic_strategy( + # convert string name to an attribute before passing + # to loader strategy. note this must be in terms + # of given entity, such as AliasedClass, etc. + (getattr(entity.entity_namespace, prop.key),), + {"do_nothing": True}, + _reconcile_to_other=False, + ) + + primary_key = [ + sql_util._deep_annotate(pk, {"_orm_adapt": True}) + for pk in self.primary_key + ] + + in_expr: ColumnElement[Any] + + if len(primary_key) > 1: + in_expr = sql.tuple_(*primary_key) + else: + in_expr = primary_key[0] + + if entity.is_aliased_class: + assert entity.mapper is self + + q = sql.select(entity).set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + + in_expr = entity._adapter.traverse(in_expr) + primary_key = [entity._adapter.traverse(k) for k in primary_key] + q = q.where( + in_expr.in_(sql.bindparam("primary_keys", expanding=True)) + ).order_by(*primary_key) + else: + q = sql.select(self).set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + q = q.where( + in_expr.in_(sql.bindparam("primary_keys", expanding=True)) + ).order_by(*primary_key) + + return q, enable_opt, disable_opt + + @HasMemoized.memoized_attribute + def _subclass_load_via_in_mapper(self): + # the default is loading this mapper against the basemost mapper + return self._subclass_load_via_in(self, self.base_mapper) + + def cascade_iterator( + self, + type_: str, + state: InstanceState[_O], + halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None, + ) -> Iterator[ + Tuple[object, Mapper[Any], InstanceState[Any], _InstanceDict] + ]: + r"""Iterate each element and its mapper in an object graph, + for all relationships that meet the given cascade rule. + + :param type\_: + The name of the cascade rule (i.e. ``"save-update"``, ``"delete"``, + etc.). + + .. note:: the ``"all"`` cascade is not accepted here. For a generic + object traversal function, see :ref:`faq_walk_objects`. + + :param state: + The lead InstanceState. child items will be processed per + the relationships defined for this object's mapper. + + :return: the method yields individual object instances. + + .. seealso:: + + :ref:`unitofwork_cascades` + + :ref:`faq_walk_objects` - illustrates a generic function to + traverse all objects without relying on cascades. + + """ + visited_states: Set[InstanceState[Any]] = set() + prp, mpp = object(), object() + + assert state.mapper.isa(self) + + # this is actually a recursive structure, fully typing it seems + # a little too difficult for what it's worth here + visitables: Deque[ + Tuple[ + Deque[Any], + object, + Optional[InstanceState[Any]], + Optional[_InstanceDict], + ] + ] + + visitables = deque( + [(deque(state.mapper._props.values()), prp, state, state.dict)] + ) + + while visitables: + iterator, item_type, parent_state, parent_dict = visitables[-1] + if not iterator: + visitables.pop() + continue + + if item_type is prp: + prop = iterator.popleft() + if not prop.cascade or type_ not in prop.cascade: + continue + assert parent_state is not None + assert parent_dict is not None + queue = deque( + prop.cascade_iterator( + type_, + parent_state, + parent_dict, + visited_states, + halt_on, + ) + ) + if queue: + visitables.append((queue, mpp, None, None)) + elif item_type is mpp: + ( + instance, + instance_mapper, + corresponding_state, + corresponding_dict, + ) = iterator.popleft() + yield ( + instance, + instance_mapper, + corresponding_state, + corresponding_dict, + ) + visitables.append( + ( + deque(instance_mapper._props.values()), + prp, + corresponding_state, + corresponding_dict, + ) + ) + + @HasMemoized.memoized_attribute + def _compiled_cache(self): + return util.LRUCache(self._compiled_cache_size) + + @HasMemoized.memoized_attribute + def _multiple_persistence_tables(self): + return len(self.tables) > 1 + + @HasMemoized.memoized_attribute + def _sorted_tables(self): + table_to_mapper: Dict[TableClause, Mapper[Any]] = {} + + for mapper in self.base_mapper.self_and_descendants: + for t in mapper.tables: + table_to_mapper.setdefault(t, mapper) + + extra_dependencies = [] + for table, mapper in table_to_mapper.items(): + super_ = mapper.inherits + if super_: + extra_dependencies.extend( + [(super_table, table) for super_table in super_.tables] + ) + + def skip(fk): + # attempt to skip dependencies that are not + # significant to the inheritance chain + # for two tables that are related by inheritance. + # while that dependency may be important, it's technically + # not what we mean to sort on here. + parent = table_to_mapper.get(fk.parent.table) + dep = table_to_mapper.get(fk.column.table) + if ( + parent is not None + and dep is not None + and dep is not parent + and dep.inherit_condition is not None + ): + cols = set(sql_util._find_columns(dep.inherit_condition)) + if parent.inherit_condition is not None: + cols = cols.union( + sql_util._find_columns(parent.inherit_condition) + ) + return fk.parent not in cols and fk.column not in cols + else: + return fk.parent not in cols + return False + + sorted_ = sql_util.sort_tables( + table_to_mapper, + skip_fn=skip, + extra_dependencies=extra_dependencies, + ) + + ret = util.OrderedDict() + for t in sorted_: + ret[t] = table_to_mapper[t] + return ret + + def _memo(self, key: Any, callable_: Callable[[], _T]) -> _T: + if key in self._memoized_values: + return cast(_T, self._memoized_values[key]) + else: + self._memoized_values[key] = value = callable_() + return value + + @util.memoized_property + def _table_to_equated(self): + """memoized map of tables to collections of columns to be + synchronized upwards to the base mapper.""" + + result: util.defaultdict[ + Table, + List[ + Tuple[ + Mapper[Any], + List[Tuple[ColumnElement[Any], ColumnElement[Any]]], + ] + ], + ] = util.defaultdict(list) + + def set_union(x, y): + return x.union(y) + + for table in self._sorted_tables: + cols = set(table.c) + + for m in self.iterate_to_root(): + if m._inherits_equated_pairs and cols.intersection( + reduce( + set_union, + [l.proxy_set for l, r in m._inherits_equated_pairs], + ) + ): + result[table].append((m, m._inherits_equated_pairs)) + + return result + + +class _OptGetColumnsNotAvailable(Exception): + pass + + +def configure_mappers() -> None: + """Initialize the inter-mapper relationships of all mappers that + have been constructed thus far across all :class:`_orm.registry` + collections. + + The configure step is used to reconcile and initialize the + :func:`_orm.relationship` linkages between mapped classes, as well as to + invoke configuration events such as the + :meth:`_orm.MapperEvents.before_configured` and + :meth:`_orm.MapperEvents.after_configured`, which may be used by ORM + extensions or user-defined extension hooks. + + Mapper configuration is normally invoked automatically, the first time + mappings from a particular :class:`_orm.registry` are used, as well as + whenever mappings are used and additional not-yet-configured mappers have + been constructed. The automatic configuration process however is local only + to the :class:`_orm.registry` involving the target mapper and any related + :class:`_orm.registry` objects which it may depend on; this is + equivalent to invoking the :meth:`_orm.registry.configure` method + on a particular :class:`_orm.registry`. + + By contrast, the :func:`_orm.configure_mappers` function will invoke the + configuration process on all :class:`_orm.registry` objects that + exist in memory, and may be useful for scenarios where many individual + :class:`_orm.registry` objects that are nonetheless interrelated are + in use. + + .. versionchanged:: 1.4 + + As of SQLAlchemy 1.4.0b2, this function works on a + per-:class:`_orm.registry` basis, locating all :class:`_orm.registry` + objects present and invoking the :meth:`_orm.registry.configure` method + on each. The :meth:`_orm.registry.configure` method may be preferred to + limit the configuration of mappers to those local to a particular + :class:`_orm.registry` and/or declarative base class. + + Points at which automatic configuration is invoked include when a mapped + class is instantiated into an instance, as well as when ORM queries + are emitted using :meth:`.Session.query` or :meth:`_orm.Session.execute` + with an ORM-enabled statement. + + The mapper configure process, whether invoked by + :func:`_orm.configure_mappers` or from :meth:`_orm.registry.configure`, + provides several event hooks that can be used to augment the mapper + configuration step. These hooks include: + + * :meth:`.MapperEvents.before_configured` - called once before + :func:`.configure_mappers` or :meth:`_orm.registry.configure` does any + work; this can be used to establish additional options, properties, or + related mappings before the operation proceeds. + + * :meth:`.MapperEvents.mapper_configured` - called as each individual + :class:`_orm.Mapper` is configured within the process; will include all + mapper state except for backrefs set up by other mappers that are still + to be configured. + + * :meth:`.MapperEvents.after_configured` - called once after + :func:`.configure_mappers` or :meth:`_orm.registry.configure` is + complete; at this stage, all :class:`_orm.Mapper` objects that fall + within the scope of the configuration operation will be fully configured. + Note that the calling application may still have other mappings that + haven't been produced yet, such as if they are in modules as yet + unimported, and may also have mappings that are still to be configured, + if they are in other :class:`_orm.registry` collections not part of the + current scope of configuration. + + """ + + _configure_registries(_all_registries(), cascade=True) + + +def _configure_registries( + registries: Set[_RegistryType], cascade: bool +) -> None: + for reg in registries: + if reg._new_mappers: + break + else: + return + + with _CONFIGURE_MUTEX: + global _already_compiling + if _already_compiling: + return + _already_compiling = True + try: + # double-check inside mutex + for reg in registries: + if reg._new_mappers: + break + else: + return + + Mapper.dispatch._for_class(Mapper).before_configured() # type: ignore # noqa: E501 + # initialize properties on all mappers + # note that _mapper_registry is unordered, which + # may randomly conceal/reveal issues related to + # the order of mapper compilation + + _do_configure_registries(registries, cascade) + finally: + _already_compiling = False + Mapper.dispatch._for_class(Mapper).after_configured() # type: ignore + + +@util.preload_module("sqlalchemy.orm.decl_api") +def _do_configure_registries( + registries: Set[_RegistryType], cascade: bool +) -> None: + registry = util.preloaded.orm_decl_api.registry + + orig = set(registries) + + for reg in registry._recurse_with_dependencies(registries): + has_skip = False + + for mapper in reg._mappers_to_configure(): + run_configure = None + + for fn in mapper.dispatch.before_mapper_configured: + run_configure = fn(mapper, mapper.class_) + if run_configure is EXT_SKIP: + has_skip = True + break + if run_configure is EXT_SKIP: + continue + + if getattr(mapper, "_configure_failed", False): + e = sa_exc.InvalidRequestError( + "One or more mappers failed to initialize - " + "can't proceed with initialization of other " + "mappers. Triggering mapper: '%s'. " + "Original exception was: %s" + % (mapper, mapper._configure_failed) + ) + e._configure_failed = mapper._configure_failed # type: ignore + raise e + + if not mapper.configured: + try: + mapper._post_configure_properties() + mapper._expire_memoizations() + mapper.dispatch.mapper_configured(mapper, mapper.class_) + except Exception: + exc = sys.exc_info()[1] + if not hasattr(exc, "_configure_failed"): + mapper._configure_failed = exc + raise + if not has_skip: + reg._new_mappers = False + + if not cascade and reg._dependencies.difference(orig): + raise sa_exc.InvalidRequestError( + "configure was called with cascade=False but " + "additional registries remain" + ) + + +@util.preload_module("sqlalchemy.orm.decl_api") +def _dispose_registries(registries: Set[_RegistryType], cascade: bool) -> None: + registry = util.preloaded.orm_decl_api.registry + + orig = set(registries) + + for reg in registry._recurse_with_dependents(registries): + if not cascade and reg._dependents.difference(orig): + raise sa_exc.InvalidRequestError( + "Registry has dependent registries that are not disposed; " + "pass cascade=True to clear these also" + ) + + while reg._managers: + try: + manager, _ = reg._managers.popitem() + except KeyError: + # guard against race between while and popitem + pass + else: + reg._dispose_manager_and_mapper(manager) + + reg._non_primary_mappers.clear() + reg._dependents.clear() + for dep in reg._dependencies: + dep._dependents.discard(reg) + reg._dependencies.clear() + # this wasn't done in the 1.3 clear_mappers() and in fact it + # was a bug, as it could cause configure_mappers() to invoke + # the "before_configured" event even though mappers had all been + # disposed. + reg._new_mappers = False + + +def reconstructor(fn): + """Decorate a method as the 'reconstructor' hook. + + Designates a single method as the "reconstructor", an ``__init__``-like + method that will be called by the ORM after the instance has been + loaded from the database or otherwise reconstituted. + + .. tip:: + + The :func:`_orm.reconstructor` decorator makes use of the + :meth:`_orm.InstanceEvents.load` event hook, which can be + used directly. + + The reconstructor will be invoked with no arguments. Scalar + (non-collection) database-mapped attributes of the instance will + be available for use within the function. Eagerly-loaded + collections are generally not yet available and will usually only + contain the first element. ORM state changes made to objects at + this stage will not be recorded for the next flush() operation, so + the activity within a reconstructor should be conservative. + + .. seealso:: + + :meth:`.InstanceEvents.load` + + """ + fn.__sa_reconstructor__ = True + return fn + + +def validates( + *names: str, include_removes: bool = False, include_backrefs: bool = True +) -> Callable[[_Fn], _Fn]: + r"""Decorate a method as a 'validator' for one or more named properties. + + Designates a method as a validator, a method which receives the + name of the attribute as well as a value to be assigned, or in the + case of a collection, the value to be added to the collection. + The function can then raise validation exceptions to halt the + process from continuing (where Python's built-in ``ValueError`` + and ``AssertionError`` exceptions are reasonable choices), or can + modify or replace the value before proceeding. The function should + otherwise return the given value. + + Note that a validator for a collection **cannot** issue a load of that + collection within the validation routine - this usage raises + an assertion to avoid recursion overflows. This is a reentrant + condition which is not supported. + + :param \*names: list of attribute names to be validated. + :param include_removes: if True, "remove" events will be + sent as well - the validation function must accept an additional + argument "is_remove" which will be a boolean. + + :param include_backrefs: defaults to ``True``; if ``False``, the + validation function will not emit if the originator is an attribute + event related via a backref. This can be used for bi-directional + :func:`.validates` usage where only one validator should emit per + attribute operation. + + .. versionchanged:: 2.0.16 This paramter inadvertently defaulted to + ``False`` for releases 2.0.0 through 2.0.15. Its correct default + of ``True`` is restored in 2.0.16. + + .. seealso:: + + :ref:`simple_validators` - usage examples for :func:`.validates` + + """ + + def wrap(fn: _Fn) -> _Fn: + fn.__sa_validators__ = names # type: ignore[attr-defined] + fn.__sa_validation_opts__ = { # type: ignore[attr-defined] + "include_removes": include_removes, + "include_backrefs": include_backrefs, + } + return fn + + return wrap + + +def _event_on_load(state, ctx): + instrumenting_mapper = state.manager.mapper + + if instrumenting_mapper._reconstructor: + instrumenting_mapper._reconstructor(state.obj()) + + +def _event_on_init(state, args, kwargs): + """Run init_instance hooks. + + This also includes mapper compilation, normally not needed + here but helps with some piecemeal configuration + scenarios (such as in the ORM tutorial). + + """ + + instrumenting_mapper = state.manager.mapper + if instrumenting_mapper: + instrumenting_mapper._check_configure() + if instrumenting_mapper._set_polymorphic_identity: + instrumenting_mapper._set_polymorphic_identity(state) + + +class _ColumnMapping(Dict["ColumnElement[Any]", "MapperProperty[Any]"]): + """Error reporting helper for mapper._columntoproperty.""" + + __slots__ = ("mapper",) + + def __init__(self, mapper): + # TODO: weakref would be a good idea here + self.mapper = mapper + + def __missing__(self, column): + prop = self.mapper._props.get(column) + if prop: + raise orm_exc.UnmappedColumnError( + "Column '%s.%s' is not available, due to " + "conflicting property '%s':%r" + % (column.table.name, column.name, column.key, prop) + ) + raise orm_exc.UnmappedColumnError( + "No column %s is configured on mapper %s..." + % (column, self.mapper) + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/path_registry.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/path_registry.py new file mode 100644 index 0000000..76484b3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/path_registry.py @@ -0,0 +1,808 @@ +# orm/path_registry.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +"""Path tracking utilities, representing mapper graph traversals. + +""" + +from __future__ import annotations + +from functools import reduce +from itertools import chain +import logging +import operator +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterator +from typing import List +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from . import base as orm_base +from ._typing import insp_is_mapper_property +from .. import exc +from .. import util +from ..sql import visitors +from ..sql.cache_key import HasCacheKey + +if TYPE_CHECKING: + from ._typing import _InternalEntityType + from .interfaces import MapperProperty + from .mapper import Mapper + from .relationships import RelationshipProperty + from .util import AliasedInsp + from ..sql.cache_key import _CacheKeyTraversalType + from ..sql.elements import BindParameter + from ..sql.visitors import anon_map + from ..util.typing import _LiteralStar + from ..util.typing import TypeGuard + + def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: ... + + def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: ... + +else: + is_root = operator.attrgetter("is_root") + is_entity = operator.attrgetter("is_entity") + + +_SerializedPath = List[Any] +_StrPathToken = str +_PathElementType = Union[ + _StrPathToken, "_InternalEntityType[Any]", "MapperProperty[Any]" +] + +# the representation is in fact +# a tuple with alternating: +# [_InternalEntityType[Any], Union[str, MapperProperty[Any]], +# _InternalEntityType[Any], Union[str, MapperProperty[Any]], ...] +# this might someday be a tuple of 2-tuples instead, but paths can be +# chopped at odd intervals as well so this is less flexible +_PathRepresentation = Tuple[_PathElementType, ...] + +# NOTE: these names are weird since the array is 0-indexed, +# the "_Odd" entries are at 0, 2, 4, etc +_OddPathRepresentation = Sequence["_InternalEntityType[Any]"] +_EvenPathRepresentation = Sequence[Union["MapperProperty[Any]", str]] + + +log = logging.getLogger(__name__) + + +def _unreduce_path(path: _SerializedPath) -> PathRegistry: + return PathRegistry.deserialize(path) + + +_WILDCARD_TOKEN: _LiteralStar = "*" +_DEFAULT_TOKEN = "_sa_default" + + +class PathRegistry(HasCacheKey): + """Represent query load paths and registry functions. + + Basically represents structures like: + + (, "orders", , "items", ) + + These structures are generated by things like + query options (joinedload(), subqueryload(), etc.) and are + used to compose keys stored in the query._attributes dictionary + for various options. + + They are then re-composed at query compile/result row time as + the query is formed and as rows are fetched, where they again + serve to compose keys to look up options in the context.attributes + dictionary, which is copied from query._attributes. + + The path structure has a limited amount of caching, where each + "root" ultimately pulls from a fixed registry associated with + the first mapper, that also contains elements for each of its + property keys. However paths longer than two elements, which + are the exception rather than the rule, are generated on an + as-needed basis. + + """ + + __slots__ = () + + is_token = False + is_root = False + has_entity = False + is_property = False + is_entity = False + + is_unnatural: bool + + path: _PathRepresentation + natural_path: _PathRepresentation + parent: Optional[PathRegistry] + root: RootRegistry + + _cache_key_traversal: _CacheKeyTraversalType = [ + ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key_list) + ] + + def __eq__(self, other: Any) -> bool: + try: + return other is not None and self.path == other._path_for_compare + except AttributeError: + util.warn( + "Comparison of PathRegistry to %r is not supported" + % (type(other)) + ) + return False + + def __ne__(self, other: Any) -> bool: + try: + return other is None or self.path != other._path_for_compare + except AttributeError: + util.warn( + "Comparison of PathRegistry to %r is not supported" + % (type(other)) + ) + return True + + @property + def _path_for_compare(self) -> Optional[_PathRepresentation]: + return self.path + + def odd_element(self, index: int) -> _InternalEntityType[Any]: + return self.path[index] # type: ignore + + def set(self, attributes: Dict[Any, Any], key: Any, value: Any) -> None: + log.debug("set '%s' on path '%s' to '%s'", key, self, value) + attributes[(key, self.natural_path)] = value + + def setdefault( + self, attributes: Dict[Any, Any], key: Any, value: Any + ) -> None: + log.debug("setdefault '%s' on path '%s' to '%s'", key, self, value) + attributes.setdefault((key, self.natural_path), value) + + def get( + self, attributes: Dict[Any, Any], key: Any, value: Optional[Any] = None + ) -> Any: + key = (key, self.natural_path) + if key in attributes: + return attributes[key] + else: + return value + + def __len__(self) -> int: + return len(self.path) + + def __hash__(self) -> int: + return id(self) + + @overload + def __getitem__(self, entity: _StrPathToken) -> TokenRegistry: ... + + @overload + def __getitem__(self, entity: int) -> _PathElementType: ... + + @overload + def __getitem__(self, entity: slice) -> _PathRepresentation: ... + + @overload + def __getitem__( + self, entity: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: ... + + @overload + def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: ... + + def __getitem__( + self, + entity: Union[ + _StrPathToken, + int, + slice, + _InternalEntityType[Any], + MapperProperty[Any], + ], + ) -> Union[ + TokenRegistry, + _PathElementType, + _PathRepresentation, + PropRegistry, + AbstractEntityRegistry, + ]: + raise NotImplementedError() + + # TODO: what are we using this for? + @property + def length(self) -> int: + return len(self.path) + + def pairs( + self, + ) -> Iterator[ + Tuple[_InternalEntityType[Any], Union[str, MapperProperty[Any]]] + ]: + odd_path = cast(_OddPathRepresentation, self.path) + even_path = cast(_EvenPathRepresentation, odd_path) + for i in range(0, len(odd_path), 2): + yield odd_path[i], even_path[i + 1] + + def contains_mapper(self, mapper: Mapper[Any]) -> bool: + _m_path = cast(_OddPathRepresentation, self.path) + for path_mapper in [_m_path[i] for i in range(0, len(_m_path), 2)]: + if path_mapper.mapper.isa(mapper): + return True + else: + return False + + def contains(self, attributes: Dict[Any, Any], key: Any) -> bool: + return (key, self.path) in attributes + + def __reduce__(self) -> Any: + return _unreduce_path, (self.serialize(),) + + @classmethod + def _serialize_path(cls, path: _PathRepresentation) -> _SerializedPath: + _m_path = cast(_OddPathRepresentation, path) + _p_path = cast(_EvenPathRepresentation, path) + + return list( + zip( + tuple( + m.class_ if (m.is_mapper or m.is_aliased_class) else str(m) + for m in [_m_path[i] for i in range(0, len(_m_path), 2)] + ), + tuple( + p.key if insp_is_mapper_property(p) else str(p) + for p in [_p_path[i] for i in range(1, len(_p_path), 2)] + ) + + (None,), + ) + ) + + @classmethod + def _deserialize_path(cls, path: _SerializedPath) -> _PathRepresentation: + def _deserialize_mapper_token(mcls: Any) -> Any: + return ( + # note: we likely dont want configure=True here however + # this is maintained at the moment for backwards compatibility + orm_base._inspect_mapped_class(mcls, configure=True) + if mcls not in PathToken._intern + else PathToken._intern[mcls] + ) + + def _deserialize_key_token(mcls: Any, key: Any) -> Any: + if key is None: + return None + elif key in PathToken._intern: + return PathToken._intern[key] + else: + mp = orm_base._inspect_mapped_class(mcls, configure=True) + assert mp is not None + return mp.attrs[key] + + p = tuple( + chain( + *[ + ( + _deserialize_mapper_token(mcls), + _deserialize_key_token(mcls, key), + ) + for mcls, key in path + ] + ) + ) + if p and p[-1] is None: + p = p[0:-1] + return p + + def serialize(self) -> _SerializedPath: + path = self.path + return self._serialize_path(path) + + @classmethod + def deserialize(cls, path: _SerializedPath) -> PathRegistry: + assert path is not None + p = cls._deserialize_path(path) + return cls.coerce(p) + + @overload + @classmethod + def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry: ... + + @overload + @classmethod + def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry: ... + + @classmethod + def per_mapper( + cls, mapper: _InternalEntityType[Any] + ) -> AbstractEntityRegistry: + if mapper.is_mapper: + return CachingEntityRegistry(cls.root, mapper) + else: + return SlotsEntityRegistry(cls.root, mapper) + + @classmethod + def coerce(cls, raw: _PathRepresentation) -> PathRegistry: + def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry: + return prev[next_] + + # can't quite get mypy to appreciate this one :) + return reduce(_red, raw, cls.root) # type: ignore + + def __add__(self, other: PathRegistry) -> PathRegistry: + def _red(prev: PathRegistry, next_: _PathElementType) -> PathRegistry: + return prev[next_] + + return reduce(_red, other.path, self) + + def __str__(self) -> str: + return f"ORM Path[{' -> '.join(str(elem) for elem in self.path)}]" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.path!r})" + + +class CreatesToken(PathRegistry): + __slots__ = () + + is_aliased_class: bool + is_root: bool + + def token(self, token: _StrPathToken) -> TokenRegistry: + if token.endswith(f":{_WILDCARD_TOKEN}"): + return TokenRegistry(self, token) + elif token.endswith(f":{_DEFAULT_TOKEN}"): + return TokenRegistry(self.root, token) + else: + raise exc.ArgumentError(f"invalid token: {token}") + + +class RootRegistry(CreatesToken): + """Root registry, defers to mappers so that + paths are maintained per-root-mapper. + + """ + + __slots__ = () + + inherit_cache = True + + path = natural_path = () + has_entity = False + is_aliased_class = False + is_root = True + is_unnatural = False + + def _getitem( + self, entity: Any + ) -> Union[TokenRegistry, AbstractEntityRegistry]: + if entity in PathToken._intern: + if TYPE_CHECKING: + assert isinstance(entity, _StrPathToken) + return TokenRegistry(self, PathToken._intern[entity]) + else: + try: + return entity._path_registry # type: ignore + except AttributeError: + raise IndexError( + f"invalid argument for RootRegistry.__getitem__: {entity}" + ) + + def _truncate_recursive(self) -> RootRegistry: + return self + + if not TYPE_CHECKING: + __getitem__ = _getitem + + +PathRegistry.root = RootRegistry() + + +class PathToken(orm_base.InspectionAttr, HasCacheKey, str): + """cacheable string token""" + + _intern: Dict[str, PathToken] = {} + + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Tuple[Any, ...]: + return (str(self),) + + @property + def _path_for_compare(self) -> Optional[_PathRepresentation]: + return None + + @classmethod + def intern(cls, strvalue: str) -> PathToken: + if strvalue in cls._intern: + return cls._intern[strvalue] + else: + cls._intern[strvalue] = result = PathToken(strvalue) + return result + + +class TokenRegistry(PathRegistry): + __slots__ = ("token", "parent", "path", "natural_path") + + inherit_cache = True + + token: _StrPathToken + parent: CreatesToken + + def __init__(self, parent: CreatesToken, token: _StrPathToken): + token = PathToken.intern(token) + + self.token = token + self.parent = parent + self.path = parent.path + (token,) + self.natural_path = parent.natural_path + (token,) + + has_entity = False + + is_token = True + + def generate_for_superclasses(self) -> Iterator[PathRegistry]: + # NOTE: this method is no longer used. consider removal + parent = self.parent + if is_root(parent): + yield self + return + + if TYPE_CHECKING: + assert isinstance(parent, AbstractEntityRegistry) + if not parent.is_aliased_class: + for mp_ent in parent.mapper.iterate_to_root(): + yield TokenRegistry(parent.parent[mp_ent], self.token) + elif ( + parent.is_aliased_class + and cast( + "AliasedInsp[Any]", + parent.entity, + )._is_with_polymorphic + ): + yield self + for ent in cast( + "AliasedInsp[Any]", parent.entity + )._with_polymorphic_entities: + yield TokenRegistry(parent.parent[ent], self.token) + else: + yield self + + def _generate_natural_for_superclasses( + self, + ) -> Iterator[_PathRepresentation]: + parent = self.parent + if is_root(parent): + yield self.natural_path + return + + if TYPE_CHECKING: + assert isinstance(parent, AbstractEntityRegistry) + for mp_ent in parent.mapper.iterate_to_root(): + yield TokenRegistry(parent.parent[mp_ent], self.token).natural_path + if ( + parent.is_aliased_class + and cast( + "AliasedInsp[Any]", + parent.entity, + )._is_with_polymorphic + ): + yield self.natural_path + for ent in cast( + "AliasedInsp[Any]", parent.entity + )._with_polymorphic_entities: + yield ( + TokenRegistry(parent.parent[ent], self.token).natural_path + ) + else: + yield self.natural_path + + def _getitem(self, entity: Any) -> Any: + try: + return self.path[entity] + except TypeError as err: + raise IndexError(f"{entity}") from err + + if not TYPE_CHECKING: + __getitem__ = _getitem + + +class PropRegistry(PathRegistry): + __slots__ = ( + "prop", + "parent", + "path", + "natural_path", + "has_entity", + "entity", + "mapper", + "_wildcard_path_loader_key", + "_default_path_loader_key", + "_loader_key", + "is_unnatural", + ) + inherit_cache = True + is_property = True + + prop: MapperProperty[Any] + mapper: Optional[Mapper[Any]] + entity: Optional[_InternalEntityType[Any]] + + def __init__( + self, parent: AbstractEntityRegistry, prop: MapperProperty[Any] + ): + # restate this path in terms of the + # given MapperProperty's parent. + insp = cast("_InternalEntityType[Any]", parent[-1]) + natural_parent: AbstractEntityRegistry = parent + + # inherit "is_unnatural" from the parent + self.is_unnatural = parent.parent.is_unnatural or bool( + parent.mapper.inherits + ) + + if not insp.is_aliased_class or insp._use_mapper_path: # type: ignore + parent = natural_parent = parent.parent[prop.parent] + elif ( + insp.is_aliased_class + and insp.with_polymorphic_mappers + and prop.parent in insp.with_polymorphic_mappers + ): + subclass_entity: _InternalEntityType[Any] = parent[-1]._entity_for_mapper(prop.parent) # type: ignore # noqa: E501 + parent = parent.parent[subclass_entity] + + # when building a path where with_polymorphic() is in use, + # special logic to determine the "natural path" when subclass + # entities are used. + # + # here we are trying to distinguish between a path that starts + # on a the with_polymorhpic entity vs. one that starts on a + # normal entity that introduces a with_polymorphic() in the + # middle using of_type(): + # + # # as in test_polymorphic_rel-> + # # test_subqueryload_on_subclass_uses_path_correctly + # wp = with_polymorphic(RegularEntity, "*") + # sess.query(wp).options(someload(wp.SomeSubEntity.foos)) + # + # vs + # + # # as in test_relationship->JoinedloadWPolyOfTypeContinued + # wp = with_polymorphic(SomeFoo, "*") + # sess.query(RegularEntity).options( + # someload(RegularEntity.foos.of_type(wp)) + # .someload(wp.SubFoo.bar) + # ) + # + # in the former case, the Query as it generates a path that we + # want to match will be in terms of the with_polymorphic at the + # beginning. in the latter case, Query will generate simple + # paths that don't know about this with_polymorphic, so we must + # use a separate natural path. + # + # + if parent.parent: + natural_parent = parent.parent[subclass_entity.mapper] + self.is_unnatural = True + else: + natural_parent = parent + elif ( + natural_parent.parent + and insp.is_aliased_class + and prop.parent # this should always be the case here + is not insp.mapper + and insp.mapper.isa(prop.parent) + ): + natural_parent = parent.parent[prop.parent] + + self.prop = prop + self.parent = parent + self.path = parent.path + (prop,) + self.natural_path = natural_parent.natural_path + (prop,) + + self.has_entity = prop._links_to_entity + if prop._is_relationship: + if TYPE_CHECKING: + assert isinstance(prop, RelationshipProperty) + self.entity = prop.entity + self.mapper = prop.mapper + else: + self.entity = None + self.mapper = None + + self._wildcard_path_loader_key = ( + "loader", + parent.natural_path + self.prop._wildcard_token, + ) + self._default_path_loader_key = self.prop._default_path_loader_key + self._loader_key = ("loader", self.natural_path) + + def _truncate_recursive(self) -> PropRegistry: + earliest = None + for i, token in enumerate(reversed(self.path[:-1])): + if token is self.prop: + earliest = i + + if earliest is None: + return self + else: + return self.coerce(self.path[0 : -(earliest + 1)]) # type: ignore + + @property + def entity_path(self) -> AbstractEntityRegistry: + assert self.entity is not None + return self[self.entity] + + def _getitem( + self, entity: Union[int, slice, _InternalEntityType[Any]] + ) -> Union[AbstractEntityRegistry, _PathElementType, _PathRepresentation]: + if isinstance(entity, (int, slice)): + return self.path[entity] + else: + return SlotsEntityRegistry(self, entity) + + if not TYPE_CHECKING: + __getitem__ = _getitem + + +class AbstractEntityRegistry(CreatesToken): + __slots__ = ( + "key", + "parent", + "is_aliased_class", + "path", + "entity", + "natural_path", + ) + + has_entity = True + is_entity = True + + parent: Union[RootRegistry, PropRegistry] + key: _InternalEntityType[Any] + entity: _InternalEntityType[Any] + is_aliased_class: bool + + def __init__( + self, + parent: Union[RootRegistry, PropRegistry], + entity: _InternalEntityType[Any], + ): + self.key = entity + self.parent = parent + self.is_aliased_class = entity.is_aliased_class + self.entity = entity + self.path = parent.path + (entity,) + + # the "natural path" is the path that we get when Query is traversing + # from the lead entities into the various relationships; it corresponds + # to the structure of mappers and relationships. when we are given a + # path that comes from loader options, as of 1.3 it can have ac-hoc + # with_polymorphic() and other AliasedInsp objects inside of it, which + # are usually not present in mappings. So here we track both the + # "enhanced" path in self.path and the "natural" path that doesn't + # include those objects so these two traversals can be matched up. + + # the test here for "(self.is_aliased_class or parent.is_unnatural)" + # are to avoid the more expensive conditional logic that follows if we + # know we don't have to do it. This conditional can just as well be + # "if parent.path:", it just is more function calls. + # + # This is basically the only place that the "is_unnatural" flag + # actually changes behavior. + if parent.path and (self.is_aliased_class or parent.is_unnatural): + # this is an infrequent code path used only for loader strategies + # that also make use of of_type(). + if entity.mapper.isa(parent.natural_path[-1].mapper): # type: ignore # noqa: E501 + self.natural_path = parent.natural_path + (entity.mapper,) + else: + self.natural_path = parent.natural_path + ( + parent.natural_path[-1].entity, # type: ignore + ) + # it seems to make sense that since these paths get mixed up + # with statements that are cached or not, we should make + # sure the natural path is cacheable across different occurrences + # of equivalent AliasedClass objects. however, so far this + # does not seem to be needed for whatever reason. + # elif not parent.path and self.is_aliased_class: + # self.natural_path = (self.entity._generate_cache_key()[0], ) + else: + self.natural_path = self.path + + def _truncate_recursive(self) -> AbstractEntityRegistry: + return self.parent._truncate_recursive()[self.entity] + + @property + def root_entity(self) -> _InternalEntityType[Any]: + return self.odd_element(0) + + @property + def entity_path(self) -> PathRegistry: + return self + + @property + def mapper(self) -> Mapper[Any]: + return self.entity.mapper + + def __bool__(self) -> bool: + return True + + def _getitem( + self, entity: Any + ) -> Union[_PathElementType, _PathRepresentation, PathRegistry]: + if isinstance(entity, (int, slice)): + return self.path[entity] + elif entity in PathToken._intern: + return TokenRegistry(self, PathToken._intern[entity]) + else: + return PropRegistry(self, entity) + + if not TYPE_CHECKING: + __getitem__ = _getitem + + +class SlotsEntityRegistry(AbstractEntityRegistry): + # for aliased class, return lightweight, no-cycles created + # version + inherit_cache = True + + +class _ERDict(Dict[Any, Any]): + def __init__(self, registry: CachingEntityRegistry): + self.registry = registry + + def __missing__(self, key: Any) -> PropRegistry: + self[key] = item = PropRegistry(self.registry, key) + + return item + + +class CachingEntityRegistry(AbstractEntityRegistry): + # for long lived mapper, return dict based caching + # version that creates reference cycles + + __slots__ = ("_cache",) + + inherit_cache = True + + def __init__( + self, + parent: Union[RootRegistry, PropRegistry], + entity: _InternalEntityType[Any], + ): + super().__init__(parent, entity) + self._cache = _ERDict(self) + + def pop(self, key: Any, default: Any) -> Any: + return self._cache.pop(key, default) + + def _getitem(self, entity: Any) -> Any: + if isinstance(entity, (int, slice)): + return self.path[entity] + elif isinstance(entity, PathToken): + return TokenRegistry(self, entity) + else: + return self._cache[entity] + + if not TYPE_CHECKING: + __getitem__ = _getitem + + +if TYPE_CHECKING: + + def path_is_entity( + path: PathRegistry, + ) -> TypeGuard[AbstractEntityRegistry]: ... + + def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]: ... + +else: + path_is_entity = operator.attrgetter("is_entity") + path_is_property = operator.attrgetter("is_property") diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py new file mode 100644 index 0000000..369fc59 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py @@ -0,0 +1,1782 @@ +# orm/persistence.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""private module containing functions used to emit INSERT, UPDATE +and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending +mappers. + +The functions here are called only by the unit of work functions +in unitofwork.py. + +""" +from __future__ import annotations + +from itertools import chain +from itertools import groupby +from itertools import zip_longest +import operator + +from . import attributes +from . import exc as orm_exc +from . import loading +from . import sync +from .base import state_str +from .. import exc as sa_exc +from .. import future +from .. import sql +from .. import util +from ..engine import cursor as _cursor +from ..sql import operators +from ..sql.elements import BooleanClauseList +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL + + +def save_obj(base_mapper, states, uowtransaction, single=False): + """Issue ``INSERT`` and/or ``UPDATE`` statements for a list + of objects. + + This is called within the context of a UOWTransaction during a + flush operation, given a list of states to be flushed. The + base mapper in an inheritance hierarchy handles the inserts/ + updates for all descendant mappers. + + """ + + # if batch=false, call _save_obj separately for each object + if not single and not base_mapper.batch: + for state in _sort_states(base_mapper, states): + save_obj(base_mapper, [state], uowtransaction, single=True) + return + + states_to_update = [] + states_to_insert = [] + + for ( + state, + dict_, + mapper, + connection, + has_identity, + row_switch, + update_version_id, + ) in _organize_states_for_save(base_mapper, states, uowtransaction): + if has_identity or row_switch: + states_to_update.append( + (state, dict_, mapper, connection, update_version_id) + ) + else: + states_to_insert.append((state, dict_, mapper, connection)) + + for table, mapper in base_mapper._sorted_tables.items(): + if table not in mapper._pks_by_table: + continue + insert = _collect_insert_commands(table, states_to_insert) + + update = _collect_update_commands( + uowtransaction, table, states_to_update + ) + + _emit_update_statements( + base_mapper, + uowtransaction, + mapper, + table, + update, + ) + + _emit_insert_statements( + base_mapper, + uowtransaction, + mapper, + table, + insert, + ) + + _finalize_insert_update_commands( + base_mapper, + uowtransaction, + chain( + ( + (state, state_dict, mapper, connection, False) + for (state, state_dict, mapper, connection) in states_to_insert + ), + ( + (state, state_dict, mapper, connection, True) + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_update + ), + ), + ) + + +def post_update(base_mapper, states, uowtransaction, post_update_cols): + """Issue UPDATE statements on behalf of a relationship() which + specifies post_update. + + """ + + states_to_update = list( + _organize_states_for_post_update(base_mapper, states, uowtransaction) + ) + + for table, mapper in base_mapper._sorted_tables.items(): + if table not in mapper._pks_by_table: + continue + + update = ( + ( + state, + state_dict, + sub_mapper, + connection, + ( + mapper._get_committed_state_attr_by_column( + state, state_dict, mapper.version_id_col + ) + if mapper.version_id_col is not None + else None + ), + ) + for state, state_dict, sub_mapper, connection in states_to_update + if table in sub_mapper._pks_by_table + ) + + update = _collect_post_update_commands( + base_mapper, uowtransaction, table, update, post_update_cols + ) + + _emit_post_update_statements( + base_mapper, + uowtransaction, + mapper, + table, + update, + ) + + +def delete_obj(base_mapper, states, uowtransaction): + """Issue ``DELETE`` statements for a list of objects. + + This is called within the context of a UOWTransaction during a + flush operation. + + """ + + states_to_delete = list( + _organize_states_for_delete(base_mapper, states, uowtransaction) + ) + + table_to_mapper = base_mapper._sorted_tables + + for table in reversed(list(table_to_mapper.keys())): + mapper = table_to_mapper[table] + if table not in mapper._pks_by_table: + continue + elif mapper.inherits and mapper.passive_deletes: + continue + + delete = _collect_delete_commands( + base_mapper, uowtransaction, table, states_to_delete + ) + + _emit_delete_statements( + base_mapper, + uowtransaction, + mapper, + table, + delete, + ) + + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_delete: + mapper.dispatch.after_delete(mapper, connection, state) + + +def _organize_states_for_save(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for INSERT or + UPDATE. + + This includes splitting out into distinct lists for + each, calling before_insert/before_update, obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state, + and the identity flag. + + """ + + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, states + ): + has_identity = bool(state.key) + + instance_key = state.key or mapper._identity_key_from_state(state) + + row_switch = update_version_id = None + + # call before_XXX extensions + if not has_identity: + mapper.dispatch.before_insert(mapper, connection, state) + else: + mapper.dispatch.before_update(mapper, connection, state) + + if mapper._validate_polymorphic_identity: + mapper._validate_polymorphic_identity(mapper, state, dict_) + + # detect if we have a "pending" instance (i.e. has + # no instance_key attached to it), and another instance + # with the same identity key already exists as persistent. + # convert to an UPDATE if so. + if ( + not has_identity + and instance_key in uowtransaction.session.identity_map + ): + instance = uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) + + if not uowtransaction.was_already_deleted(existing): + if not uowtransaction.is_deleted(existing): + util.warn( + "New instance %s with identity key %s conflicts " + "with persistent instance %s" + % (state_str(state), instance_key, state_str(existing)) + ) + else: + base_mapper._log_debug( + "detected row switch for identity %s. " + "will update %s, remove %s from " + "transaction", + instance_key, + state_str(state), + state_str(existing), + ) + + # remove the "delete" flag from the existing element + uowtransaction.remove_state_actions(existing) + row_switch = existing + + if (has_identity or row_switch) and mapper.version_id_col is not None: + update_version_id = mapper._get_committed_state_attr_by_column( + row_switch if row_switch else state, + row_switch.dict if row_switch else dict_, + mapper.version_id_col, + ) + + yield ( + state, + dict_, + mapper, + connection, + has_identity, + row_switch, + update_version_id, + ) + + +def _organize_states_for_post_update(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for UPDATE + corresponding to post_update. + + This includes obtaining key information for each state + including its dictionary, mapper, the connection to use for + the execution per state. + + """ + return _connections_for_states(base_mapper, uowtransaction, states) + + +def _organize_states_for_delete(base_mapper, states, uowtransaction): + """Make an initial pass across a set of states for DELETE. + + This includes calling out before_delete and obtaining + key information for each state including its dictionary, + mapper, the connection to use for the execution per state. + + """ + for state, dict_, mapper, connection in _connections_for_states( + base_mapper, uowtransaction, states + ): + mapper.dispatch.before_delete(mapper, connection, state) + + if mapper.version_id_col is not None: + update_version_id = mapper._get_committed_state_attr_by_column( + state, dict_, mapper.version_id_col + ) + else: + update_version_id = None + + yield (state, dict_, mapper, connection, update_version_id) + + +def _collect_insert_commands( + table, + states_to_insert, + *, + bulk=False, + return_defaults=False, + render_nulls=False, + include_bulk_keys=(), +): + """Identify sets of values to use in INSERT statements for a + list of states. + + """ + for state, state_dict, mapper, connection in states_to_insert: + if table not in mapper._pks_by_table: + continue + + params = {} + value_params = {} + + propkey_to_col = mapper._propkey_to_col[table] + + eval_none = mapper._insert_cols_evaluating_none[table] + + for propkey in set(propkey_to_col).intersection(state_dict): + value = state_dict[propkey] + col = propkey_to_col[propkey] + if value is None and col not in eval_none and not render_nulls: + continue + elif not bulk and ( + hasattr(value, "__clause_element__") + or isinstance(value, sql.ClauseElement) + ): + value_params[col] = ( + value.__clause_element__() + if hasattr(value, "__clause_element__") + else value + ) + else: + params[col.key] = value + + if not bulk: + # for all the columns that have no default and we don't have + # a value and where "None" is not a special value, add + # explicit None to the INSERT. This is a legacy behavior + # which might be worth removing, as it should not be necessary + # and also produces confusion, given that "missing" and None + # now have distinct meanings + for colkey in ( + mapper._insert_cols_as_none[table] + .difference(params) + .difference([c.key for c in value_params]) + ): + params[colkey] = None + + if not bulk or return_defaults: + # params are in terms of Column key objects, so + # compare to pk_keys_by_table + has_all_pks = mapper._pk_keys_by_table[table].issubset(params) + + if mapper.base_mapper._prefer_eager_defaults( + connection.dialect, table + ): + has_all_defaults = mapper._server_default_col_keys[ + table + ].issubset(params) + else: + has_all_defaults = True + else: + has_all_defaults = has_all_pks = True + + if ( + mapper.version_id_generator is not False + and mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + params[mapper.version_id_col.key] = mapper.version_id_generator( + None + ) + + if bulk: + if mapper._set_polymorphic_identity: + params.setdefault( + mapper._polymorphic_attr_key, mapper.polymorphic_identity + ) + + if include_bulk_keys: + params.update((k, state_dict[k]) for k in include_bulk_keys) + + yield ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) + + +def _collect_update_commands( + uowtransaction, + table, + states_to_update, + *, + bulk=False, + use_orm_update_stmt=None, + include_bulk_keys=(), +): + """Identify sets of values to use in UPDATE statements for a + list of states. + + This function works intricately with the history system + to determine exactly what values should be updated + as well as how the row should be matched within an UPDATE + statement. Includes some tricky scenarios where the primary + key of an object might have been changed. + + """ + + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_update: + if table not in mapper._pks_by_table: + continue + + pks = mapper._pks_by_table[table] + + if use_orm_update_stmt is not None: + # TODO: ordered values, etc + value_params = use_orm_update_stmt._values + else: + value_params = {} + + propkey_to_col = mapper._propkey_to_col[table] + + if bulk: + # keys here are mapped attribute keys, so + # look at mapper attribute keys for pk + params = { + propkey_to_col[propkey].key: state_dict[propkey] + for propkey in set(propkey_to_col) + .intersection(state_dict) + .difference(mapper._pk_attr_keys_by_table[table]) + } + has_all_defaults = True + else: + params = {} + for propkey in set(propkey_to_col).intersection( + state.committed_state + ): + value = state_dict[propkey] + col = propkey_to_col[propkey] + + if hasattr(value, "__clause_element__") or isinstance( + value, sql.ClauseElement + ): + value_params[col] = ( + value.__clause_element__() + if hasattr(value, "__clause_element__") + else value + ) + # guard against values that generate non-__nonzero__ + # objects for __eq__() + elif ( + state.manager[propkey].impl.is_equal( + value, state.committed_state[propkey] + ) + is not True + ): + params[col.key] = value + + if mapper.base_mapper.eager_defaults is True: + has_all_defaults = ( + mapper._server_onupdate_default_col_keys[table] + ).issubset(params) + else: + has_all_defaults = True + + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + if not bulk and not (params or value_params): + # HACK: check for history in other tables, in case the + # history is only in a different table than the one + # where the version_id_col is. This logic was lost + # from 0.9 -> 1.0.0 and restored in 1.0.6. + for prop in mapper._columntoproperty.values(): + history = state.manager[prop.key].impl.get_history( + state, state_dict, attributes.PASSIVE_NO_INITIALIZE + ) + if history.added: + break + else: + # no net change, break + continue + + col = mapper.version_id_col + no_params = not params and not value_params + params[col._label] = update_version_id + + if ( + bulk or col.key not in params + ) and mapper.version_id_generator is not False: + val = mapper.version_id_generator(update_version_id) + params[col.key] = val + elif mapper.version_id_generator is False and no_params: + # no version id generator, no values set on the table, + # and version id wasn't manually incremented. + # set version id to itself so we get an UPDATE + # statement + params[col.key] = update_version_id + + elif not (params or value_params): + continue + + has_all_pks = True + expect_pk_cascaded = False + if bulk: + # keys here are mapped attribute keys, so + # look at mapper attribute keys for pk + pk_params = { + propkey_to_col[propkey]._label: state_dict.get(propkey) + for propkey in set(propkey_to_col).intersection( + mapper._pk_attr_keys_by_table[table] + ) + } + if util.NONE_SET.intersection(pk_params.values()): + raise sa_exc.InvalidRequestError( + f"No primary key value supplied for column(s) " + f"""{ + ', '.join( + str(c) for c in pks if pk_params[c._label] is None + ) + }; """ + "per-row ORM Bulk UPDATE by Primary Key requires that " + "records contain primary key values", + code="bupq", + ) + + else: + pk_params = {} + for col in pks: + propkey = mapper._columntoproperty[col].key + + history = state.manager[propkey].impl.get_history( + state, state_dict, attributes.PASSIVE_OFF + ) + + if history.added: + if ( + not history.deleted + or ("pk_cascaded", state, col) + in uowtransaction.attributes + ): + expect_pk_cascaded = True + pk_params[col._label] = history.added[0] + params.pop(col.key, None) + else: + # else, use the old value to locate the row + pk_params[col._label] = history.deleted[0] + if col in value_params: + has_all_pks = False + else: + pk_params[col._label] = history.unchanged[0] + if pk_params[col._label] is None: + raise orm_exc.FlushError( + "Can't update table %s using NULL for primary " + "key value on column %s" % (table, col) + ) + + if include_bulk_keys: + params.update((k, state_dict[k]) for k in include_bulk_keys) + + if params or value_params: + params.update(pk_params) + yield ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) + elif expect_pk_cascaded: + # no UPDATE occurs on this table, but we expect that CASCADE rules + # have changed the primary key of the row; propagate this event to + # other columns that expect to have been modified. this normally + # occurs after the UPDATE is emitted however we invoke it here + # explicitly in the absence of our invoking an UPDATE + for m, equated_pairs in mapper._table_to_equated[table]: + sync.populate( + state, + m, + state, + m, + equated_pairs, + uowtransaction, + mapper.passive_updates, + ) + + +def _collect_post_update_commands( + base_mapper, uowtransaction, table, states_to_update, post_update_cols +): + """Identify sets of values to use in UPDATE statements for a + list of states within a post_update operation. + + """ + + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_update: + # assert table in mapper._pks_by_table + + pks = mapper._pks_by_table[table] + params = {} + hasdata = False + + for col in mapper._cols_by_table[table]: + if col in pks: + params[col._label] = mapper._get_state_attr_by_column( + state, state_dict, col, passive=attributes.PASSIVE_OFF + ) + + elif col in post_update_cols or col.onupdate is not None: + prop = mapper._columntoproperty[col] + history = state.manager[prop.key].impl.get_history( + state, state_dict, attributes.PASSIVE_NO_INITIALIZE + ) + if history.added: + value = history.added[0] + params[col.key] = value + hasdata = True + if hasdata: + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + col = mapper.version_id_col + params[col._label] = update_version_id + + if ( + bool(state.key) + and col.key not in params + and mapper.version_id_generator is not False + ): + val = mapper.version_id_generator(update_version_id) + params[col.key] = val + yield state, state_dict, mapper, connection, params + + +def _collect_delete_commands( + base_mapper, uowtransaction, table, states_to_delete +): + """Identify values to use in DELETE statements for a list of + states to be deleted.""" + + for ( + state, + state_dict, + mapper, + connection, + update_version_id, + ) in states_to_delete: + if table not in mapper._pks_by_table: + continue + + params = {} + for col in mapper._pks_by_table[table]: + params[col.key] = value = ( + mapper._get_committed_state_attr_by_column( + state, state_dict, col + ) + ) + if value is None: + raise orm_exc.FlushError( + "Can't delete from table %s " + "using NULL for primary " + "key value on column %s" % (table, col) + ) + + if ( + update_version_id is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + params[mapper.version_id_col.key] = update_version_id + yield params, connection + + +def _emit_update_statements( + base_mapper, + uowtransaction, + mapper, + table, + update, + *, + bookkeeping=True, + use_orm_update_stmt=None, + enable_check_rowcount=True, +): + """Emit UPDATE statements corresponding to value lists collected + by _collect_update_commands().""" + + needs_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) + + execution_options = {"compiled_cache": base_mapper._compiled_cache} + + def update_stmt(existing_stmt=None): + clauses = BooleanClauseList._construct_raw(operators.and_) + + for col in mapper._pks_by_table[table]: + clauses._append_inplace( + col == sql.bindparam(col._label, type_=col.type) + ) + + if needs_version_id: + clauses._append_inplace( + mapper.version_id_col + == sql.bindparam( + mapper.version_id_col._label, + type_=mapper.version_id_col.type, + ) + ) + + if existing_stmt is not None: + stmt = existing_stmt.where(clauses) + else: + stmt = table.update().where(clauses) + return stmt + + if use_orm_update_stmt is not None: + cached_stmt = update_stmt(use_orm_update_stmt) + + else: + cached_stmt = base_mapper._memo(("update", table), update_stmt) + + for ( + (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), + records, + ) in groupby( + update, + lambda rec: ( + rec[4], # connection + set(rec[2]), # set of parameter keys + bool(rec[5]), # whether or not we have "value" parameters + rec[6], # has_all_defaults + rec[7], # has all pks + ), + ): + rows = 0 + records = list(records) + + statement = cached_stmt + + if use_orm_update_stmt is not None: + statement = statement._annotate( + { + "_emit_update_table": table, + "_emit_update_mapper": mapper, + } + ) + + return_defaults = False + + if not has_all_pks: + statement = statement.return_defaults(*mapper._pks_by_table[table]) + return_defaults = True + + if ( + bookkeeping + and not has_all_defaults + and mapper.base_mapper.eager_defaults is True + # change as of #8889 - if RETURNING is not going to be used anyway, + # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure + # we can do an executemany UPDATE which is more efficient + and table.implicit_returning + and connection.dialect.update_returning + ): + statement = statement.return_defaults( + *mapper._server_onupdate_default_cols[table] + ) + return_defaults = True + + if mapper._version_id_has_server_side_value: + statement = statement.return_defaults(mapper.version_id_col) + return_defaults = True + + assert_singlerow = connection.dialect.supports_sane_rowcount + + assert_multirow = ( + assert_singlerow + and connection.dialect.supports_sane_multi_rowcount + ) + + # change as of #8889 - if RETURNING is not going to be used anyway, + # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure + # we can do an executemany UPDATE which is more efficient + allow_executemany = not return_defaults and not needs_version_id + + if hasvalue: + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: + c = connection.execute( + statement.values(value_params), + params, + execution_options=execution_options, + ) + if bookkeeping: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params, + True, + c.returned_defaults, + ) + rows += c.rowcount + check_rowcount = enable_check_rowcount and assert_singlerow + else: + if not allow_executemany: + check_rowcount = enable_check_rowcount and assert_singlerow + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: + c = connection.execute( + statement, params, execution_options=execution_options + ) + + # TODO: why with bookkeeping=False? + if bookkeeping: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params, + True, + c.returned_defaults, + ) + rows += c.rowcount + else: + multiparams = [rec[2] for rec in records] + + check_rowcount = enable_check_rowcount and ( + assert_multirow + or (assert_singlerow and len(multiparams) == 1) + ) + + c = connection.execute( + statement, multiparams, execution_options=execution_options + ) + + rows += c.rowcount + + for ( + state, + state_dict, + params, + mapper, + connection, + value_params, + has_all_defaults, + has_all_pks, + ) in records: + if bookkeeping: + _postfetch( + mapper, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + value_params, + True, + ( + c.returned_defaults + if not c.context.executemany + else None + ), + ) + + if check_rowcount: + if rows != len(records): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." + % (table.description, len(records), rows) + ) + + elif needs_version_id: + util.warn( + "Dialect %s does not support updated rowcount " + "- versioning cannot be verified." + % c.dialect.dialect_description + ) + + +def _emit_insert_statements( + base_mapper, + uowtransaction, + mapper, + table, + insert, + *, + bookkeeping=True, + use_orm_insert_stmt=None, + execution_options=None, +): + """Emit INSERT statements corresponding to value lists collected + by _collect_insert_commands().""" + + if use_orm_insert_stmt is not None: + cached_stmt = use_orm_insert_stmt + exec_opt = util.EMPTY_DICT + + # if a user query with RETURNING was passed, we definitely need + # to use RETURNING. + returning_is_required_anyway = bool(use_orm_insert_stmt._returning) + deterministic_results_reqd = ( + returning_is_required_anyway + and use_orm_insert_stmt._sort_by_parameter_order + ) or bookkeeping + else: + returning_is_required_anyway = False + deterministic_results_reqd = bookkeeping + cached_stmt = base_mapper._memo(("insert", table), table.insert) + exec_opt = {"compiled_cache": base_mapper._compiled_cache} + + if execution_options: + execution_options = util.EMPTY_DICT.merge_with( + exec_opt, execution_options + ) + else: + execution_options = exec_opt + + return_result = None + + for ( + (connection, _, hasvalue, has_all_pks, has_all_defaults), + records, + ) in groupby( + insert, + lambda rec: ( + rec[4], # connection + set(rec[2]), # parameter keys + bool(rec[5]), # whether we have "value" parameters + rec[6], + rec[7], + ), + ): + statement = cached_stmt + + if use_orm_insert_stmt is not None: + statement = statement._annotate( + { + "_emit_insert_table": table, + "_emit_insert_mapper": mapper, + } + ) + + if ( + ( + not bookkeeping + or ( + has_all_defaults + or not base_mapper._prefer_eager_defaults( + connection.dialect, table + ) + or not table.implicit_returning + or not connection.dialect.insert_returning + ) + ) + and not returning_is_required_anyway + and has_all_pks + and not hasvalue + ): + # the "we don't need newly generated values back" section. + # here we have all the PKs, all the defaults or we don't want + # to fetch them, or the dialect doesn't support RETURNING at all + # so we have to post-fetch / use lastrowid anyway. + records = list(records) + multiparams = [rec[2] for rec in records] + + result = connection.execute( + statement, multiparams, execution_options=execution_options + ) + if bookkeeping: + for ( + ( + state, + state_dict, + params, + mapper_rec, + conn, + value_params, + has_all_pks, + has_all_defaults, + ), + last_inserted_params, + ) in zip(records, result.context.compiled_parameters): + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + result, + last_inserted_params, + value_params, + False, + ( + result.returned_defaults + if not result.context.executemany + else None + ), + ) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) + + else: + # here, we need defaults and/or pk values back or we otherwise + # know that we are using RETURNING in any case + + records = list(records) + + if returning_is_required_anyway or ( + table.implicit_returning and not hasvalue and len(records) > 1 + ): + if ( + deterministic_results_reqd + and connection.dialect.insert_executemany_returning_sort_by_parameter_order # noqa: E501 + ) or ( + not deterministic_results_reqd + and connection.dialect.insert_executemany_returning + ): + do_executemany = True + elif returning_is_required_anyway: + if deterministic_results_reqd: + dt = " with RETURNING and sort by parameter order" + else: + dt = " with RETURNING" + raise sa_exc.InvalidRequestError( + f"Can't use explicit RETURNING for bulk INSERT " + f"operation with " + f"{connection.dialect.dialect_description} backend; " + f"executemany{dt} is not enabled for this dialect." + ) + else: + do_executemany = False + else: + do_executemany = False + + if use_orm_insert_stmt is None: + if ( + not has_all_defaults + and base_mapper._prefer_eager_defaults( + connection.dialect, table + ) + ): + statement = statement.return_defaults( + *mapper._server_default_cols[table], + sort_by_parameter_order=bookkeeping, + ) + + if mapper.version_id_col is not None: + statement = statement.return_defaults( + mapper.version_id_col, + sort_by_parameter_order=bookkeeping, + ) + elif do_executemany: + statement = statement.return_defaults( + *table.primary_key, sort_by_parameter_order=bookkeeping + ) + + if do_executemany: + multiparams = [rec[2] for rec in records] + + result = connection.execute( + statement, multiparams, execution_options=execution_options + ) + + if use_orm_insert_stmt is not None: + if return_result is None: + return_result = result + else: + return_result = return_result.splice_vertically(result) + + if bookkeeping: + for ( + ( + state, + state_dict, + params, + mapper_rec, + conn, + value_params, + has_all_pks, + has_all_defaults, + ), + last_inserted_params, + inserted_primary_key, + returned_defaults, + ) in zip_longest( + records, + result.context.compiled_parameters, + result.inserted_primary_key_rows, + result.returned_defaults_rows or (), + ): + if inserted_primary_key is None: + # this is a real problem and means that we didn't + # get back as many PK rows. we can't continue + # since this indicates PK rows were missing, which + # means we likely mis-populated records starting + # at that point with incorrectly matched PK + # values. + raise orm_exc.FlushError( + "Multi-row INSERT statement for %s did not " + "produce " + "the correct number of INSERTed rows for " + "RETURNING. Ensure there are no triggers or " + "special driver issues preventing INSERT from " + "functioning properly." % mapper_rec + ) + + for pk, col in zip( + inserted_primary_key, + mapper._pks_by_table[table], + ): + prop = mapper_rec._columntoproperty[col] + if state_dict.get(prop.key) is None: + state_dict[prop.key] = pk + + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + result, + last_inserted_params, + value_params, + False, + returned_defaults, + ) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) + else: + assert not returning_is_required_anyway + + for ( + state, + state_dict, + params, + mapper_rec, + connection, + value_params, + has_all_pks, + has_all_defaults, + ) in records: + if value_params: + result = connection.execute( + statement.values(value_params), + params, + execution_options=execution_options, + ) + else: + result = connection.execute( + statement, + params, + execution_options=execution_options, + ) + + primary_key = result.inserted_primary_key + if primary_key is None: + raise orm_exc.FlushError( + "Single-row INSERT statement for %s " + "did not produce a " + "new primary key result " + "being invoked. Ensure there are no triggers or " + "special driver issues preventing INSERT from " + "functioning properly." % (mapper_rec,) + ) + for pk, col in zip( + primary_key, mapper._pks_by_table[table] + ): + prop = mapper_rec._columntoproperty[col] + if ( + col in value_params + or state_dict.get(prop.key) is None + ): + state_dict[prop.key] = pk + if bookkeeping: + if state: + _postfetch( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + result, + result.context.compiled_parameters[0], + value_params, + False, + ( + result.returned_defaults + if not result.context.executemany + else None + ), + ) + else: + _postfetch_bulk_save(mapper_rec, state_dict, table) + + if use_orm_insert_stmt is not None: + if return_result is None: + return _cursor.null_dml_result() + else: + return return_result + + +def _emit_post_update_statements( + base_mapper, uowtransaction, mapper, table, update +): + """Emit UPDATE statements corresponding to value lists collected + by _collect_post_update_commands().""" + + execution_options = {"compiled_cache": base_mapper._compiled_cache} + + needs_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) + + def update_stmt(): + clauses = BooleanClauseList._construct_raw(operators.and_) + + for col in mapper._pks_by_table[table]: + clauses._append_inplace( + col == sql.bindparam(col._label, type_=col.type) + ) + + if needs_version_id: + clauses._append_inplace( + mapper.version_id_col + == sql.bindparam( + mapper.version_id_col._label, + type_=mapper.version_id_col.type, + ) + ) + + stmt = table.update().where(clauses) + + return stmt + + statement = base_mapper._memo(("post_update", table), update_stmt) + + if mapper._version_id_has_server_side_value: + statement = statement.return_defaults(mapper.version_id_col) + + # execute each UPDATE in the order according to the original + # list of states to guarantee row access order, but + # also group them into common (connection, cols) sets + # to support executemany(). + for key, records in groupby( + update, + lambda rec: (rec[3], set(rec[4])), # connection # parameter keys + ): + rows = 0 + + records = list(records) + connection = key[0] + + assert_singlerow = connection.dialect.supports_sane_rowcount + assert_multirow = ( + assert_singlerow + and connection.dialect.supports_sane_multi_rowcount + ) + allow_executemany = not needs_version_id or assert_multirow + + if not allow_executemany: + check_rowcount = assert_singlerow + for state, state_dict, mapper_rec, connection, params in records: + c = connection.execute( + statement, params, execution_options=execution_options + ) + + _postfetch_post_update( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + ) + rows += c.rowcount + else: + multiparams = [ + params + for state, state_dict, mapper_rec, conn, params in records + ] + + check_rowcount = assert_multirow or ( + assert_singlerow and len(multiparams) == 1 + ) + + c = connection.execute( + statement, multiparams, execution_options=execution_options + ) + + rows += c.rowcount + for state, state_dict, mapper_rec, connection, params in records: + _postfetch_post_update( + mapper_rec, + uowtransaction, + table, + state, + state_dict, + c, + c.context.compiled_parameters[0], + ) + + if check_rowcount: + if rows != len(records): + raise orm_exc.StaleDataError( + "UPDATE statement on table '%s' expected to " + "update %d row(s); %d were matched." + % (table.description, len(records), rows) + ) + + elif needs_version_id: + util.warn( + "Dialect %s does not support updated rowcount " + "- versioning cannot be verified." + % c.dialect.dialect_description + ) + + +def _emit_delete_statements( + base_mapper, uowtransaction, mapper, table, delete +): + """Emit DELETE statements corresponding to value lists collected + by _collect_delete_commands().""" + + need_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) + + def delete_stmt(): + clauses = BooleanClauseList._construct_raw(operators.and_) + + for col in mapper._pks_by_table[table]: + clauses._append_inplace( + col == sql.bindparam(col.key, type_=col.type) + ) + + if need_version_id: + clauses._append_inplace( + mapper.version_id_col + == sql.bindparam( + mapper.version_id_col.key, type_=mapper.version_id_col.type + ) + ) + + return table.delete().where(clauses) + + statement = base_mapper._memo(("delete", table), delete_stmt) + for connection, recs in groupby(delete, lambda rec: rec[1]): # connection + del_objects = [params for params, connection in recs] + + execution_options = {"compiled_cache": base_mapper._compiled_cache} + expected = len(del_objects) + rows_matched = -1 + only_warn = False + + if ( + need_version_id + and not connection.dialect.supports_sane_multi_rowcount + ): + if connection.dialect.supports_sane_rowcount: + rows_matched = 0 + # execute deletes individually so that versioned + # rows can be verified + for params in del_objects: + c = connection.execute( + statement, params, execution_options=execution_options + ) + rows_matched += c.rowcount + else: + util.warn( + "Dialect %s does not support deleted rowcount " + "- versioning cannot be verified." + % connection.dialect.dialect_description + ) + connection.execute( + statement, del_objects, execution_options=execution_options + ) + else: + c = connection.execute( + statement, del_objects, execution_options=execution_options + ) + + if not need_version_id: + only_warn = True + + rows_matched = c.rowcount + + if ( + base_mapper.confirm_deleted_rows + and rows_matched > -1 + and expected != rows_matched + and ( + connection.dialect.supports_sane_multi_rowcount + or len(del_objects) == 1 + ) + ): + # TODO: why does this "only warn" if versioning is turned off, + # whereas the UPDATE raises? + if only_warn: + util.warn( + "DELETE statement on table '%s' expected to " + "delete %d row(s); %d were matched. Please set " + "confirm_deleted_rows=False within the mapper " + "configuration to prevent this warning." + % (table.description, expected, rows_matched) + ) + else: + raise orm_exc.StaleDataError( + "DELETE statement on table '%s' expected to " + "delete %d row(s); %d were matched. Please set " + "confirm_deleted_rows=False within the mapper " + "configuration to prevent this warning." + % (table.description, expected, rows_matched) + ) + + +def _finalize_insert_update_commands(base_mapper, uowtransaction, states): + """finalize state on states that have been inserted or updated, + including calling after_insert/after_update events. + + """ + for state, state_dict, mapper, connection, has_identity in states: + if mapper._readonly_props: + readonly = state.unmodified_intersection( + [ + p.key + for p in mapper._readonly_props + if ( + p.expire_on_flush + and (not p.deferred or p.key in state.dict) + ) + or ( + not p.expire_on_flush + and not p.deferred + and p.key not in state.dict + ) + ] + ) + if readonly: + state._expire_attributes(state.dict, readonly) + + # if eager_defaults option is enabled, load + # all expired cols. Else if we have a version_id_col, make sure + # it isn't expired. + toload_now = [] + + # this is specifically to emit a second SELECT for eager_defaults, + # so only if it's set to True, not "auto" + if base_mapper.eager_defaults is True: + toload_now.extend( + state._unloaded_non_object.intersection( + mapper._server_default_plus_onupdate_propkeys + ) + ) + + if ( + mapper.version_id_col is not None + and mapper.version_id_generator is False + ): + if mapper._version_id_prop.key in state.unloaded: + toload_now.extend([mapper._version_id_prop.key]) + + if toload_now: + state.key = base_mapper._identity_key_from_state(state) + stmt = future.select(mapper).set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + loading.load_on_ident( + uowtransaction.session, + stmt, + state.key, + refresh_state=state, + only_load_props=toload_now, + ) + + # call after_XXX extensions + if not has_identity: + mapper.dispatch.after_insert(mapper, connection, state) + else: + mapper.dispatch.after_update(mapper, connection, state) + + if ( + mapper.version_id_generator is False + and mapper.version_id_col is not None + ): + if state_dict[mapper._version_id_prop.key] is None: + raise orm_exc.FlushError( + "Instance does not contain a non-NULL version value" + ) + + +def _postfetch_post_update( + mapper, uowtransaction, table, state, dict_, result, params +): + needs_version_id = ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ) + + if not uowtransaction.is_deleted(state): + # post updating after a regular INSERT or UPDATE, do a full postfetch + prefetch_cols = result.context.compiled.prefetch + postfetch_cols = result.context.compiled.postfetch + elif needs_version_id: + # post updating before a DELETE with a version_id_col, need to + # postfetch just version_id_col + prefetch_cols = postfetch_cols = () + else: + # post updating before a DELETE without a version_id_col, + # don't need to postfetch + return + + if needs_version_id: + prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + + refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) + if refresh_flush: + load_evt_attrs = [] + + for c in prefetch_cols: + if c.key in params and c in mapper._columntoproperty: + dict_[mapper._columntoproperty[c].key] = params[c.key] + if refresh_flush: + load_evt_attrs.append(mapper._columntoproperty[c].key) + + if refresh_flush and load_evt_attrs: + mapper.class_manager.dispatch.refresh_flush( + state, uowtransaction, load_evt_attrs + ) + + if postfetch_cols: + state._expire_attributes( + state.dict, + [ + mapper._columntoproperty[c].key + for c in postfetch_cols + if c in mapper._columntoproperty + ], + ) + + +def _postfetch( + mapper, + uowtransaction, + table, + state, + dict_, + result, + params, + value_params, + isupdate, + returned_defaults, +): + """Expire attributes in need of newly persisted database state, + after an INSERT or UPDATE statement has proceeded for that + state.""" + + prefetch_cols = result.context.compiled.prefetch + postfetch_cols = result.context.compiled.postfetch + returning_cols = result.context.compiled.effective_returning + + if ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): + prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + + refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) + if refresh_flush: + load_evt_attrs = [] + + if returning_cols: + row = returned_defaults + if row is not None: + for row_value, col in zip(row, returning_cols): + # pk cols returned from insert are handled + # distinctly, don't step on the values here + if col.primary_key and result.context.isinsert: + continue + + # note that columns can be in the "return defaults" that are + # not mapped to this mapper, typically because they are + # "excluded", which can be specified directly or also occurs + # when using declarative w/ single table inheritance + prop = mapper._columntoproperty.get(col) + if prop: + dict_[prop.key] = row_value + if refresh_flush: + load_evt_attrs.append(prop.key) + + for c in prefetch_cols: + if c.key in params and c in mapper._columntoproperty: + pkey = mapper._columntoproperty[c].key + + # set prefetched value in dict and also pop from committed_state, + # since this is new database state that replaces whatever might + # have previously been fetched (see #10800). this is essentially a + # shorthand version of set_committed_value(), which could also be + # used here directly (with more overhead) + dict_[pkey] = params[c.key] + state.committed_state.pop(pkey, None) + + if refresh_flush: + load_evt_attrs.append(pkey) + + if refresh_flush and load_evt_attrs: + mapper.class_manager.dispatch.refresh_flush( + state, uowtransaction, load_evt_attrs + ) + + if isupdate and value_params: + # explicitly suit the use case specified by + # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING + # database which are set to themselves in order to do a version bump. + postfetch_cols.extend( + [ + col + for col in value_params + if col.primary_key and col not in returning_cols + ] + ) + + if postfetch_cols: + state._expire_attributes( + state.dict, + [ + mapper._columntoproperty[c].key + for c in postfetch_cols + if c in mapper._columntoproperty + ], + ) + + # synchronize newly inserted ids from one table to the next + # TODO: this still goes a little too often. would be nice to + # have definitive list of "columns that changed" here + for m, equated_pairs in mapper._table_to_equated[table]: + sync.populate( + state, + m, + state, + m, + equated_pairs, + uowtransaction, + mapper.passive_updates, + ) + + +def _postfetch_bulk_save(mapper, dict_, table): + for m, equated_pairs in mapper._table_to_equated[table]: + sync.bulk_populate_inherit_keys(dict_, m, equated_pairs) + + +def _connections_for_states(base_mapper, uowtransaction, states): + """Return an iterator of (state, state.dict, mapper, connection). + + The states are sorted according to _sort_states, then paired + with the connection they should be using for the given + unit of work transaction. + + """ + # if session has a connection callable, + # organize individual states with the connection + # to use for update + if uowtransaction.session.connection_callable: + connection_callable = uowtransaction.session.connection_callable + else: + connection = uowtransaction.transaction.connection(base_mapper) + connection_callable = None + + for state in _sort_states(base_mapper, states): + if connection_callable: + connection = connection_callable(base_mapper, state.obj()) + + mapper = state.manager.mapper + + yield state, state.dict, mapper, connection + + +def _sort_states(mapper, states): + pending = set(states) + persistent = {s for s in pending if s.key is not None} + pending.difference_update(persistent) + + try: + persistent_sorted = sorted( + persistent, key=mapper._persistent_sortkey_fn + ) + except TypeError as err: + raise sa_exc.InvalidRequestError( + "Could not sort objects by primary key; primary key " + "values must be sortable in Python (was: %s)" % err + ) from err + return ( + sorted(pending, key=operator.attrgetter("insert_order")) + + persistent_sorted + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/properties.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/properties.py new file mode 100644 index 0000000..adee44a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/properties.py @@ -0,0 +1,886 @@ +# orm/properties.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""MapperProperty implementations. + +This is a private module which defines the behavior of individual ORM- +mapped attributes. + +""" + +from __future__ import annotations + +from typing import Any +from typing import cast +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import attributes +from . import strategy_options +from .base import _DeclarativeMapped +from .base import class_mapper +from .descriptor_props import CompositeProperty +from .descriptor_props import ConcreteInheritedProperty +from .descriptor_props import SynonymProperty +from .interfaces import _AttributeOptions +from .interfaces import _DEFAULT_ATTRIBUTE_OPTIONS +from .interfaces import _IntrospectsAnnotations +from .interfaces import _MapsColumns +from .interfaces import MapperProperty +from .interfaces import PropComparator +from .interfaces import StrategizedProperty +from .relationships import RelationshipProperty +from .util import de_stringify_annotation +from .util import de_stringify_union_elements +from .. import exc as sa_exc +from .. import ForeignKey +from .. import log +from .. import util +from ..sql import coercions +from ..sql import roles +from ..sql.base import _NoArg +from ..sql.schema import Column +from ..sql.schema import SchemaConst +from ..sql.type_api import TypeEngine +from ..util.typing import de_optionalize_union_types +from ..util.typing import is_fwd_ref +from ..util.typing import is_optional_union +from ..util.typing import is_pep593 +from ..util.typing import is_pep695 +from ..util.typing import is_union +from ..util.typing import Self +from ..util.typing import typing_get_args + +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _ORMColumnExprArgument + from ._typing import _RegistryType + from .base import Mapped + from .decl_base import _ClassScanMapperConfig + from .mapper import Mapper + from .session import Session + from .state import _InstallLoaderCallableProto + from .state import InstanceState + from ..sql._typing import _InfoType + from ..sql.elements import ColumnElement + from ..sql.elements import NamedColumn + from ..sql.operators import OperatorType + from ..util.typing import _AnnotationScanType + from ..util.typing import RODescriptorReference + +_T = TypeVar("_T", bound=Any) +_PT = TypeVar("_PT", bound=Any) +_NC = TypeVar("_NC", bound="NamedColumn[Any]") + +__all__ = [ + "ColumnProperty", + "CompositeProperty", + "ConcreteInheritedProperty", + "RelationshipProperty", + "SynonymProperty", +] + + +@log.class_logger +class ColumnProperty( + _MapsColumns[_T], + StrategizedProperty[_T], + _IntrospectsAnnotations, + log.Identified, +): + """Describes an object attribute that corresponds to a table column + or other column expression. + + Public constructor is the :func:`_orm.column_property` function. + + """ + + strategy_wildcard_key = strategy_options._COLUMN_TOKEN + inherit_cache = True + """:meta private:""" + + _links_to_entity = False + + columns: List[NamedColumn[Any]] + + _is_polymorphic_discriminator: bool + + _mapped_by_synonym: Optional[str] + + comparator_factory: Type[PropComparator[_T]] + + __slots__ = ( + "columns", + "group", + "deferred", + "instrument", + "comparator_factory", + "active_history", + "expire_on_flush", + "_creation_order", + "_is_polymorphic_discriminator", + "_mapped_by_synonym", + "_deferred_column_loader", + "_raise_column_loader", + "_renders_in_subqueries", + "raiseload", + ) + + def __init__( + self, + column: _ORMColumnExprArgument[_T], + *additional_columns: _ORMColumnExprArgument[Any], + attribute_options: Optional[_AttributeOptions] = None, + group: Optional[str] = None, + deferred: bool = False, + raiseload: bool = False, + comparator_factory: Optional[Type[PropComparator[_T]]] = None, + active_history: bool = False, + expire_on_flush: bool = True, + info: Optional[_InfoType] = None, + doc: Optional[str] = None, + _instrument: bool = True, + _assume_readonly_dc_attributes: bool = False, + ): + super().__init__( + attribute_options=attribute_options, + _assume_readonly_dc_attributes=_assume_readonly_dc_attributes, + ) + columns = (column,) + additional_columns + self.columns = [ + coercions.expect(roles.LabeledColumnExprRole, c) for c in columns + ] + self.group = group + self.deferred = deferred + self.raiseload = raiseload + self.instrument = _instrument + self.comparator_factory = ( + comparator_factory + if comparator_factory is not None + else self.__class__.Comparator + ) + self.active_history = active_history + self.expire_on_flush = expire_on_flush + + if info is not None: + self.info.update(info) + + if doc is not None: + self.doc = doc + else: + for col in reversed(self.columns): + doc = getattr(col, "doc", None) + if doc is not None: + self.doc = doc + break + else: + self.doc = None + + util.set_creation_order(self) + + self.strategy_key = ( + ("deferred", self.deferred), + ("instrument", self.instrument), + ) + if self.raiseload: + self.strategy_key += (("raiseload", True),) + + def declarative_scan( + self, + decl_scan: _ClassScanMapperConfig, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + mapped_container: Optional[Type[Mapped[Any]]], + annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: + column = self.columns[0] + if column.key is None: + column.key = key + if column.name is None: + column.name = key + + @property + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: + return self + + @property + def columns_to_assign(self) -> List[Tuple[Column[Any], int]]: + # mypy doesn't care about the isinstance here + return [ + (c, 0) # type: ignore + for c in self.columns + if isinstance(c, Column) and c.table is None + ] + + def _memoized_attr__renders_in_subqueries(self) -> bool: + if ("query_expression", True) in self.strategy_key: + return self.strategy._have_default_expression # type: ignore + + return ("deferred", True) not in self.strategy_key or ( + self not in self.parent._readonly_props # type: ignore + ) + + @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") + def _memoized_attr__deferred_column_loader( + self, + ) -> _InstallLoaderCallableProto[Any]: + state = util.preloaded.orm_state + strategies = util.preloaded.orm_strategies + return state.InstanceState._instance_level_callable_processor( + self.parent.class_manager, + strategies.LoadDeferredColumns(self.key), + self.key, + ) + + @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") + def _memoized_attr__raise_column_loader( + self, + ) -> _InstallLoaderCallableProto[Any]: + state = util.preloaded.orm_state + strategies = util.preloaded.orm_strategies + return state.InstanceState._instance_level_callable_processor( + self.parent.class_manager, + strategies.LoadDeferredColumns(self.key, True), + self.key, + ) + + def __clause_element__(self) -> roles.ColumnsClauseRole: + """Allow the ColumnProperty to work in expression before it is turned + into an instrumented attribute. + """ + + return self.expression + + @property + def expression(self) -> roles.ColumnsClauseRole: + """Return the primary column or expression for this ColumnProperty. + + E.g.:: + + + class File(Base): + # ... + + name = Column(String(64)) + extension = Column(String(8)) + filename = column_property(name + '.' + extension) + path = column_property('C:/' + filename.expression) + + .. seealso:: + + :ref:`mapper_column_property_sql_expressions_composed` + + """ + return self.columns[0] + + def instrument_class(self, mapper: Mapper[Any]) -> None: + if not self.instrument: + return + + attributes.register_descriptor( + mapper.class_, + self.key, + comparator=self.comparator_factory(self, mapper), + parententity=mapper, + doc=self.doc, + ) + + def do_init(self) -> None: + super().do_init() + + if len(self.columns) > 1 and set(self.parent.primary_key).issuperset( + self.columns + ): + util.warn( + ( + "On mapper %s, primary key column '%s' is being combined " + "with distinct primary key column '%s' in attribute '%s'. " + "Use explicit properties to give each column its own " + "mapped attribute name." + ) + % (self.parent, self.columns[1], self.columns[0], self.key) + ) + + def copy(self) -> ColumnProperty[_T]: + return ColumnProperty( + *self.columns, + deferred=self.deferred, + group=self.group, + active_history=self.active_history, + ) + + def merge( + self, + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Dict[Any, object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: + if not self.instrument: + return + elif self.key in source_dict: + value = source_dict[self.key] + + if not load: + dest_dict[self.key] = value + else: + impl = dest_state.get_impl(self.key) + impl.set(dest_state, dest_dict, value, None) + elif dest_state.has_identity and self.key not in dest_dict: + dest_state._expire_attributes( + dest_dict, [self.key], no_loader=True + ) + + class Comparator(util.MemoizedSlots, PropComparator[_PT]): + """Produce boolean, comparison, and other operators for + :class:`.ColumnProperty` attributes. + + See the documentation for :class:`.PropComparator` for a brief + overview. + + .. seealso:: + + :class:`.PropComparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + + if not TYPE_CHECKING: + # prevent pylance from being clever about slots + __slots__ = "__clause_element__", "info", "expressions" + + prop: RODescriptorReference[ColumnProperty[_PT]] + + expressions: Sequence[NamedColumn[Any]] + """The full sequence of columns referenced by this + attribute, adjusted for any aliasing in progress. + + .. versionadded:: 1.3.17 + + .. seealso:: + + :ref:`maptojoin` - usage example + """ + + def _orm_annotate_column(self, column: _NC) -> _NC: + """annotate and possibly adapt a column to be returned + as the mapped-attribute exposed version of the column. + + The column in this context needs to act as much like the + column in an ORM mapped context as possible, so includes + annotations to give hints to various ORM functions as to + the source entity of this column. It also adapts it + to the mapper's with_polymorphic selectable if one is + present. + + """ + + pe = self._parententity + annotations: Dict[str, Any] = { + "entity_namespace": pe, + "parententity": pe, + "parentmapper": pe, + "proxy_key": self.prop.key, + } + + col = column + + # for a mapper with polymorphic_on and an adapter, return + # the column against the polymorphic selectable. + # see also orm.util._orm_downgrade_polymorphic_columns + # for the reverse operation. + if self._parentmapper._polymorphic_adapter: + mapper_local_col = col + col = self._parentmapper._polymorphic_adapter.traverse(col) + + # this is a clue to the ORM Query etc. that this column + # was adapted to the mapper's polymorphic_adapter. the + # ORM uses this hint to know which column its adapting. + annotations["adapt_column"] = mapper_local_col + + return col._annotate(annotations)._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": pe} + ) + + if TYPE_CHECKING: + + def __clause_element__(self) -> NamedColumn[_PT]: ... + + def _memoized_method___clause_element__( + self, + ) -> NamedColumn[_PT]: + if self.adapter: + return self.adapter(self.prop.columns[0], self.prop.key) + else: + return self._orm_annotate_column(self.prop.columns[0]) + + def _memoized_attr_info(self) -> _InfoType: + """The .info dictionary for this attribute.""" + + ce = self.__clause_element__() + try: + return ce.info # type: ignore + except AttributeError: + return self.prop.info + + def _memoized_attr_expressions(self) -> Sequence[NamedColumn[Any]]: + """The full sequence of columns referenced by this + attribute, adjusted for any aliasing in progress. + + .. versionadded:: 1.3.17 + + """ + if self.adapter: + return [ + self.adapter(col, self.prop.key) + for col in self.prop.columns + ] + else: + return [ + self._orm_annotate_column(col) for col in self.prop.columns + ] + + def _fallback_getattr(self, key: str) -> Any: + """proxy attribute access down to the mapped column. + + this allows user-defined comparison methods to be accessed. + """ + return getattr(self.__clause_element__(), key) + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.__clause_element__(), *other, **kwargs) # type: ignore[no-any-return] # noqa: E501 + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + col = self.__clause_element__() + return op(col._bind_param(op, other), col, **kwargs) # type: ignore[no-any-return] # noqa: E501 + + def __str__(self) -> str: + if not self.parent or not self.key: + return object.__repr__(self) + return str(self.parent.class_.__name__) + "." + self.key + + +class MappedSQLExpression(ColumnProperty[_T], _DeclarativeMapped[_T]): + """Declarative front-end for the :class:`.ColumnProperty` class. + + Public constructor is the :func:`_orm.column_property` function. + + .. versionchanged:: 2.0 Added :class:`_orm.MappedSQLExpression` as + a Declarative compatible subclass for :class:`_orm.ColumnProperty`. + + .. seealso:: + + :class:`.MappedColumn` + + """ + + inherit_cache = True + """:meta private:""" + + +class MappedColumn( + _IntrospectsAnnotations, + _MapsColumns[_T], + _DeclarativeMapped[_T], +): + """Maps a single :class:`_schema.Column` on a class. + + :class:`_orm.MappedColumn` is a specialization of the + :class:`_orm.ColumnProperty` class and is oriented towards declarative + configuration. + + To construct :class:`_orm.MappedColumn` objects, use the + :func:`_orm.mapped_column` constructor function. + + .. versionadded:: 2.0 + + + """ + + __slots__ = ( + "column", + "_creation_order", + "_sort_order", + "foreign_keys", + "_has_nullable", + "_has_insert_default", + "deferred", + "deferred_group", + "deferred_raiseload", + "active_history", + "_attribute_options", + "_has_dataclass_arguments", + "_use_existing_column", + ) + + deferred: Union[_NoArg, bool] + deferred_raiseload: bool + deferred_group: Optional[str] + + column: Column[_T] + foreign_keys: Optional[Set[ForeignKey]] + _attribute_options: _AttributeOptions + + def __init__(self, *arg: Any, **kw: Any): + self._attribute_options = attr_opts = kw.pop( + "attribute_options", _DEFAULT_ATTRIBUTE_OPTIONS + ) + + self._use_existing_column = kw.pop("use_existing_column", False) + + self._has_dataclass_arguments = ( + attr_opts is not None + and attr_opts != _DEFAULT_ATTRIBUTE_OPTIONS + and any( + attr_opts[i] is not _NoArg.NO_ARG + for i, attr in enumerate(attr_opts._fields) + if attr != "dataclasses_default" + ) + ) + + insert_default = kw.pop("insert_default", _NoArg.NO_ARG) + self._has_insert_default = insert_default is not _NoArg.NO_ARG + + if self._has_insert_default: + kw["default"] = insert_default + elif attr_opts.dataclasses_default is not _NoArg.NO_ARG: + kw["default"] = attr_opts.dataclasses_default + + self.deferred_group = kw.pop("deferred_group", None) + self.deferred_raiseload = kw.pop("deferred_raiseload", None) + self.deferred = kw.pop("deferred", _NoArg.NO_ARG) + self.active_history = kw.pop("active_history", False) + + self._sort_order = kw.pop("sort_order", _NoArg.NO_ARG) + self.column = cast("Column[_T]", Column(*arg, **kw)) + self.foreign_keys = self.column.foreign_keys + self._has_nullable = "nullable" in kw and kw.get("nullable") not in ( + None, + SchemaConst.NULL_UNSPECIFIED, + ) + util.set_creation_order(self) + + def _copy(self, **kw: Any) -> Self: + new = self.__class__.__new__(self.__class__) + new.column = self.column._copy(**kw) + new.deferred = self.deferred + new.deferred_group = self.deferred_group + new.deferred_raiseload = self.deferred_raiseload + new.foreign_keys = new.column.foreign_keys + new.active_history = self.active_history + new._has_nullable = self._has_nullable + new._attribute_options = self._attribute_options + new._has_insert_default = self._has_insert_default + new._has_dataclass_arguments = self._has_dataclass_arguments + new._use_existing_column = self._use_existing_column + new._sort_order = self._sort_order + util.set_creation_order(new) + return new + + @property + def name(self) -> str: + return self.column.name + + @property + def mapper_property_to_assign(self) -> Optional[MapperProperty[_T]]: + effective_deferred = self.deferred + if effective_deferred is _NoArg.NO_ARG: + effective_deferred = bool( + self.deferred_group or self.deferred_raiseload + ) + + if effective_deferred or self.active_history: + return ColumnProperty( + self.column, + deferred=effective_deferred, + group=self.deferred_group, + raiseload=self.deferred_raiseload, + attribute_options=self._attribute_options, + active_history=self.active_history, + ) + else: + return None + + @property + def columns_to_assign(self) -> List[Tuple[Column[Any], int]]: + return [ + ( + self.column, + ( + self._sort_order + if self._sort_order is not _NoArg.NO_ARG + else 0 + ), + ) + ] + + def __clause_element__(self) -> Column[_T]: + return self.column + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(self.__clause_element__(), *other, **kwargs) # type: ignore[no-any-return] # noqa: E501 + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + col = self.__clause_element__() + return op(col._bind_param(op, other), col, **kwargs) # type: ignore[no-any-return] # noqa: E501 + + def found_in_pep593_annotated(self) -> Any: + # return a blank mapped_column(). This mapped_column()'s + # Column will be merged into it in _init_column_for_annotation(). + return MappedColumn() + + def declarative_scan( + self, + decl_scan: _ClassScanMapperConfig, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + mapped_container: Optional[Type[Mapped[Any]]], + annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: + column = self.column + + if ( + self._use_existing_column + and decl_scan.inherits + and decl_scan.single + ): + if decl_scan.is_deferred: + raise sa_exc.ArgumentError( + "Can't use use_existing_column with deferred mappers" + ) + supercls_mapper = class_mapper(decl_scan.inherits, False) + + colname = column.name if column.name is not None else key + column = self.column = supercls_mapper.local_table.c.get( # type: ignore # noqa: E501 + colname, column + ) + + if column.key is None: + column.key = key + if column.name is None: + column.name = key + + sqltype = column.type + + if extracted_mapped_annotation is None: + if sqltype._isnull and not self.column.foreign_keys: + self._raise_for_required(key, cls) + else: + return + + self._init_column_for_annotation( + cls, + registry, + extracted_mapped_annotation, + originating_module, + ) + + @util.preload_module("sqlalchemy.orm.decl_base") + def declarative_scan_for_composite( + self, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + param_name: str, + param_annotation: _AnnotationScanType, + ) -> None: + decl_base = util.preloaded.orm_decl_base + decl_base._undefer_column_name(param_name, self.column) + self._init_column_for_annotation( + cls, registry, param_annotation, originating_module + ) + + def _init_column_for_annotation( + self, + cls: Type[Any], + registry: _RegistryType, + argument: _AnnotationScanType, + originating_module: Optional[str], + ) -> None: + sqltype = self.column.type + + if isinstance(argument, str) or is_fwd_ref( + argument, check_generic=True + ): + assert originating_module is not None + argument = de_stringify_annotation( + cls, argument, originating_module, include_generic=True + ) + + if is_union(argument): + assert originating_module is not None + argument = de_stringify_union_elements( + cls, argument, originating_module + ) + + nullable = is_optional_union(argument) + + if not self._has_nullable: + self.column.nullable = nullable + + our_type = de_optionalize_union_types(argument) + + use_args_from = None + + our_original_type = our_type + + if is_pep695(our_type): + our_type = our_type.__value__ + + if is_pep593(our_type): + our_type_is_pep593 = True + + pep_593_components = typing_get_args(our_type) + raw_pep_593_type = pep_593_components[0] + if is_optional_union(raw_pep_593_type): + raw_pep_593_type = de_optionalize_union_types(raw_pep_593_type) + + nullable = True + if not self._has_nullable: + self.column.nullable = nullable + for elem in pep_593_components[1:]: + if isinstance(elem, MappedColumn): + use_args_from = elem + break + else: + our_type_is_pep593 = False + raw_pep_593_type = None + + if use_args_from is not None: + if ( + not self._has_insert_default + and use_args_from.column.default is not None + ): + self.column.default = None + + use_args_from.column._merge(self.column) + sqltype = self.column.type + + if ( + use_args_from.deferred is not _NoArg.NO_ARG + and self.deferred is _NoArg.NO_ARG + ): + self.deferred = use_args_from.deferred + + if ( + use_args_from.deferred_group is not None + and self.deferred_group is None + ): + self.deferred_group = use_args_from.deferred_group + + if ( + use_args_from.deferred_raiseload is not None + and self.deferred_raiseload is None + ): + self.deferred_raiseload = use_args_from.deferred_raiseload + + if ( + use_args_from._use_existing_column + and not self._use_existing_column + ): + self._use_existing_column = True + + if use_args_from.active_history: + self.active_history = use_args_from.active_history + + if ( + use_args_from._sort_order is not None + and self._sort_order is _NoArg.NO_ARG + ): + self._sort_order = use_args_from._sort_order + + if ( + use_args_from.column.key is not None + or use_args_from.column.name is not None + ): + util.warn_deprecated( + "Can't use the 'key' or 'name' arguments in " + "Annotated with mapped_column(); this will be ignored", + "2.0.22", + ) + + if use_args_from._has_dataclass_arguments: + for idx, arg in enumerate( + use_args_from._attribute_options._fields + ): + if ( + use_args_from._attribute_options[idx] + is not _NoArg.NO_ARG + ): + arg = arg.replace("dataclasses_", "") + util.warn_deprecated( + f"Argument '{arg}' is a dataclass argument and " + "cannot be specified within a mapped_column() " + "bundled inside of an Annotated object", + "2.0.22", + ) + + if sqltype._isnull and not self.column.foreign_keys: + new_sqltype = None + + if our_type_is_pep593: + checks = [our_original_type, raw_pep_593_type] + else: + checks = [our_original_type] + + for check_type in checks: + new_sqltype = registry._resolve_type(check_type) + if new_sqltype is not None: + break + else: + if isinstance(our_type, TypeEngine) or ( + isinstance(our_type, type) + and issubclass(our_type, TypeEngine) + ): + raise sa_exc.ArgumentError( + f"The type provided inside the {self.column.key!r} " + "attribute Mapped annotation is the SQLAlchemy type " + f"{our_type}. Expected a Python type instead" + ) + else: + raise sa_exc.ArgumentError( + "Could not locate SQLAlchemy Core type for Python " + f"type {our_type} inside the {self.column.key!r} " + "attribute Mapped annotation" + ) + + self.column._set_type(new_sqltype) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/query.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/query.py new file mode 100644 index 0000000..1dfc9cb --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/query.py @@ -0,0 +1,3394 @@ +# orm/query.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""The Query class and support. + +Defines the :class:`_query.Query` class, the central +construct used by the ORM to construct database queries. + +The :class:`_query.Query` class should not be confused with the +:class:`_expression.Select` class, which defines database +SELECT operations at the SQL (non-ORM) level. ``Query`` differs from +``Select`` in that it returns ORM-mapped objects and interacts with an +ORM session, whereas the ``Select`` construct interacts directly with the +database to return iterable result sets. + +""" +from __future__ import annotations + +import collections.abc as collections_abc +import operator +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import attributes +from . import interfaces +from . import loading +from . import util as orm_util +from ._typing import _O +from .base import _assertions +from .context import _column_descriptions +from .context import _determine_last_joined_entity +from .context import _legacy_filter_by_entity_zero +from .context import FromStatement +from .context import ORMCompileState +from .context import QueryContext +from .interfaces import ORMColumnDescription +from .interfaces import ORMColumnsClauseRole +from .util import AliasedClass +from .util import object_mapper +from .util import with_parent +from .. import exc as sa_exc +from .. import inspect +from .. import inspection +from .. import log +from .. import sql +from .. import util +from ..engine import Result +from ..engine import Row +from ..event import dispatcher +from ..event import EventTarget +from ..sql import coercions +from ..sql import expression +from ..sql import roles +from ..sql import Select +from ..sql import util as sql_util +from ..sql import visitors +from ..sql._typing import _FromClauseArgument +from ..sql._typing import _TP +from ..sql.annotation import SupportsCloneAnnotations +from ..sql.base import _entity_namespace_key +from ..sql.base import _generative +from ..sql.base import _NoArg +from ..sql.base import Executable +from ..sql.base import Generative +from ..sql.elements import BooleanClauseList +from ..sql.expression import Exists +from ..sql.selectable import _MemoizedSelectEntities +from ..sql.selectable import _SelectFromElements +from ..sql.selectable import ForUpdateArg +from ..sql.selectable import HasHints +from ..sql.selectable import HasPrefixes +from ..sql.selectable import HasSuffixes +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import SelectLabelStyle +from ..util.typing import Literal +from ..util.typing import Self + + +if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _InternalEntityType + from ._typing import SynchronizeSessionArgument + from .mapper import Mapper + from .path_registry import PathRegistry + from .session import _PKIdentityArgument + from .session import Session + from .state import InstanceState + from ..engine.cursor import CursorResult + from ..engine.interfaces import _ImmutableExecuteOptions + from ..engine.interfaces import CompiledCacheType + from ..engine.interfaces import IsolationLevel + from ..engine.interfaces import SchemaTranslateMapType + from ..engine.result import FrozenResult + from ..engine.result import ScalarResult + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _ColumnExpressionOrStrLabelArgument + from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _JoinTargetArgument + from ..sql._typing import _LimitOffsetType + from ..sql._typing import _MAYBE_ENTITY + from ..sql._typing import _no_kw + from ..sql._typing import _NOT_ENTITY + from ..sql._typing import _OnClauseArgument + from ..sql._typing import _PropagateAttrsType + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA + from ..sql.base import CacheableOptions + from ..sql.base import ExecutableOption + from ..sql.elements import ColumnElement + from ..sql.elements import Label + from ..sql.selectable import _ForUpdateOfArgument + from ..sql.selectable import _JoinTargetElement + from ..sql.selectable import _SetupJoinsElement + from ..sql.selectable import Alias + from ..sql.selectable import CTE + from ..sql.selectable import ExecutableReturnsRows + from ..sql.selectable import FromClause + from ..sql.selectable import ScalarSelect + from ..sql.selectable import Subquery + + +__all__ = ["Query", "QueryContext"] + +_T = TypeVar("_T", bound=Any) + + +@inspection._self_inspects +@log.class_logger +class Query( + _SelectFromElements, + SupportsCloneAnnotations, + HasPrefixes, + HasSuffixes, + HasHints, + EventTarget, + log.Identified, + Generative, + Executable, + Generic[_T], +): + """ORM-level SQL construction object. + + .. legacy:: The ORM :class:`.Query` object is a legacy construct + as of SQLAlchemy 2.0. See the notes at the top of + :ref:`query_api_toplevel` for an overview, including links to migration + documentation. + + :class:`_query.Query` objects are normally initially generated using the + :meth:`~.Session.query` method of :class:`.Session`, and in + less common cases by instantiating the :class:`_query.Query` directly and + associating with a :class:`.Session` using the + :meth:`_query.Query.with_session` + method. + + """ + + # elements that are in Core and can be cached in the same way + _where_criteria: Tuple[ColumnElement[Any], ...] = () + _having_criteria: Tuple[ColumnElement[Any], ...] = () + + _order_by_clauses: Tuple[ColumnElement[Any], ...] = () + _group_by_clauses: Tuple[ColumnElement[Any], ...] = () + _limit_clause: Optional[ColumnElement[Any]] = None + _offset_clause: Optional[ColumnElement[Any]] = None + + _distinct: bool = False + _distinct_on: Tuple[ColumnElement[Any], ...] = () + + _for_update_arg: Optional[ForUpdateArg] = None + _correlate: Tuple[FromClause, ...] = () + _auto_correlate: bool = True + _from_obj: Tuple[FromClause, ...] = () + _setup_joins: Tuple[_SetupJoinsElement, ...] = () + + _label_style: SelectLabelStyle = SelectLabelStyle.LABEL_STYLE_LEGACY_ORM + + _memoized_select_entities = () + + _compile_options: Union[Type[CacheableOptions], CacheableOptions] = ( + ORMCompileState.default_compile_options + ) + + _with_options: Tuple[ExecutableOption, ...] + load_options = QueryContext.default_load_options + { + "_legacy_uniquing": True + } + + _params: util.immutabledict[str, Any] = util.EMPTY_DICT + + # local Query builder state, not needed for + # compilation or execution + _enable_assertions = True + + _statement: Optional[ExecutableReturnsRows] = None + + session: Session + + dispatch: dispatcher[Query[_T]] + + # mirrors that of ClauseElement, used to propagate the "orm" + # plugin as well as the "subject" of the plugin, e.g. the mapper + # we are querying against. + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + return util.EMPTY_DICT + + def __init__( + self, + entities: Union[ + _ColumnsClauseArgument[Any], Sequence[_ColumnsClauseArgument[Any]] + ], + session: Optional[Session] = None, + ): + """Construct a :class:`_query.Query` directly. + + E.g.:: + + q = Query([User, Address], session=some_session) + + The above is equivalent to:: + + q = some_session.query(User, Address) + + :param entities: a sequence of entities and/or SQL expressions. + + :param session: a :class:`.Session` with which the + :class:`_query.Query` + will be associated. Optional; a :class:`_query.Query` + can be associated + with a :class:`.Session` generatively via the + :meth:`_query.Query.with_session` method as well. + + .. seealso:: + + :meth:`.Session.query` + + :meth:`_query.Query.with_session` + + """ + + # session is usually present. There's one case in subqueryloader + # where it stores a Query without a Session and also there are tests + # for the query(Entity).with_session(session) API which is likely in + # some old recipes, however these are legacy as select() can now be + # used. + self.session = session # type: ignore + self._set_entities(entities) + + def _set_propagate_attrs(self, values: Mapping[str, Any]) -> Self: + self._propagate_attrs = util.immutabledict(values) + return self + + def _set_entities( + self, + entities: Union[ + _ColumnsClauseArgument[Any], Iterable[_ColumnsClauseArgument[Any]] + ], + ) -> None: + self._raw_columns = [ + coercions.expect( + roles.ColumnsClauseRole, + ent, + apply_propagate_attrs=self, + post_inspect=True, + ) + for ent in util.to_list(entities) + ] + + def tuples(self: Query[_O]) -> Query[Tuple[_O]]: + """return a tuple-typed form of this :class:`.Query`. + + This method invokes the :meth:`.Query.only_return_tuples` + method with a value of ``True``, which by itself ensures that this + :class:`.Query` will always return :class:`.Row` objects, even + if the query is made against a single entity. It then also + at the typing level will return a "typed" query, if possible, + that will type result rows as ``Tuple`` objects with typed + elements. + + This method can be compared to the :meth:`.Result.tuples` method, + which returns "self", but from a typing perspective returns an object + that will yield typed ``Tuple`` objects for results. Typing + takes effect only if this :class:`.Query` object is a typed + query object already. + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.Result.tuples` - v2 equivalent method. + + """ + return self.only_return_tuples(True) # type: ignore + + def _entity_from_pre_ent_zero(self) -> Optional[_InternalEntityType[Any]]: + if not self._raw_columns: + return None + + ent = self._raw_columns[0] + + if "parententity" in ent._annotations: + return ent._annotations["parententity"] # type: ignore + elif "bundle" in ent._annotations: + return ent._annotations["bundle"] # type: ignore + else: + # label, other SQL expression + for element in visitors.iterate(ent): + if "parententity" in element._annotations: + return element._annotations["parententity"] # type: ignore # noqa: E501 + else: + return None + + def _only_full_mapper_zero(self, methname: str) -> Mapper[Any]: + if ( + len(self._raw_columns) != 1 + or "parententity" not in self._raw_columns[0]._annotations + or not self._raw_columns[0].is_selectable + ): + raise sa_exc.InvalidRequestError( + "%s() can only be used against " + "a single mapped class." % methname + ) + + return self._raw_columns[0]._annotations["parententity"] # type: ignore # noqa: E501 + + def _set_select_from( + self, obj: Iterable[_FromClauseArgument], set_base_alias: bool + ) -> None: + fa = [ + coercions.expect( + roles.StrictFromClauseRole, + elem, + allow_select=True, + apply_propagate_attrs=self, + ) + for elem in obj + ] + + self._compile_options += {"_set_base_alias": set_base_alias} + self._from_obj = tuple(fa) + + @_generative + def _set_lazyload_from(self, state: InstanceState[Any]) -> Self: + self.load_options += {"_lazy_loaded_from": state} + return self + + def _get_condition(self) -> None: + """used by legacy BakedQuery""" + self._no_criterion_condition("get", order_by=False, distinct=False) + + def _get_existing_condition(self) -> None: + self._no_criterion_assertion("get", order_by=False, distinct=False) + + def _no_criterion_assertion( + self, meth: str, order_by: bool = True, distinct: bool = True + ) -> None: + if not self._enable_assertions: + return + if ( + self._where_criteria + or self._statement is not None + or self._from_obj + or self._setup_joins + or self._limit_clause is not None + or self._offset_clause is not None + or self._group_by_clauses + or (order_by and self._order_by_clauses) + or (distinct and self._distinct) + ): + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a " + "Query with existing criterion. " % meth + ) + + def _no_criterion_condition( + self, meth: str, order_by: bool = True, distinct: bool = True + ) -> None: + self._no_criterion_assertion(meth, order_by, distinct) + + self._from_obj = self._setup_joins = () + if self._statement is not None: + self._compile_options += {"_statement": None} + self._where_criteria = () + self._distinct = False + + self._order_by_clauses = self._group_by_clauses = () + + def _no_clauseelement_condition(self, meth: str) -> None: + if not self._enable_assertions: + return + if self._order_by_clauses: + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a " + "Query with existing criterion. " % meth + ) + self._no_criterion_condition(meth) + + def _no_statement_condition(self, meth: str) -> None: + if not self._enable_assertions: + return + if self._statement is not None: + raise sa_exc.InvalidRequestError( + ( + "Query.%s() being called on a Query with an existing full " + "statement - can't apply criterion." + ) + % meth + ) + + def _no_limit_offset(self, meth: str) -> None: + if not self._enable_assertions: + return + if self._limit_clause is not None or self._offset_clause is not None: + raise sa_exc.InvalidRequestError( + "Query.%s() being called on a Query which already has LIMIT " + "or OFFSET applied. Call %s() before limit() or offset() " + "are applied." % (meth, meth) + ) + + @property + def _has_row_limiting_clause(self) -> bool: + return ( + self._limit_clause is not None or self._offset_clause is not None + ) + + def _get_options( + self, + populate_existing: Optional[bool] = None, + version_check: Optional[bool] = None, + only_load_props: Optional[Sequence[str]] = None, + refresh_state: Optional[InstanceState[Any]] = None, + identity_token: Optional[Any] = None, + ) -> Self: + load_options: Dict[str, Any] = {} + compile_options: Dict[str, Any] = {} + + if version_check: + load_options["_version_check"] = version_check + if populate_existing: + load_options["_populate_existing"] = populate_existing + if refresh_state: + load_options["_refresh_state"] = refresh_state + compile_options["_for_refresh_state"] = True + if only_load_props: + compile_options["_only_load_props"] = frozenset(only_load_props) + if identity_token: + load_options["_identity_token"] = identity_token + + if load_options: + self.load_options += load_options + if compile_options: + self._compile_options += compile_options + + return self + + def _clone(self, **kw: Any) -> Self: + return self._generate() + + def _get_select_statement_only(self) -> Select[_T]: + if self._statement is not None: + raise sa_exc.InvalidRequestError( + "Can't call this method on a Query that uses from_statement()" + ) + return cast("Select[_T]", self.statement) + + @property + def statement(self) -> Union[Select[_T], FromStatement[_T]]: + """The full SELECT statement represented by this Query. + + The statement by default will not have disambiguating labels + applied to the construct unless with_labels(True) is called + first. + + """ + + # .statement can return the direct future.Select() construct here, as + # long as we are not using subsequent adaption features that + # are made against raw entities, e.g. from_self(), with_polymorphic(), + # select_entity_from(). If these features are being used, then + # the Select() we return will not have the correct .selected_columns + # collection and will not embed in subsequent queries correctly. + # We could find a way to make this collection "correct", however + # this would not be too different from doing the full compile as + # we are doing in any case, the Select() would still not have the + # proper state for other attributes like whereclause, order_by, + # and these features are all deprecated in any case. + # + # for these reasons, Query is not a Select, it remains an ORM + # object for which __clause_element__() must be called in order for + # it to provide a real expression object. + # + # from there, it starts to look much like Query itself won't be + # passed into the execute process and won't generate its own cache + # key; this will all occur in terms of the ORM-enabled Select. + if not self._compile_options._set_base_alias: + # if we don't have legacy top level aliasing features in use + # then convert to a future select() directly + stmt = self._statement_20(for_statement=True) + else: + stmt = self._compile_state(for_statement=True).statement + + if self._params: + stmt = stmt.params(self._params) + + return stmt + + def _final_statement(self, legacy_query_style: bool = True) -> Select[Any]: + """Return the 'final' SELECT statement for this :class:`.Query`. + + This is used by the testing suite only and is fairly inefficient. + + This is the Core-only select() that will be rendered by a complete + compilation of this query, and is what .statement used to return + in 1.3. + + + """ + + q = self._clone() + + return q._compile_state( + use_legacy_query_style=legacy_query_style + ).statement # type: ignore + + def _statement_20( + self, for_statement: bool = False, use_legacy_query_style: bool = True + ) -> Union[Select[_T], FromStatement[_T]]: + # TODO: this event needs to be deprecated, as it currently applies + # only to ORM query and occurs at this spot that is now more + # or less an artificial spot + if self.dispatch.before_compile: + for fn in self.dispatch.before_compile: + new_query = fn(self) + if new_query is not None and new_query is not self: + self = new_query + if not fn._bake_ok: # type: ignore + self._compile_options += {"_bake_ok": False} + + compile_options = self._compile_options + compile_options += { + "_for_statement": for_statement, + "_use_legacy_query_style": use_legacy_query_style, + } + + stmt: Union[Select[_T], FromStatement[_T]] + + if self._statement is not None: + stmt = FromStatement(self._raw_columns, self._statement) + stmt.__dict__.update( + _with_options=self._with_options, + _with_context_options=self._with_context_options, + _compile_options=compile_options, + _execution_options=self._execution_options, + _propagate_attrs=self._propagate_attrs, + ) + else: + # Query / select() internal attributes are 99% cross-compatible + stmt = Select._create_raw_select(**self.__dict__) + stmt.__dict__.update( + _label_style=self._label_style, + _compile_options=compile_options, + _propagate_attrs=self._propagate_attrs, + ) + stmt.__dict__.pop("session", None) + + # ensure the ORM context is used to compile the statement, even + # if it has no ORM entities. This is so ORM-only things like + # _legacy_joins are picked up that wouldn't be picked up by the + # Core statement context + if "compile_state_plugin" not in stmt._propagate_attrs: + stmt._propagate_attrs = stmt._propagate_attrs.union( + {"compile_state_plugin": "orm", "plugin_subject": None} + ) + + return stmt + + def subquery( + self, + name: Optional[str] = None, + with_labels: bool = False, + reduce_columns: bool = False, + ) -> Subquery: + """Return the full SELECT statement represented by + this :class:`_query.Query`, embedded within an + :class:`_expression.Alias`. + + Eager JOIN generation within the query is disabled. + + .. seealso:: + + :meth:`_sql.Select.subquery` - v2 comparable method. + + :param name: string name to be assigned as the alias; + this is passed through to :meth:`_expression.FromClause.alias`. + If ``None``, a name will be deterministically generated + at compile time. + + :param with_labels: if True, :meth:`.with_labels` will be called + on the :class:`_query.Query` first to apply table-qualified labels + to all columns. + + :param reduce_columns: if True, + :meth:`_expression.Select.reduce_columns` will + be called on the resulting :func:`_expression.select` construct, + to remove same-named columns where one also refers to the other + via foreign key or WHERE clause equivalence. + + """ + q = self.enable_eagerloads(False) + if with_labels: + q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + + stmt = q._get_select_statement_only() + + if TYPE_CHECKING: + assert isinstance(stmt, Select) + + if reduce_columns: + stmt = stmt.reduce_columns() + return stmt.subquery(name=name) + + def cte( + self, + name: Optional[str] = None, + recursive: bool = False, + nesting: bool = False, + ) -> CTE: + r"""Return the full SELECT statement represented by this + :class:`_query.Query` represented as a common table expression (CTE). + + Parameters and usage are the same as those of the + :meth:`_expression.SelectBase.cte` method; see that method for + further details. + + Here is the `PostgreSQL WITH + RECURSIVE example + `_. + Note that, in this example, the ``included_parts`` cte and the + ``incl_alias`` alias of it are Core selectables, which + means the columns are accessed via the ``.c.`` attribute. The + ``parts_alias`` object is an :func:`_orm.aliased` instance of the + ``Part`` entity, so column-mapped attributes are available + directly:: + + from sqlalchemy.orm import aliased + + class Part(Base): + __tablename__ = 'part' + part = Column(String, primary_key=True) + sub_part = Column(String, primary_key=True) + quantity = Column(Integer) + + included_parts = session.query( + Part.sub_part, + Part.part, + Part.quantity).\ + filter(Part.part=="our part").\ + cte(name="included_parts", recursive=True) + + incl_alias = aliased(included_parts, name="pr") + parts_alias = aliased(Part, name="p") + included_parts = included_parts.union_all( + session.query( + parts_alias.sub_part, + parts_alias.part, + parts_alias.quantity).\ + filter(parts_alias.part==incl_alias.c.sub_part) + ) + + q = session.query( + included_parts.c.sub_part, + func.sum(included_parts.c.quantity). + label('total_quantity') + ).\ + group_by(included_parts.c.sub_part) + + .. seealso:: + + :meth:`_sql.Select.cte` - v2 equivalent method. + + """ + return ( + self.enable_eagerloads(False) + ._get_select_statement_only() + .cte(name=name, recursive=recursive, nesting=nesting) + ) + + def label(self, name: Optional[str]) -> Label[Any]: + """Return the full SELECT statement represented by this + :class:`_query.Query`, converted + to a scalar subquery with a label of the given name. + + .. seealso:: + + :meth:`_sql.Select.label` - v2 comparable method. + + """ + + return ( + self.enable_eagerloads(False) + ._get_select_statement_only() + .label(name) + ) + + @overload + def as_scalar( + self: Query[Tuple[_MAYBE_ENTITY]], + ) -> ScalarSelect[_MAYBE_ENTITY]: ... + + @overload + def as_scalar( + self: Query[Tuple[_NOT_ENTITY]], + ) -> ScalarSelect[_NOT_ENTITY]: ... + + @overload + def as_scalar(self) -> ScalarSelect[Any]: ... + + @util.deprecated( + "1.4", + "The :meth:`_query.Query.as_scalar` method is deprecated and will be " + "removed in a future release. Please refer to " + ":meth:`_query.Query.scalar_subquery`.", + ) + def as_scalar(self) -> ScalarSelect[Any]: + """Return the full SELECT statement represented by this + :class:`_query.Query`, converted to a scalar subquery. + + """ + return self.scalar_subquery() + + @overload + def scalar_subquery( + self: Query[Tuple[_MAYBE_ENTITY]], + ) -> ScalarSelect[Any]: ... + + @overload + def scalar_subquery( + self: Query[Tuple[_NOT_ENTITY]], + ) -> ScalarSelect[_NOT_ENTITY]: ... + + @overload + def scalar_subquery(self) -> ScalarSelect[Any]: ... + + def scalar_subquery(self) -> ScalarSelect[Any]: + """Return the full SELECT statement represented by this + :class:`_query.Query`, converted to a scalar subquery. + + Analogous to + :meth:`sqlalchemy.sql.expression.SelectBase.scalar_subquery`. + + .. versionchanged:: 1.4 The :meth:`_query.Query.scalar_subquery` + method replaces the :meth:`_query.Query.as_scalar` method. + + .. seealso:: + + :meth:`_sql.Select.scalar_subquery` - v2 comparable method. + + """ + + return ( + self.enable_eagerloads(False) + ._get_select_statement_only() + .scalar_subquery() + ) + + @property + def selectable(self) -> Union[Select[_T], FromStatement[_T]]: + """Return the :class:`_expression.Select` object emitted by this + :class:`_query.Query`. + + Used for :func:`_sa.inspect` compatibility, this is equivalent to:: + + query.enable_eagerloads(False).with_labels().statement + + """ + return self.__clause_element__() + + def __clause_element__(self) -> Union[Select[_T], FromStatement[_T]]: + return ( + self._with_compile_options( + _enable_eagerloads=False, _render_for_subquery=True + ) + .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + .statement + ) + + @overload + def only_return_tuples( + self: Query[_O], value: Literal[True] + ) -> RowReturningQuery[Tuple[_O]]: ... + + @overload + def only_return_tuples( + self: Query[_O], value: Literal[False] + ) -> Query[_O]: ... + + @_generative + def only_return_tuples(self, value: bool) -> Query[Any]: + """When set to True, the query results will always be a + :class:`.Row` object. + + This can change a query that normally returns a single entity + as a scalar to return a :class:`.Row` result in all cases. + + .. seealso:: + + :meth:`.Query.tuples` - returns tuples, but also at the typing + level will type results as ``Tuple``. + + :meth:`_query.Query.is_single_entity` + + :meth:`_engine.Result.tuples` - v2 comparable method. + + """ + self.load_options += dict(_only_return_tuples=value) + return self + + @property + def is_single_entity(self) -> bool: + """Indicates if this :class:`_query.Query` + returns tuples or single entities. + + Returns True if this query returns a single entity for each instance + in its result list, and False if this query returns a tuple of entities + for each result. + + .. versionadded:: 1.3.11 + + .. seealso:: + + :meth:`_query.Query.only_return_tuples` + + """ + return ( + not self.load_options._only_return_tuples + and len(self._raw_columns) == 1 + and "parententity" in self._raw_columns[0]._annotations + and isinstance( + self._raw_columns[0]._annotations["parententity"], + ORMColumnsClauseRole, + ) + ) + + @_generative + def enable_eagerloads(self, value: bool) -> Self: + """Control whether or not eager joins and subqueries are + rendered. + + When set to False, the returned Query will not render + eager joins regardless of :func:`~sqlalchemy.orm.joinedload`, + :func:`~sqlalchemy.orm.subqueryload` options + or mapper-level ``lazy='joined'``/``lazy='subquery'`` + configurations. + + This is used primarily when nesting the Query's + statement into a subquery or other + selectable, or when using :meth:`_query.Query.yield_per`. + + """ + self._compile_options += {"_enable_eagerloads": value} + return self + + @_generative + def _with_compile_options(self, **opt: Any) -> Self: + self._compile_options += opt + return self + + @util.became_legacy_20( + ":meth:`_orm.Query.with_labels` and :meth:`_orm.Query.apply_labels`", + alternative="Use set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) " + "instead.", + ) + def with_labels(self) -> Self: + return self.set_label_style( + SelectLabelStyle.LABEL_STYLE_TABLENAME_PLUS_COL + ) + + apply_labels = with_labels + + @property + def get_label_style(self) -> SelectLabelStyle: + """ + Retrieve the current label style. + + .. versionadded:: 1.4 + + .. seealso:: + + :meth:`_sql.Select.get_label_style` - v2 equivalent method. + + """ + return self._label_style + + def set_label_style(self, style: SelectLabelStyle) -> Self: + """Apply column labels to the return value of Query.statement. + + Indicates that this Query's `statement` accessor should return + a SELECT statement that applies labels to all columns in the + form _; this is commonly used to + disambiguate columns from multiple tables which have the same + name. + + When the `Query` actually issues SQL to load rows, it always + uses column labeling. + + .. note:: The :meth:`_query.Query.set_label_style` method *only* applies + the output of :attr:`_query.Query.statement`, and *not* to any of + the result-row invoking systems of :class:`_query.Query` itself, + e.g. + :meth:`_query.Query.first`, :meth:`_query.Query.all`, etc. + To execute + a query using :meth:`_query.Query.set_label_style`, invoke the + :attr:`_query.Query.statement` using :meth:`.Session.execute`:: + + result = session.execute( + query + .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + .statement + ) + + .. versionadded:: 1.4 + + + .. seealso:: + + :meth:`_sql.Select.set_label_style` - v2 equivalent method. + + """ # noqa + if self._label_style is not style: + self = self._generate() + self._label_style = style + return self + + @_generative + def enable_assertions(self, value: bool) -> Self: + """Control whether assertions are generated. + + When set to False, the returned Query will + not assert its state before certain operations, + including that LIMIT/OFFSET has not been applied + when filter() is called, no criterion exists + when get() is called, and no "from_statement()" + exists when filter()/order_by()/group_by() etc. + is called. This more permissive mode is used by + custom Query subclasses to specify criterion or + other modifiers outside of the usual usage patterns. + + Care should be taken to ensure that the usage + pattern is even possible. A statement applied + by from_statement() will override any criterion + set by filter() or order_by(), for example. + + """ + self._enable_assertions = value + return self + + @property + def whereclause(self) -> Optional[ColumnElement[bool]]: + """A readonly attribute which returns the current WHERE criterion for + this Query. + + This returned value is a SQL expression construct, or ``None`` if no + criterion has been established. + + .. seealso:: + + :attr:`_sql.Select.whereclause` - v2 equivalent property. + + """ + return BooleanClauseList._construct_for_whereclause( + self._where_criteria + ) + + @_generative + def _with_current_path(self, path: PathRegistry) -> Self: + """indicate that this query applies to objects loaded + within a certain path. + + Used by deferred loaders (see strategies.py) which transfer + query options from an originating query to a newly generated + query intended for the deferred load. + + """ + self._compile_options += {"_current_path": path} + return self + + @_generative + def yield_per(self, count: int) -> Self: + r"""Yield only ``count`` rows at a time. + + The purpose of this method is when fetching very large result sets + (> 10K rows), to batch results in sub-collections and yield them + out partially, so that the Python interpreter doesn't need to declare + very large areas of memory which is both time consuming and leads + to excessive memory use. The performance from fetching hundreds of + thousands of rows can often double when a suitable yield-per setting + (e.g. approximately 1000) is used, even with DBAPIs that buffer + rows (which are most). + + As of SQLAlchemy 1.4, the :meth:`_orm.Query.yield_per` method is + equivalent to using the ``yield_per`` execution option at the ORM + level. See the section :ref:`orm_queryguide_yield_per` for further + background on this option. + + .. seealso:: + + :ref:`orm_queryguide_yield_per` + + """ + self.load_options += {"_yield_per": count} + return self + + @util.became_legacy_20( + ":meth:`_orm.Query.get`", + alternative="The method is now available as :meth:`_orm.Session.get`", + ) + def get(self, ident: _PKIdentityArgument) -> Optional[Any]: + """Return an instance based on the given primary key identifier, + or ``None`` if not found. + + E.g.:: + + my_user = session.query(User).get(5) + + some_object = session.query(VersionedFoo).get((5, 10)) + + some_object = session.query(VersionedFoo).get( + {"id": 5, "version_id": 10}) + + :meth:`_query.Query.get` is special in that it provides direct + access to the identity map of the owning :class:`.Session`. + If the given primary key identifier is present + in the local identity map, the object is returned + directly from this collection and no SQL is emitted, + unless the object has been marked fully expired. + If not present, + a SELECT is performed in order to locate the object. + + :meth:`_query.Query.get` also will perform a check if + the object is present in the identity map and + marked as expired - a SELECT + is emitted to refresh the object as well as to + ensure that the row is still present. + If not, :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. + + :meth:`_query.Query.get` is only used to return a single + mapped instance, not multiple instances or + individual column constructs, and strictly + on a single primary key value. The originating + :class:`_query.Query` must be constructed in this way, + i.e. against a single mapped entity, + with no additional filtering criterion. Loading + options via :meth:`_query.Query.options` may be applied + however, and will be used if the object is not + yet locally present. + + :param ident: A scalar, tuple, or dictionary representing the + primary key. For a composite (e.g. multiple column) primary key, + a tuple or dictionary should be passed. + + For a single-column primary key, the scalar calling form is typically + the most expedient. If the primary key of a row is the value "5", + the call looks like:: + + my_object = query.get(5) + + The tuple form contains primary key values typically in + the order in which they correspond to the mapped + :class:`_schema.Table` + object's primary key columns, or if the + :paramref:`_orm.Mapper.primary_key` configuration parameter were + used, in + the order used for that parameter. For example, if the primary key + of a row is represented by the integer + digits "5, 10" the call would look like:: + + my_object = query.get((5, 10)) + + The dictionary form should include as keys the mapped attribute names + corresponding to each element of the primary key. If the mapped class + has the attributes ``id``, ``version_id`` as the attributes which + store the object's primary key value, the call would look like:: + + my_object = query.get({"id": 5, "version_id": 10}) + + .. versionadded:: 1.3 the :meth:`_query.Query.get` + method now optionally + accepts a dictionary of attribute names to values in order to + indicate a primary key identifier. + + + :return: The object instance, or ``None``. + + """ + self._no_criterion_assertion("get", order_by=False, distinct=False) + + # we still implement _get_impl() so that baked query can override + # it + return self._get_impl(ident, loading.load_on_pk_identity) + + def _get_impl( + self, + primary_key_identity: _PKIdentityArgument, + db_load_fn: Callable[..., Any], + identity_token: Optional[Any] = None, + ) -> Optional[Any]: + mapper = self._only_full_mapper_zero("get") + return self.session._get_impl( + mapper, + primary_key_identity, + db_load_fn, + populate_existing=self.load_options._populate_existing, + with_for_update=self._for_update_arg, + options=self._with_options, + identity_token=identity_token, + execution_options=self._execution_options, + ) + + @property + def lazy_loaded_from(self) -> Optional[InstanceState[Any]]: + """An :class:`.InstanceState` that is using this :class:`_query.Query` + for a lazy load operation. + + .. deprecated:: 1.4 This attribute should be viewed via the + :attr:`.ORMExecuteState.lazy_loaded_from` attribute, within + the context of the :meth:`.SessionEvents.do_orm_execute` + event. + + .. seealso:: + + :attr:`.ORMExecuteState.lazy_loaded_from` + + """ + return self.load_options._lazy_loaded_from # type: ignore + + @property + def _current_path(self) -> PathRegistry: + return self._compile_options._current_path # type: ignore + + @_generative + def correlate( + self, + *fromclauses: Union[Literal[None, False], _FromClauseArgument], + ) -> Self: + """Return a :class:`.Query` construct which will correlate the given + FROM clauses to that of an enclosing :class:`.Query` or + :func:`~.expression.select`. + + The method here accepts mapped classes, :func:`.aliased` constructs, + and :class:`_orm.Mapper` constructs as arguments, which are resolved + into expression constructs, in addition to appropriate expression + constructs. + + The correlation arguments are ultimately passed to + :meth:`_expression.Select.correlate` + after coercion to expression constructs. + + The correlation arguments take effect in such cases + as when :meth:`_query.Query.from_self` is used, or when + a subquery as returned by :meth:`_query.Query.subquery` is + embedded in another :func:`_expression.select` construct. + + .. seealso:: + + :meth:`_sql.Select.correlate` - v2 equivalent method. + + """ + + self._auto_correlate = False + if fromclauses and fromclauses[0] in {None, False}: + self._correlate = () + else: + self._correlate = self._correlate + tuple( + coercions.expect(roles.FromClauseRole, f) for f in fromclauses + ) + return self + + @_generative + def autoflush(self, setting: bool) -> Self: + """Return a Query with a specific 'autoflush' setting. + + As of SQLAlchemy 1.4, the :meth:`_orm.Query.autoflush` method + is equivalent to using the ``autoflush`` execution option at the + ORM level. See the section :ref:`orm_queryguide_autoflush` for + further background on this option. + + """ + self.load_options += {"_autoflush": setting} + return self + + @_generative + def populate_existing(self) -> Self: + """Return a :class:`_query.Query` + that will expire and refresh all instances + as they are loaded, or reused from the current :class:`.Session`. + + As of SQLAlchemy 1.4, the :meth:`_orm.Query.populate_existing` method + is equivalent to using the ``populate_existing`` execution option at + the ORM level. See the section :ref:`orm_queryguide_populate_existing` + for further background on this option. + + """ + self.load_options += {"_populate_existing": True} + return self + + @_generative + def _with_invoke_all_eagers(self, value: bool) -> Self: + """Set the 'invoke all eagers' flag which causes joined- and + subquery loaders to traverse into already-loaded related objects + and collections. + + Default is that of :attr:`_query.Query._invoke_all_eagers`. + + """ + self.load_options += {"_invoke_all_eagers": value} + return self + + @util.became_legacy_20( + ":meth:`_orm.Query.with_parent`", + alternative="Use the :func:`_orm.with_parent` standalone construct.", + ) + @util.preload_module("sqlalchemy.orm.relationships") + def with_parent( + self, + instance: object, + property: Optional[ # noqa: A002 + attributes.QueryableAttribute[Any] + ] = None, + from_entity: Optional[_ExternalEntityType[Any]] = None, + ) -> Self: + """Add filtering criterion that relates the given instance + to a child object or collection, using its attribute state + as well as an established :func:`_orm.relationship()` + configuration. + + The method uses the :func:`.with_parent` function to generate + the clause, the result of which is passed to + :meth:`_query.Query.filter`. + + Parameters are the same as :func:`.with_parent`, with the exception + that the given property can be None, in which case a search is + performed against this :class:`_query.Query` object's target mapper. + + :param instance: + An instance which has some :func:`_orm.relationship`. + + :param property: + Class bound attribute which indicates + what relationship from the instance should be used to reconcile the + parent/child relationship. + + :param from_entity: + Entity in which to consider as the left side. This defaults to the + "zero" entity of the :class:`_query.Query` itself. + + """ + relationships = util.preloaded.orm_relationships + + if from_entity: + entity_zero = inspect(from_entity) + else: + entity_zero = _legacy_filter_by_entity_zero(self) + if property is None: + # TODO: deprecate, property has to be supplied + mapper = object_mapper(instance) + + for prop in mapper.iterate_properties: + if ( + isinstance(prop, relationships.RelationshipProperty) + and prop.mapper is entity_zero.mapper # type: ignore + ): + property = prop # type: ignore # noqa: A001 + break + else: + raise sa_exc.InvalidRequestError( + "Could not locate a property which relates instances " + "of class '%s' to instances of class '%s'" + % ( + entity_zero.mapper.class_.__name__, # type: ignore + instance.__class__.__name__, + ) + ) + + return self.filter( + with_parent( + instance, + property, # type: ignore + entity_zero.entity, # type: ignore + ) + ) + + @_generative + def add_entity( + self, + entity: _EntityType[Any], + alias: Optional[Union[Alias, Subquery]] = None, + ) -> Query[Any]: + """add a mapped entity to the list of result columns + to be returned. + + .. seealso:: + + :meth:`_sql.Select.add_columns` - v2 comparable method. + """ + + if alias is not None: + # TODO: deprecate + entity = AliasedClass(entity, alias) + + self._raw_columns = list(self._raw_columns) + + self._raw_columns.append( + coercions.expect( + roles.ColumnsClauseRole, entity, apply_propagate_attrs=self + ) + ) + return self + + @_generative + def with_session(self, session: Session) -> Self: + """Return a :class:`_query.Query` that will use the given + :class:`.Session`. + + While the :class:`_query.Query` + object is normally instantiated using the + :meth:`.Session.query` method, it is legal to build the + :class:`_query.Query` + directly without necessarily using a :class:`.Session`. Such a + :class:`_query.Query` object, or any :class:`_query.Query` + already associated + with a different :class:`.Session`, can produce a new + :class:`_query.Query` + object associated with a target session using this method:: + + from sqlalchemy.orm import Query + + query = Query([MyClass]).filter(MyClass.id == 5) + + result = query.with_session(my_session).one() + + """ + + self.session = session + return self + + def _legacy_from_self( + self, *entities: _ColumnsClauseArgument[Any] + ) -> Self: + # used for query.count() as well as for the same + # function in BakedQuery, as well as some old tests in test_baked.py. + + fromclause = ( + self.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + .correlate(None) + .subquery() + ._anonymous_fromclause() + ) + + q = self._from_selectable(fromclause) + + if entities: + q._set_entities(entities) + return q + + @_generative + def _set_enable_single_crit(self, val: bool) -> Self: + self._compile_options += {"_enable_single_crit": val} + return self + + @_generative + def _from_selectable( + self, fromclause: FromClause, set_entity_from: bool = True + ) -> Self: + for attr in ( + "_where_criteria", + "_order_by_clauses", + "_group_by_clauses", + "_limit_clause", + "_offset_clause", + "_last_joined_entity", + "_setup_joins", + "_memoized_select_entities", + "_distinct", + "_distinct_on", + "_having_criteria", + "_prefixes", + "_suffixes", + ): + self.__dict__.pop(attr, None) + self._set_select_from([fromclause], set_entity_from) + self._compile_options += { + "_enable_single_crit": False, + } + + return self + + @util.deprecated( + "1.4", + ":meth:`_query.Query.values` " + "is deprecated and will be removed in a " + "future release. Please use :meth:`_query.Query.with_entities`", + ) + def values(self, *columns: _ColumnsClauseArgument[Any]) -> Iterable[Any]: + """Return an iterator yielding result tuples corresponding + to the given list of columns + + """ + return self._values_no_warn(*columns) + + _values = values + + def _values_no_warn( + self, *columns: _ColumnsClauseArgument[Any] + ) -> Iterable[Any]: + if not columns: + return iter(()) + q = self._clone().enable_eagerloads(False) + q._set_entities(columns) + if not q.load_options._yield_per: + q.load_options += {"_yield_per": 10} + return iter(q) + + @util.deprecated( + "1.4", + ":meth:`_query.Query.value` " + "is deprecated and will be removed in a " + "future release. Please use :meth:`_query.Query.with_entities` " + "in combination with :meth:`_query.Query.scalar`", + ) + def value(self, column: _ColumnExpressionArgument[Any]) -> Any: + """Return a scalar result corresponding to the given + column expression. + + """ + try: + return next(self._values_no_warn(column))[0] # type: ignore + except StopIteration: + return None + + @overload + def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]: ... + + @overload + def with_entities( + self, + _colexpr: roles.TypedColumnsClauseRole[_T], + ) -> RowReturningQuery[Tuple[_T]]: ... + + # START OVERLOADED FUNCTIONS self.with_entities RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def with_entities( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: ... + + @overload + def with_entities( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... + + @overload + def with_entities( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... + + # END OVERLOADED FUNCTIONS self.with_entities + + @overload + def with_entities( + self, *entities: _ColumnsClauseArgument[Any] + ) -> Query[Any]: ... + + @_generative + def with_entities( + self, *entities: _ColumnsClauseArgument[Any], **__kw: Any + ) -> Query[Any]: + r"""Return a new :class:`_query.Query` + replacing the SELECT list with the + given entities. + + e.g.:: + + # Users, filtered on some arbitrary criterion + # and then ordered by related email address + q = session.query(User).\ + join(User.address).\ + filter(User.name.like('%ed%')).\ + order_by(Address.email) + + # given *only* User.id==5, Address.email, and 'q', what + # would the *next* User in the result be ? + subq = q.with_entities(Address.email).\ + order_by(None).\ + filter(User.id==5).\ + subquery() + q = q.join((subq, subq.c.email < Address.email)).\ + limit(1) + + .. seealso:: + + :meth:`_sql.Select.with_only_columns` - v2 comparable method. + """ + if __kw: + raise _no_kw() + + # Query has all the same fields as Select for this operation + # this could in theory be based on a protocol but not sure if it's + # worth it + _MemoizedSelectEntities._generate_for_statement(self) # type: ignore + self._set_entities(entities) + return self + + @_generative + def add_columns( + self, *column: _ColumnExpressionArgument[Any] + ) -> Query[Any]: + """Add one or more column expressions to the list + of result columns to be returned. + + .. seealso:: + + :meth:`_sql.Select.add_columns` - v2 comparable method. + """ + + self._raw_columns = list(self._raw_columns) + + self._raw_columns.extend( + coercions.expect( + roles.ColumnsClauseRole, + c, + apply_propagate_attrs=self, + post_inspect=True, + ) + for c in column + ) + return self + + @util.deprecated( + "1.4", + ":meth:`_query.Query.add_column` " + "is deprecated and will be removed in a " + "future release. Please use :meth:`_query.Query.add_columns`", + ) + def add_column(self, column: _ColumnExpressionArgument[Any]) -> Query[Any]: + """Add a column expression to the list of result columns to be + returned. + + """ + return self.add_columns(column) + + @_generative + def options(self, *args: ExecutableOption) -> Self: + """Return a new :class:`_query.Query` object, + applying the given list of + mapper options. + + Most supplied options regard changing how column- and + relationship-mapped attributes are loaded. + + .. seealso:: + + :ref:`loading_columns` + + :ref:`relationship_loader_options` + + """ + + opts = tuple(util.flatten_iterator(args)) + if self._compile_options._current_path: + # opting for lower method overhead for the checks + for opt in opts: + if not opt._is_core and opt._is_legacy_option: # type: ignore + opt.process_query_conditionally(self) # type: ignore + else: + for opt in opts: + if not opt._is_core and opt._is_legacy_option: # type: ignore + opt.process_query(self) # type: ignore + + self._with_options += opts + return self + + def with_transformation( + self, fn: Callable[[Query[Any]], Query[Any]] + ) -> Query[Any]: + """Return a new :class:`_query.Query` object transformed by + the given function. + + E.g.:: + + def filter_something(criterion): + def transform(q): + return q.filter(criterion) + return transform + + q = q.with_transformation(filter_something(x==5)) + + This allows ad-hoc recipes to be created for :class:`_query.Query` + objects. + + """ + return fn(self) + + def get_execution_options(self) -> _ImmutableExecuteOptions: + """Get the non-SQL options which will take effect during execution. + + .. versionadded:: 1.3 + + .. seealso:: + + :meth:`_query.Query.execution_options` + + :meth:`_sql.Select.get_execution_options` - v2 comparable method. + + """ + return self._execution_options + + @overload + def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + no_parameters: bool = False, + stream_results: bool = False, + max_row_buffer: int = ..., + yield_per: int = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + populate_existing: bool = False, + autoflush: bool = False, + preserve_rowcount: bool = False, + **opt: Any, + ) -> Self: ... + + @overload + def execution_options(self, **opt: Any) -> Self: ... + + @_generative + def execution_options(self, **kwargs: Any) -> Self: + """Set non-SQL options which take effect during execution. + + Options allowed here include all of those accepted by + :meth:`_engine.Connection.execution_options`, as well as a series + of ORM specific options: + + ``populate_existing=True`` - equivalent to using + :meth:`_orm.Query.populate_existing` + + ``autoflush=True|False`` - equivalent to using + :meth:`_orm.Query.autoflush` + + ``yield_per=`` - equivalent to using + :meth:`_orm.Query.yield_per` + + Note that the ``stream_results`` execution option is enabled + automatically if the :meth:`~sqlalchemy.orm.query.Query.yield_per()` + method or execution option is used. + + .. versionadded:: 1.4 - added ORM options to + :meth:`_orm.Query.execution_options` + + The execution options may also be specified on a per execution basis + when using :term:`2.0 style` queries via the + :paramref:`_orm.Session.execution_options` parameter. + + .. warning:: The + :paramref:`_engine.Connection.execution_options.stream_results` + parameter should not be used at the level of individual ORM + statement executions, as the :class:`_orm.Session` will not track + objects from different schema translate maps within a single + session. For multiple schema translate maps within the scope of a + single :class:`_orm.Session`, see :ref:`examples_sharding`. + + + .. seealso:: + + :ref:`engine_stream_results` + + :meth:`_query.Query.get_execution_options` + + :meth:`_sql.Select.execution_options` - v2 equivalent method. + + """ + self._execution_options = self._execution_options.union(kwargs) + return self + + @_generative + def with_for_update( + self, + *, + nowait: bool = False, + read: bool = False, + of: Optional[_ForUpdateOfArgument] = None, + skip_locked: bool = False, + key_share: bool = False, + ) -> Self: + """return a new :class:`_query.Query` + with the specified options for the + ``FOR UPDATE`` clause. + + The behavior of this method is identical to that of + :meth:`_expression.GenerativeSelect.with_for_update`. + When called with no arguments, + the resulting ``SELECT`` statement will have a ``FOR UPDATE`` clause + appended. When additional arguments are specified, backend-specific + options such as ``FOR UPDATE NOWAIT`` or ``LOCK IN SHARE MODE`` + can take effect. + + E.g.:: + + q = sess.query(User).populate_existing().with_for_update(nowait=True, of=User) + + The above query on a PostgreSQL backend will render like:: + + SELECT users.id AS users_id FROM users FOR UPDATE OF users NOWAIT + + .. warning:: + + Using ``with_for_update`` in the context of eager loading + relationships is not officially supported or recommended by + SQLAlchemy and may not work with certain queries on various + database backends. When ``with_for_update`` is successfully used + with a query that involves :func:`_orm.joinedload`, SQLAlchemy will + attempt to emit SQL that locks all involved tables. + + .. note:: It is generally a good idea to combine the use of the + :meth:`_orm.Query.populate_existing` method when using the + :meth:`_orm.Query.with_for_update` method. The purpose of + :meth:`_orm.Query.populate_existing` is to force all the data read + from the SELECT to be populated into the ORM objects returned, + even if these objects are already in the :term:`identity map`. + + .. seealso:: + + :meth:`_expression.GenerativeSelect.with_for_update` + - Core level method with + full argument and behavioral description. + + :meth:`_orm.Query.populate_existing` - overwrites attributes of + objects already loaded in the identity map. + + """ # noqa: E501 + + self._for_update_arg = ForUpdateArg( + read=read, + nowait=nowait, + of=of, + skip_locked=skip_locked, + key_share=key_share, + ) + return self + + @_generative + def params( + self, __params: Optional[Dict[str, Any]] = None, **kw: Any + ) -> Self: + r"""Add values for bind parameters which may have been + specified in filter(). + + Parameters may be specified using \**kwargs, or optionally a single + dictionary as the first positional argument. The reason for both is + that \**kwargs is convenient, however some parameter dictionaries + contain unicode keys in which case \**kwargs cannot be used. + + """ + if __params: + kw.update(__params) + self._params = self._params.union(kw) + return self + + def where(self, *criterion: _ColumnExpressionArgument[bool]) -> Self: + """A synonym for :meth:`.Query.filter`. + + .. versionadded:: 1.4 + + .. seealso:: + + :meth:`_sql.Select.where` - v2 equivalent method. + + """ + return self.filter(*criterion) + + @_generative + @_assertions(_no_statement_condition, _no_limit_offset) + def filter(self, *criterion: _ColumnExpressionArgument[bool]) -> Self: + r"""Apply the given filtering criterion to a copy + of this :class:`_query.Query`, using SQL expressions. + + e.g.:: + + session.query(MyClass).filter(MyClass.name == 'some name') + + Multiple criteria may be specified as comma separated; the effect + is that they will be joined together using the :func:`.and_` + function:: + + session.query(MyClass).\ + filter(MyClass.name == 'some name', MyClass.id > 5) + + The criterion is any SQL expression object applicable to the + WHERE clause of a select. String expressions are coerced + into SQL expression constructs via the :func:`_expression.text` + construct. + + .. seealso:: + + :meth:`_query.Query.filter_by` - filter on keyword expressions. + + :meth:`_sql.Select.where` - v2 equivalent method. + + """ + for crit in list(criterion): + crit = coercions.expect( + roles.WhereHavingRole, crit, apply_propagate_attrs=self + ) + + self._where_criteria += (crit,) + return self + + @util.memoized_property + def _last_joined_entity( + self, + ) -> Optional[Union[_InternalEntityType[Any], _JoinTargetElement]]: + if self._setup_joins: + return _determine_last_joined_entity( + self._setup_joins, + ) + else: + return None + + def _filter_by_zero(self) -> Any: + """for the filter_by() method, return the target entity for which + we will attempt to derive an expression from based on string name. + + """ + + if self._setup_joins: + _last_joined_entity = self._last_joined_entity + if _last_joined_entity is not None: + return _last_joined_entity + + # discussion related to #7239 + # special check determines if we should try to derive attributes + # for filter_by() from the "from object", i.e., if the user + # called query.select_from(some selectable).filter_by(some_attr=value). + # We don't want to do that in the case that methods like + # from_self(), select_entity_from(), or a set op like union() were + # called; while these methods also place a + # selectable in the _from_obj collection, they also set up + # the _set_base_alias boolean which turns on the whole "adapt the + # entity to this selectable" thing, meaning the query still continues + # to construct itself in terms of the lead entity that was passed + # to query(), e.g. query(User).from_self() is still in terms of User, + # and not the subquery that from_self() created. This feature of + # "implicitly adapt all occurrences of entity X to some arbitrary + # subquery" is the main thing I am trying to do away with in 2.0 as + # users should now used aliased() for that, but I can't entirely get + # rid of it due to query.union() and other set ops relying upon it. + # + # compare this to the base Select()._filter_by_zero() which can + # just return self._from_obj[0] if present, because there is no + # "_set_base_alias" feature. + # + # IOW, this conditional essentially detects if + # "select_from(some_selectable)" has been called, as opposed to + # "select_entity_from()", "from_self()" + # or "union() / some_set_op()". + if self._from_obj and not self._compile_options._set_base_alias: + return self._from_obj[0] + + return self._raw_columns[0] + + def filter_by(self, **kwargs: Any) -> Self: + r"""Apply the given filtering criterion to a copy + of this :class:`_query.Query`, using keyword expressions. + + e.g.:: + + session.query(MyClass).filter_by(name = 'some name') + + Multiple criteria may be specified as comma separated; the effect + is that they will be joined together using the :func:`.and_` + function:: + + session.query(MyClass).\ + filter_by(name = 'some name', id = 5) + + The keyword expressions are extracted from the primary + entity of the query, or the last entity that was the + target of a call to :meth:`_query.Query.join`. + + .. seealso:: + + :meth:`_query.Query.filter` - filter on SQL expressions. + + :meth:`_sql.Select.filter_by` - v2 comparable method. + + """ + from_entity = self._filter_by_zero() + + clauses = [ + _entity_namespace_key(from_entity, key) == value + for key, value in kwargs.items() + ] + return self.filter(*clauses) + + @_generative + def order_by( + self, + __first: Union[ + Literal[None, False, _NoArg.NO_ARG], + _ColumnExpressionOrStrLabelArgument[Any], + ] = _NoArg.NO_ARG, + *clauses: _ColumnExpressionOrStrLabelArgument[Any], + ) -> Self: + """Apply one or more ORDER BY criteria to the query and return + the newly resulting :class:`_query.Query`. + + e.g.:: + + q = session.query(Entity).order_by(Entity.id, Entity.name) + + Calling this method multiple times is equivalent to calling it once + with all the clauses concatenated. All existing ORDER BY criteria may + be cancelled by passing ``None`` by itself. New ORDER BY criteria may + then be added by invoking :meth:`_orm.Query.order_by` again, e.g.:: + + # will erase all ORDER BY and ORDER BY new_col alone + q = q.order_by(None).order_by(new_col) + + .. seealso:: + + These sections describe ORDER BY in terms of :term:`2.0 style` + invocation but apply to :class:`_orm.Query` as well: + + :ref:`tutorial_order_by` - in the :ref:`unified_tutorial` + + :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial` + + :meth:`_sql.Select.order_by` - v2 equivalent method. + + """ + + for assertion in (self._no_statement_condition, self._no_limit_offset): + assertion("order_by") + + if not clauses and (__first is None or __first is False): + self._order_by_clauses = () + elif __first is not _NoArg.NO_ARG: + criterion = tuple( + coercions.expect(roles.OrderByRole, clause) + for clause in (__first,) + clauses + ) + self._order_by_clauses += criterion + + return self + + @_generative + def group_by( + self, + __first: Union[ + Literal[None, False, _NoArg.NO_ARG], + _ColumnExpressionOrStrLabelArgument[Any], + ] = _NoArg.NO_ARG, + *clauses: _ColumnExpressionOrStrLabelArgument[Any], + ) -> Self: + """Apply one or more GROUP BY criterion to the query and return + the newly resulting :class:`_query.Query`. + + All existing GROUP BY settings can be suppressed by + passing ``None`` - this will suppress any GROUP BY configured + on mappers as well. + + .. seealso:: + + These sections describe GROUP BY in terms of :term:`2.0 style` + invocation but apply to :class:`_orm.Query` as well: + + :ref:`tutorial_group_by_w_aggregates` - in the + :ref:`unified_tutorial` + + :ref:`tutorial_order_by_label` - in the :ref:`unified_tutorial` + + :meth:`_sql.Select.group_by` - v2 equivalent method. + + """ + + for assertion in (self._no_statement_condition, self._no_limit_offset): + assertion("group_by") + + if not clauses and (__first is None or __first is False): + self._group_by_clauses = () + elif __first is not _NoArg.NO_ARG: + criterion = tuple( + coercions.expect(roles.GroupByRole, clause) + for clause in (__first,) + clauses + ) + self._group_by_clauses += criterion + return self + + @_generative + @_assertions(_no_statement_condition, _no_limit_offset) + def having(self, *having: _ColumnExpressionArgument[bool]) -> Self: + r"""Apply a HAVING criterion to the query and return the + newly resulting :class:`_query.Query`. + + :meth:`_query.Query.having` is used in conjunction with + :meth:`_query.Query.group_by`. + + HAVING criterion makes it possible to use filters on aggregate + functions like COUNT, SUM, AVG, MAX, and MIN, eg.:: + + q = session.query(User.id).\ + join(User.addresses).\ + group_by(User.id).\ + having(func.count(Address.id) > 2) + + .. seealso:: + + :meth:`_sql.Select.having` - v2 equivalent method. + + """ + + for criterion in having: + having_criteria = coercions.expect( + roles.WhereHavingRole, criterion + ) + self._having_criteria += (having_criteria,) + return self + + def _set_op(self, expr_fn: Any, *q: Query[Any]) -> Self: + list_of_queries = (self,) + q + return self._from_selectable(expr_fn(*(list_of_queries)).subquery()) + + def union(self, *q: Query[Any]) -> Self: + """Produce a UNION of this Query against one or more queries. + + e.g.:: + + q1 = sess.query(SomeClass).filter(SomeClass.foo=='bar') + q2 = sess.query(SomeClass).filter(SomeClass.bar=='foo') + + q3 = q1.union(q2) + + The method accepts multiple Query objects so as to control + the level of nesting. A series of ``union()`` calls such as:: + + x.union(y).union(z).all() + + will nest on each ``union()``, and produces:: + + SELECT * FROM (SELECT * FROM (SELECT * FROM X UNION + SELECT * FROM y) UNION SELECT * FROM Z) + + Whereas:: + + x.union(y, z).all() + + produces:: + + SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y UNION + SELECT * FROM Z) + + Note that many database backends do not allow ORDER BY to + be rendered on a query called within UNION, EXCEPT, etc. + To disable all ORDER BY clauses including those configured + on mappers, issue ``query.order_by(None)`` - the resulting + :class:`_query.Query` object will not render ORDER BY within + its SELECT statement. + + .. seealso:: + + :meth:`_sql.Select.union` - v2 equivalent method. + + """ + return self._set_op(expression.union, *q) + + def union_all(self, *q: Query[Any]) -> Self: + """Produce a UNION ALL of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + .. seealso:: + + :meth:`_sql.Select.union_all` - v2 equivalent method. + + """ + return self._set_op(expression.union_all, *q) + + def intersect(self, *q: Query[Any]) -> Self: + """Produce an INTERSECT of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + .. seealso:: + + :meth:`_sql.Select.intersect` - v2 equivalent method. + + """ + return self._set_op(expression.intersect, *q) + + def intersect_all(self, *q: Query[Any]) -> Self: + """Produce an INTERSECT ALL of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + .. seealso:: + + :meth:`_sql.Select.intersect_all` - v2 equivalent method. + + """ + return self._set_op(expression.intersect_all, *q) + + def except_(self, *q: Query[Any]) -> Self: + """Produce an EXCEPT of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + .. seealso:: + + :meth:`_sql.Select.except_` - v2 equivalent method. + + """ + return self._set_op(expression.except_, *q) + + def except_all(self, *q: Query[Any]) -> Self: + """Produce an EXCEPT ALL of this Query against one or more queries. + + Works the same way as :meth:`~sqlalchemy.orm.query.Query.union`. See + that method for usage examples. + + .. seealso:: + + :meth:`_sql.Select.except_all` - v2 equivalent method. + + """ + return self._set_op(expression.except_all, *q) + + @_generative + @_assertions(_no_statement_condition, _no_limit_offset) + def join( + self, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + isouter: bool = False, + full: bool = False, + ) -> Self: + r"""Create a SQL JOIN against this :class:`_query.Query` + object's criterion + and apply generatively, returning the newly resulting + :class:`_query.Query`. + + **Simple Relationship Joins** + + Consider a mapping between two classes ``User`` and ``Address``, + with a relationship ``User.addresses`` representing a collection + of ``Address`` objects associated with each ``User``. The most + common usage of :meth:`_query.Query.join` + is to create a JOIN along this + relationship, using the ``User.addresses`` attribute as an indicator + for how this should occur:: + + q = session.query(User).join(User.addresses) + + Where above, the call to :meth:`_query.Query.join` along + ``User.addresses`` will result in SQL approximately equivalent to:: + + SELECT user.id, user.name + FROM user JOIN address ON user.id = address.user_id + + In the above example we refer to ``User.addresses`` as passed to + :meth:`_query.Query.join` as the "on clause", that is, it indicates + how the "ON" portion of the JOIN should be constructed. + + To construct a chain of joins, multiple :meth:`_query.Query.join` + calls may be used. The relationship-bound attribute implies both + the left and right side of the join at once:: + + q = session.query(User).\ + join(User.orders).\ + join(Order.items).\ + join(Item.keywords) + + .. note:: as seen in the above example, **the order in which each + call to the join() method occurs is important**. Query would not, + for example, know how to join correctly if we were to specify + ``User``, then ``Item``, then ``Order``, in our chain of joins; in + such a case, depending on the arguments passed, it may raise an + error that it doesn't know how to join, or it may produce invalid + SQL in which case the database will raise an error. In correct + practice, the + :meth:`_query.Query.join` method is invoked in such a way that lines + up with how we would want the JOIN clauses in SQL to be + rendered, and each call should represent a clear link from what + precedes it. + + **Joins to a Target Entity or Selectable** + + A second form of :meth:`_query.Query.join` allows any mapped entity or + core selectable construct as a target. In this usage, + :meth:`_query.Query.join` will attempt to create a JOIN along the + natural foreign key relationship between two entities:: + + q = session.query(User).join(Address) + + In the above calling form, :meth:`_query.Query.join` is called upon to + create the "on clause" automatically for us. This calling form will + ultimately raise an error if either there are no foreign keys between + the two entities, or if there are multiple foreign key linkages between + the target entity and the entity or entities already present on the + left side such that creating a join requires more information. Note + that when indicating a join to a target without any ON clause, ORM + configured relationships are not taken into account. + + **Joins to a Target with an ON Clause** + + The third calling form allows both the target entity as well + as the ON clause to be passed explicitly. A example that includes + a SQL expression as the ON clause is as follows:: + + q = session.query(User).join(Address, User.id==Address.user_id) + + The above form may also use a relationship-bound attribute as the + ON clause as well:: + + q = session.query(User).join(Address, User.addresses) + + The above syntax can be useful for the case where we wish + to join to an alias of a particular target entity. If we wanted + to join to ``Address`` twice, it could be achieved using two + aliases set up using the :func:`~sqlalchemy.orm.aliased` function:: + + a1 = aliased(Address) + a2 = aliased(Address) + + q = session.query(User).\ + join(a1, User.addresses).\ + join(a2, User.addresses).\ + filter(a1.email_address=='ed@foo.com').\ + filter(a2.email_address=='ed@bar.com') + + The relationship-bound calling form can also specify a target entity + using the :meth:`_orm.PropComparator.of_type` method; a query + equivalent to the one above would be:: + + a1 = aliased(Address) + a2 = aliased(Address) + + q = session.query(User).\ + join(User.addresses.of_type(a1)).\ + join(User.addresses.of_type(a2)).\ + filter(a1.email_address == 'ed@foo.com').\ + filter(a2.email_address == 'ed@bar.com') + + **Augmenting Built-in ON Clauses** + + As a substitute for providing a full custom ON condition for an + existing relationship, the :meth:`_orm.PropComparator.and_` function + may be applied to a relationship attribute to augment additional + criteria into the ON clause; the additional criteria will be combined + with the default criteria using AND:: + + q = session.query(User).join( + User.addresses.and_(Address.email_address != 'foo@bar.com') + ) + + .. versionadded:: 1.4 + + **Joining to Tables and Subqueries** + + + The target of a join may also be any table or SELECT statement, + which may be related to a target entity or not. Use the + appropriate ``.subquery()`` method in order to make a subquery + out of a query:: + + subq = session.query(Address).\ + filter(Address.email_address == 'ed@foo.com').\ + subquery() + + + q = session.query(User).join( + subq, User.id == subq.c.user_id + ) + + Joining to a subquery in terms of a specific relationship and/or + target entity may be achieved by linking the subquery to the + entity using :func:`_orm.aliased`:: + + subq = session.query(Address).\ + filter(Address.email_address == 'ed@foo.com').\ + subquery() + + address_subq = aliased(Address, subq) + + q = session.query(User).join( + User.addresses.of_type(address_subq) + ) + + + **Controlling what to Join From** + + In cases where the left side of the current state of + :class:`_query.Query` is not in line with what we want to join from, + the :meth:`_query.Query.select_from` method may be used:: + + q = session.query(Address).select_from(User).\ + join(User.addresses).\ + filter(User.name == 'ed') + + Which will produce SQL similar to:: + + SELECT address.* FROM user + JOIN address ON user.id=address.user_id + WHERE user.name = :name_1 + + .. seealso:: + + :meth:`_sql.Select.join` - v2 equivalent method. + + :param \*props: Incoming arguments for :meth:`_query.Query.join`, + the props collection in modern use should be considered to be a one + or two argument form, either as a single "target" entity or ORM + attribute-bound relationship, or as a target entity plus an "on + clause" which may be a SQL expression or ORM attribute-bound + relationship. + + :param isouter=False: If True, the join used will be a left outer join, + just as if the :meth:`_query.Query.outerjoin` method were called. + + :param full=False: render FULL OUTER JOIN; implies ``isouter``. + + """ + + join_target = coercions.expect( + roles.JoinTargetRole, + target, + apply_propagate_attrs=self, + legacy=True, + ) + if onclause is not None: + onclause_element = coercions.expect( + roles.OnClauseRole, onclause, legacy=True + ) + else: + onclause_element = None + + self._setup_joins += ( + ( + join_target, + onclause_element, + None, + { + "isouter": isouter, + "full": full, + }, + ), + ) + + self.__dict__.pop("_last_joined_entity", None) + return self + + def outerjoin( + self, + target: _JoinTargetArgument, + onclause: Optional[_OnClauseArgument] = None, + *, + full: bool = False, + ) -> Self: + """Create a left outer join against this ``Query`` object's criterion + and apply generatively, returning the newly resulting ``Query``. + + Usage is the same as the ``join()`` method. + + .. seealso:: + + :meth:`_sql.Select.outerjoin` - v2 equivalent method. + + """ + return self.join(target, onclause=onclause, isouter=True, full=full) + + @_generative + @_assertions(_no_statement_condition) + def reset_joinpoint(self) -> Self: + """Return a new :class:`.Query`, where the "join point" has + been reset back to the base FROM entities of the query. + + This method is usually used in conjunction with the + ``aliased=True`` feature of the :meth:`~.Query.join` + method. See the example in :meth:`~.Query.join` for how + this is used. + + """ + self._last_joined_entity = None + + return self + + @_generative + @_assertions(_no_clauseelement_condition) + def select_from(self, *from_obj: _FromClauseArgument) -> Self: + r"""Set the FROM clause of this :class:`.Query` explicitly. + + :meth:`.Query.select_from` is often used in conjunction with + :meth:`.Query.join` in order to control which entity is selected + from on the "left" side of the join. + + The entity or selectable object here effectively replaces the + "left edge" of any calls to :meth:`~.Query.join`, when no + joinpoint is otherwise established - usually, the default "join + point" is the leftmost entity in the :class:`~.Query` object's + list of entities to be selected. + + A typical example:: + + q = session.query(Address).select_from(User).\ + join(User.addresses).\ + filter(User.name == 'ed') + + Which produces SQL equivalent to:: + + SELECT address.* FROM user + JOIN address ON user.id=address.user_id + WHERE user.name = :name_1 + + :param \*from_obj: collection of one or more entities to apply + to the FROM clause. Entities can be mapped classes, + :class:`.AliasedClass` objects, :class:`.Mapper` objects + as well as core :class:`.FromClause` elements like subqueries. + + .. seealso:: + + :meth:`~.Query.join` + + :meth:`.Query.select_entity_from` + + :meth:`_sql.Select.select_from` - v2 equivalent method. + + """ + + self._set_select_from(from_obj, False) + return self + + def __getitem__(self, item: Any) -> Any: + return orm_util._getitem( + self, + item, + ) + + @_generative + @_assertions(_no_statement_condition) + def slice( + self, + start: int, + stop: int, + ) -> Self: + """Computes the "slice" of the :class:`_query.Query` represented by + the given indices and returns the resulting :class:`_query.Query`. + + The start and stop indices behave like the argument to Python's + built-in :func:`range` function. This method provides an + alternative to using ``LIMIT``/``OFFSET`` to get a slice of the + query. + + For example, :: + + session.query(User).order_by(User.id).slice(1, 3) + + renders as + + .. sourcecode:: sql + + SELECT users.id AS users_id, + users.name AS users_name + FROM users ORDER BY users.id + LIMIT ? OFFSET ? + (2, 1) + + .. seealso:: + + :meth:`_query.Query.limit` + + :meth:`_query.Query.offset` + + :meth:`_sql.Select.slice` - v2 equivalent method. + + """ + + self._limit_clause, self._offset_clause = sql_util._make_slice( + self._limit_clause, self._offset_clause, start, stop + ) + return self + + @_generative + @_assertions(_no_statement_condition) + def limit(self, limit: _LimitOffsetType) -> Self: + """Apply a ``LIMIT`` to the query and return the newly resulting + ``Query``. + + .. seealso:: + + :meth:`_sql.Select.limit` - v2 equivalent method. + + """ + self._limit_clause = sql_util._offset_or_limit_clause(limit) + return self + + @_generative + @_assertions(_no_statement_condition) + def offset(self, offset: _LimitOffsetType) -> Self: + """Apply an ``OFFSET`` to the query and return the newly resulting + ``Query``. + + .. seealso:: + + :meth:`_sql.Select.offset` - v2 equivalent method. + """ + self._offset_clause = sql_util._offset_or_limit_clause(offset) + return self + + @_generative + @_assertions(_no_statement_condition) + def distinct(self, *expr: _ColumnExpressionArgument[Any]) -> Self: + r"""Apply a ``DISTINCT`` to the query and return the newly resulting + ``Query``. + + + .. note:: + + The ORM-level :meth:`.distinct` call includes logic that will + automatically add columns from the ORDER BY of the query to the + columns clause of the SELECT statement, to satisfy the common need + of the database backend that ORDER BY columns be part of the SELECT + list when DISTINCT is used. These columns *are not* added to the + list of columns actually fetched by the :class:`_query.Query`, + however, + so would not affect results. The columns are passed through when + using the :attr:`_query.Query.statement` accessor, however. + + .. deprecated:: 2.0 This logic is deprecated and will be removed + in SQLAlchemy 2.0. See :ref:`migration_20_query_distinct` + for a description of this use case in 2.0. + + .. seealso:: + + :meth:`_sql.Select.distinct` - v2 equivalent method. + + :param \*expr: optional column expressions. When present, + the PostgreSQL dialect will render a ``DISTINCT ON ()`` + construct. + + .. deprecated:: 1.4 Using \*expr in other dialects is deprecated + and will raise :class:`_exc.CompileError` in a future version. + + """ + if expr: + self._distinct = True + self._distinct_on = self._distinct_on + tuple( + coercions.expect(roles.ByOfRole, e) for e in expr + ) + else: + self._distinct = True + return self + + def all(self) -> List[_T]: + """Return the results represented by this :class:`_query.Query` + as a list. + + This results in an execution of the underlying SQL statement. + + .. warning:: The :class:`_query.Query` object, + when asked to return either + a sequence or iterator that consists of full ORM-mapped entities, + will **deduplicate entries based on primary key**. See the FAQ for + more details. + + .. seealso:: + + :ref:`faq_query_deduplicating` + + .. seealso:: + + :meth:`_engine.Result.all` - v2 comparable method. + + :meth:`_engine.Result.scalars` - v2 comparable method. + """ + return self._iter().all() # type: ignore + + @_generative + @_assertions(_no_clauseelement_condition) + def from_statement(self, statement: ExecutableReturnsRows) -> Self: + """Execute the given SELECT statement and return results. + + This method bypasses all internal statement compilation, and the + statement is executed without modification. + + The statement is typically either a :func:`_expression.text` + or :func:`_expression.select` construct, and should return the set + of columns + appropriate to the entity class represented by this + :class:`_query.Query`. + + .. seealso:: + + :meth:`_sql.Select.from_statement` - v2 comparable method. + + """ + statement = coercions.expect( + roles.SelectStatementRole, statement, apply_propagate_attrs=self + ) + self._statement = statement + return self + + def first(self) -> Optional[_T]: + """Return the first result of this ``Query`` or + None if the result doesn't contain any row. + + first() applies a limit of one within the generated SQL, so that + only one primary entity row is generated on the server side + (note this may consist of multiple result rows if join-loaded + collections are present). + + Calling :meth:`_query.Query.first` + results in an execution of the underlying + query. + + .. seealso:: + + :meth:`_query.Query.one` + + :meth:`_query.Query.one_or_none` + + :meth:`_engine.Result.first` - v2 comparable method. + + :meth:`_engine.Result.scalars` - v2 comparable method. + + """ + # replicates limit(1) behavior + if self._statement is not None: + return self._iter().first() # type: ignore + else: + return self.limit(1)._iter().first() # type: ignore + + def one_or_none(self) -> Optional[_T]: + """Return at most one result or raise an exception. + + Returns ``None`` if the query selects + no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` + if multiple object identities are returned, or if multiple + rows are returned for a query that returns only scalar values + as opposed to full identity-mapped entities. + + Calling :meth:`_query.Query.one_or_none` + results in an execution of the + underlying query. + + .. seealso:: + + :meth:`_query.Query.first` + + :meth:`_query.Query.one` + + :meth:`_engine.Result.one_or_none` - v2 comparable method. + + :meth:`_engine.Result.scalar_one_or_none` - v2 comparable method. + + """ + return self._iter().one_or_none() # type: ignore + + def one(self) -> _T: + """Return exactly one result or raise an exception. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` + if multiple object identities are returned, or if multiple + rows are returned for a query that returns only scalar values + as opposed to full identity-mapped entities. + + Calling :meth:`.one` results in an execution of the underlying query. + + .. seealso:: + + :meth:`_query.Query.first` + + :meth:`_query.Query.one_or_none` + + :meth:`_engine.Result.one` - v2 comparable method. + + :meth:`_engine.Result.scalar_one` - v2 comparable method. + + """ + return self._iter().one() # type: ignore + + def scalar(self) -> Any: + """Return the first element of the first result or None + if no rows present. If multiple rows are returned, + raises MultipleResultsFound. + + >>> session.query(Item).scalar() + + >>> session.query(Item.id).scalar() + 1 + >>> session.query(Item.id).filter(Item.id < 0).scalar() + None + >>> session.query(Item.id, Item.name).scalar() + 1 + >>> session.query(func.count(Parent.id)).scalar() + 20 + + This results in an execution of the underlying query. + + .. seealso:: + + :meth:`_engine.Result.scalar` - v2 comparable method. + + """ + # TODO: not sure why we can't use result.scalar() here + try: + ret = self.one() + if not isinstance(ret, collections_abc.Sequence): + return ret + return ret[0] + except sa_exc.NoResultFound: + return None + + def __iter__(self) -> Iterator[_T]: + result = self._iter() + try: + yield from result # type: ignore + except GeneratorExit: + # issue #8710 - direct iteration is not re-usable after + # an iterable block is broken, so close the result + result._soft_close() + raise + + def _iter(self) -> Union[ScalarResult[_T], Result[_T]]: + # new style execution. + params = self._params + + statement = self._statement_20() + result: Union[ScalarResult[_T], Result[_T]] = self.session.execute( + statement, + params, + execution_options={"_sa_orm_load_options": self.load_options}, + ) + + # legacy: automatically set scalars, unique + if result._attributes.get("is_single_entity", False): + result = cast("Result[_T]", result).scalars() + + if ( + result._attributes.get("filtered", False) + and not self.load_options._yield_per + ): + result = result.unique() + + return result + + def __str__(self) -> str: + statement = self._statement_20() + + try: + bind = ( + self._get_bind_args(statement, self.session.get_bind) + if self.session + else None + ) + except sa_exc.UnboundExecutionError: + bind = None + + return str(statement.compile(bind)) + + def _get_bind_args(self, statement: Any, fn: Any, **kw: Any) -> Any: + return fn(clause=statement, **kw) + + @property + def column_descriptions(self) -> List[ORMColumnDescription]: + """Return metadata about the columns which would be + returned by this :class:`_query.Query`. + + Format is a list of dictionaries:: + + user_alias = aliased(User, name='user2') + q = sess.query(User, User.id, user_alias) + + # this expression: + q.column_descriptions + + # would return: + [ + { + 'name':'User', + 'type':User, + 'aliased':False, + 'expr':User, + 'entity': User + }, + { + 'name':'id', + 'type':Integer(), + 'aliased':False, + 'expr':User.id, + 'entity': User + }, + { + 'name':'user2', + 'type':User, + 'aliased':True, + 'expr':user_alias, + 'entity': user_alias + } + ] + + .. seealso:: + + This API is available using :term:`2.0 style` queries as well, + documented at: + + * :ref:`queryguide_inspection` + + * :attr:`.Select.column_descriptions` + + """ + + return _column_descriptions(self, legacy=True) + + @util.deprecated( + "2.0", + "The :meth:`_orm.Query.instances` method is deprecated and will " + "be removed in a future release. " + "Use the Select.from_statement() method or aliased() construct in " + "conjunction with Session.execute() instead.", + ) + def instances( + self, + result_proxy: CursorResult[Any], + context: Optional[QueryContext] = None, + ) -> Any: + """Return an ORM result given a :class:`_engine.CursorResult` and + :class:`.QueryContext`. + + """ + if context is None: + util.warn_deprecated( + "Using the Query.instances() method without a context " + "is deprecated and will be disallowed in a future release. " + "Please make use of :meth:`_query.Query.from_statement` " + "for linking ORM results to arbitrary select constructs.", + version="1.4", + ) + compile_state = self._compile_state(for_statement=False) + + context = QueryContext( + compile_state, + compile_state.statement, + self._params, + self.session, + self.load_options, + ) + + result = loading.instances(result_proxy, context) + + # legacy: automatically set scalars, unique + if result._attributes.get("is_single_entity", False): + result = result.scalars() # type: ignore + + if result._attributes.get("filtered", False): + result = result.unique() + + # TODO: isn't this supposed to be a list? + return result + + @util.became_legacy_20( + ":meth:`_orm.Query.merge_result`", + alternative="The method is superseded by the " + ":func:`_orm.merge_frozen_result` function.", + enable_warnings=False, # warnings occur via loading.merge_result + ) + def merge_result( + self, + iterator: Union[ + FrozenResult[Any], Iterable[Sequence[Any]], Iterable[object] + ], + load: bool = True, + ) -> Union[FrozenResult[Any], Iterable[Any]]: + """Merge a result into this :class:`_query.Query` object's Session. + + Given an iterator returned by a :class:`_query.Query` + of the same structure + as this one, return an identical iterator of results, with all mapped + instances merged into the session using :meth:`.Session.merge`. This + is an optimized method which will merge all mapped instances, + preserving the structure of the result rows and unmapped columns with + less method overhead than that of calling :meth:`.Session.merge` + explicitly for each value. + + The structure of the results is determined based on the column list of + this :class:`_query.Query` - if these do not correspond, + unchecked errors + will occur. + + The 'load' argument is the same as that of :meth:`.Session.merge`. + + For an example of how :meth:`_query.Query.merge_result` is used, see + the source code for the example :ref:`examples_caching`, where + :meth:`_query.Query.merge_result` is used to efficiently restore state + from a cache back into a target :class:`.Session`. + + """ + + return loading.merge_result(self, iterator, load) + + def exists(self) -> Exists: + """A convenience method that turns a query into an EXISTS subquery + of the form EXISTS (SELECT 1 FROM ... WHERE ...). + + e.g.:: + + q = session.query(User).filter(User.name == 'fred') + session.query(q.exists()) + + Producing SQL similar to:: + + SELECT EXISTS ( + SELECT 1 FROM users WHERE users.name = :name_1 + ) AS anon_1 + + The EXISTS construct is usually used in the WHERE clause:: + + session.query(User.id).filter(q.exists()).scalar() + + Note that some databases such as SQL Server don't allow an + EXISTS expression to be present in the columns clause of a + SELECT. To select a simple boolean value based on the exists + as a WHERE, use :func:`.literal`:: + + from sqlalchemy import literal + + session.query(literal(True)).filter(q.exists()).scalar() + + .. seealso:: + + :meth:`_sql.Select.exists` - v2 comparable method. + + """ + + # .add_columns() for the case that we are a query().select_from(X), + # so that ".statement" can be produced (#2995) but also without + # omitting the FROM clause from a query(X) (#2818); + # .with_only_columns() after we have a core select() so that + # we get just "SELECT 1" without any entities. + + inner = ( + self.enable_eagerloads(False) + .add_columns(sql.literal_column("1")) + .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + ._get_select_statement_only() + .with_only_columns(1) + ) + + ezero = self._entity_from_pre_ent_zero() + if ezero is not None: + inner = inner.select_from(ezero) + + return sql.exists(inner) + + def count(self) -> int: + r"""Return a count of rows this the SQL formed by this :class:`Query` + would return. + + This generates the SQL for this Query as follows:: + + SELECT count(1) AS count_1 FROM ( + SELECT + ) AS anon_1 + + The above SQL returns a single row, which is the aggregate value + of the count function; the :meth:`_query.Query.count` + method then returns + that single integer value. + + .. warning:: + + It is important to note that the value returned by + count() is **not the same as the number of ORM objects that this + Query would return from a method such as the .all() method**. + The :class:`_query.Query` object, + when asked to return full entities, + will **deduplicate entries based on primary key**, meaning if the + same primary key value would appear in the results more than once, + only one object of that primary key would be present. This does + not apply to a query that is against individual columns. + + .. seealso:: + + :ref:`faq_query_deduplicating` + + For fine grained control over specific columns to count, to skip the + usage of a subquery or otherwise control of the FROM clause, or to use + other aggregate functions, use :attr:`~sqlalchemy.sql.expression.func` + expressions in conjunction with :meth:`~.Session.query`, i.e.:: + + from sqlalchemy import func + + # count User records, without + # using a subquery. + session.query(func.count(User.id)) + + # return count of user "id" grouped + # by "name" + session.query(func.count(User.id)).\ + group_by(User.name) + + from sqlalchemy import distinct + + # count distinct "name" values + session.query(func.count(distinct(User.name))) + + .. seealso:: + + :ref:`migration_20_query_usage` + + """ + col = sql.func.count(sql.literal_column("*")) + return ( # type: ignore + self._legacy_from_self(col).enable_eagerloads(False).scalar() + ) + + def delete( + self, synchronize_session: SynchronizeSessionArgument = "auto" + ) -> int: + r"""Perform a DELETE with an arbitrary WHERE clause. + + Deletes rows matched by this query from the database. + + E.g.:: + + sess.query(User).filter(User.age == 25).\ + delete(synchronize_session=False) + + sess.query(User).filter(User.age == 25).\ + delete(synchronize_session='evaluate') + + .. warning:: + + See the section :ref:`orm_expression_update_delete` for important + caveats and warnings, including limitations when using bulk UPDATE + and DELETE with mapper inheritance configurations. + + :param synchronize_session: chooses the strategy to update the + attributes on objects in the session. See the section + :ref:`orm_expression_update_delete` for a discussion of these + strategies. + + :return: the count of rows matched as returned by the database's + "row count" feature. + + .. seealso:: + + :ref:`orm_expression_update_delete` + + """ + + bulk_del = BulkDelete(self) + if self.dispatch.before_compile_delete: + for fn in self.dispatch.before_compile_delete: + new_query = fn(bulk_del.query, bulk_del) + if new_query is not None: + bulk_del.query = new_query + + self = bulk_del.query + + delete_ = sql.delete(*self._raw_columns) # type: ignore + delete_._where_criteria = self._where_criteria + result: CursorResult[Any] = self.session.execute( + delete_, + self._params, + execution_options=self._execution_options.union( + {"synchronize_session": synchronize_session} + ), + ) + bulk_del.result = result # type: ignore + self.session.dispatch.after_bulk_delete(bulk_del) + result.close() + + return result.rowcount + + def update( + self, + values: Dict[_DMLColumnArgument, Any], + synchronize_session: SynchronizeSessionArgument = "auto", + update_args: Optional[Dict[Any, Any]] = None, + ) -> int: + r"""Perform an UPDATE with an arbitrary WHERE clause. + + Updates rows matched by this query in the database. + + E.g.:: + + sess.query(User).filter(User.age == 25).\ + update({User.age: User.age - 10}, synchronize_session=False) + + sess.query(User).filter(User.age == 25).\ + update({"age": User.age - 10}, synchronize_session='evaluate') + + .. warning:: + + See the section :ref:`orm_expression_update_delete` for important + caveats and warnings, including limitations when using arbitrary + UPDATE and DELETE with mapper inheritance configurations. + + :param values: a dictionary with attributes names, or alternatively + mapped attributes or SQL expressions, as keys, and literal + values or sql expressions as values. If :ref:`parameter-ordered + mode ` is desired, the values can + be passed as a list of 2-tuples; this requires that the + :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order` + flag is passed to the :paramref:`.Query.update.update_args` dictionary + as well. + + :param synchronize_session: chooses the strategy to update the + attributes on objects in the session. See the section + :ref:`orm_expression_update_delete` for a discussion of these + strategies. + + :param update_args: Optional dictionary, if present will be passed + to the underlying :func:`_expression.update` + construct as the ``**kw`` for + the object. May be used to pass dialect-specific arguments such + as ``mysql_limit``, as well as other special arguments such as + :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`. + + :return: the count of rows matched as returned by the database's + "row count" feature. + + + .. seealso:: + + :ref:`orm_expression_update_delete` + + """ + + update_args = update_args or {} + + bulk_ud = BulkUpdate(self, values, update_args) + + if self.dispatch.before_compile_update: + for fn in self.dispatch.before_compile_update: + new_query = fn(bulk_ud.query, bulk_ud) + if new_query is not None: + bulk_ud.query = new_query + self = bulk_ud.query + + upd = sql.update(*self._raw_columns) # type: ignore + + ppo = update_args.pop("preserve_parameter_order", False) + if ppo: + upd = upd.ordered_values(*values) # type: ignore + else: + upd = upd.values(values) + if update_args: + upd = upd.with_dialect_options(**update_args) + + upd._where_criteria = self._where_criteria + result: CursorResult[Any] = self.session.execute( + upd, + self._params, + execution_options=self._execution_options.union( + {"synchronize_session": synchronize_session} + ), + ) + bulk_ud.result = result # type: ignore + self.session.dispatch.after_bulk_update(bulk_ud) + result.close() + return result.rowcount + + def _compile_state( + self, for_statement: bool = False, **kw: Any + ) -> ORMCompileState: + """Create an out-of-compiler ORMCompileState object. + + The ORMCompileState object is normally created directly as a result + of the SQLCompiler.process() method being handed a Select() + or FromStatement() object that uses the "orm" plugin. This method + provides a means of creating this ORMCompileState object directly + without using the compiler. + + This method is used only for deprecated cases, which include + the .from_self() method for a Query that has multiple levels + of .from_self() in use, as well as the instances() method. It is + also used within the test suite to generate ORMCompileState objects + for test purposes. + + """ + + stmt = self._statement_20(for_statement=for_statement, **kw) + assert for_statement == stmt._compile_options._for_statement + + # this chooses between ORMFromStatementCompileState and + # ORMSelectCompileState. We could also base this on + # query._statement is not None as we have the ORM Query here + # however this is the more general path. + compile_state_cls = cast( + ORMCompileState, + ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"), + ) + + return compile_state_cls.create_for_statement(stmt, None) + + def _compile_context(self, for_statement: bool = False) -> QueryContext: + compile_state = self._compile_state(for_statement=for_statement) + context = QueryContext( + compile_state, + compile_state.statement, + self._params, + self.session, + self.load_options, + ) + + return context + + +class AliasOption(interfaces.LoaderOption): + inherit_cache = False + + @util.deprecated( + "1.4", + "The :class:`.AliasOption` object is not necessary " + "for entities to be matched up to a query that is established " + "via :meth:`.Query.from_statement` and now does nothing.", + ) + def __init__(self, alias: Union[Alias, Subquery]): + r"""Return a :class:`.MapperOption` that will indicate to the + :class:`_query.Query` + that the main table has been aliased. + + """ + + def process_compile_state(self, compile_state: ORMCompileState) -> None: + pass + + +class BulkUD: + """State used for the orm.Query version of update() / delete(). + + This object is now specific to Query only. + + """ + + def __init__(self, query: Query[Any]): + self.query = query.enable_eagerloads(False) + self._validate_query_state() + self.mapper = self.query._entity_from_pre_ent_zero() + + def _validate_query_state(self) -> None: + for attr, methname, notset, op in ( + ("_limit_clause", "limit()", None, operator.is_), + ("_offset_clause", "offset()", None, operator.is_), + ("_order_by_clauses", "order_by()", (), operator.eq), + ("_group_by_clauses", "group_by()", (), operator.eq), + ("_distinct", "distinct()", False, operator.is_), + ( + "_from_obj", + "join(), outerjoin(), select_from(), or from_self()", + (), + operator.eq, + ), + ( + "_setup_joins", + "join(), outerjoin(), select_from(), or from_self()", + (), + operator.eq, + ), + ): + if not op(getattr(self.query, attr), notset): + raise sa_exc.InvalidRequestError( + "Can't call Query.update() or Query.delete() " + "when %s has been called" % (methname,) + ) + + @property + def session(self) -> Session: + return self.query.session + + +class BulkUpdate(BulkUD): + """BulkUD which handles UPDATEs.""" + + def __init__( + self, + query: Query[Any], + values: Dict[_DMLColumnArgument, Any], + update_kwargs: Optional[Dict[Any, Any]], + ): + super().__init__(query) + self.values = values + self.update_kwargs = update_kwargs + + +class BulkDelete(BulkUD): + """BulkUD which handles DELETEs.""" + + +class RowReturningQuery(Query[Row[_TP]]): + if TYPE_CHECKING: + + def tuples(self) -> Query[_TP]: # type: ignore + ... diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/relationships.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/relationships.py new file mode 100644 index 0000000..b5e33ff --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/relationships.py @@ -0,0 +1,3500 @@ +# orm/relationships.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Heuristics related to join conditions as used in +:func:`_orm.relationship`. + +Provides the :class:`.JoinCondition` object, which encapsulates +SQL annotation and aliasing behavior focused on the `primaryjoin` +and `secondaryjoin` aspects of :func:`_orm.relationship`. + +""" +from __future__ import annotations + +import collections +from collections import abc +import dataclasses +import inspect as _py_inspect +import itertools +import re +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Collection +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NamedTuple +from typing import NoReturn +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import strategy_options +from ._typing import insp_is_aliased_class +from ._typing import is_has_collection_adapter +from .base import _DeclarativeMapped +from .base import _is_mapped_class +from .base import class_mapper +from .base import DynamicMapped +from .base import LoaderCallableStatus +from .base import PassiveFlag +from .base import state_str +from .base import WriteOnlyMapped +from .interfaces import _AttributeOptions +from .interfaces import _IntrospectsAnnotations +from .interfaces import MANYTOMANY +from .interfaces import MANYTOONE +from .interfaces import ONETOMANY +from .interfaces import PropComparator +from .interfaces import RelationshipDirection +from .interfaces import StrategizedProperty +from .util import _orm_annotate +from .util import _orm_deannotate +from .util import CascadeOptions +from .. import exc as sa_exc +from .. import Exists +from .. import log +from .. import schema +from .. import sql +from .. import util +from ..inspection import inspect +from ..sql import coercions +from ..sql import expression +from ..sql import operators +from ..sql import roles +from ..sql import visitors +from ..sql._typing import _ColumnExpressionArgument +from ..sql._typing import _HasClauseElement +from ..sql.annotation import _safe_annotate +from ..sql.elements import ColumnClause +from ..sql.elements import ColumnElement +from ..sql.util import _deep_annotate +from ..sql.util import _deep_deannotate +from ..sql.util import _shallow_annotate +from ..sql.util import adapt_criterion_to_null +from ..sql.util import ClauseAdapter +from ..sql.util import join_condition +from ..sql.util import selectables_overlap +from ..sql.util import visit_binary_product +from ..util.typing import de_optionalize_union_types +from ..util.typing import Literal +from ..util.typing import resolve_name_to_real_class_name + +if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _ExternalEntityType + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _InternalEntityType + from ._typing import _O + from ._typing import _RegistryType + from .base import Mapped + from .clsregistry import _class_resolver + from .clsregistry import _ModNS + from .decl_base import _ClassScanMapperConfig + from .dependency import DependencyProcessor + from .mapper import Mapper + from .query import Query + from .session import Session + from .state import InstanceState + from .strategies import LazyLoader + from .util import AliasedClass + from .util import AliasedInsp + from ..sql._typing import _CoreAdapterProto + from ..sql._typing import _EquivalentColumnMap + from ..sql._typing import _InfoType + from ..sql.annotation import _AnnotationDict + from ..sql.annotation import SupportsAnnotations + from ..sql.elements import BinaryExpression + from ..sql.elements import BindParameter + from ..sql.elements import ClauseElement + from ..sql.schema import Table + from ..sql.selectable import FromClause + from ..util.typing import _AnnotationScanType + from ..util.typing import RODescriptorReference + +_T = TypeVar("_T", bound=Any) +_T1 = TypeVar("_T1", bound=Any) +_T2 = TypeVar("_T2", bound=Any) + +_PT = TypeVar("_PT", bound=Any) + +_PT2 = TypeVar("_PT2", bound=Any) + + +_RelationshipArgumentType = Union[ + str, + Type[_T], + Callable[[], Type[_T]], + "Mapper[_T]", + "AliasedClass[_T]", + Callable[[], "Mapper[_T]"], + Callable[[], "AliasedClass[_T]"], +] + +_LazyLoadArgumentType = Literal[ + "select", + "joined", + "selectin", + "subquery", + "raise", + "raise_on_sql", + "noload", + "immediate", + "write_only", + "dynamic", + True, + False, + None, +] + + +_RelationshipJoinConditionArgument = Union[ + str, _ColumnExpressionArgument[bool] +] +_RelationshipSecondaryArgument = Union[ + "FromClause", str, Callable[[], "FromClause"] +] +_ORMOrderByArgument = Union[ + Literal[False], + str, + _ColumnExpressionArgument[Any], + Callable[[], _ColumnExpressionArgument[Any]], + Callable[[], Iterable[_ColumnExpressionArgument[Any]]], + Iterable[Union[str, _ColumnExpressionArgument[Any]]], +] +ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] + +_ORMColCollectionElement = Union[ + ColumnClause[Any], + _HasClauseElement[Any], + roles.DMLColumnRole, + "Mapped[Any]", +] +_ORMColCollectionArgument = Union[ + str, + Sequence[_ORMColCollectionElement], + Callable[[], Sequence[_ORMColCollectionElement]], + Callable[[], _ORMColCollectionElement], + _ORMColCollectionElement, +] + + +_CEA = TypeVar("_CEA", bound=_ColumnExpressionArgument[Any]) + +_CE = TypeVar("_CE", bound="ColumnElement[Any]") + + +_ColumnPairIterable = Iterable[Tuple[ColumnElement[Any], ColumnElement[Any]]] + +_ColumnPairs = Sequence[Tuple[ColumnElement[Any], ColumnElement[Any]]] + +_MutableColumnPairs = List[Tuple[ColumnElement[Any], ColumnElement[Any]]] + + +def remote(expr: _CEA) -> _CEA: + """Annotate a portion of a primaryjoin expression + with a 'remote' annotation. + + See the section :ref:`relationship_custom_foreign` for a + description of use. + + .. seealso:: + + :ref:`relationship_custom_foreign` + + :func:`.foreign` + + """ + return _annotate_columns( # type: ignore + coercions.expect(roles.ColumnArgumentRole, expr), {"remote": True} + ) + + +def foreign(expr: _CEA) -> _CEA: + """Annotate a portion of a primaryjoin expression + with a 'foreign' annotation. + + See the section :ref:`relationship_custom_foreign` for a + description of use. + + .. seealso:: + + :ref:`relationship_custom_foreign` + + :func:`.remote` + + """ + + return _annotate_columns( # type: ignore + coercions.expect(roles.ColumnArgumentRole, expr), {"foreign": True} + ) + + +@dataclasses.dataclass +class _RelationshipArg(Generic[_T1, _T2]): + """stores a user-defined parameter value that must be resolved and + parsed later at mapper configuration time. + + """ + + __slots__ = "name", "argument", "resolved" + name: str + argument: _T1 + resolved: Optional[_T2] + + def _is_populated(self) -> bool: + return self.argument is not None + + def _resolve_against_registry( + self, clsregistry_resolver: Callable[[str, bool], _class_resolver] + ) -> None: + attr_value = self.argument + + if isinstance(attr_value, str): + self.resolved = clsregistry_resolver( + attr_value, self.name == "secondary" + )() + elif callable(attr_value) and not _is_mapped_class(attr_value): + self.resolved = attr_value() + else: + self.resolved = attr_value + + +_RelationshipOrderByArg = Union[Literal[False], Tuple[ColumnElement[Any], ...]] + + +class _RelationshipArgs(NamedTuple): + """stores user-passed parameters that are resolved at mapper configuration + time. + + """ + + secondary: _RelationshipArg[ + Optional[_RelationshipSecondaryArgument], + Optional[FromClause], + ] + primaryjoin: _RelationshipArg[ + Optional[_RelationshipJoinConditionArgument], + Optional[ColumnElement[Any]], + ] + secondaryjoin: _RelationshipArg[ + Optional[_RelationshipJoinConditionArgument], + Optional[ColumnElement[Any]], + ] + order_by: _RelationshipArg[_ORMOrderByArgument, _RelationshipOrderByArg] + foreign_keys: _RelationshipArg[ + Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]] + ] + remote_side: _RelationshipArg[ + Optional[_ORMColCollectionArgument], Set[ColumnElement[Any]] + ] + + +@log.class_logger +class RelationshipProperty( + _IntrospectsAnnotations, StrategizedProperty[_T], log.Identified +): + """Describes an object property that holds a single item or list + of items that correspond to a related database table. + + Public constructor is the :func:`_orm.relationship` function. + + .. seealso:: + + :ref:`relationship_config_toplevel` + + """ + + strategy_wildcard_key = strategy_options._RELATIONSHIP_TOKEN + inherit_cache = True + """:meta private:""" + + _links_to_entity = True + _is_relationship = True + + _overlaps: Sequence[str] + + _lazy_strategy: LazyLoader + + _persistence_only = dict( + passive_deletes=False, + passive_updates=True, + enable_typechecks=True, + active_history=False, + cascade_backrefs=False, + ) + + _dependency_processor: Optional[DependencyProcessor] = None + + primaryjoin: ColumnElement[bool] + secondaryjoin: Optional[ColumnElement[bool]] + secondary: Optional[FromClause] + _join_condition: JoinCondition + order_by: _RelationshipOrderByArg + + _user_defined_foreign_keys: Set[ColumnElement[Any]] + _calculated_foreign_keys: Set[ColumnElement[Any]] + + remote_side: Set[ColumnElement[Any]] + local_columns: Set[ColumnElement[Any]] + + synchronize_pairs: _ColumnPairs + secondary_synchronize_pairs: Optional[_ColumnPairs] + + local_remote_pairs: Optional[_ColumnPairs] + + direction: RelationshipDirection + + _init_args: _RelationshipArgs + + def __init__( + self, + argument: Optional[_RelationshipArgumentType[_T]] = None, + secondary: Optional[_RelationshipSecondaryArgument] = None, + *, + uselist: Optional[bool] = None, + collection_class: Optional[ + Union[Type[Collection[Any]], Callable[[], Collection[Any]]] + ] = None, + primaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None, + back_populates: Optional[str] = None, + order_by: _ORMOrderByArgument = False, + backref: Optional[ORMBackrefArgument] = None, + overlaps: Optional[str] = None, + post_update: bool = False, + cascade: str = "save-update, merge", + viewonly: bool = False, + attribute_options: Optional[_AttributeOptions] = None, + lazy: _LazyLoadArgumentType = "select", + passive_deletes: Union[Literal["all"], bool] = False, + passive_updates: bool = True, + active_history: bool = False, + enable_typechecks: bool = True, + foreign_keys: Optional[_ORMColCollectionArgument] = None, + remote_side: Optional[_ORMColCollectionArgument] = None, + join_depth: Optional[int] = None, + comparator_factory: Optional[ + Type[RelationshipProperty.Comparator[Any]] + ] = None, + single_parent: bool = False, + innerjoin: bool = False, + distinct_target_key: Optional[bool] = None, + load_on_pending: bool = False, + query_class: Optional[Type[Query[Any]]] = None, + info: Optional[_InfoType] = None, + omit_join: Literal[None, False] = None, + sync_backref: Optional[bool] = None, + doc: Optional[str] = None, + bake_queries: Literal[True] = True, + cascade_backrefs: Literal[False] = False, + _local_remote_pairs: Optional[_ColumnPairs] = None, + _legacy_inactive_history_style: bool = False, + ): + super().__init__(attribute_options=attribute_options) + + self.uselist = uselist + self.argument = argument + + self._init_args = _RelationshipArgs( + _RelationshipArg("secondary", secondary, None), + _RelationshipArg("primaryjoin", primaryjoin, None), + _RelationshipArg("secondaryjoin", secondaryjoin, None), + _RelationshipArg("order_by", order_by, None), + _RelationshipArg("foreign_keys", foreign_keys, None), + _RelationshipArg("remote_side", remote_side, None), + ) + + self.post_update = post_update + self.viewonly = viewonly + if viewonly: + self._warn_for_persistence_only_flags( + passive_deletes=passive_deletes, + passive_updates=passive_updates, + enable_typechecks=enable_typechecks, + active_history=active_history, + cascade_backrefs=cascade_backrefs, + ) + if viewonly and sync_backref: + raise sa_exc.ArgumentError( + "sync_backref and viewonly cannot both be True" + ) + self.sync_backref = sync_backref + self.lazy = lazy + self.single_parent = single_parent + self.collection_class = collection_class + self.passive_deletes = passive_deletes + + if cascade_backrefs: + raise sa_exc.ArgumentError( + "The 'cascade_backrefs' parameter passed to " + "relationship() may only be set to False." + ) + + self.passive_updates = passive_updates + self.enable_typechecks = enable_typechecks + self.query_class = query_class + self.innerjoin = innerjoin + self.distinct_target_key = distinct_target_key + self.doc = doc + self.active_history = active_history + self._legacy_inactive_history_style = _legacy_inactive_history_style + + self.join_depth = join_depth + if omit_join: + util.warn( + "setting omit_join to True is not supported; selectin " + "loading of this relationship may not work correctly if this " + "flag is set explicitly. omit_join optimization is " + "automatically detected for conditions under which it is " + "supported." + ) + + self.omit_join = omit_join + self.local_remote_pairs = _local_remote_pairs + self.load_on_pending = load_on_pending + self.comparator_factory = ( + comparator_factory or RelationshipProperty.Comparator + ) + util.set_creation_order(self) + + if info is not None: + self.info.update(info) + + self.strategy_key = (("lazy", self.lazy),) + + self._reverse_property: Set[RelationshipProperty[Any]] = set() + + if overlaps: + self._overlaps = set(re.split(r"\s*,\s*", overlaps)) # type: ignore # noqa: E501 + else: + self._overlaps = () + + # mypy ignoring the @property setter + self.cascade = cascade # type: ignore + + self.back_populates = back_populates + + if self.back_populates: + if backref: + raise sa_exc.ArgumentError( + "backref and back_populates keyword arguments " + "are mutually exclusive" + ) + self.backref = None + else: + self.backref = backref + + def _warn_for_persistence_only_flags(self, **kw: Any) -> None: + for k, v in kw.items(): + if v != self._persistence_only[k]: + # we are warning here rather than warn deprecated as this is a + # configuration mistake, and Python shows regular warnings more + # aggressively than deprecation warnings by default. Unlike the + # case of setting viewonly with cascade, the settings being + # warned about here are not actively doing the wrong thing + # against viewonly=True, so it is not as urgent to have these + # raise an error. + util.warn( + "Setting %s on relationship() while also " + "setting viewonly=True does not make sense, as a " + "viewonly=True relationship does not perform persistence " + "operations. This configuration may raise an error " + "in a future release." % (k,) + ) + + def instrument_class(self, mapper: Mapper[Any]) -> None: + attributes.register_descriptor( + mapper.class_, + self.key, + comparator=self.comparator_factory(self, mapper), + parententity=mapper, + doc=self.doc, + ) + + class Comparator(util.MemoizedSlots, PropComparator[_PT]): + """Produce boolean, comparison, and other operators for + :class:`.RelationshipProperty` attributes. + + See the documentation for :class:`.PropComparator` for a brief + overview of ORM level operator definition. + + .. seealso:: + + :class:`.PropComparator` + + :class:`.ColumnProperty.Comparator` + + :class:`.ColumnOperators` + + :ref:`types_operators` + + :attr:`.TypeEngine.comparator_factory` + + """ + + __slots__ = ( + "entity", + "mapper", + "property", + "_of_type", + "_extra_criteria", + ) + + prop: RODescriptorReference[RelationshipProperty[_PT]] + _of_type: Optional[_EntityType[_PT]] + + def __init__( + self, + prop: RelationshipProperty[_PT], + parentmapper: _InternalEntityType[Any], + adapt_to_entity: Optional[AliasedInsp[Any]] = None, + of_type: Optional[_EntityType[_PT]] = None, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), + ): + """Construction of :class:`.RelationshipProperty.Comparator` + is internal to the ORM's attribute mechanics. + + """ + self.prop = prop + self._parententity = parentmapper + self._adapt_to_entity = adapt_to_entity + if of_type: + self._of_type = of_type + else: + self._of_type = None + self._extra_criteria = extra_criteria + + def adapt_to_entity( + self, adapt_to_entity: AliasedInsp[Any] + ) -> RelationshipProperty.Comparator[Any]: + return self.__class__( + self.prop, + self._parententity, + adapt_to_entity=adapt_to_entity, + of_type=self._of_type, + ) + + entity: _InternalEntityType[_PT] + """The target entity referred to by this + :class:`.RelationshipProperty.Comparator`. + + This is either a :class:`_orm.Mapper` or :class:`.AliasedInsp` + object. + + This is the "target" or "remote" side of the + :func:`_orm.relationship`. + + """ + + mapper: Mapper[_PT] + """The target :class:`_orm.Mapper` referred to by this + :class:`.RelationshipProperty.Comparator`. + + This is the "target" or "remote" side of the + :func:`_orm.relationship`. + + """ + + def _memoized_attr_entity(self) -> _InternalEntityType[_PT]: + if self._of_type: + return inspect(self._of_type) # type: ignore + else: + return self.prop.entity + + def _memoized_attr_mapper(self) -> Mapper[_PT]: + return self.entity.mapper + + def _source_selectable(self) -> FromClause: + if self._adapt_to_entity: + return self._adapt_to_entity.selectable + else: + return self.property.parent._with_polymorphic_selectable + + def __clause_element__(self) -> ColumnElement[bool]: + adapt_from = self._source_selectable() + if self._of_type: + of_type_entity = inspect(self._of_type) + else: + of_type_entity = None + + ( + pj, + sj, + source, + dest, + secondary, + target_adapter, + ) = self.prop._create_joins( + source_selectable=adapt_from, + source_polymorphic=True, + of_type_entity=of_type_entity, + alias_secondary=True, + extra_criteria=self._extra_criteria, + ) + if sj is not None: + return pj & sj + else: + return pj + + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_PT]: + r"""Redefine this object in terms of a polymorphic subclass. + + See :meth:`.PropComparator.of_type` for an example. + + + """ + return RelationshipProperty.Comparator( + self.prop, + self._parententity, + adapt_to_entity=self._adapt_to_entity, + of_type=class_, + extra_criteria=self._extra_criteria, + ) + + def and_( + self, *criteria: _ColumnExpressionArgument[bool] + ) -> PropComparator[Any]: + """Add AND criteria. + + See :meth:`.PropComparator.and_` for an example. + + .. versionadded:: 1.4 + + """ + exprs = tuple( + coercions.expect(roles.WhereHavingRole, clause) + for clause in util.coerce_generator_arg(criteria) + ) + + return RelationshipProperty.Comparator( + self.prop, + self._parententity, + adapt_to_entity=self._adapt_to_entity, + of_type=self._of_type, + extra_criteria=self._extra_criteria + exprs, + ) + + def in_(self, other: Any) -> NoReturn: + """Produce an IN clause - this is not implemented + for :func:`_orm.relationship`-based attributes at this time. + + """ + raise NotImplementedError( + "in_() not yet supported for " + "relationships. For a simple " + "many-to-one, use in_() against " + "the set of foreign key values." + ) + + # https://github.com/python/mypy/issues/4266 + __hash__ = None # type: ignore + + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + """Implement the ``==`` operator. + + In a many-to-one context, such as:: + + MyClass.some_prop == + + this will typically produce a + clause such as:: + + mytable.related_id == + + Where ```` is the primary key of the given + object. + + The ``==`` operator provides partial functionality for non- + many-to-one comparisons: + + * Comparisons against collections are not supported. + Use :meth:`~.Relationship.Comparator.contains`. + * Compared to a scalar one-to-many, will produce a + clause that compares the target columns in the parent to + the given target. + * Compared to a scalar many-to-many, an alias + of the association table will be rendered as + well, forming a natural join that is part of the + main body of the query. This will not work for + queries that go beyond simple AND conjunctions of + comparisons, such as those which use OR. Use + explicit joins, outerjoins, or + :meth:`~.Relationship.Comparator.has` for + more comprehensive non-many-to-one scalar + membership tests. + * Comparisons against ``None`` given in a one-to-many + or many-to-many context produce a NOT EXISTS clause. + + """ + if other is None or isinstance(other, expression.Null): + if self.property.direction in [ONETOMANY, MANYTOMANY]: + return ~self._criterion_exists() + else: + return _orm_annotate( + self.property._optimized_compare( + None, adapt_source=self.adapter + ) + ) + elif self.property.uselist: + raise sa_exc.InvalidRequestError( + "Can't compare a collection to an object or collection; " + "use contains() to test for membership." + ) + else: + return _orm_annotate( + self.property._optimized_compare( + other, adapt_source=self.adapter + ) + ) + + def _criterion_exists( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> Exists: + where_criteria = ( + coercions.expect(roles.WhereHavingRole, criterion) + if criterion is not None + else None + ) + + if getattr(self, "_of_type", None): + info: Optional[_InternalEntityType[Any]] = inspect( + self._of_type + ) + assert info is not None + target_mapper, to_selectable, is_aliased_class = ( + info.mapper, + info.selectable, + info.is_aliased_class, + ) + if self.property._is_self_referential and not is_aliased_class: + to_selectable = to_selectable._anonymous_fromclause() + + single_crit = target_mapper._single_table_criterion + if single_crit is not None: + if where_criteria is not None: + where_criteria = single_crit & where_criteria + else: + where_criteria = single_crit + else: + is_aliased_class = False + to_selectable = None + + if self.adapter: + source_selectable = self._source_selectable() + else: + source_selectable = None + + ( + pj, + sj, + source, + dest, + secondary, + target_adapter, + ) = self.property._create_joins( + dest_selectable=to_selectable, + source_selectable=source_selectable, + ) + + for k in kwargs: + crit = getattr(self.property.mapper.class_, k) == kwargs[k] + if where_criteria is None: + where_criteria = crit + else: + where_criteria = where_criteria & crit + + # annotate the *local* side of the join condition, in the case + # of pj + sj this is the full primaryjoin, in the case of just + # pj its the local side of the primaryjoin. + if sj is not None: + j = _orm_annotate(pj) & sj + else: + j = _orm_annotate(pj, exclude=self.property.remote_side) + + if ( + where_criteria is not None + and target_adapter + and not is_aliased_class + ): + # limit this adapter to annotated only? + where_criteria = target_adapter.traverse(where_criteria) + + # only have the "joined left side" of what we + # return be subject to Query adaption. The right + # side of it is used for an exists() subquery and + # should not correlate or otherwise reach out + # to anything in the enclosing query. + if where_criteria is not None: + where_criteria = where_criteria._annotate( + {"no_replacement_traverse": True} + ) + + crit = j & sql.True_._ifnone(where_criteria) + + if secondary is not None: + ex = ( + sql.exists(1) + .where(crit) + .select_from(dest, secondary) + .correlate_except(dest, secondary) + ) + else: + ex = ( + sql.exists(1) + .where(crit) + .select_from(dest) + .correlate_except(dest) + ) + return ex + + def any( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + """Produce an expression that tests a collection against + particular criterion, using EXISTS. + + An expression like:: + + session.query(MyClass).filter( + MyClass.somereference.any(SomeRelated.x==2) + ) + + + Will produce a query like:: + + SELECT * FROM my_table WHERE + EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id + AND related.x=2) + + Because :meth:`~.Relationship.Comparator.any` uses + a correlated subquery, its performance is not nearly as + good when compared against large target tables as that of + using a join. + + :meth:`~.Relationship.Comparator.any` is particularly + useful for testing for empty collections:: + + session.query(MyClass).filter( + ~MyClass.somereference.any() + ) + + will produce:: + + SELECT * FROM my_table WHERE + NOT (EXISTS (SELECT 1 FROM related WHERE + related.my_id=my_table.id)) + + :meth:`~.Relationship.Comparator.any` is only + valid for collections, i.e. a :func:`_orm.relationship` + that has ``uselist=True``. For scalar references, + use :meth:`~.Relationship.Comparator.has`. + + """ + if not self.property.uselist: + raise sa_exc.InvalidRequestError( + "'any()' not implemented for scalar " + "attributes. Use has()." + ) + + return self._criterion_exists(criterion, **kwargs) + + def has( + self, + criterion: Optional[_ColumnExpressionArgument[bool]] = None, + **kwargs: Any, + ) -> ColumnElement[bool]: + """Produce an expression that tests a scalar reference against + particular criterion, using EXISTS. + + An expression like:: + + session.query(MyClass).filter( + MyClass.somereference.has(SomeRelated.x==2) + ) + + + Will produce a query like:: + + SELECT * FROM my_table WHERE + EXISTS (SELECT 1 FROM related WHERE + related.id==my_table.related_id AND related.x=2) + + Because :meth:`~.Relationship.Comparator.has` uses + a correlated subquery, its performance is not nearly as + good when compared against large target tables as that of + using a join. + + :meth:`~.Relationship.Comparator.has` is only + valid for scalar references, i.e. a :func:`_orm.relationship` + that has ``uselist=False``. For collection references, + use :meth:`~.Relationship.Comparator.any`. + + """ + if self.property.uselist: + raise sa_exc.InvalidRequestError( + "'has()' not implemented for collections. Use any()." + ) + return self._criterion_exists(criterion, **kwargs) + + def contains( + self, other: _ColumnExpressionArgument[Any], **kwargs: Any + ) -> ColumnElement[bool]: + """Return a simple expression that tests a collection for + containment of a particular item. + + :meth:`~.Relationship.Comparator.contains` is + only valid for a collection, i.e. a + :func:`_orm.relationship` that implements + one-to-many or many-to-many with ``uselist=True``. + + When used in a simple one-to-many context, an + expression like:: + + MyClass.contains(other) + + Produces a clause like:: + + mytable.id == + + Where ```` is the value of the foreign key + attribute on ``other`` which refers to the primary + key of its parent object. From this it follows that + :meth:`~.Relationship.Comparator.contains` is + very useful when used with simple one-to-many + operations. + + For many-to-many operations, the behavior of + :meth:`~.Relationship.Comparator.contains` + has more caveats. The association table will be + rendered in the statement, producing an "implicit" + join, that is, includes multiple tables in the FROM + clause which are equated in the WHERE clause:: + + query(MyClass).filter(MyClass.contains(other)) + + Produces a query like:: + + SELECT * FROM my_table, my_association_table AS + my_association_table_1 WHERE + my_table.id = my_association_table_1.parent_id + AND my_association_table_1.child_id = + + Where ```` would be the primary key of + ``other``. From the above, it is clear that + :meth:`~.Relationship.Comparator.contains` + will **not** work with many-to-many collections when + used in queries that move beyond simple AND + conjunctions, such as multiple + :meth:`~.Relationship.Comparator.contains` + expressions joined by OR. In such cases subqueries or + explicit "outer joins" will need to be used instead. + See :meth:`~.Relationship.Comparator.any` for + a less-performant alternative using EXISTS, or refer + to :meth:`_query.Query.outerjoin` + as well as :ref:`orm_queryguide_joins` + for more details on constructing outer joins. + + kwargs may be ignored by this operator but are required for API + conformance. + """ + if not self.prop.uselist: + raise sa_exc.InvalidRequestError( + "'contains' not implemented for scalar " + "attributes. Use ==" + ) + + clause = self.prop._optimized_compare( + other, adapt_source=self.adapter + ) + + if self.prop.secondaryjoin is not None: + clause.negation_clause = self.__negated_contains_or_equals( + other + ) + + return clause + + def __negated_contains_or_equals( + self, other: Any + ) -> ColumnElement[bool]: + if self.prop.direction == MANYTOONE: + state = attributes.instance_state(other) + + def state_bindparam( + local_col: ColumnElement[Any], + state: InstanceState[Any], + remote_col: ColumnElement[Any], + ) -> BindParameter[Any]: + dict_ = state.dict + return sql.bindparam( + local_col.key, + type_=local_col.type, + unique=True, + callable_=self.prop._get_attr_w_warn_on_none( + self.prop.mapper, state, dict_, remote_col + ), + ) + + def adapt(col: _CE) -> _CE: + if self.adapter: + return self.adapter(col) + else: + return col + + if self.property._use_get: + return sql.and_( + *[ + sql.or_( + adapt(x) + != state_bindparam(adapt(x), state, y), + adapt(x) == None, + ) + for (x, y) in self.property.local_remote_pairs + ] + ) + + criterion = sql.and_( + *[ + x == y + for (x, y) in zip( + self.property.mapper.primary_key, + self.property.mapper.primary_key_from_instance(other), + ) + ] + ) + + return ~self._criterion_exists(criterion) + + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + """Implement the ``!=`` operator. + + In a many-to-one context, such as:: + + MyClass.some_prop != + + This will typically produce a clause such as:: + + mytable.related_id != + + Where ```` is the primary key of the + given object. + + The ``!=`` operator provides partial functionality for non- + many-to-one comparisons: + + * Comparisons against collections are not supported. + Use + :meth:`~.Relationship.Comparator.contains` + in conjunction with :func:`_expression.not_`. + * Compared to a scalar one-to-many, will produce a + clause that compares the target columns in the parent to + the given target. + * Compared to a scalar many-to-many, an alias + of the association table will be rendered as + well, forming a natural join that is part of the + main body of the query. This will not work for + queries that go beyond simple AND conjunctions of + comparisons, such as those which use OR. Use + explicit joins, outerjoins, or + :meth:`~.Relationship.Comparator.has` in + conjunction with :func:`_expression.not_` for + more comprehensive non-many-to-one scalar + membership tests. + * Comparisons against ``None`` given in a one-to-many + or many-to-many context produce an EXISTS clause. + + """ + if other is None or isinstance(other, expression.Null): + if self.property.direction == MANYTOONE: + return _orm_annotate( + ~self.property._optimized_compare( + None, adapt_source=self.adapter + ) + ) + + else: + return self._criterion_exists() + elif self.property.uselist: + raise sa_exc.InvalidRequestError( + "Can't compare a collection" + " to an object or collection; use " + "contains() to test for membership." + ) + else: + return _orm_annotate(self.__negated_contains_or_equals(other)) + + def _memoized_attr_property(self) -> RelationshipProperty[_PT]: + self.prop.parent._check_configure() + return self.prop + + def _with_parent( + self, + instance: object, + alias_secondary: bool = True, + from_entity: Optional[_EntityType[Any]] = None, + ) -> ColumnElement[bool]: + assert instance is not None + adapt_source: Optional[_CoreAdapterProto] = None + if from_entity is not None: + insp: Optional[_InternalEntityType[Any]] = inspect(from_entity) + assert insp is not None + if insp_is_aliased_class(insp): + adapt_source = insp._adapter.adapt_clause + return self._optimized_compare( + instance, + value_is_parent=True, + adapt_source=adapt_source, + alias_secondary=alias_secondary, + ) + + def _optimized_compare( + self, + state: Any, + value_is_parent: bool = False, + adapt_source: Optional[_CoreAdapterProto] = None, + alias_secondary: bool = True, + ) -> ColumnElement[bool]: + if state is not None: + try: + state = inspect(state) + except sa_exc.NoInspectionAvailable: + state = None + + if state is None or not getattr(state, "is_instance", False): + raise sa_exc.ArgumentError( + "Mapped instance expected for relationship " + "comparison to object. Classes, queries and other " + "SQL elements are not accepted in this context; for " + "comparison with a subquery, " + "use %s.has(**criteria)." % self + ) + reverse_direction = not value_is_parent + + if state is None: + return self._lazy_none_clause( + reverse_direction, adapt_source=adapt_source + ) + + if not reverse_direction: + criterion, bind_to_col = ( + self._lazy_strategy._lazywhere, + self._lazy_strategy._bind_to_col, + ) + else: + criterion, bind_to_col = ( + self._lazy_strategy._rev_lazywhere, + self._lazy_strategy._rev_bind_to_col, + ) + + if reverse_direction: + mapper = self.mapper + else: + mapper = self.parent + + dict_ = attributes.instance_dict(state.obj()) + + def visit_bindparam(bindparam: BindParameter[Any]) -> None: + if bindparam._identifying_key in bind_to_col: + bindparam.callable = self._get_attr_w_warn_on_none( + mapper, + state, + dict_, + bind_to_col[bindparam._identifying_key], + ) + + if self.secondary is not None and alias_secondary: + criterion = ClauseAdapter( + self.secondary._anonymous_fromclause() + ).traverse(criterion) + + criterion = visitors.cloned_traverse( + criterion, {}, {"bindparam": visit_bindparam} + ) + + if adapt_source: + criterion = adapt_source(criterion) + return criterion + + def _get_attr_w_warn_on_none( + self, + mapper: Mapper[Any], + state: InstanceState[Any], + dict_: _InstanceDict, + column: ColumnElement[Any], + ) -> Callable[[], Any]: + """Create the callable that is used in a many-to-one expression. + + E.g.:: + + u1 = s.query(User).get(5) + + expr = Address.user == u1 + + Above, the SQL should be "address.user_id = 5". The callable + returned by this method produces the value "5" based on the identity + of ``u1``. + + """ + + # in this callable, we're trying to thread the needle through + # a wide variety of scenarios, including: + # + # * the object hasn't been flushed yet and there's no value for + # the attribute as of yet + # + # * the object hasn't been flushed yet but it has a user-defined + # value + # + # * the object has a value but it's expired and not locally present + # + # * the object has a value but it's expired and not locally present, + # and the object is also detached + # + # * The object hadn't been flushed yet, there was no value, but + # later, the object has been expired and detached, and *now* + # they're trying to evaluate it + # + # * the object had a value, but it was changed to a new value, and + # then expired + # + # * the object had a value, but it was changed to a new value, and + # then expired, then the object was detached + # + # * the object has a user-set value, but it's None and we don't do + # the comparison correctly for that so warn + # + + prop = mapper.get_property_by_column(column) + + # by invoking this method, InstanceState will track the last known + # value for this key each time the attribute is to be expired. + # this feature was added explicitly for use in this method. + state._track_last_known_value(prop.key) + + lkv_fixed = state._last_known_values + + def _go() -> Any: + assert lkv_fixed is not None + last_known = to_return = lkv_fixed[prop.key] + existing_is_available = ( + last_known is not LoaderCallableStatus.NO_VALUE + ) + + # we support that the value may have changed. so here we + # try to get the most recent value including re-fetching. + # only if we can't get a value now due to detachment do we return + # the last known value + current_value = mapper._get_state_attr_by_column( + state, + dict_, + column, + passive=( + PassiveFlag.PASSIVE_OFF + if state.persistent + else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK + ), + ) + + if current_value is LoaderCallableStatus.NEVER_SET: + if not existing_is_available: + raise sa_exc.InvalidRequestError( + "Can't resolve value for column %s on object " + "%s; no value has been set for this column" + % (column, state_str(state)) + ) + elif current_value is LoaderCallableStatus.PASSIVE_NO_RESULT: + if not existing_is_available: + raise sa_exc.InvalidRequestError( + "Can't resolve value for column %s on object " + "%s; the object is detached and the value was " + "expired" % (column, state_str(state)) + ) + else: + to_return = current_value + if to_return is None: + util.warn( + "Got None for value of column %s; this is unsupported " + "for a relationship comparison and will not " + "currently produce an IS comparison " + "(but may in a future release)" % column + ) + return to_return + + return _go + + def _lazy_none_clause( + self, + reverse_direction: bool = False, + adapt_source: Optional[_CoreAdapterProto] = None, + ) -> ColumnElement[bool]: + if not reverse_direction: + criterion, bind_to_col = ( + self._lazy_strategy._lazywhere, + self._lazy_strategy._bind_to_col, + ) + else: + criterion, bind_to_col = ( + self._lazy_strategy._rev_lazywhere, + self._lazy_strategy._rev_bind_to_col, + ) + + criterion = adapt_criterion_to_null(criterion, bind_to_col) + + if adapt_source: + criterion = adapt_source(criterion) + return criterion + + def __str__(self) -> str: + return str(self.parent.class_.__name__) + "." + self.key + + def merge( + self, + session: Session, + source_state: InstanceState[Any], + source_dict: _InstanceDict, + dest_state: InstanceState[Any], + dest_dict: _InstanceDict, + load: bool, + _recursive: Dict[Any, object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> None: + if load: + for r in self._reverse_property: + if (source_state, r) in _recursive: + return + + if "merge" not in self._cascade: + return + + if self.key not in source_dict: + return + + if self.uselist: + impl = source_state.get_impl(self.key) + + assert is_has_collection_adapter(impl) + instances_iterable = impl.get_collection(source_state, source_dict) + + # if this is a CollectionAttributeImpl, then empty should + # be False, otherwise "self.key in source_dict" should not be + # True + assert not instances_iterable.empty if impl.collection else True + + if load: + # for a full merge, pre-load the destination collection, + # so that individual _merge of each item pulls from identity + # map for those already present. + # also assumes CollectionAttributeImpl behavior of loading + # "old" list in any case + dest_state.get_impl(self.key).get( + dest_state, dest_dict, passive=PassiveFlag.PASSIVE_MERGE + ) + + dest_list = [] + for current in instances_iterable: + current_state = attributes.instance_state(current) + current_dict = attributes.instance_dict(current) + _recursive[(current_state, self)] = True + obj = session._merge( + current_state, + current_dict, + load=load, + _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map, + ) + if obj is not None: + dest_list.append(obj) + + if not load: + coll = attributes.init_state_collection( + dest_state, dest_dict, self.key + ) + for c in dest_list: + coll.append_without_event(c) + else: + dest_impl = dest_state.get_impl(self.key) + assert is_has_collection_adapter(dest_impl) + dest_impl.set( + dest_state, + dest_dict, + dest_list, + _adapt=False, + passive=PassiveFlag.PASSIVE_MERGE, + ) + else: + current = source_dict[self.key] + if current is not None: + current_state = attributes.instance_state(current) + current_dict = attributes.instance_dict(current) + _recursive[(current_state, self)] = True + obj = session._merge( + current_state, + current_dict, + load=load, + _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map, + ) + else: + obj = None + + if not load: + dest_dict[self.key] = obj + else: + dest_state.get_impl(self.key).set( + dest_state, dest_dict, obj, None + ) + + def _value_as_iterable( + self, + state: InstanceState[_O], + dict_: _InstanceDict, + key: str, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> Sequence[Tuple[InstanceState[_O], _O]]: + """Return a list of tuples (state, obj) for the given + key. + + returns an empty list if the value is None/empty/PASSIVE_NO_RESULT + """ + + impl = state.manager[key].impl + x = impl.get(state, dict_, passive=passive) + if x is LoaderCallableStatus.PASSIVE_NO_RESULT or x is None: + return [] + elif is_has_collection_adapter(impl): + return [ + (attributes.instance_state(o), o) + for o in impl.get_collection(state, dict_, x, passive=passive) + ] + else: + return [(attributes.instance_state(x), x)] + + def cascade_iterator( + self, + type_: str, + state: InstanceState[Any], + dict_: _InstanceDict, + visited_states: Set[InstanceState[Any]], + halt_on: Optional[Callable[[InstanceState[Any]], bool]] = None, + ) -> Iterator[Tuple[Any, Mapper[Any], InstanceState[Any], _InstanceDict]]: + # assert type_ in self._cascade + + # only actively lazy load on the 'delete' cascade + if type_ != "delete" or self.passive_deletes: + passive = PassiveFlag.PASSIVE_NO_INITIALIZE + else: + passive = PassiveFlag.PASSIVE_OFF | PassiveFlag.NO_RAISE + + if type_ == "save-update": + tuples = state.manager[self.key].impl.get_all_pending(state, dict_) + else: + tuples = self._value_as_iterable( + state, dict_, self.key, passive=passive + ) + + skip_pending = ( + type_ == "refresh-expire" and "delete-orphan" not in self._cascade + ) + + for instance_state, c in tuples: + if instance_state in visited_states: + continue + + if c is None: + # would like to emit a warning here, but + # would not be consistent with collection.append(None) + # current behavior of silently skipping. + # see [ticket:2229] + continue + + assert instance_state is not None + instance_dict = attributes.instance_dict(c) + + if halt_on and halt_on(instance_state): + continue + + if skip_pending and not instance_state.key: + continue + + instance_mapper = instance_state.manager.mapper + + if not instance_mapper.isa(self.mapper.class_manager.mapper): + raise AssertionError( + "Attribute '%s' on class '%s' " + "doesn't handle objects " + "of type '%s'" + % (self.key, self.parent.class_, c.__class__) + ) + + visited_states.add(instance_state) + + yield c, instance_mapper, instance_state, instance_dict + + @property + def _effective_sync_backref(self) -> bool: + if self.viewonly: + return False + else: + return self.sync_backref is not False + + @staticmethod + def _check_sync_backref( + rel_a: RelationshipProperty[Any], rel_b: RelationshipProperty[Any] + ) -> None: + if rel_a.viewonly and rel_b.sync_backref: + raise sa_exc.InvalidRequestError( + "Relationship %s cannot specify sync_backref=True since %s " + "includes viewonly=True." % (rel_b, rel_a) + ) + if ( + rel_a.viewonly + and not rel_b.viewonly + and rel_b.sync_backref is not False + ): + rel_b.sync_backref = False + + def _add_reverse_property(self, key: str) -> None: + other = self.mapper.get_property(key, _configure_mappers=False) + if not isinstance(other, RelationshipProperty): + raise sa_exc.InvalidRequestError( + "back_populates on relationship '%s' refers to attribute '%s' " + "that is not a relationship. The back_populates parameter " + "should refer to the name of a relationship on the target " + "class." % (self, other) + ) + # viewonly and sync_backref cases + # 1. self.viewonly==True and other.sync_backref==True -> error + # 2. self.viewonly==True and other.viewonly==False and + # other.sync_backref==None -> warn sync_backref=False, set to False + self._check_sync_backref(self, other) + # 3. other.viewonly==True and self.sync_backref==True -> error + # 4. other.viewonly==True and self.viewonly==False and + # self.sync_backref==None -> warn sync_backref=False, set to False + self._check_sync_backref(other, self) + + self._reverse_property.add(other) + other._reverse_property.add(self) + + other._setup_entity() + + if not other.mapper.common_parent(self.parent): + raise sa_exc.ArgumentError( + "reverse_property %r on " + "relationship %s references relationship %s, which " + "does not reference mapper %s" + % (key, self, other, self.parent) + ) + + if ( + other._configure_started + and self.direction in (ONETOMANY, MANYTOONE) + and self.direction == other.direction + ): + raise sa_exc.ArgumentError( + "%s and back-reference %s are " + "both of the same direction %r. Did you mean to " + "set remote_side on the many-to-one side ?" + % (other, self, self.direction) + ) + + @util.memoized_property + def entity(self) -> _InternalEntityType[_T]: + """Return the target mapped entity, which is an inspect() of the + class or aliased class that is referenced by this + :class:`.RelationshipProperty`. + + """ + self.parent._check_configure() + return self.entity + + @util.memoized_property + def mapper(self) -> Mapper[_T]: + """Return the targeted :class:`_orm.Mapper` for this + :class:`.RelationshipProperty`. + + """ + return self.entity.mapper + + def do_init(self) -> None: + self._check_conflicts() + self._process_dependent_arguments() + self._setup_entity() + self._setup_registry_dependencies() + self._setup_join_conditions() + self._check_cascade_settings(self._cascade) + self._post_init() + self._generate_backref() + self._join_condition._warn_for_conflicting_sync_targets() + super().do_init() + self._lazy_strategy = cast( + "LazyLoader", self._get_strategy((("lazy", "select"),)) + ) + + def _setup_registry_dependencies(self) -> None: + self.parent.mapper.registry._set_depends_on( + self.entity.mapper.registry + ) + + def _process_dependent_arguments(self) -> None: + """Convert incoming configuration arguments to their + proper form. + + Callables are resolved, ORM annotations removed. + + """ + + # accept callables for other attributes which may require + # deferred initialization. This technique is used + # by declarative "string configs" and some recipes. + init_args = self._init_args + + for attr in ( + "order_by", + "primaryjoin", + "secondaryjoin", + "secondary", + "foreign_keys", + "remote_side", + ): + rel_arg = getattr(init_args, attr) + + rel_arg._resolve_against_registry(self._clsregistry_resolvers[1]) + + # remove "annotations" which are present if mapped class + # descriptors are used to create the join expression. + for attr in "primaryjoin", "secondaryjoin": + rel_arg = getattr(init_args, attr) + val = rel_arg.resolved + if val is not None: + rel_arg.resolved = _orm_deannotate( + coercions.expect( + roles.ColumnArgumentRole, val, argname=attr + ) + ) + + secondary = init_args.secondary.resolved + if secondary is not None and _is_mapped_class(secondary): + raise sa_exc.ArgumentError( + "secondary argument %s passed to to relationship() %s must " + "be a Table object or other FROM clause; can't send a mapped " + "class directly as rows in 'secondary' are persisted " + "independently of a class that is mapped " + "to that same table." % (secondary, self) + ) + + # ensure expressions in self.order_by, foreign_keys, + # remote_side are all columns, not strings. + if ( + init_args.order_by.resolved is not False + and init_args.order_by.resolved is not None + ): + self.order_by = tuple( + coercions.expect( + roles.ColumnArgumentRole, x, argname="order_by" + ) + for x in util.to_list(init_args.order_by.resolved) + ) + else: + self.order_by = False + + self._user_defined_foreign_keys = util.column_set( + coercions.expect( + roles.ColumnArgumentRole, x, argname="foreign_keys" + ) + for x in util.to_column_set(init_args.foreign_keys.resolved) + ) + + self.remote_side = util.column_set( + coercions.expect( + roles.ColumnArgumentRole, x, argname="remote_side" + ) + for x in util.to_column_set(init_args.remote_side.resolved) + ) + + def declarative_scan( + self, + decl_scan: _ClassScanMapperConfig, + registry: _RegistryType, + cls: Type[Any], + originating_module: Optional[str], + key: str, + mapped_container: Optional[Type[Mapped[Any]]], + annotation: Optional[_AnnotationScanType], + extracted_mapped_annotation: Optional[_AnnotationScanType], + is_dataclass_field: bool, + ) -> None: + argument = extracted_mapped_annotation + + if extracted_mapped_annotation is None: + if self.argument is None: + self._raise_for_required(key, cls) + else: + return + + argument = extracted_mapped_annotation + assert originating_module is not None + + if mapped_container is not None: + is_write_only = issubclass(mapped_container, WriteOnlyMapped) + is_dynamic = issubclass(mapped_container, DynamicMapped) + if is_write_only: + self.lazy = "write_only" + self.strategy_key = (("lazy", self.lazy),) + elif is_dynamic: + self.lazy = "dynamic" + self.strategy_key = (("lazy", self.lazy),) + else: + is_write_only = is_dynamic = False + + argument = de_optionalize_union_types(argument) + + if hasattr(argument, "__origin__"): + arg_origin = argument.__origin__ + if isinstance(arg_origin, type) and issubclass( + arg_origin, abc.Collection + ): + if self.collection_class is None: + if _py_inspect.isabstract(arg_origin): + raise sa_exc.ArgumentError( + f"Collection annotation type {arg_origin} cannot " + "be instantiated; please provide an explicit " + "'collection_class' parameter " + "(e.g. list, set, etc.) to the " + "relationship() function to accompany this " + "annotation" + ) + + self.collection_class = arg_origin + + elif not is_write_only and not is_dynamic: + self.uselist = False + + if argument.__args__: # type: ignore + if isinstance(arg_origin, type) and issubclass( + arg_origin, typing.Mapping + ): + type_arg = argument.__args__[-1] # type: ignore + else: + type_arg = argument.__args__[0] # type: ignore + if hasattr(type_arg, "__forward_arg__"): + str_argument = type_arg.__forward_arg__ + + argument = resolve_name_to_real_class_name( + str_argument, originating_module + ) + else: + argument = type_arg + else: + raise sa_exc.ArgumentError( + f"Generic alias {argument} requires an argument" + ) + elif hasattr(argument, "__forward_arg__"): + argument = argument.__forward_arg__ + + argument = resolve_name_to_real_class_name( + argument, originating_module + ) + + if ( + self.collection_class is None + and not is_write_only + and not is_dynamic + ): + self.uselist = False + + # ticket #8759 + # if a lead argument was given to relationship(), like + # `relationship("B")`, use that, don't replace it with class we + # found in the annotation. The declarative_scan() method call here is + # still useful, as we continue to derive collection type and do + # checking of the annotation in any case. + if self.argument is None: + self.argument = cast("_RelationshipArgumentType[_T]", argument) + + @util.preload_module("sqlalchemy.orm.mapper") + def _setup_entity(self, __argument: Any = None) -> None: + if "entity" in self.__dict__: + return + + mapperlib = util.preloaded.orm_mapper + + if __argument: + argument = __argument + else: + argument = self.argument + + resolved_argument: _ExternalEntityType[Any] + + if isinstance(argument, str): + # we might want to cleanup clsregistry API to make this + # more straightforward + resolved_argument = cast( + "_ExternalEntityType[Any]", + self._clsregistry_resolve_name(argument)(), + ) + elif callable(argument) and not isinstance( + argument, (type, mapperlib.Mapper) + ): + resolved_argument = argument() + else: + resolved_argument = argument + + entity: _InternalEntityType[Any] + + if isinstance(resolved_argument, type): + entity = class_mapper(resolved_argument, configure=False) + else: + try: + entity = inspect(resolved_argument) + except sa_exc.NoInspectionAvailable: + entity = None # type: ignore + + if not hasattr(entity, "mapper"): + raise sa_exc.ArgumentError( + "relationship '%s' expects " + "a class or a mapper argument (received: %s)" + % (self.key, type(resolved_argument)) + ) + + self.entity = entity + self.target = self.entity.persist_selectable + + def _setup_join_conditions(self) -> None: + self._join_condition = jc = JoinCondition( + parent_persist_selectable=self.parent.persist_selectable, + child_persist_selectable=self.entity.persist_selectable, + parent_local_selectable=self.parent.local_table, + child_local_selectable=self.entity.local_table, + primaryjoin=self._init_args.primaryjoin.resolved, + secondary=self._init_args.secondary.resolved, + secondaryjoin=self._init_args.secondaryjoin.resolved, + parent_equivalents=self.parent._equivalent_columns, + child_equivalents=self.mapper._equivalent_columns, + consider_as_foreign_keys=self._user_defined_foreign_keys, + local_remote_pairs=self.local_remote_pairs, + remote_side=self.remote_side, + self_referential=self._is_self_referential, + prop=self, + support_sync=not self.viewonly, + can_be_synced_fn=self._columns_are_mapped, + ) + self.primaryjoin = jc.primaryjoin + self.secondaryjoin = jc.secondaryjoin + self.secondary = jc.secondary + self.direction = jc.direction + self.local_remote_pairs = jc.local_remote_pairs + self.remote_side = jc.remote_columns + self.local_columns = jc.local_columns + self.synchronize_pairs = jc.synchronize_pairs + self._calculated_foreign_keys = jc.foreign_key_columns + self.secondary_synchronize_pairs = jc.secondary_synchronize_pairs + + @property + def _clsregistry_resolve_arg( + self, + ) -> Callable[[str, bool], _class_resolver]: + return self._clsregistry_resolvers[1] + + @property + def _clsregistry_resolve_name( + self, + ) -> Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]]: + return self._clsregistry_resolvers[0] + + @util.memoized_property + @util.preload_module("sqlalchemy.orm.clsregistry") + def _clsregistry_resolvers( + self, + ) -> Tuple[ + Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]], + Callable[[str, bool], _class_resolver], + ]: + _resolver = util.preloaded.orm_clsregistry._resolver + + return _resolver(self.parent.class_, self) + + def _check_conflicts(self) -> None: + """Test that this relationship is legal, warn about + inheritance conflicts.""" + if self.parent.non_primary and not class_mapper( + self.parent.class_, configure=False + ).has_property(self.key): + raise sa_exc.ArgumentError( + "Attempting to assign a new " + "relationship '%s' to a non-primary mapper on " + "class '%s'. New relationships can only be added " + "to the primary mapper, i.e. the very first mapper " + "created for class '%s' " + % ( + self.key, + self.parent.class_.__name__, + self.parent.class_.__name__, + ) + ) + + @property + def cascade(self) -> CascadeOptions: + """Return the current cascade setting for this + :class:`.RelationshipProperty`. + """ + return self._cascade + + @cascade.setter + def cascade(self, cascade: Union[str, CascadeOptions]) -> None: + self._set_cascade(cascade) + + def _set_cascade(self, cascade_arg: Union[str, CascadeOptions]) -> None: + cascade = CascadeOptions(cascade_arg) + + if self.viewonly: + cascade = CascadeOptions( + cascade.intersection(CascadeOptions._viewonly_cascades) + ) + + if "mapper" in self.__dict__: + self._check_cascade_settings(cascade) + self._cascade = cascade + + if self._dependency_processor: + self._dependency_processor.cascade = cascade + + def _check_cascade_settings(self, cascade: CascadeOptions) -> None: + if ( + cascade.delete_orphan + and not self.single_parent + and (self.direction is MANYTOMANY or self.direction is MANYTOONE) + ): + raise sa_exc.ArgumentError( + "For %(direction)s relationship %(rel)s, delete-orphan " + "cascade is normally " + 'configured only on the "one" side of a one-to-many ' + "relationship, " + 'and not on the "many" side of a many-to-one or many-to-many ' + "relationship. " + "To force this relationship to allow a particular " + '"%(relatedcls)s" object to be referenced by only ' + 'a single "%(clsname)s" object at a time via the ' + "%(rel)s relationship, which " + "would allow " + "delete-orphan cascade to take place in this direction, set " + "the single_parent=True flag." + % { + "rel": self, + "direction": ( + "many-to-one" + if self.direction is MANYTOONE + else "many-to-many" + ), + "clsname": self.parent.class_.__name__, + "relatedcls": self.mapper.class_.__name__, + }, + code="bbf0", + ) + + if self.passive_deletes == "all" and ( + "delete" in cascade or "delete-orphan" in cascade + ): + raise sa_exc.ArgumentError( + "On %s, can't set passive_deletes='all' in conjunction " + "with 'delete' or 'delete-orphan' cascade" % self + ) + + if cascade.delete_orphan: + self.mapper.primary_mapper()._delete_orphans.append( + (self.key, self.parent.class_) + ) + + def _persists_for(self, mapper: Mapper[Any]) -> bool: + """Return True if this property will persist values on behalf + of the given mapper. + + """ + + return ( + self.key in mapper.relationships + and mapper.relationships[self.key] is self + ) + + def _columns_are_mapped(self, *cols: ColumnElement[Any]) -> bool: + """Return True if all columns in the given collection are + mapped by the tables referenced by this :class:`.RelationshipProperty`. + + """ + + secondary = self._init_args.secondary.resolved + for c in cols: + if secondary is not None and secondary.c.contains_column(c): + continue + if not self.parent.persist_selectable.c.contains_column( + c + ) and not self.target.c.contains_column(c): + return False + return True + + def _generate_backref(self) -> None: + """Interpret the 'backref' instruction to create a + :func:`_orm.relationship` complementary to this one.""" + + if self.parent.non_primary: + return + if self.backref is not None and not self.back_populates: + kwargs: Dict[str, Any] + if isinstance(self.backref, str): + backref_key, kwargs = self.backref, {} + else: + backref_key, kwargs = self.backref + mapper = self.mapper.primary_mapper() + + if not mapper.concrete: + check = set(mapper.iterate_to_root()).union( + mapper.self_and_descendants + ) + for m in check: + if m.has_property(backref_key) and not m.concrete: + raise sa_exc.ArgumentError( + "Error creating backref " + "'%s' on relationship '%s': property of that " + "name exists on mapper '%s'" + % (backref_key, self, m) + ) + + # determine primaryjoin/secondaryjoin for the + # backref. Use the one we had, so that + # a custom join doesn't have to be specified in + # both directions. + if self.secondary is not None: + # for many to many, just switch primaryjoin/ + # secondaryjoin. use the annotated + # pj/sj on the _join_condition. + pj = kwargs.pop( + "primaryjoin", + self._join_condition.secondaryjoin_minus_local, + ) + sj = kwargs.pop( + "secondaryjoin", + self._join_condition.primaryjoin_minus_local, + ) + else: + pj = kwargs.pop( + "primaryjoin", + self._join_condition.primaryjoin_reverse_remote, + ) + sj = kwargs.pop("secondaryjoin", None) + if sj: + raise sa_exc.InvalidRequestError( + "Can't assign 'secondaryjoin' on a backref " + "against a non-secondary relationship." + ) + + foreign_keys = kwargs.pop( + "foreign_keys", self._user_defined_foreign_keys + ) + parent = self.parent.primary_mapper() + kwargs.setdefault("viewonly", self.viewonly) + kwargs.setdefault("post_update", self.post_update) + kwargs.setdefault("passive_updates", self.passive_updates) + kwargs.setdefault("sync_backref", self.sync_backref) + self.back_populates = backref_key + relationship = RelationshipProperty( + parent, + self.secondary, + primaryjoin=pj, + secondaryjoin=sj, + foreign_keys=foreign_keys, + back_populates=self.key, + **kwargs, + ) + mapper._configure_property( + backref_key, relationship, warn_for_existing=True + ) + + if self.back_populates: + self._add_reverse_property(self.back_populates) + + @util.preload_module("sqlalchemy.orm.dependency") + def _post_init(self) -> None: + dependency = util.preloaded.orm_dependency + + if self.uselist is None: + self.uselist = self.direction is not MANYTOONE + if not self.viewonly: + self._dependency_processor = ( # type: ignore + dependency.DependencyProcessor.from_relationship + )(self) + + @util.memoized_property + def _use_get(self) -> bool: + """memoize the 'use_get' attribute of this RelationshipLoader's + lazyloader.""" + + strategy = self._lazy_strategy + return strategy.use_get + + @util.memoized_property + def _is_self_referential(self) -> bool: + return self.mapper.common_parent(self.parent) + + def _create_joins( + self, + source_polymorphic: bool = False, + source_selectable: Optional[FromClause] = None, + dest_selectable: Optional[FromClause] = None, + of_type_entity: Optional[_InternalEntityType[Any]] = None, + alias_secondary: bool = False, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), + ) -> Tuple[ + ColumnElement[bool], + Optional[ColumnElement[bool]], + FromClause, + FromClause, + Optional[FromClause], + Optional[ClauseAdapter], + ]: + aliased = False + + if alias_secondary and self.secondary is not None: + aliased = True + + if source_selectable is None: + if source_polymorphic and self.parent.with_polymorphic: + source_selectable = self.parent._with_polymorphic_selectable + + if of_type_entity: + dest_mapper = of_type_entity.mapper + if dest_selectable is None: + dest_selectable = of_type_entity.selectable + aliased = True + else: + dest_mapper = self.mapper + + if dest_selectable is None: + dest_selectable = self.entity.selectable + if self.mapper.with_polymorphic: + aliased = True + + if self._is_self_referential and source_selectable is None: + dest_selectable = dest_selectable._anonymous_fromclause() + aliased = True + elif ( + dest_selectable is not self.mapper._with_polymorphic_selectable + or self.mapper.with_polymorphic + ): + aliased = True + + single_crit = dest_mapper._single_table_criterion + aliased = aliased or ( + source_selectable is not None + and ( + source_selectable + is not self.parent._with_polymorphic_selectable + or source_selectable._is_subquery + ) + ) + + ( + primaryjoin, + secondaryjoin, + secondary, + target_adapter, + dest_selectable, + ) = self._join_condition.join_targets( + source_selectable, + dest_selectable, + aliased, + single_crit, + extra_criteria, + ) + if source_selectable is None: + source_selectable = self.parent.local_table + if dest_selectable is None: + dest_selectable = self.entity.local_table + return ( + primaryjoin, + secondaryjoin, + source_selectable, + dest_selectable, + secondary, + target_adapter, + ) + + +def _annotate_columns(element: _CE, annotations: _AnnotationDict) -> _CE: + def clone(elem: _CE) -> _CE: + if isinstance(elem, expression.ColumnClause): + elem = elem._annotate(annotations.copy()) # type: ignore + elem._copy_internals(clone=clone) + return elem + + if element is not None: + element = clone(element) + clone = None # type: ignore # remove gc cycles + return element + + +class JoinCondition: + primaryjoin_initial: Optional[ColumnElement[bool]] + primaryjoin: ColumnElement[bool] + secondaryjoin: Optional[ColumnElement[bool]] + secondary: Optional[FromClause] + prop: RelationshipProperty[Any] + + synchronize_pairs: _ColumnPairs + secondary_synchronize_pairs: _ColumnPairs + direction: RelationshipDirection + + parent_persist_selectable: FromClause + child_persist_selectable: FromClause + parent_local_selectable: FromClause + child_local_selectable: FromClause + + _local_remote_pairs: Optional[_ColumnPairs] + + def __init__( + self, + parent_persist_selectable: FromClause, + child_persist_selectable: FromClause, + parent_local_selectable: FromClause, + child_local_selectable: FromClause, + *, + primaryjoin: Optional[ColumnElement[bool]] = None, + secondary: Optional[FromClause] = None, + secondaryjoin: Optional[ColumnElement[bool]] = None, + parent_equivalents: Optional[_EquivalentColumnMap] = None, + child_equivalents: Optional[_EquivalentColumnMap] = None, + consider_as_foreign_keys: Any = None, + local_remote_pairs: Optional[_ColumnPairs] = None, + remote_side: Any = None, + self_referential: Any = False, + prop: RelationshipProperty[Any], + support_sync: bool = True, + can_be_synced_fn: Callable[..., bool] = lambda *c: True, + ): + self.parent_persist_selectable = parent_persist_selectable + self.parent_local_selectable = parent_local_selectable + self.child_persist_selectable = child_persist_selectable + self.child_local_selectable = child_local_selectable + self.parent_equivalents = parent_equivalents + self.child_equivalents = child_equivalents + self.primaryjoin_initial = primaryjoin + self.secondaryjoin = secondaryjoin + self.secondary = secondary + self.consider_as_foreign_keys = consider_as_foreign_keys + self._local_remote_pairs = local_remote_pairs + self._remote_side = remote_side + self.prop = prop + self.self_referential = self_referential + self.support_sync = support_sync + self.can_be_synced_fn = can_be_synced_fn + + self._determine_joins() + assert self.primaryjoin is not None + + self._sanitize_joins() + self._annotate_fks() + self._annotate_remote() + self._annotate_local() + self._annotate_parentmapper() + self._setup_pairs() + self._check_foreign_cols(self.primaryjoin, True) + if self.secondaryjoin is not None: + self._check_foreign_cols(self.secondaryjoin, False) + self._determine_direction() + self._check_remote_side() + self._log_joins() + + def _log_joins(self) -> None: + log = self.prop.logger + log.info("%s setup primary join %s", self.prop, self.primaryjoin) + log.info("%s setup secondary join %s", self.prop, self.secondaryjoin) + log.info( + "%s synchronize pairs [%s]", + self.prop, + ",".join( + "(%s => %s)" % (l, r) for (l, r) in self.synchronize_pairs + ), + ) + log.info( + "%s secondary synchronize pairs [%s]", + self.prop, + ",".join( + "(%s => %s)" % (l, r) + for (l, r) in self.secondary_synchronize_pairs or [] + ), + ) + log.info( + "%s local/remote pairs [%s]", + self.prop, + ",".join( + "(%s / %s)" % (l, r) for (l, r) in self.local_remote_pairs + ), + ) + log.info( + "%s remote columns [%s]", + self.prop, + ",".join("%s" % col for col in self.remote_columns), + ) + log.info( + "%s local columns [%s]", + self.prop, + ",".join("%s" % col for col in self.local_columns), + ) + log.info("%s relationship direction %s", self.prop, self.direction) + + def _sanitize_joins(self) -> None: + """remove the parententity annotation from our join conditions which + can leak in here based on some declarative patterns and maybe others. + + "parentmapper" is relied upon both by the ORM evaluator as well as + the use case in _join_fixture_inh_selfref_w_entity + that relies upon it being present, see :ticket:`3364`. + + """ + + self.primaryjoin = _deep_deannotate( + self.primaryjoin, values=("parententity", "proxy_key") + ) + if self.secondaryjoin is not None: + self.secondaryjoin = _deep_deannotate( + self.secondaryjoin, values=("parententity", "proxy_key") + ) + + def _determine_joins(self) -> None: + """Determine the 'primaryjoin' and 'secondaryjoin' attributes, + if not passed to the constructor already. + + This is based on analysis of the foreign key relationships + between the parent and target mapped selectables. + + """ + if self.secondaryjoin is not None and self.secondary is None: + raise sa_exc.ArgumentError( + "Property %s specified with secondary " + "join condition but " + "no secondary argument" % self.prop + ) + + # find a join between the given mapper's mapped table and + # the given table. will try the mapper's local table first + # for more specificity, then if not found will try the more + # general mapped table, which in the case of inheritance is + # a join. + try: + consider_as_foreign_keys = self.consider_as_foreign_keys or None + if self.secondary is not None: + if self.secondaryjoin is None: + self.secondaryjoin = join_condition( + self.child_persist_selectable, + self.secondary, + a_subset=self.child_local_selectable, + consider_as_foreign_keys=consider_as_foreign_keys, + ) + if self.primaryjoin_initial is None: + self.primaryjoin = join_condition( + self.parent_persist_selectable, + self.secondary, + a_subset=self.parent_local_selectable, + consider_as_foreign_keys=consider_as_foreign_keys, + ) + else: + self.primaryjoin = self.primaryjoin_initial + else: + if self.primaryjoin_initial is None: + self.primaryjoin = join_condition( + self.parent_persist_selectable, + self.child_persist_selectable, + a_subset=self.parent_local_selectable, + consider_as_foreign_keys=consider_as_foreign_keys, + ) + else: + self.primaryjoin = self.primaryjoin_initial + except sa_exc.NoForeignKeysError as nfe: + if self.secondary is not None: + raise sa_exc.NoForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are no foreign keys " + "linking these tables via secondary table '%s'. " + "Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or " + "specify 'primaryjoin' and 'secondaryjoin' " + "expressions." % (self.prop, self.secondary) + ) from nfe + else: + raise sa_exc.NoForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are no foreign keys " + "linking these tables. " + "Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or " + "specify a 'primaryjoin' expression." % self.prop + ) from nfe + except sa_exc.AmbiguousForeignKeysError as afe: + if self.secondary is not None: + raise sa_exc.AmbiguousForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are multiple foreign key " + "paths linking the tables via secondary table '%s'. " + "Specify the 'foreign_keys' " + "argument, providing a list of those columns which " + "should be counted as containing a foreign key " + "reference from the secondary table to each of the " + "parent and child tables." % (self.prop, self.secondary) + ) from afe + else: + raise sa_exc.AmbiguousForeignKeysError( + "Could not determine join " + "condition between parent/child tables on " + "relationship %s - there are multiple foreign key " + "paths linking the tables. Specify the " + "'foreign_keys' argument, providing a list of those " + "columns which should be counted as containing a " + "foreign key reference to the parent table." % self.prop + ) from afe + + @property + def primaryjoin_minus_local(self) -> ColumnElement[bool]: + return _deep_deannotate(self.primaryjoin, values=("local", "remote")) + + @property + def secondaryjoin_minus_local(self) -> ColumnElement[bool]: + assert self.secondaryjoin is not None + return _deep_deannotate(self.secondaryjoin, values=("local", "remote")) + + @util.memoized_property + def primaryjoin_reverse_remote(self) -> ColumnElement[bool]: + """Return the primaryjoin condition suitable for the + "reverse" direction. + + If the primaryjoin was delivered here with pre-existing + "remote" annotations, the local/remote annotations + are reversed. Otherwise, the local/remote annotations + are removed. + + """ + if self._has_remote_annotations: + + def replace(element: _CE, **kw: Any) -> Optional[_CE]: + if "remote" in element._annotations: + v = dict(element._annotations) + del v["remote"] + v["local"] = True + return element._with_annotations(v) + elif "local" in element._annotations: + v = dict(element._annotations) + del v["local"] + v["remote"] = True + return element._with_annotations(v) + + return None + + return visitors.replacement_traverse(self.primaryjoin, {}, replace) + else: + if self._has_foreign_annotations: + # TODO: coverage + return _deep_deannotate( + self.primaryjoin, values=("local", "remote") + ) + else: + return _deep_deannotate(self.primaryjoin) + + def _has_annotation(self, clause: ClauseElement, annotation: str) -> bool: + for col in visitors.iterate(clause, {}): + if annotation in col._annotations: + return True + else: + return False + + @util.memoized_property + def _has_foreign_annotations(self) -> bool: + return self._has_annotation(self.primaryjoin, "foreign") + + @util.memoized_property + def _has_remote_annotations(self) -> bool: + return self._has_annotation(self.primaryjoin, "remote") + + def _annotate_fks(self) -> None: + """Annotate the primaryjoin and secondaryjoin + structures with 'foreign' annotations marking columns + considered as foreign. + + """ + if self._has_foreign_annotations: + return + + if self.consider_as_foreign_keys: + self._annotate_from_fk_list() + else: + self._annotate_present_fks() + + def _annotate_from_fk_list(self) -> None: + def check_fk(element: _CE, **kw: Any) -> Optional[_CE]: + if element in self.consider_as_foreign_keys: + return element._annotate({"foreign": True}) + return None + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, check_fk + ) + if self.secondaryjoin is not None: + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, {}, check_fk + ) + + def _annotate_present_fks(self) -> None: + if self.secondary is not None: + secondarycols = util.column_set(self.secondary.c) + else: + secondarycols = set() + + def is_foreign( + a: ColumnElement[Any], b: ColumnElement[Any] + ) -> Optional[ColumnElement[Any]]: + if isinstance(a, schema.Column) and isinstance(b, schema.Column): + if a.references(b): + return a + elif b.references(a): + return b + + if secondarycols: + if a in secondarycols and b not in secondarycols: + return a + elif b in secondarycols and a not in secondarycols: + return b + + return None + + def visit_binary(binary: BinaryExpression[Any]) -> None: + if not isinstance( + binary.left, sql.ColumnElement + ) or not isinstance(binary.right, sql.ColumnElement): + return + + if ( + "foreign" not in binary.left._annotations + and "foreign" not in binary.right._annotations + ): + col = is_foreign(binary.left, binary.right) + if col is not None: + if col.compare(binary.left): + binary.left = binary.left._annotate({"foreign": True}) + elif col.compare(binary.right): + binary.right = binary.right._annotate( + {"foreign": True} + ) + + self.primaryjoin = visitors.cloned_traverse( + self.primaryjoin, {}, {"binary": visit_binary} + ) + if self.secondaryjoin is not None: + self.secondaryjoin = visitors.cloned_traverse( + self.secondaryjoin, {}, {"binary": visit_binary} + ) + + def _refers_to_parent_table(self) -> bool: + """Return True if the join condition contains column + comparisons where both columns are in both tables. + + """ + pt = self.parent_persist_selectable + mt = self.child_persist_selectable + result = False + + def visit_binary(binary: BinaryExpression[Any]) -> None: + nonlocal result + c, f = binary.left, binary.right + if ( + isinstance(c, expression.ColumnClause) + and isinstance(f, expression.ColumnClause) + and pt.is_derived_from(c.table) + and pt.is_derived_from(f.table) + and mt.is_derived_from(c.table) + and mt.is_derived_from(f.table) + ): + result = True + + visitors.traverse(self.primaryjoin, {}, {"binary": visit_binary}) + return result + + def _tables_overlap(self) -> bool: + """Return True if parent/child tables have some overlap.""" + + return selectables_overlap( + self.parent_persist_selectable, self.child_persist_selectable + ) + + def _annotate_remote(self) -> None: + """Annotate the primaryjoin and secondaryjoin + structures with 'remote' annotations marking columns + considered as part of the 'remote' side. + + """ + if self._has_remote_annotations: + return + + if self.secondary is not None: + self._annotate_remote_secondary() + elif self._local_remote_pairs or self._remote_side: + self._annotate_remote_from_args() + elif self._refers_to_parent_table(): + self._annotate_selfref( + lambda col: "foreign" in col._annotations, False + ) + elif self._tables_overlap(): + self._annotate_remote_with_overlap() + else: + self._annotate_remote_distinct_selectables() + + def _annotate_remote_secondary(self) -> None: + """annotate 'remote' in primaryjoin, secondaryjoin + when 'secondary' is present. + + """ + + assert self.secondary is not None + fixed_secondary = self.secondary + + def repl(element: _CE, **kw: Any) -> Optional[_CE]: + if fixed_secondary.c.contains_column(element): + return element._annotate({"remote": True}) + return None + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl + ) + + assert self.secondaryjoin is not None + self.secondaryjoin = visitors.replacement_traverse( + self.secondaryjoin, {}, repl + ) + + def _annotate_selfref( + self, fn: Callable[[ColumnElement[Any]], bool], remote_side_given: bool + ) -> None: + """annotate 'remote' in primaryjoin, secondaryjoin + when the relationship is detected as self-referential. + + """ + + def visit_binary(binary: BinaryExpression[Any]) -> None: + equated = binary.left.compare(binary.right) + if isinstance(binary.left, expression.ColumnClause) and isinstance( + binary.right, expression.ColumnClause + ): + # assume one to many - FKs are "remote" + if fn(binary.left): + binary.left = binary.left._annotate({"remote": True}) + if fn(binary.right) and not equated: + binary.right = binary.right._annotate({"remote": True}) + elif not remote_side_given: + self._warn_non_column_elements() + + self.primaryjoin = visitors.cloned_traverse( + self.primaryjoin, {}, {"binary": visit_binary} + ) + + def _annotate_remote_from_args(self) -> None: + """annotate 'remote' in primaryjoin, secondaryjoin + when the 'remote_side' or '_local_remote_pairs' + arguments are used. + + """ + if self._local_remote_pairs: + if self._remote_side: + raise sa_exc.ArgumentError( + "remote_side argument is redundant " + "against more detailed _local_remote_side " + "argument." + ) + + remote_side = [r for (l, r) in self._local_remote_pairs] + else: + remote_side = self._remote_side + + if self._refers_to_parent_table(): + self._annotate_selfref(lambda col: col in remote_side, True) + else: + + def repl(element: _CE, **kw: Any) -> Optional[_CE]: + # use set() to avoid generating ``__eq__()`` expressions + # against each element + if element in set(remote_side): + return element._annotate({"remote": True}) + return None + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl + ) + + def _annotate_remote_with_overlap(self) -> None: + """annotate 'remote' in primaryjoin, secondaryjoin + when the parent/child tables have some set of + tables in common, though is not a fully self-referential + relationship. + + """ + + def visit_binary(binary: BinaryExpression[Any]) -> None: + binary.left, binary.right = proc_left_right( + binary.left, binary.right + ) + binary.right, binary.left = proc_left_right( + binary.right, binary.left + ) + + check_entities = ( + self.prop is not None and self.prop.mapper is not self.prop.parent + ) + + def proc_left_right( + left: ColumnElement[Any], right: ColumnElement[Any] + ) -> Tuple[ColumnElement[Any], ColumnElement[Any]]: + if isinstance(left, expression.ColumnClause) and isinstance( + right, expression.ColumnClause + ): + if self.child_persist_selectable.c.contains_column( + right + ) and self.parent_persist_selectable.c.contains_column(left): + right = right._annotate({"remote": True}) + elif ( + check_entities + and right._annotations.get("parentmapper") is self.prop.mapper + ): + right = right._annotate({"remote": True}) + elif ( + check_entities + and left._annotations.get("parentmapper") is self.prop.mapper + ): + left = left._annotate({"remote": True}) + else: + self._warn_non_column_elements() + + return left, right + + self.primaryjoin = visitors.cloned_traverse( + self.primaryjoin, {}, {"binary": visit_binary} + ) + + def _annotate_remote_distinct_selectables(self) -> None: + """annotate 'remote' in primaryjoin, secondaryjoin + when the parent/child tables are entirely + separate. + + """ + + def repl(element: _CE, **kw: Any) -> Optional[_CE]: + if self.child_persist_selectable.c.contains_column(element) and ( + not self.parent_local_selectable.c.contains_column(element) + or self.child_local_selectable.c.contains_column(element) + ): + return element._annotate({"remote": True}) + return None + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, repl + ) + + def _warn_non_column_elements(self) -> None: + util.warn( + "Non-simple column elements in primary " + "join condition for property %s - consider using " + "remote() annotations to mark the remote side." % self.prop + ) + + def _annotate_local(self) -> None: + """Annotate the primaryjoin and secondaryjoin + structures with 'local' annotations. + + This annotates all column elements found + simultaneously in the parent table + and the join condition that don't have a + 'remote' annotation set up from + _annotate_remote() or user-defined. + + """ + if self._has_annotation(self.primaryjoin, "local"): + return + + if self._local_remote_pairs: + local_side = util.column_set( + [l for (l, r) in self._local_remote_pairs] + ) + else: + local_side = util.column_set(self.parent_persist_selectable.c) + + def locals_(element: _CE, **kw: Any) -> Optional[_CE]: + if "remote" not in element._annotations and element in local_side: + return element._annotate({"local": True}) + return None + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, locals_ + ) + + def _annotate_parentmapper(self) -> None: + def parentmappers_(element: _CE, **kw: Any) -> Optional[_CE]: + if "remote" in element._annotations: + return element._annotate({"parentmapper": self.prop.mapper}) + elif "local" in element._annotations: + return element._annotate({"parentmapper": self.prop.parent}) + return None + + self.primaryjoin = visitors.replacement_traverse( + self.primaryjoin, {}, parentmappers_ + ) + + def _check_remote_side(self) -> None: + if not self.local_remote_pairs: + raise sa_exc.ArgumentError( + "Relationship %s could " + "not determine any unambiguous local/remote column " + "pairs based on join condition and remote_side " + "arguments. " + "Consider using the remote() annotation to " + "accurately mark those elements of the join " + "condition that are on the remote side of " + "the relationship." % (self.prop,) + ) + else: + not_target = util.column_set( + self.parent_persist_selectable.c + ).difference(self.child_persist_selectable.c) + + for _, rmt in self.local_remote_pairs: + if rmt in not_target: + util.warn( + "Expression %s is marked as 'remote', but these " + "column(s) are local to the local side. The " + "remote() annotation is needed only for a " + "self-referential relationship where both sides " + "of the relationship refer to the same tables." + % (rmt,) + ) + + def _check_foreign_cols( + self, join_condition: ColumnElement[bool], primary: bool + ) -> None: + """Check the foreign key columns collected and emit error + messages.""" + + can_sync = False + + foreign_cols = self._gather_columns_with_annotation( + join_condition, "foreign" + ) + + has_foreign = bool(foreign_cols) + + if primary: + can_sync = bool(self.synchronize_pairs) + else: + can_sync = bool(self.secondary_synchronize_pairs) + + if ( + self.support_sync + and can_sync + or (not self.support_sync and has_foreign) + ): + return + + # from here below is just determining the best error message + # to report. Check for a join condition using any operator + # (not just ==), perhaps they need to turn on "viewonly=True". + if self.support_sync and has_foreign and not can_sync: + err = ( + "Could not locate any simple equality expressions " + "involving locally mapped foreign key columns for " + "%s join condition " + "'%s' on relationship %s." + % ( + primary and "primary" or "secondary", + join_condition, + self.prop, + ) + ) + err += ( + " Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or are " + "annotated in the join condition with the foreign() " + "annotation. To allow comparison operators other than " + "'==', the relationship can be marked as viewonly=True." + ) + + raise sa_exc.ArgumentError(err) + else: + err = ( + "Could not locate any relevant foreign key columns " + "for %s join condition '%s' on relationship %s." + % ( + primary and "primary" or "secondary", + join_condition, + self.prop, + ) + ) + err += ( + " Ensure that referencing columns are associated " + "with a ForeignKey or ForeignKeyConstraint, or are " + "annotated in the join condition with the foreign() " + "annotation." + ) + raise sa_exc.ArgumentError(err) + + def _determine_direction(self) -> None: + """Determine if this relationship is one to many, many to one, + many to many. + + """ + if self.secondaryjoin is not None: + self.direction = MANYTOMANY + else: + parentcols = util.column_set(self.parent_persist_selectable.c) + targetcols = util.column_set(self.child_persist_selectable.c) + + # fk collection which suggests ONETOMANY. + onetomany_fk = targetcols.intersection(self.foreign_key_columns) + + # fk collection which suggests MANYTOONE. + + manytoone_fk = parentcols.intersection(self.foreign_key_columns) + + if onetomany_fk and manytoone_fk: + # fks on both sides. test for overlap of local/remote + # with foreign key. + # we will gather columns directly from their annotations + # without deannotating, so that we can distinguish on a column + # that refers to itself. + + # 1. columns that are both remote and FK suggest + # onetomany. + onetomany_local = self._gather_columns_with_annotation( + self.primaryjoin, "remote", "foreign" + ) + + # 2. columns that are FK but are not remote (e.g. local) + # suggest manytoone. + manytoone_local = { + c + for c in self._gather_columns_with_annotation( + self.primaryjoin, "foreign" + ) + if "remote" not in c._annotations + } + + # 3. if both collections are present, remove columns that + # refer to themselves. This is for the case of + # and_(Me.id == Me.remote_id, Me.version == Me.version) + if onetomany_local and manytoone_local: + self_equated = self.remote_columns.intersection( + self.local_columns + ) + onetomany_local = onetomany_local.difference(self_equated) + manytoone_local = manytoone_local.difference(self_equated) + + # at this point, if only one or the other collection is + # present, we know the direction, otherwise it's still + # ambiguous. + + if onetomany_local and not manytoone_local: + self.direction = ONETOMANY + elif manytoone_local and not onetomany_local: + self.direction = MANYTOONE + else: + raise sa_exc.ArgumentError( + "Can't determine relationship" + " direction for relationship '%s' - foreign " + "key columns within the join condition are present " + "in both the parent and the child's mapped tables. " + "Ensure that only those columns referring " + "to a parent column are marked as foreign, " + "either via the foreign() annotation or " + "via the foreign_keys argument." % self.prop + ) + elif onetomany_fk: + self.direction = ONETOMANY + elif manytoone_fk: + self.direction = MANYTOONE + else: + raise sa_exc.ArgumentError( + "Can't determine relationship " + "direction for relationship '%s' - foreign " + "key columns are present in neither the parent " + "nor the child's mapped tables" % self.prop + ) + + def _deannotate_pairs( + self, collection: _ColumnPairIterable + ) -> _MutableColumnPairs: + """provide deannotation for the various lists of + pairs, so that using them in hashes doesn't incur + high-overhead __eq__() comparisons against + original columns mapped. + + """ + return [(x._deannotate(), y._deannotate()) for x, y in collection] + + def _setup_pairs(self) -> None: + sync_pairs: _MutableColumnPairs = [] + lrp: util.OrderedSet[Tuple[ColumnElement[Any], ColumnElement[Any]]] = ( + util.OrderedSet([]) + ) + secondary_sync_pairs: _MutableColumnPairs = [] + + def go( + joincond: ColumnElement[bool], + collection: _MutableColumnPairs, + ) -> None: + def visit_binary( + binary: BinaryExpression[Any], + left: ColumnElement[Any], + right: ColumnElement[Any], + ) -> None: + if ( + "remote" in right._annotations + and "remote" not in left._annotations + and self.can_be_synced_fn(left) + ): + lrp.add((left, right)) + elif ( + "remote" in left._annotations + and "remote" not in right._annotations + and self.can_be_synced_fn(right) + ): + lrp.add((right, left)) + if binary.operator is operators.eq and self.can_be_synced_fn( + left, right + ): + if "foreign" in right._annotations: + collection.append((left, right)) + elif "foreign" in left._annotations: + collection.append((right, left)) + + visit_binary_product(visit_binary, joincond) + + for joincond, collection in [ + (self.primaryjoin, sync_pairs), + (self.secondaryjoin, secondary_sync_pairs), + ]: + if joincond is None: + continue + go(joincond, collection) + + self.local_remote_pairs = self._deannotate_pairs(lrp) + self.synchronize_pairs = self._deannotate_pairs(sync_pairs) + self.secondary_synchronize_pairs = self._deannotate_pairs( + secondary_sync_pairs + ) + + _track_overlapping_sync_targets: weakref.WeakKeyDictionary[ + ColumnElement[Any], + weakref.WeakKeyDictionary[ + RelationshipProperty[Any], ColumnElement[Any] + ], + ] = weakref.WeakKeyDictionary() + + def _warn_for_conflicting_sync_targets(self) -> None: + if not self.support_sync: + return + + # we would like to detect if we are synchronizing any column + # pairs in conflict with another relationship that wishes to sync + # an entirely different column to the same target. This is a + # very rare edge case so we will try to minimize the memory/overhead + # impact of this check + for from_, to_ in [ + (from_, to_) for (from_, to_) in self.synchronize_pairs + ] + [ + (from_, to_) for (from_, to_) in self.secondary_synchronize_pairs + ]: + # save ourselves a ton of memory and overhead by only + # considering columns that are subject to a overlapping + # FK constraints at the core level. This condition can arise + # if multiple relationships overlap foreign() directly, but + # we're going to assume it's typically a ForeignKeyConstraint- + # level configuration that benefits from this warning. + + if to_ not in self._track_overlapping_sync_targets: + self._track_overlapping_sync_targets[to_] = ( + weakref.WeakKeyDictionary({self.prop: from_}) + ) + else: + other_props = [] + prop_to_from = self._track_overlapping_sync_targets[to_] + + for pr, fr_ in prop_to_from.items(): + if ( + not pr.mapper._dispose_called + and pr not in self.prop._reverse_property + and pr.key not in self.prop._overlaps + and self.prop.key not in pr._overlaps + # note: the "__*" symbol is used internally by + # SQLAlchemy as a general means of suppressing the + # overlaps warning for some extension cases, however + # this is not currently + # a publicly supported symbol and may change at + # any time. + and "__*" not in self.prop._overlaps + and "__*" not in pr._overlaps + and not self.prop.parent.is_sibling(pr.parent) + and not self.prop.mapper.is_sibling(pr.mapper) + and not self.prop.parent.is_sibling(pr.mapper) + and not self.prop.mapper.is_sibling(pr.parent) + and ( + self.prop.key != pr.key + or not self.prop.parent.common_parent(pr.parent) + ) + ): + other_props.append((pr, fr_)) + + if other_props: + util.warn( + "relationship '%s' will copy column %s to column %s, " + "which conflicts with relationship(s): %s. " + "If this is not the intention, consider if these " + "relationships should be linked with " + "back_populates, or if viewonly=True should be " + "applied to one or more if they are read-only. " + "For the less common case that foreign key " + "constraints are partially overlapping, the " + "orm.foreign() " + "annotation can be used to isolate the columns that " + "should be written towards. To silence this " + "warning, add the parameter 'overlaps=\"%s\"' to the " + "'%s' relationship." + % ( + self.prop, + from_, + to_, + ", ".join( + sorted( + "'%s' (copies %s to %s)" % (pr, fr_, to_) + for (pr, fr_) in other_props + ) + ), + ",".join(sorted(pr.key for pr, fr in other_props)), + self.prop, + ), + code="qzyx", + ) + self._track_overlapping_sync_targets[to_][self.prop] = from_ + + @util.memoized_property + def remote_columns(self) -> Set[ColumnElement[Any]]: + return self._gather_join_annotations("remote") + + @util.memoized_property + def local_columns(self) -> Set[ColumnElement[Any]]: + return self._gather_join_annotations("local") + + @util.memoized_property + def foreign_key_columns(self) -> Set[ColumnElement[Any]]: + return self._gather_join_annotations("foreign") + + def _gather_join_annotations( + self, annotation: str + ) -> Set[ColumnElement[Any]]: + s = set( + self._gather_columns_with_annotation(self.primaryjoin, annotation) + ) + if self.secondaryjoin is not None: + s.update( + self._gather_columns_with_annotation( + self.secondaryjoin, annotation + ) + ) + return {x._deannotate() for x in s} + + def _gather_columns_with_annotation( + self, clause: ColumnElement[Any], *annotation: Iterable[str] + ) -> Set[ColumnElement[Any]]: + annotation_set = set(annotation) + return { + cast(ColumnElement[Any], col) + for col in visitors.iterate(clause, {}) + if annotation_set.issubset(col._annotations) + } + + @util.memoized_property + def _secondary_lineage_set(self) -> FrozenSet[ColumnElement[Any]]: + if self.secondary is not None: + return frozenset( + itertools.chain(*[c.proxy_set for c in self.secondary.c]) + ) + else: + return util.EMPTY_SET + + def join_targets( + self, + source_selectable: Optional[FromClause], + dest_selectable: FromClause, + aliased: bool, + single_crit: Optional[ColumnElement[bool]] = None, + extra_criteria: Tuple[ColumnElement[bool], ...] = (), + ) -> Tuple[ + ColumnElement[bool], + Optional[ColumnElement[bool]], + Optional[FromClause], + Optional[ClauseAdapter], + FromClause, + ]: + """Given a source and destination selectable, create a + join between them. + + This takes into account aliasing the join clause + to reference the appropriate corresponding columns + in the target objects, as well as the extra child + criterion, equivalent column sets, etc. + + """ + # place a barrier on the destination such that + # replacement traversals won't ever dig into it. + # its internal structure remains fixed + # regardless of context. + dest_selectable = _shallow_annotate( + dest_selectable, {"no_replacement_traverse": True} + ) + + primaryjoin, secondaryjoin, secondary = ( + self.primaryjoin, + self.secondaryjoin, + self.secondary, + ) + + # adjust the join condition for single table inheritance, + # in the case that the join is to a subclass + # this is analogous to the + # "_adjust_for_single_table_inheritance()" method in Query. + + if single_crit is not None: + if secondaryjoin is not None: + secondaryjoin = secondaryjoin & single_crit + else: + primaryjoin = primaryjoin & single_crit + + if extra_criteria: + + def mark_exclude_cols( + elem: SupportsAnnotations, annotations: _AnnotationDict + ) -> SupportsAnnotations: + """note unrelated columns in the "extra criteria" as either + should be adapted or not adapted, even though they are not + part of our "local" or "remote" side. + + see #9779 for this case, as well as #11010 for a follow up + + """ + + parentmapper_for_element = elem._annotations.get( + "parentmapper", None + ) + + if ( + parentmapper_for_element is not self.prop.parent + and parentmapper_for_element is not self.prop.mapper + and elem not in self._secondary_lineage_set + ): + return _safe_annotate(elem, annotations) + else: + return elem + + extra_criteria = tuple( + _deep_annotate( + elem, + {"should_not_adapt": True}, + annotate_callable=mark_exclude_cols, + ) + for elem in extra_criteria + ) + + if secondaryjoin is not None: + secondaryjoin = secondaryjoin & sql.and_(*extra_criteria) + else: + primaryjoin = primaryjoin & sql.and_(*extra_criteria) + + if aliased: + if secondary is not None: + secondary = secondary._anonymous_fromclause(flat=True) + primary_aliasizer = ClauseAdapter( + secondary, + exclude_fn=_local_col_exclude, + ) + secondary_aliasizer = ClauseAdapter( + dest_selectable, equivalents=self.child_equivalents + ).chain(primary_aliasizer) + if source_selectable is not None: + primary_aliasizer = ClauseAdapter( + secondary, + exclude_fn=_local_col_exclude, + ).chain( + ClauseAdapter( + source_selectable, + equivalents=self.parent_equivalents, + ) + ) + + secondaryjoin = secondary_aliasizer.traverse(secondaryjoin) + else: + primary_aliasizer = ClauseAdapter( + dest_selectable, + exclude_fn=_local_col_exclude, + equivalents=self.child_equivalents, + ) + if source_selectable is not None: + primary_aliasizer.chain( + ClauseAdapter( + source_selectable, + exclude_fn=_remote_col_exclude, + equivalents=self.parent_equivalents, + ) + ) + secondary_aliasizer = None + + primaryjoin = primary_aliasizer.traverse(primaryjoin) + target_adapter = secondary_aliasizer or primary_aliasizer + target_adapter.exclude_fn = None + else: + target_adapter = None + return ( + primaryjoin, + secondaryjoin, + secondary, + target_adapter, + dest_selectable, + ) + + def create_lazy_clause(self, reverse_direction: bool = False) -> Tuple[ + ColumnElement[bool], + Dict[str, ColumnElement[Any]], + Dict[ColumnElement[Any], ColumnElement[Any]], + ]: + binds: Dict[ColumnElement[Any], BindParameter[Any]] = {} + equated_columns: Dict[ColumnElement[Any], ColumnElement[Any]] = {} + + has_secondary = self.secondaryjoin is not None + + if has_secondary: + lookup = collections.defaultdict(list) + for l, r in self.local_remote_pairs: + lookup[l].append((l, r)) + equated_columns[r] = l + elif not reverse_direction: + for l, r in self.local_remote_pairs: + equated_columns[r] = l + else: + for l, r in self.local_remote_pairs: + equated_columns[l] = r + + def col_to_bind( + element: ColumnElement[Any], **kw: Any + ) -> Optional[BindParameter[Any]]: + if ( + (not reverse_direction and "local" in element._annotations) + or reverse_direction + and ( + (has_secondary and element in lookup) + or (not has_secondary and "remote" in element._annotations) + ) + ): + if element not in binds: + binds[element] = sql.bindparam( + None, None, type_=element.type, unique=True + ) + return binds[element] + return None + + lazywhere = self.primaryjoin + if self.secondaryjoin is None or not reverse_direction: + lazywhere = visitors.replacement_traverse( + lazywhere, {}, col_to_bind + ) + + if self.secondaryjoin is not None: + secondaryjoin = self.secondaryjoin + if reverse_direction: + secondaryjoin = visitors.replacement_traverse( + secondaryjoin, {}, col_to_bind + ) + lazywhere = sql.and_(lazywhere, secondaryjoin) + + bind_to_col = {binds[col].key: col for col in binds} + + return lazywhere, bind_to_col, equated_columns + + +class _ColInAnnotations: + """Serializable object that tests for names in c._annotations. + + TODO: does this need to be serializable anymore? can we find what the + use case was for that? + + """ + + __slots__ = ("names",) + + def __init__(self, *names: str): + self.names = frozenset(names) + + def __call__(self, c: ClauseElement) -> bool: + return bool(self.names.intersection(c._annotations)) + + +_local_col_exclude = _ColInAnnotations("local", "should_not_adapt") +_remote_col_exclude = _ColInAnnotations("remote", "should_not_adapt") + + +class Relationship( + RelationshipProperty[_T], + _DeclarativeMapped[_T], +): + """Describes an object property that holds a single item or list + of items that correspond to a related database table. + + Public constructor is the :func:`_orm.relationship` function. + + .. seealso:: + + :ref:`relationship_config_toplevel` + + .. versionchanged:: 2.0 Added :class:`_orm.Relationship` as a Declarative + compatible subclass for :class:`_orm.RelationshipProperty`. + + """ + + inherit_cache = True + """:meta private:""" + + +class _RelationshipDeclared( # type: ignore[misc] + Relationship[_T], + WriteOnlyMapped[_T], # not compatible with Mapped[_T] + DynamicMapped[_T], # not compatible with Mapped[_T] +): + """Relationship subclass used implicitly for declarative mapping.""" + + inherit_cache = True + """:meta private:""" + + @classmethod + def _mapper_property_name(cls) -> str: + return "Relationship" diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/scoping.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/scoping.py new file mode 100644 index 0000000..819616a --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/scoping.py @@ -0,0 +1,2165 @@ +# orm/scoping.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from .session import _S +from .session import Session +from .. import exc as sa_exc +from .. import util +from ..util import create_proxy_methods +from ..util import ScopedRegistry +from ..util import ThreadLocalRegistry +from ..util import warn +from ..util import warn_deprecated +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _IdentityKeyType + from ._typing import OrmExecuteOptionsParameter + from .identity import IdentityMap + from .interfaces import ORMOption + from .mapper import Mapper + from .query import Query + from .query import RowReturningQuery + from .session import _BindArguments + from .session import _EntityBindKey + from .session import _PKIdentityArgument + from .session import _SessionBind + from .session import sessionmaker + from .session import SessionTransaction + from ..engine import Connection + from ..engine import CursorResult + from ..engine import Engine + from ..engine import Result + from ..engine import Row + from ..engine import RowMapping + from ..engine.interfaces import _CoreAnyExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import CoreExecuteOptionsParameter + from ..engine.result import ScalarResult + from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA + from ..sql.base import Executable + from ..sql.dml import UpdateBase + from ..sql.elements import ClauseElement + from ..sql.roles import TypedColumnsClauseRole + from ..sql.selectable import ForUpdateParameter + from ..sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) + + +class QueryPropertyDescriptor(Protocol): + """Describes the type applied to a class-level + :meth:`_orm.scoped_session.query_property` attribute. + + .. versionadded:: 2.0.5 + + """ + + def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]: ... + + +_O = TypeVar("_O", bound=object) + +__all__ = ["scoped_session"] + + +@create_proxy_methods( + Session, + ":class:`_orm.Session`", + ":class:`_orm.scoping.scoped_session`", + classmethods=["close_all", "object_session", "identity_key"], + methods=[ + "__contains__", + "__iter__", + "add", + "add_all", + "begin", + "begin_nested", + "close", + "reset", + "commit", + "connection", + "delete", + "execute", + "expire", + "expire_all", + "expunge", + "expunge_all", + "flush", + "get", + "get_one", + "get_bind", + "is_modified", + "bulk_save_objects", + "bulk_insert_mappings", + "bulk_update_mappings", + "merge", + "query", + "refresh", + "rollback", + "scalar", + "scalars", + ], + attributes=[ + "bind", + "dirty", + "deleted", + "new", + "identity_map", + "is_active", + "autoflush", + "no_autoflush", + "info", + ], +) +class scoped_session(Generic[_S]): + """Provides scoped management of :class:`.Session` objects. + + See :ref:`unitofwork_contextual` for a tutorial. + + .. note:: + + When using :ref:`asyncio_toplevel`, the async-compatible + :class:`_asyncio.async_scoped_session` class should be + used in place of :class:`.scoped_session`. + + """ + + _support_async: bool = False + + session_factory: sessionmaker[_S] + """The `session_factory` provided to `__init__` is stored in this + attribute and may be accessed at a later time. This can be useful when + a new non-scoped :class:`.Session` is needed.""" + + registry: ScopedRegistry[_S] + + def __init__( + self, + session_factory: sessionmaker[_S], + scopefunc: Optional[Callable[[], Any]] = None, + ): + """Construct a new :class:`.scoped_session`. + + :param session_factory: a factory to create new :class:`.Session` + instances. This is usually, but not necessarily, an instance + of :class:`.sessionmaker`. + :param scopefunc: optional function which defines + the current scope. If not passed, the :class:`.scoped_session` + object assumes "thread-local" scope, and will use + a Python ``threading.local()`` in order to maintain the current + :class:`.Session`. If passed, the function should return + a hashable token; this token will be used as the key in a + dictionary in order to store and retrieve the current + :class:`.Session`. + + """ + self.session_factory = session_factory + + if scopefunc: + self.registry = ScopedRegistry(session_factory, scopefunc) + else: + self.registry = ThreadLocalRegistry(session_factory) + + @property + def _proxied(self) -> _S: + return self.registry() + + def __call__(self, **kw: Any) -> _S: + r"""Return the current :class:`.Session`, creating it + using the :attr:`.scoped_session.session_factory` if not present. + + :param \**kw: Keyword arguments will be passed to the + :attr:`.scoped_session.session_factory` callable, if an existing + :class:`.Session` is not present. If the :class:`.Session` is present + and keyword arguments have been passed, + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + if kw: + if self.registry.has(): + raise sa_exc.InvalidRequestError( + "Scoped session is already present; " + "no new arguments may be specified." + ) + else: + sess = self.session_factory(**kw) + self.registry.set(sess) + else: + sess = self.registry() + if not self._support_async and sess._is_asyncio: + warn_deprecated( + "Using `scoped_session` with asyncio is deprecated and " + "will raise an error in a future version. " + "Please use `async_scoped_session` instead.", + "1.4.23", + ) + return sess + + def configure(self, **kwargs: Any) -> None: + """reconfigure the :class:`.sessionmaker` used by this + :class:`.scoped_session`. + + See :meth:`.sessionmaker.configure`. + + """ + + if self.registry.has(): + warn( + "At least one scoped session is already present. " + " configure() can not affect sessions that have " + "already been created." + ) + + self.session_factory.configure(**kwargs) + + def remove(self) -> None: + """Dispose of the current :class:`.Session`, if present. + + This will first call :meth:`.Session.close` method + on the current :class:`.Session`, which releases any existing + transactional/connection resources still being held; transactions + specifically are rolled back. The :class:`.Session` is then + discarded. Upon next usage within the same scope, + the :class:`.scoped_session` will produce a new + :class:`.Session` object. + + """ + + if self.registry.has(): + self.registry().close() + self.registry.clear() + + def query_property( + self, query_cls: Optional[Type[Query[_T]]] = None + ) -> QueryPropertyDescriptor: + """return a class property which produces a legacy + :class:`_query.Query` object against the class and the current + :class:`.Session` when called. + + .. legacy:: The :meth:`_orm.scoped_session.query_property` accessor + is specific to the legacy :class:`.Query` object and is not + considered to be part of :term:`2.0-style` ORM use. + + e.g.:: + + from sqlalchemy.orm import QueryPropertyDescriptor + from sqlalchemy.orm import scoped_session + from sqlalchemy.orm import sessionmaker + + Session = scoped_session(sessionmaker()) + + class MyClass: + query: QueryPropertyDescriptor = Session.query_property() + + # after mappers are defined + result = MyClass.query.filter(MyClass.name=='foo').all() + + Produces instances of the session's configured query class by + default. To override and use a custom implementation, provide + a ``query_cls`` callable. The callable will be invoked with + the class's mapper as a positional argument and a session + keyword argument. + + There is no limit to the number of query properties placed on + a class. + + """ + + class query: + def __get__(s, instance: Any, owner: Type[_O]) -> Query[_O]: + if query_cls: + # custom query class + return query_cls(owner, session=self.registry()) # type: ignore # noqa: E501 + else: + # session's configured query class + return self.registry().query(owner) + + return query() + + # START PROXY METHODS scoped_session + + # code within this block is **programmatically, + # statically generated** by tools/generate_proxy_methods.py + + def __contains__(self, instance: object) -> bool: + r"""Return True if the instance is associated with this session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + The instance may be pending or persistent within the Session for a + result of True. + + + """ # noqa: E501 + + return self._proxied.__contains__(instance) + + def __iter__(self) -> Iterator[object]: + r"""Iterate over all pending or persistent instances within this + Session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + + """ # noqa: E501 + + return self._proxied.__iter__() + + def add(self, instance: object, _warn: bool = True) -> None: + r"""Place an object into this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Objects that are in the :term:`transient` state when passed to the + :meth:`_orm.Session.add` method will move to the + :term:`pending` state, until the next flush, at which point they + will move to the :term:`persistent` state. + + Objects that are in the :term:`detached` state when passed to the + :meth:`_orm.Session.add` method will move to the :term:`persistent` + state directly. + + If the transaction used by the :class:`_orm.Session` is rolled back, + objects which were transient when they were passed to + :meth:`_orm.Session.add` will be moved back to the + :term:`transient` state, and will no longer be present within this + :class:`_orm.Session`. + + .. seealso:: + + :meth:`_orm.Session.add_all` + + :ref:`session_adding` - at :ref:`session_basics` + + + """ # noqa: E501 + + return self._proxied.add(instance, _warn=_warn) + + def add_all(self, instances: Iterable[object]) -> None: + r"""Add the given collection of instances to this :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + See the documentation for :meth:`_orm.Session.add` for a general + behavioral description. + + .. seealso:: + + :meth:`_orm.Session.add` + + :ref:`session_adding` - at :ref:`session_basics` + + + """ # noqa: E501 + + return self._proxied.add_all(instances) + + def begin(self, nested: bool = False) -> SessionTransaction: + r"""Begin a transaction, or nested transaction, + on this :class:`.Session`, if one is not already begun. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + The :class:`_orm.Session` object features **autobegin** behavior, + so that normally it is not necessary to call the + :meth:`_orm.Session.begin` + method explicitly. However, it may be used in order to control + the scope of when the transactional state is begun. + + When used to begin the outermost transaction, an error is raised + if this :class:`.Session` is already inside of a transaction. + + :param nested: if True, begins a SAVEPOINT transaction and is + equivalent to calling :meth:`~.Session.begin_nested`. For + documentation on SAVEPOINT transactions, please see + :ref:`session_begin_nested`. + + :return: the :class:`.SessionTransaction` object. Note that + :class:`.SessionTransaction` + acts as a Python context manager, allowing :meth:`.Session.begin` + to be used in a "with" block. See :ref:`session_explicit_begin` for + an example. + + .. seealso:: + + :ref:`session_autobegin` + + :ref:`unitofwork_transaction` + + :meth:`.Session.begin_nested` + + + + """ # noqa: E501 + + return self._proxied.begin(nested=nested) + + def begin_nested(self) -> SessionTransaction: + r"""Begin a "nested" transaction on this Session, e.g. SAVEPOINT. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + The target database(s) and associated drivers must support SQL + SAVEPOINT for this method to function correctly. + + For documentation on SAVEPOINT + transactions, please see :ref:`session_begin_nested`. + + :return: the :class:`.SessionTransaction` object. Note that + :class:`.SessionTransaction` acts as a context manager, allowing + :meth:`.Session.begin_nested` to be used in a "with" block. + See :ref:`session_begin_nested` for a usage example. + + .. seealso:: + + :ref:`session_begin_nested` + + :ref:`pysqlite_serializable` - special workarounds required + with the SQLite driver in order for SAVEPOINT to work + correctly. For asyncio use cases, see the section + :ref:`aiosqlite_serializable`. + + + """ # noqa: E501 + + return self._proxied.begin_nested() + + def close(self) -> None: + r"""Close out the transactional resources and ORM objects used by this + :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This expunges all ORM objects associated with this + :class:`_orm.Session`, ends any transaction in progress and + :term:`releases` any :class:`_engine.Connection` objects which this + :class:`_orm.Session` itself has checked out from associated + :class:`_engine.Engine` objects. The operation then leaves the + :class:`_orm.Session` in a state which it may be used again. + + .. tip:: + + In the default running mode the :meth:`_orm.Session.close` + method **does not prevent the Session from being used again**. + The :class:`_orm.Session` itself does not actually have a + distinct "closed" state; it merely means + the :class:`_orm.Session` will release all database connections + and ORM objects. + + Setting the parameter :paramref:`_orm.Session.close_resets_only` + to ``False`` will instead make the ``close`` final, meaning that + any further action on the session will be forbidden. + + .. versionchanged:: 1.4 The :meth:`.Session.close` method does not + immediately create a new :class:`.SessionTransaction` object; + instead, the new :class:`.SessionTransaction` is created only if + the :class:`.Session` is used again for a database operation. + + .. seealso:: + + :ref:`session_closing` - detail on the semantics of + :meth:`_orm.Session.close` and :meth:`_orm.Session.reset`. + + :meth:`_orm.Session.reset` - a similar method that behaves like + ``close()`` with the parameter + :paramref:`_orm.Session.close_resets_only` set to ``True``. + + + """ # noqa: E501 + + return self._proxied.close() + + def reset(self) -> None: + r"""Close out the transactional resources and ORM objects used by this + :class:`_orm.Session`, resetting the session to its initial state. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This method provides for same "reset-only" behavior that the + :meth:`_orm.Session.close` method has provided historically, where the + state of the :class:`_orm.Session` is reset as though the object were + brand new, and ready to be used again. + This method may then be useful for :class:`_orm.Session` objects + which set :paramref:`_orm.Session.close_resets_only` to ``False``, + so that "reset only" behavior is still available. + + .. versionadded:: 2.0.22 + + .. seealso:: + + :ref:`session_closing` - detail on the semantics of + :meth:`_orm.Session.close` and :meth:`_orm.Session.reset`. + + :meth:`_orm.Session.close` - a similar method will additionally + prevent re-use of the Session when the parameter + :paramref:`_orm.Session.close_resets_only` is set to ``False``. + + """ # noqa: E501 + + return self._proxied.reset() + + def commit(self) -> None: + r"""Flush pending changes and commit the current transaction. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + When the COMMIT operation is complete, all objects are fully + :term:`expired`, erasing their internal contents, which will be + automatically re-loaded when the objects are next accessed. In the + interim, these objects are in an expired state and will not function if + they are :term:`detached` from the :class:`.Session`. Additionally, + this re-load operation is not supported when using asyncio-oriented + APIs. The :paramref:`.Session.expire_on_commit` parameter may be used + to disable this behavior. + + When there is no transaction in place for the :class:`.Session`, + indicating that no operations were invoked on this :class:`.Session` + since the previous call to :meth:`.Session.commit`, the method will + begin and commit an internal-only "logical" transaction, that does not + normally affect the database unless pending flush changes were + detected, but will still invoke event handlers and object expiration + rules. + + The outermost database transaction is committed unconditionally, + automatically releasing any SAVEPOINTs in effect. + + .. seealso:: + + :ref:`session_committing` + + :ref:`unitofwork_transaction` + + :ref:`asyncio_orm_avoid_lazyloads` + + + """ # noqa: E501 + + return self._proxied.commit() + + def connection( + self, + bind_arguments: Optional[_BindArguments] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Connection: + r"""Return a :class:`_engine.Connection` object corresponding to this + :class:`.Session` object's transactional state. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Either the :class:`_engine.Connection` corresponding to the current + transaction is returned, or if no transaction is in progress, a new + one is begun and the :class:`_engine.Connection` + returned (note that no + transactional state is established with the DBAPI until the first + SQL statement is emitted). + + Ambiguity in multi-bind or unbound :class:`.Session` objects can be + resolved through any of the optional keyword arguments. This + ultimately makes usage of the :meth:`.get_bind` method for resolution. + + :param bind_arguments: dictionary of bind arguments. May include + "mapper", "bind", "clause", other custom arguments that are passed + to :meth:`.Session.get_bind`. + + :param execution_options: a dictionary of execution options that will + be passed to :meth:`_engine.Connection.execution_options`, **when the + connection is first procured only**. If the connection is already + present within the :class:`.Session`, a warning is emitted and + the arguments are ignored. + + .. seealso:: + + :ref:`session_transaction_isolation` + + + """ # noqa: E501 + + return self._proxied.connection( + bind_arguments=bind_arguments, execution_options=execution_options + ) + + def delete(self, instance: object) -> None: + r"""Mark an instance as deleted. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + The object is assumed to be either :term:`persistent` or + :term:`detached` when passed; after the method is called, the + object will remain in the :term:`persistent` state until the next + flush proceeds. During this time, the object will also be a member + of the :attr:`_orm.Session.deleted` collection. + + When the next flush proceeds, the object will move to the + :term:`deleted` state, indicating a ``DELETE`` statement was emitted + for its row within the current transaction. When the transaction + is successfully committed, + the deleted object is moved to the :term:`detached` state and is + no longer present within this :class:`_orm.Session`. + + .. seealso:: + + :ref:`session_deleting` - at :ref:`session_basics` + + + """ # noqa: E501 + + return self._proxied.delete(instance) + + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: ... + + @overload + def execute( + self, + statement: UpdateBase, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> CursorResult[Any]: ... + + @overload + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: ... + + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + r"""Execute a SQL expression construct. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Returns a :class:`_engine.Result` object representing + results of the statement execution. + + E.g.:: + + from sqlalchemy import select + result = session.execute( + select(User).where(User.id == 5) + ) + + The API contract of :meth:`_orm.Session.execute` is similar to that + of :meth:`_engine.Connection.execute`, the :term:`2.0 style` version + of :class:`_engine.Connection`. + + .. versionchanged:: 1.4 the :meth:`_orm.Session.execute` method is + now the primary point of ORM statement execution when using + :term:`2.0 style` ORM usage. + + :param statement: + An executable statement (i.e. an :class:`.Executable` expression + such as :func:`_expression.select`). + + :param params: + Optional dictionary, or list of dictionaries, containing + bound parameter values. If a single dictionary, single-row + execution occurs; if a list of dictionaries, an + "executemany" will be invoked. The keys in each dictionary + must correspond to parameter names present in the statement. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. This + dictionary can provide a subset of the options that are accepted + by :meth:`_engine.Connection.execution_options`, and may also + provide additional options understood only in an ORM context. + + .. seealso:: + + :ref:`orm_queryguide_execution_options` - ORM-specific execution + options + + :param bind_arguments: dictionary of additional arguments to determine + the bind. May include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + + :return: a :class:`_engine.Result` object. + + + + """ # noqa: E501 + + return self._proxied.execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, + ) + + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: + r"""Expire the attributes on an instance. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Marks the attributes of an instance as out of date. When an expired + attribute is next accessed, a query will be issued to the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire all objects in the :class:`.Session` simultaneously, + use :meth:`Session.expire_all`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire` only makes sense for the specific + case that a non-ORM SQL statement was emitted in the current + transaction. + + :param instance: The instance to be refreshed. + :param attribute_names: optional list of string attribute names + indicating a subset of attributes to be expired. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + """ # noqa: E501 + + return self._proxied.expire(instance, attribute_names=attribute_names) + + def expire_all(self) -> None: + r"""Expires all persistent instances within this Session. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + When any attributes on a persistent instance is next accessed, + a query will be issued using the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire individual objects and individual attributes + on those objects, use :meth:`Session.expire`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire_all` is not usually needed, + assuming the transaction is isolated. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + + """ # noqa: E501 + + return self._proxied.expire_all() + + def expunge(self, instance: object) -> None: + r"""Remove the `instance` from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This will free all internal references to the instance. Cascading + will be applied according to the *expunge* cascade rule. + + + """ # noqa: E501 + + return self._proxied.expunge(instance) + + def expunge_all(self) -> None: + r"""Remove all object instances from this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This is equivalent to calling ``expunge(obj)`` on all objects in this + ``Session``. + + + """ # noqa: E501 + + return self._proxied.expunge_all() + + def flush(self, objects: Optional[Sequence[Any]] = None) -> None: + r"""Flush all the object changes to the database. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Writes out all pending object creations, deletions and modifications + to the database as INSERTs, DELETEs, UPDATEs, etc. Operations are + automatically ordered by the Session's unit of work dependency + solver. + + Database operations will be issued in the current transactional + context and do not affect the state of the transaction, unless an + error occurs, in which case the entire transaction is rolled back. + You may flush() as often as you like within a transaction to move + changes from Python to the database's transaction buffer. + + :param objects: Optional; restricts the flush operation to operate + only on elements that are in the given collection. + + This feature is for an extremely narrow set of use cases where + particular objects may need to be operated upon before the + full flush() occurs. It is not intended for general use. + + + """ # noqa: E501 + + return self._proxied.flush(objects=objects) + + def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + ) -> Optional[_O]: + r"""Return an instance based on the given primary key identifier, + or ``None`` if not found. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + E.g.:: + + my_user = session.get(User, 5) + + some_object = session.get(VersionedFoo, (5, 10)) + + some_object = session.get( + VersionedFoo, + {"id": 5, "version_id": 10} + ) + + .. versionadded:: 1.4 Added :meth:`_orm.Session.get`, which is moved + from the now legacy :meth:`_orm.Query.get` method. + + :meth:`_orm.Session.get` is special in that it provides direct + access to the identity map of the :class:`.Session`. + If the given primary key identifier is present + in the local identity map, the object is returned + directly from this collection and no SQL is emitted, + unless the object has been marked fully expired. + If not present, + a SELECT is performed in order to locate the object. + + :meth:`_orm.Session.get` also will perform a check if + the object is present in the identity map and + marked as expired - a SELECT + is emitted to refresh the object as well as to + ensure that the row is still present. + If not, :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. + + :param entity: a mapped class or :class:`.Mapper` indicating the + type of entity to be loaded. + + :param ident: A scalar, tuple, or dictionary representing the + primary key. For a composite (e.g. multiple column) primary key, + a tuple or dictionary should be passed. + + For a single-column primary key, the scalar calling form is typically + the most expedient. If the primary key of a row is the value "5", + the call looks like:: + + my_object = session.get(SomeClass, 5) + + The tuple form contains primary key values typically in + the order in which they correspond to the mapped + :class:`_schema.Table` + object's primary key columns, or if the + :paramref:`_orm.Mapper.primary_key` configuration parameter were + used, in + the order used for that parameter. For example, if the primary key + of a row is represented by the integer + digits "5, 10" the call would look like:: + + my_object = session.get(SomeClass, (5, 10)) + + The dictionary form should include as keys the mapped attribute names + corresponding to each element of the primary key. If the mapped class + has the attributes ``id``, ``version_id`` as the attributes which + store the object's primary key value, the call would look like:: + + my_object = session.get(SomeClass, {"id": 5, "version_id": 10}) + + :param options: optional sequence of loader options which will be + applied to the query, if one is emitted. + + :param populate_existing: causes the method to unconditionally emit + a SQL query and refresh the object with the newly loaded data, + regardless of whether or not the object is already present. + + :param with_for_update: optional boolean ``True`` indicating FOR UPDATE + should be used, or may be a dictionary containing flags to + indicate a more specific set of FOR UPDATE flags for the SELECT; + flags should match the parameters of + :meth:`_query.Query.with_for_update`. + Supersedes the :paramref:`.Session.refresh.lockmode` parameter. + + :param execution_options: optional dictionary of execution options, + which will be associated with the query execution if one is emitted. + This dictionary can provide a subset of the options that are + accepted by :meth:`_engine.Connection.execution_options`, and may + also provide additional options understood only in an ORM context. + + .. versionadded:: 1.4.29 + + .. seealso:: + + :ref:`orm_queryguide_execution_options` - ORM-specific execution + options + + :param bind_arguments: dictionary of additional arguments to determine + the bind. May include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + + .. versionadded: 2.0.0rc1 + + :return: The object instance, or ``None``. + + + """ # noqa: E501 + + return self._proxied.get( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + + def get_one( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + ) -> _O: + r"""Return exactly one instance based on the given primary key + identifier, or raise an exception if not found. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query + selects no rows. + + For a detailed documentation of the arguments see the + method :meth:`.Session.get`. + + .. versionadded:: 2.0.22 + + :return: The object instance. + + .. seealso:: + + :meth:`.Session.get` - equivalent method that instead + returns ``None`` if no row was found with the provided primary + key + + + """ # noqa: E501 + + return self._proxied.get_one( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + *, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + _sa_skip_events: Optional[bool] = None, + _sa_skip_for_implicit_returning: bool = False, + **kw: Any, + ) -> Union[Engine, Connection]: + r"""Return a "bind" to which this :class:`.Session` is bound. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + The "bind" is usually an instance of :class:`_engine.Engine`, + except in the case where the :class:`.Session` has been + explicitly bound directly to a :class:`_engine.Connection`. + + For a multiply-bound or unbound :class:`.Session`, the + ``mapper`` or ``clause`` arguments are used to determine the + appropriate bind to return. + + Note that the "mapper" argument is usually present + when :meth:`.Session.get_bind` is called via an ORM + operation such as a :meth:`.Session.query`, each + individual INSERT/UPDATE/DELETE operation within a + :meth:`.Session.flush`, call, etc. + + The order of resolution is: + + 1. if mapper given and :paramref:`.Session.binds` is present, + locate a bind based first on the mapper in use, then + on the mapped class in use, then on any base classes that are + present in the ``__mro__`` of the mapped class, from more specific + superclasses to more general. + 2. if clause given and ``Session.binds`` is present, + locate a bind based on :class:`_schema.Table` objects + found in the given clause present in ``Session.binds``. + 3. if ``Session.binds`` is present, return that. + 4. if clause given, attempt to return a bind + linked to the :class:`_schema.MetaData` ultimately + associated with the clause. + 5. if mapper given, attempt to return a bind + linked to the :class:`_schema.MetaData` ultimately + associated with the :class:`_schema.Table` or other + selectable to which the mapper is mapped. + 6. No bind can be found, :exc:`~sqlalchemy.exc.UnboundExecutionError` + is raised. + + Note that the :meth:`.Session.get_bind` method can be overridden on + a user-defined subclass of :class:`.Session` to provide any kind + of bind resolution scheme. See the example at + :ref:`session_custom_partitioning`. + + :param mapper: + Optional mapped class or corresponding :class:`_orm.Mapper` instance. + The bind can be derived from a :class:`_orm.Mapper` first by + consulting the "binds" map associated with this :class:`.Session`, + and secondly by consulting the :class:`_schema.MetaData` associated + with the :class:`_schema.Table` to which the :class:`_orm.Mapper` is + mapped for a bind. + + :param clause: + A :class:`_expression.ClauseElement` (i.e. + :func:`_expression.select`, + :func:`_expression.text`, + etc.). If the ``mapper`` argument is not present or could not + produce a bind, the given expression construct will be searched + for a bound element, typically a :class:`_schema.Table` + associated with + bound :class:`_schema.MetaData`. + + .. seealso:: + + :ref:`session_partitioning` + + :paramref:`.Session.binds` + + :meth:`.Session.bind_mapper` + + :meth:`.Session.bind_table` + + + """ # noqa: E501 + + return self._proxied.get_bind( + mapper=mapper, + clause=clause, + bind=bind, + _sa_skip_events=_sa_skip_events, + _sa_skip_for_implicit_returning=_sa_skip_for_implicit_returning, + **kw, + ) + + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: + r"""Return ``True`` if the given instance has locally + modified attributes. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This method retrieves the history for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value, if any. + + It is in effect a more expensive and accurate + version of checking for the given instance in the + :attr:`.Session.dirty` collection; a full test for + each attribute's net "dirty" status is performed. + + E.g.:: + + return session.is_modified(someobject) + + A few caveats to this method apply: + + * Instances present in the :attr:`.Session.dirty` collection may + report ``False`` when tested with this method. This is because + the object may have received change events via attribute mutation, + thus placing it in :attr:`.Session.dirty`, but ultimately the state + is the same as that loaded from the database, resulting in no net + change here. + * Scalar attributes may not have recorded the previously set + value when a new value was applied, if the attribute was not loaded, + or was expired, at the time the new value was received - in these + cases, the attribute is assumed to have a change, even if there is + ultimately no net change against its database value. SQLAlchemy in + most cases does not need the "old" value when a set event occurs, so + it skips the expense of a SQL call if the old value isn't present, + based on the assumption that an UPDATE of the scalar value is + usually needed, and in those few cases where it isn't, is less + expensive on average than issuing a defensive SELECT. + + The "old" value is fetched unconditionally upon set only if the + attribute container has the ``active_history`` flag set to ``True``. + This flag is set typically for primary key attributes and scalar + object references that are not a simple many-to-one. To set this + flag for any arbitrary mapped column, use the ``active_history`` + argument with :func:`.column_property`. + + :param instance: mapped instance to be tested for pending changes. + :param include_collections: Indicates if multivalued collections + should be included in the operation. Setting this to ``False`` is a + way to detect only local-column based properties (i.e. scalar columns + or many-to-one foreign keys) that would result in an UPDATE for this + instance upon flush. + + + """ # noqa: E501 + + return self._proxied.is_modified( + instance, include_collections=include_collections + ) + + def bulk_save_objects( + self, + objects: Iterable[object], + return_defaults: bool = False, + update_changed_only: bool = True, + preserve_order: bool = True, + ) -> None: + r"""Perform a bulk save of the given list of objects. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + .. legacy:: + + This method is a legacy feature as of the 2.0 series of + SQLAlchemy. For modern bulk INSERT and UPDATE, see + the sections :ref:`orm_queryguide_bulk_insert` and + :ref:`orm_queryguide_bulk_update`. + + For general INSERT and UPDATE of existing ORM mapped objects, + prefer standard :term:`unit of work` data management patterns, + introduced in the :ref:`unified_tutorial` at + :ref:`tutorial_orm_data_manipulation`. SQLAlchemy 2.0 + now uses :ref:`engine_insertmanyvalues` with modern dialects + which solves previous issues of bulk INSERT slowness. + + :param objects: a sequence of mapped object instances. The mapped + objects are persisted as is, and are **not** associated with the + :class:`.Session` afterwards. + + For each object, whether the object is sent as an INSERT or an + UPDATE is dependent on the same rules used by the :class:`.Session` + in traditional operation; if the object has the + :attr:`.InstanceState.key` + attribute set, then the object is assumed to be "detached" and + will result in an UPDATE. Otherwise, an INSERT is used. + + In the case of an UPDATE, statements are grouped based on which + attributes have changed, and are thus to be the subject of each + SET clause. If ``update_changed_only`` is False, then all + attributes present within each object are applied to the UPDATE + statement, which may help in allowing the statements to be grouped + together into a larger executemany(), and will also reduce the + overhead of checking history on attributes. + + :param return_defaults: when True, rows that are missing values which + generate defaults, namely integer primary key defaults and sequences, + will be inserted **one at a time**, so that the primary key value + is available. In particular this will allow joined-inheritance + and other multi-table mappings to insert correctly without the need + to provide primary key values ahead of time; however, + :paramref:`.Session.bulk_save_objects.return_defaults` **greatly + reduces the performance gains** of the method overall. It is strongly + advised to please use the standard :meth:`_orm.Session.add_all` + approach. + + :param update_changed_only: when True, UPDATE statements are rendered + based on those attributes in each state that have logged changes. + When False, all attributes present are rendered into the SET clause + with the exception of primary key attributes. + + :param preserve_order: when True, the order of inserts and updates + matches exactly the order in which the objects are given. When + False, common types of objects are grouped into inserts + and updates, to allow for more batching opportunities. + + .. seealso:: + + :doc:`queryguide/dml` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_update_mappings` + + + """ # noqa: E501 + + return self._proxied.bulk_save_objects( + objects, + return_defaults=return_defaults, + update_changed_only=update_changed_only, + preserve_order=preserve_order, + ) + + def bulk_insert_mappings( + self, + mapper: Mapper[Any], + mappings: Iterable[Dict[str, Any]], + return_defaults: bool = False, + render_nulls: bool = False, + ) -> None: + r"""Perform a bulk insert of the given list of mapping dictionaries. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + .. legacy:: + + This method is a legacy feature as of the 2.0 series of + SQLAlchemy. For modern bulk INSERT and UPDATE, see + the sections :ref:`orm_queryguide_bulk_insert` and + :ref:`orm_queryguide_bulk_update`. The 2.0 API shares + implementation details with this method and adds new features + as well. + + :param mapper: a mapped class, or the actual :class:`_orm.Mapper` + object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a sequence of dictionaries, each one containing the + state of the mapped row to be inserted, in terms of the attribute + names on the mapped class. If the mapping refers to multiple tables, + such as a joined-inheritance mapping, each dictionary must contain all + keys to be populated into all tables. + + :param return_defaults: when True, the INSERT process will be altered + to ensure that newly generated primary key values will be fetched. + The rationale for this parameter is typically to enable + :ref:`Joined Table Inheritance ` mappings to + be bulk inserted. + + .. note:: for backends that don't support RETURNING, the + :paramref:`_orm.Session.bulk_insert_mappings.return_defaults` + parameter can significantly decrease performance as INSERT + statements can no longer be batched. See + :ref:`engine_insertmanyvalues` + for background on which backends are affected. + + :param render_nulls: When True, a value of ``None`` will result + in a NULL value being included in the INSERT statement, rather + than the column being omitted from the INSERT. This allows all + the rows being INSERTed to have the identical set of columns which + allows the full set of rows to be batched to the DBAPI. Normally, + each column-set that contains a different combination of NULL values + than the previous row must omit a different series of columns from + the rendered INSERT statement, which means it must be emitted as a + separate statement. By passing this flag, the full set of rows + are guaranteed to be batchable into one batch; the cost however is + that server-side defaults which are invoked by an omitted column will + be skipped, so care must be taken to ensure that these are not + necessary. + + .. warning:: + + When this flag is set, **server side default SQL values will + not be invoked** for those columns that are inserted as NULL; + the NULL value will be sent explicitly. Care must be taken + to ensure that no server-side default functions need to be + invoked for the operation as a whole. + + .. seealso:: + + :doc:`queryguide/dml` + + :meth:`.Session.bulk_save_objects` + + :meth:`.Session.bulk_update_mappings` + + + """ # noqa: E501 + + return self._proxied.bulk_insert_mappings( + mapper, + mappings, + return_defaults=return_defaults, + render_nulls=render_nulls, + ) + + def bulk_update_mappings( + self, mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]] + ) -> None: + r"""Perform a bulk update of the given list of mapping dictionaries. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + .. legacy:: + + This method is a legacy feature as of the 2.0 series of + SQLAlchemy. For modern bulk INSERT and UPDATE, see + the sections :ref:`orm_queryguide_bulk_insert` and + :ref:`orm_queryguide_bulk_update`. The 2.0 API shares + implementation details with this method and adds new features + as well. + + :param mapper: a mapped class, or the actual :class:`_orm.Mapper` + object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a sequence of dictionaries, each one containing the + state of the mapped row to be updated, in terms of the attribute names + on the mapped class. If the mapping refers to multiple tables, such + as a joined-inheritance mapping, each dictionary may contain keys + corresponding to all tables. All those keys which are present and + are not part of the primary key are applied to the SET clause of the + UPDATE statement; the primary key values, which are required, are + applied to the WHERE clause. + + + .. seealso:: + + :doc:`queryguide/dml` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_save_objects` + + + """ # noqa: E501 + + return self._proxied.bulk_update_mappings(mapper, mappings) + + def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: + r"""Copy the state of a given instance into a corresponding instance + within this :class:`.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + :meth:`.Session.merge` examines the primary key attributes of the + source instance, and attempts to reconcile it with an instance of the + same primary key in the session. If not found locally, it attempts + to load the object from the database based on primary key, and if + none can be located, creates a new instance. The state of each + attribute on the source instance is then copied to the target + instance. The resulting target instance is then returned by the + method; the original source instance is left unmodified, and + un-associated with the :class:`.Session` if not already. + + This operation cascades to associated instances if the association is + mapped with ``cascade="merge"``. + + See :ref:`unitofwork_merging` for a detailed discussion of merging. + + :param instance: Instance to be merged. + :param load: Boolean, when False, :meth:`.merge` switches into + a "high performance" mode which causes it to forego emitting history + events as well as all database access. This flag is used for + cases such as transferring graphs of objects into a :class:`.Session` + from a second level cache, or to transfer just-loaded objects + into the :class:`.Session` owned by a worker thread or process + without re-querying the database. + + The ``load=False`` use case adds the caveat that the given + object has to be in a "clean" state, that is, has no pending changes + to be flushed - even if the incoming object is detached from any + :class:`.Session`. This is so that when + the merge operation populates local attributes and + cascades to related objects and + collections, the values can be "stamped" onto the + target object as is, without generating any history or attribute + events, and without the need to reconcile the incoming data with + any existing related objects or collections that might not + be loaded. The resulting objects from ``load=False`` are always + produced as "clean", so it is only appropriate that the given objects + should be "clean" as well, else this suggests a mis-use of the + method. + :param options: optional sequence of loader options which will be + applied to the :meth:`_orm.Session.get` method when the merge + operation loads the existing version of the object from the database. + + .. versionadded:: 1.4.24 + + + .. seealso:: + + :func:`.make_transient_to_detached` - provides for an alternative + means of "merging" a single object into the :class:`.Session` + + + """ # noqa: E501 + + return self._proxied.merge(instance, load=load, options=options) + + @overload + def query(self, _entity: _EntityType[_O]) -> Query[_O]: ... + + @overload + def query( + self, _colexpr: TypedColumnsClauseRole[_T] + ) -> RowReturningQuery[Tuple[_T]]: ... + + # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: ... + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... + + # END OVERLOADED FUNCTIONS self.query + + @overload + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: ... + + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: + r"""Return a new :class:`_query.Query` object corresponding to this + :class:`_orm.Session`. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Note that the :class:`_query.Query` object is legacy as of + SQLAlchemy 2.0; the :func:`_sql.select` construct is now used + to construct ORM queries. + + .. seealso:: + + :ref:`unified_tutorial` + + :ref:`queryguide_toplevel` + + :ref:`query_api_toplevel` - legacy API doc + + + """ # noqa: E501 + + return self._proxied.query(*entities, **kwargs) + + def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: ForUpdateParameter = None, + ) -> None: + r"""Expire and refresh attributes on the given instance. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + The selected attributes will first be expired as they would when using + :meth:`_orm.Session.expire`; then a SELECT statement will be issued to + the database to refresh column-oriented attributes with the current + value available in the current transaction. + + :func:`_orm.relationship` oriented attributes will also be immediately + loaded if they were already eagerly loaded on the object, using the + same eager loading strategy that they were loaded with originally. + + .. versionadded:: 1.4 - the :meth:`_orm.Session.refresh` method + can also refresh eagerly loaded attributes. + + :func:`_orm.relationship` oriented attributes that would normally + load using the ``select`` (or "lazy") loader strategy will also + load **if they are named explicitly in the attribute_names + collection**, emitting a SELECT statement for the attribute using the + ``immediate`` loader strategy. If lazy-loaded relationships are not + named in :paramref:`_orm.Session.refresh.attribute_names`, then + they remain as "lazy loaded" attributes and are not implicitly + refreshed. + + .. versionchanged:: 2.0.4 The :meth:`_orm.Session.refresh` method + will now refresh lazy-loaded :func:`_orm.relationship` oriented + attributes for those which are named explicitly in the + :paramref:`_orm.Session.refresh.attribute_names` collection. + + .. tip:: + + While the :meth:`_orm.Session.refresh` method is capable of + refreshing both column and relationship oriented attributes, its + primary focus is on refreshing of local column-oriented attributes + on a single instance. For more open ended "refresh" functionality, + including the ability to refresh the attributes on many objects at + once while having explicit control over relationship loader + strategies, use the + :ref:`populate existing ` feature + instead. + + Note that a highly isolated transaction will return the same values as + were previously read in that same transaction, regardless of changes + in database state outside of that transaction. Refreshing + attributes usually only makes sense at the start of a transaction + where database rows have not yet been accessed. + + :param attribute_names: optional. An iterable collection of + string attribute names indicating a subset of attributes to + be refreshed. + + :param with_for_update: optional boolean ``True`` indicating FOR UPDATE + should be used, or may be a dictionary containing flags to + indicate a more specific set of FOR UPDATE flags for the SELECT; + flags should match the parameters of + :meth:`_query.Query.with_for_update`. + Supersedes the :paramref:`.Session.refresh.lockmode` parameter. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.expire_all` + + :ref:`orm_queryguide_populate_existing` - allows any ORM query + to refresh objects as they would be loaded normally. + + + """ # noqa: E501 + + return self._proxied.refresh( + instance, + attribute_names=attribute_names, + with_for_update=with_for_update, + ) + + def rollback(self) -> None: + r"""Rollback the current transaction in progress. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + If no transaction is in progress, this method is a pass-through. + + The method always rolls back + the topmost database transaction, discarding any nested + transactions that may be in progress. + + .. seealso:: + + :ref:`session_rollback` + + :ref:`unitofwork_transaction` + + + """ # noqa: E501 + + return self._proxied.rollback() + + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: ... + + @overload + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: ... + + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + r"""Execute a statement and return a scalar result. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Usage and parameters are the same as that of + :meth:`_orm.Session.execute`; the return result is a scalar Python + value. + + + """ # noqa: E501 + + return self._proxied.scalar( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: ... + + @overload + def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: ... + + def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + r"""Execute a statement and return the results as scalars. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + Usage and parameters are the same as that of + :meth:`_orm.Session.execute`; the return result is a + :class:`_result.ScalarResult` filtering object which + will return single elements rather than :class:`_row.Row` objects. + + :return: a :class:`_result.ScalarResult` object + + .. versionadded:: 1.4.24 Added :meth:`_orm.Session.scalars` + + .. versionadded:: 1.4.26 Added :meth:`_orm.scoped_session.scalars` + + .. seealso:: + + :ref:`orm_queryguide_select_orm_entities` - contrasts the behavior + of :meth:`_orm.Session.execute` to :meth:`_orm.Session.scalars` + + + """ # noqa: E501 + + return self._proxied.scalars( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + @property + def bind(self) -> Optional[Union[Engine, Connection]]: + r"""Proxy for the :attr:`_orm.Session.bind` attribute + on behalf of the :class:`_orm.scoping.scoped_session` class. + + """ # noqa: E501 + + return self._proxied.bind + + @bind.setter + def bind(self, attr: Optional[Union[Engine, Connection]]) -> None: + self._proxied.bind = attr + + @property + def dirty(self) -> Any: + r"""The set of all persistent instances considered dirty. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_orm.scoping.scoped_session` class. + + E.g.:: + + some_mapped_object in session.dirty + + Instances are considered dirty when they were modified but not + deleted. + + Note that this 'dirty' calculation is 'optimistic'; most + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). + + To check if an instance has actionable net changes to its + attributes, use the :meth:`.Session.is_modified` method. + + + """ # noqa: E501 + + return self._proxied.dirty + + @property + def deleted(self) -> Any: + r"""The set of all instances marked as 'deleted' within this ``Session`` + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_orm.scoping.scoped_session` class. + + """ # noqa: E501 + + return self._proxied.deleted + + @property + def new(self) -> Any: + r"""The set of all instances marked as 'new' within this ``Session``. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_orm.scoping.scoped_session` class. + + """ # noqa: E501 + + return self._proxied.new + + @property + def identity_map(self) -> IdentityMap: + r"""Proxy for the :attr:`_orm.Session.identity_map` attribute + on behalf of the :class:`_orm.scoping.scoped_session` class. + + """ # noqa: E501 + + return self._proxied.identity_map + + @identity_map.setter + def identity_map(self, attr: IdentityMap) -> None: + self._proxied.identity_map = attr + + @property + def is_active(self) -> Any: + r"""True if this :class:`.Session` not in "partial rollback" state. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_orm.scoping.scoped_session` class. + + .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins + a new transaction immediately, so this attribute will be False + when the :class:`_orm.Session` is first instantiated. + + "partial rollback" state typically indicates that the flush process + of the :class:`_orm.Session` has failed, and that the + :meth:`_orm.Session.rollback` method must be emitted in order to + fully roll back the transaction. + + If this :class:`_orm.Session` is not in a transaction at all, the + :class:`_orm.Session` will autobegin when it is first used, so in this + case :attr:`_orm.Session.is_active` will return True. + + Otherwise, if this :class:`_orm.Session` is within a transaction, + and that transaction has not been rolled back internally, the + :attr:`_orm.Session.is_active` will also return True. + + .. seealso:: + + :ref:`faq_session_rollback` + + :meth:`_orm.Session.in_transaction` + + + """ # noqa: E501 + + return self._proxied.is_active + + @property + def autoflush(self) -> bool: + r"""Proxy for the :attr:`_orm.Session.autoflush` attribute + on behalf of the :class:`_orm.scoping.scoped_session` class. + + """ # noqa: E501 + + return self._proxied.autoflush + + @autoflush.setter + def autoflush(self, attr: bool) -> None: + self._proxied.autoflush = attr + + @property + def no_autoflush(self) -> Any: + r"""Return a context manager that disables autoflush. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_orm.scoping.scoped_session` class. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + + """ # noqa: E501 + + return self._proxied.no_autoflush + + @property + def info(self) -> Any: + r"""A user-modifiable dictionary. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class + on behalf of the :class:`_orm.scoping.scoped_session` class. + + The initial value of this dictionary can be populated using the + ``info`` argument to the :class:`.Session` constructor or + :class:`.sessionmaker` constructor or factory methods. The dictionary + here is always local to this :class:`.Session` and can be modified + independently of all other :class:`.Session` objects. + + + """ # noqa: E501 + + return self._proxied.info + + @classmethod + def close_all(cls) -> None: + r"""Close *all* sessions in memory. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + .. deprecated:: 1.3 The :meth:`.Session.close_all` method is deprecated and will be removed in a future release. Please refer to :func:`.session.close_all_sessions`. + + """ # noqa: E501 + + return Session.close_all() + + @classmethod + def object_session(cls, instance: object) -> Optional[Session]: + r"""Return the :class:`.Session` to which an object belongs. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This is an alias of :func:`.object_session`. + + + """ # noqa: E501 + + return Session.object_session(instance) + + @classmethod + def identity_key( + cls, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[Any] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: + r"""Return an identity key. + + .. container:: class_bases + + Proxied for the :class:`_orm.Session` class on + behalf of the :class:`_orm.scoping.scoped_session` class. + + This is an alias of :func:`.util.identity_key`. + + + """ # noqa: E501 + + return Session.identity_key( + class_=class_, + ident=ident, + instance=instance, + row=row, + identity_token=identity_token, + ) + + # END PROXY METHODS scoped_session + + +ScopedSession = scoped_session +"""Old name for backwards compatibility.""" diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/session.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/session.py new file mode 100644 index 0000000..3eba5aa --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/session.py @@ -0,0 +1,5238 @@ +# orm/session.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Provides the Session class and related utilities.""" + +from __future__ import annotations + +import contextlib +from enum import Enum +import itertools +import sys +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes +from . import bulk_persistence +from . import context +from . import descriptor_props +from . import exc +from . import identity +from . import loading +from . import query +from . import state as statelib +from ._typing import _O +from ._typing import insp_is_mapper +from ._typing import is_composite_class +from ._typing import is_orm_option +from ._typing import is_user_defined_option +from .base import _class_to_mapper +from .base import _none_set +from .base import _state_mapper +from .base import instance_str +from .base import LoaderCallableStatus +from .base import object_mapper +from .base import object_state +from .base import PassiveFlag +from .base import state_str +from .context import FromStatement +from .context import ORMCompileState +from .identity import IdentityMap +from .query import Query +from .state import InstanceState +from .state_changes import _StateChange +from .state_changes import _StateChangeState +from .state_changes import _StateChangeStates +from .unitofwork import UOWTransaction +from .. import engine +from .. import exc as sa_exc +from .. import sql +from .. import util +from ..engine import Connection +from ..engine import Engine +from ..engine.util import TransactionalContext +from ..event import dispatcher +from ..event import EventTarget +from ..inspection import inspect +from ..inspection import Inspectable +from ..sql import coercions +from ..sql import dml +from ..sql import roles +from ..sql import Select +from ..sql import TableClause +from ..sql import visitors +from ..sql.base import _NoArg +from ..sql.base import CompileState +from ..sql.schema import Table +from ..sql.selectable import ForUpdateArg +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..util import IdentitySet +from ..util.typing import Literal +from ..util.typing import Protocol + +if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import OrmExecuteOptionsParameter + from .interfaces import ORMOption + from .interfaces import UserDefinedOption + from .mapper import Mapper + from .path_registry import PathRegistry + from .query import RowReturningQuery + from ..engine import CursorResult + from ..engine import Result + from ..engine import Row + from ..engine import RowMapping + from ..engine.base import Transaction + from ..engine.base import TwoPhaseTransaction + from ..engine.interfaces import _CoreAnyExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import CoreExecuteOptionsParameter + from ..engine.result import ScalarResult + from ..event import _InstanceLevelDispatch + from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _InfoType + from ..sql._typing import _T0 + from ..sql._typing import _T1 + from ..sql._typing import _T2 + from ..sql._typing import _T3 + from ..sql._typing import _T4 + from ..sql._typing import _T5 + from ..sql._typing import _T6 + from ..sql._typing import _T7 + from ..sql._typing import _TypedColumnClauseArgument as _TCCA + from ..sql.base import Executable + from ..sql.base import ExecutableOption + from ..sql.dml import UpdateBase + from ..sql.elements import ClauseElement + from ..sql.roles import TypedColumnsClauseRole + from ..sql.selectable import ForUpdateParameter + from ..sql.selectable import TypedReturnsRows + +_T = TypeVar("_T", bound=Any) + +__all__ = [ + "Session", + "SessionTransaction", + "sessionmaker", + "ORMExecuteState", + "close_all_sessions", + "make_transient", + "make_transient_to_detached", + "object_session", +] + +_sessions: weakref.WeakValueDictionary[int, Session] = ( + weakref.WeakValueDictionary() +) +"""Weak-referencing dictionary of :class:`.Session` objects. +""" + +statelib._sessions = _sessions + +_PKIdentityArgument = Union[Any, Tuple[Any, ...]] + +_BindArguments = Dict[str, Any] + +_EntityBindKey = Union[Type[_O], "Mapper[_O]"] +_SessionBindKey = Union[Type[Any], "Mapper[Any]", "TableClause", str] +_SessionBind = Union["Engine", "Connection"] + +JoinTransactionMode = Literal[ + "conditional_savepoint", + "rollback_only", + "control_fully", + "create_savepoint", +] + + +class _ConnectionCallableProto(Protocol): + """a callable that returns a :class:`.Connection` given an instance. + + This callable, when present on a :class:`.Session`, is called only from the + ORM's persistence mechanism (i.e. the unit of work flush process) to allow + for connection-per-instance schemes (i.e. horizontal sharding) to be used + as persistence time. + + This callable is not present on a plain :class:`.Session`, however + is established when using the horizontal sharding extension. + + """ + + def __call__( + self, + mapper: Optional[Mapper[Any]] = None, + instance: Optional[object] = None, + **kw: Any, + ) -> Connection: ... + + +def _state_session(state: InstanceState[Any]) -> Optional[Session]: + """Given an :class:`.InstanceState`, return the :class:`.Session` + associated, if any. + """ + return state.session + + +class _SessionClassMethods: + """Class-level methods for :class:`.Session`, :class:`.sessionmaker`.""" + + @classmethod + @util.deprecated( + "1.3", + "The :meth:`.Session.close_all` method is deprecated and will be " + "removed in a future release. Please refer to " + ":func:`.session.close_all_sessions`.", + ) + def close_all(cls) -> None: + """Close *all* sessions in memory.""" + + close_all_sessions() + + @classmethod + @util.preload_module("sqlalchemy.orm.util") + def identity_key( + cls, + class_: Optional[Type[Any]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[Any] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, + identity_token: Optional[Any] = None, + ) -> _IdentityKeyType[Any]: + """Return an identity key. + + This is an alias of :func:`.util.identity_key`. + + """ + return util.preloaded.orm_util.identity_key( + class_, + ident, + instance=instance, + row=row, + identity_token=identity_token, + ) + + @classmethod + def object_session(cls, instance: object) -> Optional[Session]: + """Return the :class:`.Session` to which an object belongs. + + This is an alias of :func:`.object_session`. + + """ + + return object_session(instance) + + +class SessionTransactionState(_StateChangeState): + ACTIVE = 1 + PREPARED = 2 + COMMITTED = 3 + DEACTIVE = 4 + CLOSED = 5 + PROVISIONING_CONNECTION = 6 + + +# backwards compatibility +ACTIVE, PREPARED, COMMITTED, DEACTIVE, CLOSED, PROVISIONING_CONNECTION = tuple( + SessionTransactionState +) + + +class ORMExecuteState(util.MemoizedSlots): + """Represents a call to the :meth:`_orm.Session.execute` method, as passed + to the :meth:`.SessionEvents.do_orm_execute` event hook. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`session_execute_events` - top level documentation on how + to use :meth:`_orm.SessionEvents.do_orm_execute` + + """ + + __slots__ = ( + "session", + "statement", + "parameters", + "execution_options", + "local_execution_options", + "bind_arguments", + "identity_token", + "_compile_state_cls", + "_starting_event_idx", + "_events_todo", + "_update_execution_options", + ) + + session: Session + """The :class:`_orm.Session` in use.""" + + statement: Executable + """The SQL statement being invoked. + + For an ORM selection as would + be retrieved from :class:`_orm.Query`, this is an instance of + :class:`_sql.select` that was generated from the ORM query. + """ + + parameters: Optional[_CoreAnyExecuteParams] + """Dictionary of parameters that was passed to + :meth:`_orm.Session.execute`.""" + + execution_options: _ExecuteOptions + """The complete dictionary of current execution options. + + This is a merge of the statement level options with the + locally passed execution options. + + .. seealso:: + + :attr:`_orm.ORMExecuteState.local_execution_options` + + :meth:`_sql.Executable.execution_options` + + :ref:`orm_queryguide_execution_options` + + """ + + local_execution_options: _ExecuteOptions + """Dictionary view of the execution options passed to the + :meth:`.Session.execute` method. + + This does not include options that may be associated with the statement + being invoked. + + .. seealso:: + + :attr:`_orm.ORMExecuteState.execution_options` + + """ + + bind_arguments: _BindArguments + """The dictionary passed as the + :paramref:`_orm.Session.execute.bind_arguments` dictionary. + + This dictionary may be used by extensions to :class:`_orm.Session` to pass + arguments that will assist in determining amongst a set of database + connections which one should be used to invoke this statement. + + """ + + _compile_state_cls: Optional[Type[ORMCompileState]] + _starting_event_idx: int + _events_todo: List[Any] + _update_execution_options: Optional[_ExecuteOptions] + + def __init__( + self, + session: Session, + statement: Executable, + parameters: Optional[_CoreAnyExecuteParams], + execution_options: _ExecuteOptions, + bind_arguments: _BindArguments, + compile_state_cls: Optional[Type[ORMCompileState]], + events_todo: List[_InstanceLevelDispatch[Session]], + ): + """Construct a new :class:`_orm.ORMExecuteState`. + + this object is constructed internally. + + """ + self.session = session + self.statement = statement + self.parameters = parameters + self.local_execution_options = execution_options + self.execution_options = statement._execution_options.union( + execution_options + ) + self.bind_arguments = bind_arguments + self._compile_state_cls = compile_state_cls + self._events_todo = list(events_todo) + + def _remaining_events(self) -> List[_InstanceLevelDispatch[Session]]: + return self._events_todo[self._starting_event_idx + 1 :] + + def invoke_statement( + self, + statement: Optional[Executable] = None, + params: Optional[_CoreAnyExecuteParams] = None, + execution_options: Optional[OrmExecuteOptionsParameter] = None, + bind_arguments: Optional[_BindArguments] = None, + ) -> Result[Any]: + """Execute the statement represented by this + :class:`.ORMExecuteState`, without re-invoking events that have + already proceeded. + + This method essentially performs a re-entrant execution of the current + statement for which the :meth:`.SessionEvents.do_orm_execute` event is + being currently invoked. The use case for this is for event handlers + that want to override how the ultimate + :class:`_engine.Result` object is returned, such as for schemes that + retrieve results from an offline cache or which concatenate results + from multiple executions. + + When the :class:`_engine.Result` object is returned by the actual + handler function within :meth:`_orm.SessionEvents.do_orm_execute` and + is propagated to the calling + :meth:`_orm.Session.execute` method, the remainder of the + :meth:`_orm.Session.execute` method is preempted and the + :class:`_engine.Result` object is returned to the caller of + :meth:`_orm.Session.execute` immediately. + + :param statement: optional statement to be invoked, in place of the + statement currently represented by :attr:`.ORMExecuteState.statement`. + + :param params: optional dictionary of parameters or list of parameters + which will be merged into the existing + :attr:`.ORMExecuteState.parameters` of this :class:`.ORMExecuteState`. + + .. versionchanged:: 2.0 a list of parameter dictionaries is accepted + for executemany executions. + + :param execution_options: optional dictionary of execution options + will be merged into the existing + :attr:`.ORMExecuteState.execution_options` of this + :class:`.ORMExecuteState`. + + :param bind_arguments: optional dictionary of bind_arguments + which will be merged amongst the current + :attr:`.ORMExecuteState.bind_arguments` + of this :class:`.ORMExecuteState`. + + :return: a :class:`_engine.Result` object with ORM-level results. + + .. seealso:: + + :ref:`do_orm_execute_re_executing` - background and examples on the + appropriate usage of :meth:`_orm.ORMExecuteState.invoke_statement`. + + + """ + + if statement is None: + statement = self.statement + + _bind_arguments = dict(self.bind_arguments) + if bind_arguments: + _bind_arguments.update(bind_arguments) + _bind_arguments["_sa_skip_events"] = True + + _params: Optional[_CoreAnyExecuteParams] + if params: + if self.is_executemany: + _params = [] + exec_many_parameters = cast( + "List[Dict[str, Any]]", self.parameters + ) + for _existing_params, _new_params in itertools.zip_longest( + exec_many_parameters, + cast("List[Dict[str, Any]]", params), + ): + if _existing_params is None or _new_params is None: + raise sa_exc.InvalidRequestError( + f"Can't apply executemany parameters to " + f"statement; number of parameter sets passed to " + f"Session.execute() ({len(exec_many_parameters)}) " + f"does not match number of parameter sets given " + f"to ORMExecuteState.invoke_statement() " + f"({len(params)})" + ) + _existing_params = dict(_existing_params) + _existing_params.update(_new_params) + _params.append(_existing_params) + else: + _params = dict(cast("Dict[str, Any]", self.parameters)) + _params.update(cast("Dict[str, Any]", params)) + else: + _params = self.parameters + + _execution_options = self.local_execution_options + if execution_options: + _execution_options = _execution_options.union(execution_options) + + return self.session._execute_internal( + statement, + _params, + execution_options=_execution_options, + bind_arguments=_bind_arguments, + _parent_execute_state=self, + ) + + @property + def bind_mapper(self) -> Optional[Mapper[Any]]: + """Return the :class:`_orm.Mapper` that is the primary "bind" mapper. + + For an :class:`_orm.ORMExecuteState` object invoking an ORM + statement, that is, the :attr:`_orm.ORMExecuteState.is_orm_statement` + attribute is ``True``, this attribute will return the + :class:`_orm.Mapper` that is considered to be the "primary" mapper + of the statement. The term "bind mapper" refers to the fact that + a :class:`_orm.Session` object may be "bound" to multiple + :class:`_engine.Engine` objects keyed to mapped classes, and the + "bind mapper" determines which of those :class:`_engine.Engine` objects + would be selected. + + For a statement that is invoked against a single mapped class, + :attr:`_orm.ORMExecuteState.bind_mapper` is intended to be a reliable + way of getting this mapper. + + .. versionadded:: 1.4.0b2 + + .. seealso:: + + :attr:`_orm.ORMExecuteState.all_mappers` + + + """ + mp: Optional[Mapper[Any]] = self.bind_arguments.get("mapper", None) + return mp + + @property + def all_mappers(self) -> Sequence[Mapper[Any]]: + """Return a sequence of all :class:`_orm.Mapper` objects that are + involved at the top level of this statement. + + By "top level" we mean those :class:`_orm.Mapper` objects that would + be represented in the result set rows for a :func:`_sql.select` + query, or for a :func:`_dml.update` or :func:`_dml.delete` query, + the mapper that is the main subject of the UPDATE or DELETE. + + .. versionadded:: 1.4.0b2 + + .. seealso:: + + :attr:`_orm.ORMExecuteState.bind_mapper` + + + + """ + if not self.is_orm_statement: + return [] + elif isinstance(self.statement, (Select, FromStatement)): + result = [] + seen = set() + for d in self.statement.column_descriptions: + ent = d["entity"] + if ent: + insp = inspect(ent, raiseerr=False) + if insp and insp.mapper and insp.mapper not in seen: + seen.add(insp.mapper) + result.append(insp.mapper) + return result + elif self.statement.is_dml and self.bind_mapper: + return [self.bind_mapper] + else: + return [] + + @property + def is_orm_statement(self) -> bool: + """return True if the operation is an ORM statement. + + This indicates that the select(), insert(), update(), or delete() + being invoked contains ORM entities as subjects. For a statement + that does not have ORM entities and instead refers only to + :class:`.Table` metadata, it is invoked as a Core SQL statement + and no ORM-level automation takes place. + + """ + return self._compile_state_cls is not None + + @property + def is_executemany(self) -> bool: + """return True if the parameters are a multi-element list of + dictionaries with more than one dictionary. + + .. versionadded:: 2.0 + + """ + return isinstance(self.parameters, list) + + @property + def is_select(self) -> bool: + """return True if this is a SELECT operation.""" + return self.statement.is_select + + @property + def is_insert(self) -> bool: + """return True if this is an INSERT operation.""" + return self.statement.is_dml and self.statement.is_insert + + @property + def is_update(self) -> bool: + """return True if this is an UPDATE operation.""" + return self.statement.is_dml and self.statement.is_update + + @property + def is_delete(self) -> bool: + """return True if this is a DELETE operation.""" + return self.statement.is_dml and self.statement.is_delete + + @property + def _is_crud(self) -> bool: + return isinstance(self.statement, (dml.Update, dml.Delete)) + + def update_execution_options(self, **opts: Any) -> None: + """Update the local execution options with new values.""" + self.local_execution_options = self.local_execution_options.union(opts) + + def _orm_compile_options( + self, + ) -> Optional[ + Union[ + context.ORMCompileState.default_compile_options, + Type[context.ORMCompileState.default_compile_options], + ] + ]: + if not self.is_select: + return None + try: + opts = self.statement._compile_options + except AttributeError: + return None + + if opts is not None and opts.isinstance( + context.ORMCompileState.default_compile_options + ): + return opts # type: ignore + else: + return None + + @property + def lazy_loaded_from(self) -> Optional[InstanceState[Any]]: + """An :class:`.InstanceState` that is using this statement execution + for a lazy load operation. + + The primary rationale for this attribute is to support the horizontal + sharding extension, where it is available within specific query + execution time hooks created by this extension. To that end, the + attribute is only intended to be meaningful at **query execution + time**, and importantly not any time prior to that, including query + compilation time. + + """ + return self.load_options._lazy_loaded_from + + @property + def loader_strategy_path(self) -> Optional[PathRegistry]: + """Return the :class:`.PathRegistry` for the current load path. + + This object represents the "path" in a query along relationships + when a particular object or collection is being loaded. + + """ + opts = self._orm_compile_options() + if opts is not None: + return opts._current_path + else: + return None + + @property + def is_column_load(self) -> bool: + """Return True if the operation is refreshing column-oriented + attributes on an existing ORM object. + + This occurs during operations such as :meth:`_orm.Session.refresh`, + as well as when an attribute deferred by :func:`_orm.defer` is + being loaded, or an attribute that was expired either directly + by :meth:`_orm.Session.expire` or via a commit operation is being + loaded. + + Handlers will very likely not want to add any options to queries + when such an operation is occurring as the query should be a straight + primary key fetch which should not have any additional WHERE criteria, + and loader options travelling with the instance + will have already been added to the query. + + .. versionadded:: 1.4.0b2 + + .. seealso:: + + :attr:`_orm.ORMExecuteState.is_relationship_load` + + """ + opts = self._orm_compile_options() + return opts is not None and opts._for_refresh_state + + @property + def is_relationship_load(self) -> bool: + """Return True if this load is loading objects on behalf of a + relationship. + + This means, the loader in effect is either a LazyLoader, + SelectInLoader, SubqueryLoader, or similar, and the entire + SELECT statement being emitted is on behalf of a relationship + load. + + Handlers will very likely not want to add any options to queries + when such an operation is occurring, as loader options are already + capable of being propagated to relationship loaders and should + be already present. + + .. seealso:: + + :attr:`_orm.ORMExecuteState.is_column_load` + + """ + opts = self._orm_compile_options() + if opts is None: + return False + path = self.loader_strategy_path + return path is not None and not path.is_root + + @property + def load_options( + self, + ) -> Union[ + context.QueryContext.default_load_options, + Type[context.QueryContext.default_load_options], + ]: + """Return the load_options that will be used for this execution.""" + + if not self.is_select: + raise sa_exc.InvalidRequestError( + "This ORM execution is not against a SELECT statement " + "so there are no load options." + ) + + lo: Union[ + context.QueryContext.default_load_options, + Type[context.QueryContext.default_load_options], + ] = self.execution_options.get( + "_sa_orm_load_options", context.QueryContext.default_load_options + ) + return lo + + @property + def update_delete_options( + self, + ) -> Union[ + bulk_persistence.BulkUDCompileState.default_update_options, + Type[bulk_persistence.BulkUDCompileState.default_update_options], + ]: + """Return the update_delete_options that will be used for this + execution.""" + + if not self._is_crud: + raise sa_exc.InvalidRequestError( + "This ORM execution is not against an UPDATE or DELETE " + "statement so there are no update options." + ) + uo: Union[ + bulk_persistence.BulkUDCompileState.default_update_options, + Type[bulk_persistence.BulkUDCompileState.default_update_options], + ] = self.execution_options.get( + "_sa_orm_update_options", + bulk_persistence.BulkUDCompileState.default_update_options, + ) + return uo + + @property + def _non_compile_orm_options(self) -> Sequence[ORMOption]: + return [ + opt + for opt in self.statement._with_options + if is_orm_option(opt) and not opt._is_compile_state + ] + + @property + def user_defined_options(self) -> Sequence[UserDefinedOption]: + """The sequence of :class:`.UserDefinedOptions` that have been + associated with the statement being invoked. + + """ + return [ + opt + for opt in self.statement._with_options + if is_user_defined_option(opt) + ] + + +class SessionTransactionOrigin(Enum): + """indicates the origin of a :class:`.SessionTransaction`. + + This enumeration is present on the + :attr:`.SessionTransaction.origin` attribute of any + :class:`.SessionTransaction` object. + + .. versionadded:: 2.0 + + """ + + AUTOBEGIN = 0 + """transaction were started by autobegin""" + + BEGIN = 1 + """transaction were started by calling :meth:`_orm.Session.begin`""" + + BEGIN_NESTED = 2 + """tranaction were started by :meth:`_orm.Session.begin_nested`""" + + SUBTRANSACTION = 3 + """transaction is an internal "subtransaction" """ + + +class SessionTransaction(_StateChange, TransactionalContext): + """A :class:`.Session`-level transaction. + + :class:`.SessionTransaction` is produced from the + :meth:`_orm.Session.begin` + and :meth:`_orm.Session.begin_nested` methods. It's largely an internal + object that in modern use provides a context manager for session + transactions. + + Documentation on interacting with :class:`_orm.SessionTransaction` is + at: :ref:`unitofwork_transaction`. + + + .. versionchanged:: 1.4 The scoping and API methods to work with the + :class:`_orm.SessionTransaction` object directly have been simplified. + + .. seealso:: + + :ref:`unitofwork_transaction` + + :meth:`.Session.begin` + + :meth:`.Session.begin_nested` + + :meth:`.Session.rollback` + + :meth:`.Session.commit` + + :meth:`.Session.in_transaction` + + :meth:`.Session.in_nested_transaction` + + :meth:`.Session.get_transaction` + + :meth:`.Session.get_nested_transaction` + + + """ + + _rollback_exception: Optional[BaseException] = None + + _connections: Dict[ + Union[Engine, Connection], Tuple[Connection, Transaction, bool, bool] + ] + session: Session + _parent: Optional[SessionTransaction] + + _state: SessionTransactionState + + _new: weakref.WeakKeyDictionary[InstanceState[Any], object] + _deleted: weakref.WeakKeyDictionary[InstanceState[Any], object] + _dirty: weakref.WeakKeyDictionary[InstanceState[Any], object] + _key_switches: weakref.WeakKeyDictionary[ + InstanceState[Any], Tuple[Any, Any] + ] + + origin: SessionTransactionOrigin + """Origin of this :class:`_orm.SessionTransaction`. + + Refers to a :class:`.SessionTransactionOrigin` instance which is an + enumeration indicating the source event that led to constructing + this :class:`_orm.SessionTransaction`. + + .. versionadded:: 2.0 + + """ + + nested: bool = False + """Indicates if this is a nested, or SAVEPOINT, transaction. + + When :attr:`.SessionTransaction.nested` is True, it is expected + that :attr:`.SessionTransaction.parent` will be present as well, + linking to the enclosing :class:`.SessionTransaction`. + + .. seealso:: + + :attr:`.SessionTransaction.origin` + + """ + + def __init__( + self, + session: Session, + origin: SessionTransactionOrigin, + parent: Optional[SessionTransaction] = None, + ): + TransactionalContext._trans_ctx_check(session) + + self.session = session + self._connections = {} + self._parent = parent + self.nested = nested = origin is SessionTransactionOrigin.BEGIN_NESTED + self.origin = origin + + if session._close_state is _SessionCloseState.CLOSED: + raise sa_exc.InvalidRequestError( + "This Session has been permanently closed and is unable " + "to handle any more transaction requests." + ) + + if nested: + if not parent: + raise sa_exc.InvalidRequestError( + "Can't start a SAVEPOINT transaction when no existing " + "transaction is in progress" + ) + + self._previous_nested_transaction = session._nested_transaction + elif origin is SessionTransactionOrigin.SUBTRANSACTION: + assert parent is not None + else: + assert parent is None + + self._state = SessionTransactionState.ACTIVE + + self._take_snapshot() + + # make sure transaction is assigned before we call the + # dispatch + self.session._transaction = self + + self.session.dispatch.after_transaction_create(self.session, self) + + def _raise_for_prerequisite_state( + self, operation_name: str, state: _StateChangeState + ) -> NoReturn: + if state is SessionTransactionState.DEACTIVE: + if self._rollback_exception: + raise sa_exc.PendingRollbackError( + "This Session's transaction has been rolled back " + "due to a previous exception during flush." + " To begin a new transaction with this Session, " + "first issue Session.rollback()." + f" Original exception was: {self._rollback_exception}", + code="7s2a", + ) + else: + raise sa_exc.InvalidRequestError( + "This session is in 'inactive' state, due to the " + "SQL transaction being rolled back; no further SQL " + "can be emitted within this transaction." + ) + elif state is SessionTransactionState.CLOSED: + raise sa_exc.ResourceClosedError("This transaction is closed") + elif state is SessionTransactionState.PROVISIONING_CONNECTION: + raise sa_exc.InvalidRequestError( + "This session is provisioning a new connection; concurrent " + "operations are not permitted", + code="isce", + ) + else: + raise sa_exc.InvalidRequestError( + f"This session is in '{state.name.lower()}' state; no " + "further SQL can be emitted within this transaction." + ) + + @property + def parent(self) -> Optional[SessionTransaction]: + """The parent :class:`.SessionTransaction` of this + :class:`.SessionTransaction`. + + If this attribute is ``None``, indicates this + :class:`.SessionTransaction` is at the top of the stack, and + corresponds to a real "COMMIT"/"ROLLBACK" + block. If non-``None``, then this is either a "subtransaction" + (an internal marker object used by the flush process) or a + "nested" / SAVEPOINT transaction. If the + :attr:`.SessionTransaction.nested` attribute is ``True``, then + this is a SAVEPOINT, and if ``False``, indicates this a subtransaction. + + """ + return self._parent + + @property + def is_active(self) -> bool: + return ( + self.session is not None + and self._state is SessionTransactionState.ACTIVE + ) + + @property + def _is_transaction_boundary(self) -> bool: + return self.nested or not self._parent + + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE + ) + def connection( + self, + bindkey: Optional[Mapper[Any]], + execution_options: Optional[_ExecuteOptions] = None, + **kwargs: Any, + ) -> Connection: + bind = self.session.get_bind(bindkey, **kwargs) + return self._connection_for_bind(bind, execution_options) + + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE + ) + def _begin(self, nested: bool = False) -> SessionTransaction: + return SessionTransaction( + self.session, + ( + SessionTransactionOrigin.BEGIN_NESTED + if nested + else SessionTransactionOrigin.SUBTRANSACTION + ), + self, + ) + + def _iterate_self_and_parents( + self, upto: Optional[SessionTransaction] = None + ) -> Iterable[SessionTransaction]: + current = self + result: Tuple[SessionTransaction, ...] = () + while current: + result += (current,) + if current._parent is upto: + break + elif current._parent is None: + raise sa_exc.InvalidRequestError( + "Transaction %s is not on the active transaction list" + % (upto) + ) + else: + current = current._parent + + return result + + def _take_snapshot(self) -> None: + if not self._is_transaction_boundary: + parent = self._parent + assert parent is not None + self._new = parent._new + self._deleted = parent._deleted + self._dirty = parent._dirty + self._key_switches = parent._key_switches + return + + is_begin = self.origin in ( + SessionTransactionOrigin.BEGIN, + SessionTransactionOrigin.AUTOBEGIN, + ) + if not is_begin and not self.session._flushing: + self.session.flush() + + self._new = weakref.WeakKeyDictionary() + self._deleted = weakref.WeakKeyDictionary() + self._dirty = weakref.WeakKeyDictionary() + self._key_switches = weakref.WeakKeyDictionary() + + def _restore_snapshot(self, dirty_only: bool = False) -> None: + """Restore the restoration state taken before a transaction began. + + Corresponds to a rollback. + + """ + assert self._is_transaction_boundary + + to_expunge = set(self._new).union(self.session._new) + self.session._expunge_states(to_expunge, to_transient=True) + + for s, (oldkey, newkey) in self._key_switches.items(): + # we probably can do this conditionally based on + # if we expunged or not, but safe_discard does that anyway + self.session.identity_map.safe_discard(s) + + # restore the old key + s.key = oldkey + + # now restore the object, but only if we didn't expunge + if s not in to_expunge: + self.session.identity_map.replace(s) + + for s in set(self._deleted).union(self.session._deleted): + self.session._update_impl(s, revert_deletion=True) + + assert not self.session._deleted + + for s in self.session.identity_map.all_states(): + if not dirty_only or s.modified or s in self._dirty: + s._expire(s.dict, self.session.identity_map._modified) + + def _remove_snapshot(self) -> None: + """Remove the restoration state taken before a transaction began. + + Corresponds to a commit. + + """ + assert self._is_transaction_boundary + + if not self.nested and self.session.expire_on_commit: + for s in self.session.identity_map.all_states(): + s._expire(s.dict, self.session.identity_map._modified) + + statelib.InstanceState._detach_states( + list(self._deleted), self.session + ) + self._deleted.clear() + elif self.nested: + parent = self._parent + assert parent is not None + parent._new.update(self._new) + parent._dirty.update(self._dirty) + parent._deleted.update(self._deleted) + parent._key_switches.update(self._key_switches) + + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), _StateChangeStates.NO_CHANGE + ) + def _connection_for_bind( + self, + bind: _SessionBind, + execution_options: Optional[CoreExecuteOptionsParameter], + ) -> Connection: + if bind in self._connections: + if execution_options: + util.warn( + "Connection is already established for the " + "given bind; execution_options ignored" + ) + return self._connections[bind][0] + + self._state = SessionTransactionState.PROVISIONING_CONNECTION + + local_connect = False + should_commit = True + + try: + if self._parent: + conn = self._parent._connection_for_bind( + bind, execution_options + ) + if not self.nested: + return conn + else: + if isinstance(bind, engine.Connection): + conn = bind + if conn.engine in self._connections: + raise sa_exc.InvalidRequestError( + "Session already has a Connection associated " + "for the given Connection's Engine" + ) + else: + conn = bind.connect() + local_connect = True + + try: + if execution_options: + conn = conn.execution_options(**execution_options) + + transaction: Transaction + if self.session.twophase and self._parent is None: + # TODO: shouldn't we only be here if not + # conn.in_transaction() ? + # if twophase is set and conn.in_transaction(), validate + # that it is in fact twophase. + transaction = conn.begin_twophase() + elif self.nested: + transaction = conn.begin_nested() + elif conn.in_transaction(): + join_transaction_mode = self.session.join_transaction_mode + + if join_transaction_mode == "conditional_savepoint": + if conn.in_nested_transaction(): + join_transaction_mode = "create_savepoint" + else: + join_transaction_mode = "rollback_only" + + if join_transaction_mode in ( + "control_fully", + "rollback_only", + ): + if conn.in_nested_transaction(): + transaction = ( + conn._get_required_nested_transaction() + ) + else: + transaction = conn._get_required_transaction() + if join_transaction_mode == "rollback_only": + should_commit = False + elif join_transaction_mode == "create_savepoint": + transaction = conn.begin_nested() + else: + assert False, join_transaction_mode + else: + transaction = conn.begin() + except: + # connection will not not be associated with this Session; + # close it immediately so that it isn't closed under GC + if local_connect: + conn.close() + raise + else: + bind_is_connection = isinstance(bind, engine.Connection) + + self._connections[conn] = self._connections[conn.engine] = ( + conn, + transaction, + should_commit, + not bind_is_connection, + ) + self.session.dispatch.after_begin(self.session, self, conn) + return conn + finally: + self._state = SessionTransactionState.ACTIVE + + def prepare(self) -> None: + if self._parent is not None or not self.session.twophase: + raise sa_exc.InvalidRequestError( + "'twophase' mode not enabled, or not root transaction; " + "can't prepare." + ) + self._prepare_impl() + + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE,), SessionTransactionState.PREPARED + ) + def _prepare_impl(self) -> None: + if self._parent is None or self.nested: + self.session.dispatch.before_commit(self.session) + + stx = self.session._transaction + assert stx is not None + if stx is not self: + for subtransaction in stx._iterate_self_and_parents(upto=self): + subtransaction.commit() + + if not self.session._flushing: + for _flush_guard in range(100): + if self.session._is_clean(): + break + self.session.flush() + else: + raise exc.FlushError( + "Over 100 subsequent flushes have occurred within " + "session.commit() - is an after_flush() hook " + "creating new objects?" + ) + + if self._parent is None and self.session.twophase: + try: + for t in set(self._connections.values()): + cast("TwoPhaseTransaction", t[1]).prepare() + except: + with util.safe_reraise(): + self.rollback() + + self._state = SessionTransactionState.PREPARED + + @_StateChange.declare_states( + (SessionTransactionState.ACTIVE, SessionTransactionState.PREPARED), + SessionTransactionState.CLOSED, + ) + def commit(self, _to_root: bool = False) -> None: + if self._state is not SessionTransactionState.PREPARED: + with self._expect_state(SessionTransactionState.PREPARED): + self._prepare_impl() + + if self._parent is None or self.nested: + for conn, trans, should_commit, autoclose in set( + self._connections.values() + ): + if should_commit: + trans.commit() + + self._state = SessionTransactionState.COMMITTED + self.session.dispatch.after_commit(self.session) + + self._remove_snapshot() + + with self._expect_state(SessionTransactionState.CLOSED): + self.close() + + if _to_root and self._parent: + self._parent.commit(_to_root=True) + + @_StateChange.declare_states( + ( + SessionTransactionState.ACTIVE, + SessionTransactionState.DEACTIVE, + SessionTransactionState.PREPARED, + ), + SessionTransactionState.CLOSED, + ) + def rollback( + self, _capture_exception: bool = False, _to_root: bool = False + ) -> None: + stx = self.session._transaction + assert stx is not None + if stx is not self: + for subtransaction in stx._iterate_self_and_parents(upto=self): + subtransaction.close() + + boundary = self + rollback_err = None + if self._state in ( + SessionTransactionState.ACTIVE, + SessionTransactionState.PREPARED, + ): + for transaction in self._iterate_self_and_parents(): + if transaction._parent is None or transaction.nested: + try: + for t in set(transaction._connections.values()): + t[1].rollback() + + transaction._state = SessionTransactionState.DEACTIVE + self.session.dispatch.after_rollback(self.session) + except: + rollback_err = sys.exc_info() + finally: + transaction._state = SessionTransactionState.DEACTIVE + transaction._restore_snapshot( + dirty_only=transaction.nested + ) + boundary = transaction + break + else: + transaction._state = SessionTransactionState.DEACTIVE + + sess = self.session + + if not rollback_err and not sess._is_clean(): + # if items were added, deleted, or mutated + # here, we need to re-restore the snapshot + util.warn( + "Session's state has been changed on " + "a non-active transaction - this state " + "will be discarded." + ) + boundary._restore_snapshot(dirty_only=boundary.nested) + + with self._expect_state(SessionTransactionState.CLOSED): + self.close() + + if self._parent and _capture_exception: + self._parent._rollback_exception = sys.exc_info()[1] + + if rollback_err and rollback_err[1]: + raise rollback_err[1].with_traceback(rollback_err[2]) + + sess.dispatch.after_soft_rollback(sess, self) + + if _to_root and self._parent: + self._parent.rollback(_to_root=True) + + @_StateChange.declare_states( + _StateChangeStates.ANY, SessionTransactionState.CLOSED + ) + def close(self, invalidate: bool = False) -> None: + if self.nested: + self.session._nested_transaction = ( + self._previous_nested_transaction + ) + + self.session._transaction = self._parent + + for connection, transaction, should_commit, autoclose in set( + self._connections.values() + ): + if invalidate and self._parent is None: + connection.invalidate() + if should_commit and transaction.is_active: + transaction.close() + if autoclose and self._parent is None: + connection.close() + + self._state = SessionTransactionState.CLOSED + sess = self.session + + # TODO: these two None sets were historically after the + # event hook below, and in 2.0 I changed it this way for some reason, + # and I remember there being a reason, but not what it was. + # Why do we need to get rid of them at all? test_memusage::CycleTest + # passes with these commented out. + # self.session = None # type: ignore + # self._connections = None # type: ignore + + sess.dispatch.after_transaction_end(sess, self) + + def _get_subject(self) -> Session: + return self.session + + def _transaction_is_active(self) -> bool: + return self._state is SessionTransactionState.ACTIVE + + def _transaction_is_closed(self) -> bool: + return self._state is SessionTransactionState.CLOSED + + def _rollback_can_be_called(self) -> bool: + return self._state not in (COMMITTED, CLOSED) + + +class _SessionCloseState(Enum): + ACTIVE = 1 + CLOSED = 2 + CLOSE_IS_RESET = 3 + + +class Session(_SessionClassMethods, EventTarget): + """Manages persistence operations for ORM-mapped objects. + + The :class:`_orm.Session` is **not safe for use in concurrent threads.**. + See :ref:`session_faq_threadsafe` for background. + + The Session's usage paradigm is described at :doc:`/orm/session`. + + + """ + + _is_asyncio = False + + dispatch: dispatcher[Session] + + identity_map: IdentityMap + """A mapping of object identities to objects themselves. + + Iterating through ``Session.identity_map.values()`` provides + access to the full set of persistent objects (i.e., those + that have row identity) currently in the session. + + .. seealso:: + + :func:`.identity_key` - helper function to produce the keys used + in this dictionary. + + """ + + _new: Dict[InstanceState[Any], Any] + _deleted: Dict[InstanceState[Any], Any] + bind: Optional[Union[Engine, Connection]] + __binds: Dict[_SessionBindKey, _SessionBind] + _flushing: bool + _warn_on_events: bool + _transaction: Optional[SessionTransaction] + _nested_transaction: Optional[SessionTransaction] + hash_key: int + autoflush: bool + expire_on_commit: bool + enable_baked_queries: bool + twophase: bool + join_transaction_mode: JoinTransactionMode + _query_cls: Type[Query[Any]] + _close_state: _SessionCloseState + + def __init__( + self, + bind: Optional[_SessionBind] = None, + *, + autoflush: bool = True, + future: Literal[True] = True, + expire_on_commit: bool = True, + autobegin: bool = True, + twophase: bool = False, + binds: Optional[Dict[_SessionBindKey, _SessionBind]] = None, + enable_baked_queries: bool = True, + info: Optional[_InfoType] = None, + query_cls: Optional[Type[Query[Any]]] = None, + autocommit: Literal[False] = False, + join_transaction_mode: JoinTransactionMode = "conditional_savepoint", + close_resets_only: Union[bool, _NoArg] = _NoArg.NO_ARG, + ): + r"""Construct a new :class:`_orm.Session`. + + See also the :class:`.sessionmaker` function which is used to + generate a :class:`.Session`-producing callable with a given + set of arguments. + + :param autoflush: When ``True``, all query operations will issue a + :meth:`~.Session.flush` call to this ``Session`` before proceeding. + This is a convenience feature so that :meth:`~.Session.flush` need + not be called repeatedly in order for database queries to retrieve + results. + + .. seealso:: + + :ref:`session_flushing` - additional background on autoflush + + :param autobegin: Automatically start transactions (i.e. equivalent to + invoking :meth:`_orm.Session.begin`) when database access is + requested by an operation. Defaults to ``True``. Set to + ``False`` to prevent a :class:`_orm.Session` from implicitly + beginning transactions after construction, as well as after any of + the :meth:`_orm.Session.rollback`, :meth:`_orm.Session.commit`, + or :meth:`_orm.Session.close` methods are called. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`session_autobegin_disable` + + :param bind: An optional :class:`_engine.Engine` or + :class:`_engine.Connection` to + which this ``Session`` should be bound. When specified, all SQL + operations performed by this session will execute via this + connectable. + + :param binds: A dictionary which may specify any number of + :class:`_engine.Engine` or :class:`_engine.Connection` + objects as the source of + connectivity for SQL operations on a per-entity basis. The keys + of the dictionary consist of any series of mapped classes, + arbitrary Python classes that are bases for mapped classes, + :class:`_schema.Table` objects and :class:`_orm.Mapper` objects. + The + values of the dictionary are then instances of + :class:`_engine.Engine` + or less commonly :class:`_engine.Connection` objects. + Operations which + proceed relative to a particular mapped class will consult this + dictionary for the closest matching entity in order to determine + which :class:`_engine.Engine` should be used for a particular SQL + operation. The complete heuristics for resolution are + described at :meth:`.Session.get_bind`. Usage looks like:: + + Session = sessionmaker(binds={ + SomeMappedClass: create_engine('postgresql+psycopg2://engine1'), + SomeDeclarativeBase: create_engine('postgresql+psycopg2://engine2'), + some_mapper: create_engine('postgresql+psycopg2://engine3'), + some_table: create_engine('postgresql+psycopg2://engine4'), + }) + + .. seealso:: + + :ref:`session_partitioning` + + :meth:`.Session.bind_mapper` + + :meth:`.Session.bind_table` + + :meth:`.Session.get_bind` + + + :param \class_: Specify an alternate class other than + ``sqlalchemy.orm.session.Session`` which should be used by the + returned class. This is the only argument that is local to the + :class:`.sessionmaker` function, and is not sent directly to the + constructor for ``Session``. + + :param enable_baked_queries: legacy; defaults to ``True``. + A parameter consumed + by the :mod:`sqlalchemy.ext.baked` extension to determine if + "baked queries" should be cached, as is the normal operation + of this extension. When set to ``False``, caching as used by + this particular extension is disabled. + + .. versionchanged:: 1.4 The ``sqlalchemy.ext.baked`` extension is + legacy and is not used by any of SQLAlchemy's internals. This + flag therefore only affects applications that are making explicit + use of this extension within their own code. + + :param expire_on_commit: Defaults to ``True``. When ``True``, all + instances will be fully expired after each :meth:`~.commit`, + so that all attribute/object access subsequent to a completed + transaction will load from the most recent database state. + + .. seealso:: + + :ref:`session_committing` + + :param future: Deprecated; this flag is always True. + + .. seealso:: + + :ref:`migration_20_toplevel` + + :param info: optional dictionary of arbitrary data to be associated + with this :class:`.Session`. Is available via the + :attr:`.Session.info` attribute. Note the dictionary is copied at + construction time so that modifications to the per- + :class:`.Session` dictionary will be local to that + :class:`.Session`. + + :param query_cls: Class which should be used to create new Query + objects, as returned by the :meth:`~.Session.query` method. + Defaults to :class:`_query.Query`. + + :param twophase: When ``True``, all transactions will be started as + a "two phase" transaction, i.e. using the "two phase" semantics + of the database in use along with an XID. During a + :meth:`~.commit`, after :meth:`~.flush` has been issued for all + attached databases, the :meth:`~.TwoPhaseTransaction.prepare` + method on each database's :class:`.TwoPhaseTransaction` will be + called. This allows each database to roll back the entire + transaction, before each transaction is committed. + + :param autocommit: the "autocommit" keyword is present for backwards + compatibility but must remain at its default value of ``False``. + + :param join_transaction_mode: Describes the transactional behavior to + take when a given bind is a :class:`_engine.Connection` that + has already begun a transaction outside the scope of this + :class:`_orm.Session`; in other words the + :meth:`_engine.Connection.in_transaction()` method returns True. + + The following behaviors only take effect when the :class:`_orm.Session` + **actually makes use of the connection given**; that is, a method + such as :meth:`_orm.Session.execute`, :meth:`_orm.Session.connection`, + etc. are actually invoked: + + * ``"conditional_savepoint"`` - this is the default. if the given + :class:`_engine.Connection` is begun within a transaction but + does not have a SAVEPOINT, then ``"rollback_only"`` is used. + If the :class:`_engine.Connection` is additionally within + a SAVEPOINT, in other words + :meth:`_engine.Connection.in_nested_transaction()` method returns + True, then ``"create_savepoint"`` is used. + + ``"conditional_savepoint"`` behavior attempts to make use of + savepoints in order to keep the state of the existing transaction + unchanged, but only if there is already a savepoint in progress; + otherwise, it is not assumed that the backend in use has adequate + support for SAVEPOINT, as availability of this feature varies. + ``"conditional_savepoint"`` also seeks to establish approximate + backwards compatibility with previous :class:`_orm.Session` + behavior, for applications that are not setting a specific mode. It + is recommended that one of the explicit settings be used. + + * ``"create_savepoint"`` - the :class:`_orm.Session` will use + :meth:`_engine.Connection.begin_nested()` in all cases to create + its own transaction. This transaction by its nature rides + "on top" of any existing transaction that's opened on the given + :class:`_engine.Connection`; if the underlying database and + the driver in use has full, non-broken support for SAVEPOINT, the + external transaction will remain unaffected throughout the + lifespan of the :class:`_orm.Session`. + + The ``"create_savepoint"`` mode is the most useful for integrating + a :class:`_orm.Session` into a test suite where an externally + initiated transaction should remain unaffected; however, it relies + on proper SAVEPOINT support from the underlying driver and + database. + + .. tip:: When using SQLite, the SQLite driver included through + Python 3.11 does not handle SAVEPOINTs correctly in all cases + without workarounds. See the sections + :ref:`pysqlite_serializable` and :ref:`aiosqlite_serializable` + for details on current workarounds. + + * ``"control_fully"`` - the :class:`_orm.Session` will take + control of the given transaction as its own; + :meth:`_orm.Session.commit` will call ``.commit()`` on the + transaction, :meth:`_orm.Session.rollback` will call + ``.rollback()`` on the transaction, :meth:`_orm.Session.close` will + call ``.rollback`` on the transaction. + + .. tip:: This mode of use is equivalent to how SQLAlchemy 1.4 would + handle a :class:`_engine.Connection` given with an existing + SAVEPOINT (i.e. :meth:`_engine.Connection.begin_nested`); the + :class:`_orm.Session` would take full control of the existing + SAVEPOINT. + + * ``"rollback_only"`` - the :class:`_orm.Session` will take control + of the given transaction for ``.rollback()`` calls only; + ``.commit()`` calls will not be propagated to the given + transaction. ``.close()`` calls will have no effect on the + given transaction. + + .. tip:: This mode of use is equivalent to how SQLAlchemy 1.4 would + handle a :class:`_engine.Connection` given with an existing + regular database transaction (i.e. + :meth:`_engine.Connection.begin`); the :class:`_orm.Session` + would propagate :meth:`_orm.Session.rollback` calls to the + underlying transaction, but not :meth:`_orm.Session.commit` or + :meth:`_orm.Session.close` calls. + + .. versionadded:: 2.0.0rc1 + + :param close_resets_only: Defaults to ``True``. Determines if + the session should reset itself after calling ``.close()`` + or should pass in a no longer usable state, disabling re-use. + + .. versionadded:: 2.0.22 added flag ``close_resets_only``. + A future SQLAlchemy version may change the default value of + this flag to ``False``. + + .. seealso:: + + :ref:`session_closing` - Detail on the semantics of + :meth:`_orm.Session.close` and :meth:`_orm.Session.reset`. + + """ # noqa + + # considering allowing the "autocommit" keyword to still be accepted + # as long as it's False, so that external test suites, oslo.db etc + # continue to function as the argument appears to be passed in lots + # of cases including in our own test suite + if autocommit: + raise sa_exc.ArgumentError( + "autocommit=True is no longer supported" + ) + self.identity_map = identity.WeakInstanceDict() + + if not future: + raise sa_exc.ArgumentError( + "The 'future' parameter passed to " + "Session() may only be set to True." + ) + + self._new = {} # InstanceState->object, strong refs object + self._deleted = {} # same + self.bind = bind + self.__binds = {} + self._flushing = False + self._warn_on_events = False + self._transaction = None + self._nested_transaction = None + self.hash_key = _new_sessionid() + self.autobegin = autobegin + self.autoflush = autoflush + self.expire_on_commit = expire_on_commit + self.enable_baked_queries = enable_baked_queries + + # the idea is that at some point NO_ARG will warn that in the future + # the default will switch to close_resets_only=False. + if close_resets_only or close_resets_only is _NoArg.NO_ARG: + self._close_state = _SessionCloseState.CLOSE_IS_RESET + else: + self._close_state = _SessionCloseState.ACTIVE + if ( + join_transaction_mode + and join_transaction_mode + not in JoinTransactionMode.__args__ # type: ignore + ): + raise sa_exc.ArgumentError( + f"invalid selection for join_transaction_mode: " + f'"{join_transaction_mode}"' + ) + self.join_transaction_mode = join_transaction_mode + + self.twophase = twophase + self._query_cls = query_cls if query_cls else query.Query + if info: + self.info.update(info) + + if binds is not None: + for key, bind in binds.items(): + self._add_bind(key, bind) + + _sessions[self.hash_key] = self + + # used by sqlalchemy.engine.util.TransactionalContext + _trans_context_manager: Optional[TransactionalContext] = None + + connection_callable: Optional[_ConnectionCallableProto] = None + + def __enter__(self: _S) -> _S: + return self + + def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: + self.close() + + @contextlib.contextmanager + def _maker_context_manager(self: _S) -> Iterator[_S]: + with self: + with self.begin(): + yield self + + def in_transaction(self) -> bool: + """Return True if this :class:`_orm.Session` has begun a transaction. + + .. versionadded:: 1.4 + + .. seealso:: + + :attr:`_orm.Session.is_active` + + + """ + return self._transaction is not None + + def in_nested_transaction(self) -> bool: + """Return True if this :class:`_orm.Session` has begun a nested + transaction, e.g. SAVEPOINT. + + .. versionadded:: 1.4 + + """ + return self._nested_transaction is not None + + def get_transaction(self) -> Optional[SessionTransaction]: + """Return the current root transaction in progress, if any. + + .. versionadded:: 1.4 + + """ + trans = self._transaction + while trans is not None and trans._parent is not None: + trans = trans._parent + return trans + + def get_nested_transaction(self) -> Optional[SessionTransaction]: + """Return the current nested transaction in progress, if any. + + .. versionadded:: 1.4 + + """ + + return self._nested_transaction + + @util.memoized_property + def info(self) -> _InfoType: + """A user-modifiable dictionary. + + The initial value of this dictionary can be populated using the + ``info`` argument to the :class:`.Session` constructor or + :class:`.sessionmaker` constructor or factory methods. The dictionary + here is always local to this :class:`.Session` and can be modified + independently of all other :class:`.Session` objects. + + """ + return {} + + def _autobegin_t(self, begin: bool = False) -> SessionTransaction: + if self._transaction is None: + if not begin and not self.autobegin: + raise sa_exc.InvalidRequestError( + "Autobegin is disabled on this Session; please call " + "session.begin() to start a new transaction" + ) + trans = SessionTransaction( + self, + ( + SessionTransactionOrigin.BEGIN + if begin + else SessionTransactionOrigin.AUTOBEGIN + ), + ) + assert self._transaction is trans + return trans + + return self._transaction + + def begin(self, nested: bool = False) -> SessionTransaction: + """Begin a transaction, or nested transaction, + on this :class:`.Session`, if one is not already begun. + + The :class:`_orm.Session` object features **autobegin** behavior, + so that normally it is not necessary to call the + :meth:`_orm.Session.begin` + method explicitly. However, it may be used in order to control + the scope of when the transactional state is begun. + + When used to begin the outermost transaction, an error is raised + if this :class:`.Session` is already inside of a transaction. + + :param nested: if True, begins a SAVEPOINT transaction and is + equivalent to calling :meth:`~.Session.begin_nested`. For + documentation on SAVEPOINT transactions, please see + :ref:`session_begin_nested`. + + :return: the :class:`.SessionTransaction` object. Note that + :class:`.SessionTransaction` + acts as a Python context manager, allowing :meth:`.Session.begin` + to be used in a "with" block. See :ref:`session_explicit_begin` for + an example. + + .. seealso:: + + :ref:`session_autobegin` + + :ref:`unitofwork_transaction` + + :meth:`.Session.begin_nested` + + + """ + + trans = self._transaction + if trans is None: + trans = self._autobegin_t(begin=True) + + if not nested: + return trans + + assert trans is not None + + if nested: + trans = trans._begin(nested=nested) + assert self._transaction is trans + self._nested_transaction = trans + else: + raise sa_exc.InvalidRequestError( + "A transaction is already begun on this Session." + ) + + return trans # needed for __enter__/__exit__ hook + + def begin_nested(self) -> SessionTransaction: + """Begin a "nested" transaction on this Session, e.g. SAVEPOINT. + + The target database(s) and associated drivers must support SQL + SAVEPOINT for this method to function correctly. + + For documentation on SAVEPOINT + transactions, please see :ref:`session_begin_nested`. + + :return: the :class:`.SessionTransaction` object. Note that + :class:`.SessionTransaction` acts as a context manager, allowing + :meth:`.Session.begin_nested` to be used in a "with" block. + See :ref:`session_begin_nested` for a usage example. + + .. seealso:: + + :ref:`session_begin_nested` + + :ref:`pysqlite_serializable` - special workarounds required + with the SQLite driver in order for SAVEPOINT to work + correctly. For asyncio use cases, see the section + :ref:`aiosqlite_serializable`. + + """ + return self.begin(nested=True) + + def rollback(self) -> None: + """Rollback the current transaction in progress. + + If no transaction is in progress, this method is a pass-through. + + The method always rolls back + the topmost database transaction, discarding any nested + transactions that may be in progress. + + .. seealso:: + + :ref:`session_rollback` + + :ref:`unitofwork_transaction` + + """ + if self._transaction is None: + pass + else: + self._transaction.rollback(_to_root=True) + + def commit(self) -> None: + """Flush pending changes and commit the current transaction. + + When the COMMIT operation is complete, all objects are fully + :term:`expired`, erasing their internal contents, which will be + automatically re-loaded when the objects are next accessed. In the + interim, these objects are in an expired state and will not function if + they are :term:`detached` from the :class:`.Session`. Additionally, + this re-load operation is not supported when using asyncio-oriented + APIs. The :paramref:`.Session.expire_on_commit` parameter may be used + to disable this behavior. + + When there is no transaction in place for the :class:`.Session`, + indicating that no operations were invoked on this :class:`.Session` + since the previous call to :meth:`.Session.commit`, the method will + begin and commit an internal-only "logical" transaction, that does not + normally affect the database unless pending flush changes were + detected, but will still invoke event handlers and object expiration + rules. + + The outermost database transaction is committed unconditionally, + automatically releasing any SAVEPOINTs in effect. + + .. seealso:: + + :ref:`session_committing` + + :ref:`unitofwork_transaction` + + :ref:`asyncio_orm_avoid_lazyloads` + + """ + trans = self._transaction + if trans is None: + trans = self._autobegin_t() + + trans.commit(_to_root=True) + + def prepare(self) -> None: + """Prepare the current transaction in progress for two phase commit. + + If no transaction is in progress, this method raises an + :exc:`~sqlalchemy.exc.InvalidRequestError`. + + Only root transactions of two phase sessions can be prepared. If the + current transaction is not such, an + :exc:`~sqlalchemy.exc.InvalidRequestError` is raised. + + """ + trans = self._transaction + if trans is None: + trans = self._autobegin_t() + + trans.prepare() + + def connection( + self, + bind_arguments: Optional[_BindArguments] = None, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + ) -> Connection: + r"""Return a :class:`_engine.Connection` object corresponding to this + :class:`.Session` object's transactional state. + + Either the :class:`_engine.Connection` corresponding to the current + transaction is returned, or if no transaction is in progress, a new + one is begun and the :class:`_engine.Connection` + returned (note that no + transactional state is established with the DBAPI until the first + SQL statement is emitted). + + Ambiguity in multi-bind or unbound :class:`.Session` objects can be + resolved through any of the optional keyword arguments. This + ultimately makes usage of the :meth:`.get_bind` method for resolution. + + :param bind_arguments: dictionary of bind arguments. May include + "mapper", "bind", "clause", other custom arguments that are passed + to :meth:`.Session.get_bind`. + + :param execution_options: a dictionary of execution options that will + be passed to :meth:`_engine.Connection.execution_options`, **when the + connection is first procured only**. If the connection is already + present within the :class:`.Session`, a warning is emitted and + the arguments are ignored. + + .. seealso:: + + :ref:`session_transaction_isolation` + + """ + + if bind_arguments: + bind = bind_arguments.pop("bind", None) + + if bind is None: + bind = self.get_bind(**bind_arguments) + else: + bind = self.get_bind() + + return self._connection_for_bind( + bind, + execution_options=execution_options, + ) + + def _connection_for_bind( + self, + engine: _SessionBind, + execution_options: Optional[CoreExecuteOptionsParameter] = None, + **kw: Any, + ) -> Connection: + TransactionalContext._trans_ctx_check(self) + + trans = self._transaction + if trans is None: + trans = self._autobegin_t() + return trans._connection_for_bind(engine, execution_options) + + @overload + def _execute_internal( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + _scalar_result: Literal[True] = ..., + ) -> Any: ... + + @overload + def _execute_internal( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + _scalar_result: bool = ..., + ) -> Result[Any]: ... + + def _execute_internal( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + _scalar_result: bool = False, + ) -> Any: + statement = coercions.expect(roles.StatementRole, statement) + + if not bind_arguments: + bind_arguments = {} + else: + bind_arguments = dict(bind_arguments) + + if ( + statement._propagate_attrs.get("compile_state_plugin", None) + == "orm" + ): + compile_state_cls = CompileState._get_plugin_class_for_plugin( + statement, "orm" + ) + if TYPE_CHECKING: + assert isinstance( + compile_state_cls, context.AbstractORMCompileState + ) + else: + compile_state_cls = None + bind_arguments.setdefault("clause", statement) + + execution_options = util.coerce_to_immutabledict(execution_options) + + if _parent_execute_state: + events_todo = _parent_execute_state._remaining_events() + else: + events_todo = self.dispatch.do_orm_execute + if _add_event: + events_todo = list(events_todo) + [_add_event] + + if events_todo: + if compile_state_cls is not None: + # for event handlers, do the orm_pre_session_exec + # pass ahead of the event handlers, so that things like + # .load_options, .update_delete_options etc. are populated. + # is_pre_event=True allows the hook to hold off on things + # it doesn't want to do twice, including autoflush as well + # as "pre fetch" for DML, etc. + ( + statement, + execution_options, + ) = compile_state_cls.orm_pre_session_exec( + self, + statement, + params, + execution_options, + bind_arguments, + True, + ) + + orm_exec_state = ORMExecuteState( + self, + statement, + params, + execution_options, + bind_arguments, + compile_state_cls, + events_todo, + ) + for idx, fn in enumerate(events_todo): + orm_exec_state._starting_event_idx = idx + fn_result: Optional[Result[Any]] = fn(orm_exec_state) + if fn_result: + if _scalar_result: + return fn_result.scalar() + else: + return fn_result + + statement = orm_exec_state.statement + execution_options = orm_exec_state.local_execution_options + + if compile_state_cls is not None: + # now run orm_pre_session_exec() "for real". if there were + # event hooks, this will re-run the steps that interpret + # new execution_options into load_options / update_delete_options, + # which we assume the event hook might have updated. + # autoflush will also be invoked in this step if enabled. + ( + statement, + execution_options, + ) = compile_state_cls.orm_pre_session_exec( + self, + statement, + params, + execution_options, + bind_arguments, + False, + ) + + bind = self.get_bind(**bind_arguments) + + conn = self._connection_for_bind(bind) + + if _scalar_result and not compile_state_cls: + if TYPE_CHECKING: + params = cast(_CoreSingleExecuteParams, params) + return conn.scalar( + statement, params or {}, execution_options=execution_options + ) + + if compile_state_cls: + result: Result[Any] = compile_state_cls.orm_execute_statement( + self, + statement, + params or {}, + execution_options, + bind_arguments, + conn, + ) + else: + result = conn.execute( + statement, params or {}, execution_options=execution_options + ) + + if _scalar_result: + return result.scalar() + else: + return result + + @overload + def execute( + self, + statement: TypedReturnsRows[_T], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[_T]: ... + + @overload + def execute( + self, + statement: UpdateBase, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> CursorResult[Any]: ... + + @overload + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: ... + + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result[Any]: + r"""Execute a SQL expression construct. + + Returns a :class:`_engine.Result` object representing + results of the statement execution. + + E.g.:: + + from sqlalchemy import select + result = session.execute( + select(User).where(User.id == 5) + ) + + The API contract of :meth:`_orm.Session.execute` is similar to that + of :meth:`_engine.Connection.execute`, the :term:`2.0 style` version + of :class:`_engine.Connection`. + + .. versionchanged:: 1.4 the :meth:`_orm.Session.execute` method is + now the primary point of ORM statement execution when using + :term:`2.0 style` ORM usage. + + :param statement: + An executable statement (i.e. an :class:`.Executable` expression + such as :func:`_expression.select`). + + :param params: + Optional dictionary, or list of dictionaries, containing + bound parameter values. If a single dictionary, single-row + execution occurs; if a list of dictionaries, an + "executemany" will be invoked. The keys in each dictionary + must correspond to parameter names present in the statement. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. This + dictionary can provide a subset of the options that are accepted + by :meth:`_engine.Connection.execution_options`, and may also + provide additional options understood only in an ORM context. + + .. seealso:: + + :ref:`orm_queryguide_execution_options` - ORM-specific execution + options + + :param bind_arguments: dictionary of additional arguments to determine + the bind. May include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + + :return: a :class:`_engine.Result` object. + + + """ + return self._execute_internal( + statement, + params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, + ) + + @overload + def scalar( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Optional[_T]: ... + + @overload + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: ... + + def scalar( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> Any: + """Execute a statement and return a scalar result. + + Usage and parameters are the same as that of + :meth:`_orm.Session.execute`; the return result is a scalar Python + value. + + """ + + return self._execute_internal( + statement, + params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _scalar_result=True, + **kw, + ) + + @overload + def scalars( + self, + statement: TypedReturnsRows[Tuple[_T]], + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[_T]: ... + + @overload + def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: ... + + def scalars( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + **kw: Any, + ) -> ScalarResult[Any]: + """Execute a statement and return the results as scalars. + + Usage and parameters are the same as that of + :meth:`_orm.Session.execute`; the return result is a + :class:`_result.ScalarResult` filtering object which + will return single elements rather than :class:`_row.Row` objects. + + :return: a :class:`_result.ScalarResult` object + + .. versionadded:: 1.4.24 Added :meth:`_orm.Session.scalars` + + .. versionadded:: 1.4.26 Added :meth:`_orm.scoped_session.scalars` + + .. seealso:: + + :ref:`orm_queryguide_select_orm_entities` - contrasts the behavior + of :meth:`_orm.Session.execute` to :meth:`_orm.Session.scalars` + + """ + + return self._execute_internal( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _scalar_result=False, # mypy appreciates this + **kw, + ).scalars() + + def close(self) -> None: + """Close out the transactional resources and ORM objects used by this + :class:`_orm.Session`. + + This expunges all ORM objects associated with this + :class:`_orm.Session`, ends any transaction in progress and + :term:`releases` any :class:`_engine.Connection` objects which this + :class:`_orm.Session` itself has checked out from associated + :class:`_engine.Engine` objects. The operation then leaves the + :class:`_orm.Session` in a state which it may be used again. + + .. tip:: + + In the default running mode the :meth:`_orm.Session.close` + method **does not prevent the Session from being used again**. + The :class:`_orm.Session` itself does not actually have a + distinct "closed" state; it merely means + the :class:`_orm.Session` will release all database connections + and ORM objects. + + Setting the parameter :paramref:`_orm.Session.close_resets_only` + to ``False`` will instead make the ``close`` final, meaning that + any further action on the session will be forbidden. + + .. versionchanged:: 1.4 The :meth:`.Session.close` method does not + immediately create a new :class:`.SessionTransaction` object; + instead, the new :class:`.SessionTransaction` is created only if + the :class:`.Session` is used again for a database operation. + + .. seealso:: + + :ref:`session_closing` - detail on the semantics of + :meth:`_orm.Session.close` and :meth:`_orm.Session.reset`. + + :meth:`_orm.Session.reset` - a similar method that behaves like + ``close()`` with the parameter + :paramref:`_orm.Session.close_resets_only` set to ``True``. + + """ + self._close_impl(invalidate=False) + + def reset(self) -> None: + """Close out the transactional resources and ORM objects used by this + :class:`_orm.Session`, resetting the session to its initial state. + + This method provides for same "reset-only" behavior that the + :meth:`_orm.Session.close` method has provided historically, where the + state of the :class:`_orm.Session` is reset as though the object were + brand new, and ready to be used again. + This method may then be useful for :class:`_orm.Session` objects + which set :paramref:`_orm.Session.close_resets_only` to ``False``, + so that "reset only" behavior is still available. + + .. versionadded:: 2.0.22 + + .. seealso:: + + :ref:`session_closing` - detail on the semantics of + :meth:`_orm.Session.close` and :meth:`_orm.Session.reset`. + + :meth:`_orm.Session.close` - a similar method will additionally + prevent re-use of the Session when the parameter + :paramref:`_orm.Session.close_resets_only` is set to ``False``. + """ + self._close_impl(invalidate=False, is_reset=True) + + def invalidate(self) -> None: + """Close this Session, using connection invalidation. + + This is a variant of :meth:`.Session.close` that will additionally + ensure that the :meth:`_engine.Connection.invalidate` + method will be called on each :class:`_engine.Connection` object + that is currently in use for a transaction (typically there is only + one connection unless the :class:`_orm.Session` is used with + multiple engines). + + This can be called when the database is known to be in a state where + the connections are no longer safe to be used. + + Below illustrates a scenario when using `gevent + `_, which can produce ``Timeout`` exceptions + that may mean the underlying connection should be discarded:: + + import gevent + + try: + sess = Session() + sess.add(User()) + sess.commit() + except gevent.Timeout: + sess.invalidate() + raise + except: + sess.rollback() + raise + + The method additionally does everything that :meth:`_orm.Session.close` + does, including that all ORM objects are expunged. + + """ + self._close_impl(invalidate=True) + + def _close_impl(self, invalidate: bool, is_reset: bool = False) -> None: + if not is_reset and self._close_state is _SessionCloseState.ACTIVE: + self._close_state = _SessionCloseState.CLOSED + self.expunge_all() + if self._transaction is not None: + for transaction in self._transaction._iterate_self_and_parents(): + transaction.close(invalidate) + + def expunge_all(self) -> None: + """Remove all object instances from this ``Session``. + + This is equivalent to calling ``expunge(obj)`` on all objects in this + ``Session``. + + """ + + all_states = self.identity_map.all_states() + list(self._new) + self.identity_map._kill() + self.identity_map = identity.WeakInstanceDict() + self._new = {} + self._deleted = {} + + statelib.InstanceState._detach_states(all_states, self) + + def _add_bind(self, key: _SessionBindKey, bind: _SessionBind) -> None: + try: + insp = inspect(key) + except sa_exc.NoInspectionAvailable as err: + if not isinstance(key, type): + raise sa_exc.ArgumentError( + "Not an acceptable bind target: %s" % key + ) from err + else: + self.__binds[key] = bind + else: + if TYPE_CHECKING: + assert isinstance(insp, Inspectable) + + if isinstance(insp, TableClause): + self.__binds[insp] = bind + elif insp_is_mapper(insp): + self.__binds[insp.class_] = bind + for _selectable in insp._all_tables: + self.__binds[_selectable] = bind + else: + raise sa_exc.ArgumentError( + "Not an acceptable bind target: %s" % key + ) + + def bind_mapper( + self, mapper: _EntityBindKey[_O], bind: _SessionBind + ) -> None: + """Associate a :class:`_orm.Mapper` or arbitrary Python class with a + "bind", e.g. an :class:`_engine.Engine` or + :class:`_engine.Connection`. + + The given entity is added to a lookup used by the + :meth:`.Session.get_bind` method. + + :param mapper: a :class:`_orm.Mapper` object, + or an instance of a mapped + class, or any Python class that is the base of a set of mapped + classes. + + :param bind: an :class:`_engine.Engine` or :class:`_engine.Connection` + object. + + .. seealso:: + + :ref:`session_partitioning` + + :paramref:`.Session.binds` + + :meth:`.Session.bind_table` + + + """ + self._add_bind(mapper, bind) + + def bind_table(self, table: TableClause, bind: _SessionBind) -> None: + """Associate a :class:`_schema.Table` with a "bind", e.g. an + :class:`_engine.Engine` + or :class:`_engine.Connection`. + + The given :class:`_schema.Table` is added to a lookup used by the + :meth:`.Session.get_bind` method. + + :param table: a :class:`_schema.Table` object, + which is typically the target + of an ORM mapping, or is present within a selectable that is + mapped. + + :param bind: an :class:`_engine.Engine` or :class:`_engine.Connection` + object. + + .. seealso:: + + :ref:`session_partitioning` + + :paramref:`.Session.binds` + + :meth:`.Session.bind_mapper` + + + """ + self._add_bind(table, bind) + + def get_bind( + self, + mapper: Optional[_EntityBindKey[_O]] = None, + *, + clause: Optional[ClauseElement] = None, + bind: Optional[_SessionBind] = None, + _sa_skip_events: Optional[bool] = None, + _sa_skip_for_implicit_returning: bool = False, + **kw: Any, + ) -> Union[Engine, Connection]: + """Return a "bind" to which this :class:`.Session` is bound. + + The "bind" is usually an instance of :class:`_engine.Engine`, + except in the case where the :class:`.Session` has been + explicitly bound directly to a :class:`_engine.Connection`. + + For a multiply-bound or unbound :class:`.Session`, the + ``mapper`` or ``clause`` arguments are used to determine the + appropriate bind to return. + + Note that the "mapper" argument is usually present + when :meth:`.Session.get_bind` is called via an ORM + operation such as a :meth:`.Session.query`, each + individual INSERT/UPDATE/DELETE operation within a + :meth:`.Session.flush`, call, etc. + + The order of resolution is: + + 1. if mapper given and :paramref:`.Session.binds` is present, + locate a bind based first on the mapper in use, then + on the mapped class in use, then on any base classes that are + present in the ``__mro__`` of the mapped class, from more specific + superclasses to more general. + 2. if clause given and ``Session.binds`` is present, + locate a bind based on :class:`_schema.Table` objects + found in the given clause present in ``Session.binds``. + 3. if ``Session.binds`` is present, return that. + 4. if clause given, attempt to return a bind + linked to the :class:`_schema.MetaData` ultimately + associated with the clause. + 5. if mapper given, attempt to return a bind + linked to the :class:`_schema.MetaData` ultimately + associated with the :class:`_schema.Table` or other + selectable to which the mapper is mapped. + 6. No bind can be found, :exc:`~sqlalchemy.exc.UnboundExecutionError` + is raised. + + Note that the :meth:`.Session.get_bind` method can be overridden on + a user-defined subclass of :class:`.Session` to provide any kind + of bind resolution scheme. See the example at + :ref:`session_custom_partitioning`. + + :param mapper: + Optional mapped class or corresponding :class:`_orm.Mapper` instance. + The bind can be derived from a :class:`_orm.Mapper` first by + consulting the "binds" map associated with this :class:`.Session`, + and secondly by consulting the :class:`_schema.MetaData` associated + with the :class:`_schema.Table` to which the :class:`_orm.Mapper` is + mapped for a bind. + + :param clause: + A :class:`_expression.ClauseElement` (i.e. + :func:`_expression.select`, + :func:`_expression.text`, + etc.). If the ``mapper`` argument is not present or could not + produce a bind, the given expression construct will be searched + for a bound element, typically a :class:`_schema.Table` + associated with + bound :class:`_schema.MetaData`. + + .. seealso:: + + :ref:`session_partitioning` + + :paramref:`.Session.binds` + + :meth:`.Session.bind_mapper` + + :meth:`.Session.bind_table` + + """ + + # this function is documented as a subclassing hook, so we have + # to call this method even if the return is simple + if bind: + return bind + elif not self.__binds and self.bind: + # simplest and most common case, we have a bind and no + # per-mapper/table binds, we're done + return self.bind + + # we don't have self.bind and either have self.__binds + # or we don't have self.__binds (which is legacy). Look at the + # mapper and the clause + if mapper is None and clause is None: + if self.bind: + return self.bind + else: + raise sa_exc.UnboundExecutionError( + "This session is not bound to a single Engine or " + "Connection, and no context was provided to locate " + "a binding." + ) + + # look more closely at the mapper. + if mapper is not None: + try: + inspected_mapper = inspect(mapper) + except sa_exc.NoInspectionAvailable as err: + if isinstance(mapper, type): + raise exc.UnmappedClassError(mapper) from err + else: + raise + else: + inspected_mapper = None + + # match up the mapper or clause in the __binds + if self.__binds: + # matching mappers and selectables to entries in the + # binds dictionary; supported use case. + if inspected_mapper: + for cls in inspected_mapper.class_.__mro__: + if cls in self.__binds: + return self.__binds[cls] + if clause is None: + clause = inspected_mapper.persist_selectable + + if clause is not None: + plugin_subject = clause._propagate_attrs.get( + "plugin_subject", None + ) + + if plugin_subject is not None: + for cls in plugin_subject.mapper.class_.__mro__: + if cls in self.__binds: + return self.__binds[cls] + + for obj in visitors.iterate(clause): + if obj in self.__binds: + if TYPE_CHECKING: + assert isinstance(obj, Table) + return self.__binds[obj] + + # none of the __binds matched, but we have a fallback bind. + # return that + if self.bind: + return self.bind + + context = [] + if inspected_mapper is not None: + context.append(f"mapper {inspected_mapper}") + if clause is not None: + context.append("SQL expression") + + raise sa_exc.UnboundExecutionError( + f"Could not locate a bind configured on " + f'{", ".join(context)} or this Session.' + ) + + @overload + def query(self, _entity: _EntityType[_O]) -> Query[_O]: ... + + @overload + def query( + self, _colexpr: TypedColumnsClauseRole[_T] + ) -> RowReturningQuery[Tuple[_T]]: ... + + # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> RowReturningQuery[Tuple[_T0, _T1]]: ... + + @overload + def query( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... + + @overload + def query( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... + + # END OVERLOADED FUNCTIONS self.query + + @overload + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: ... + + def query( + self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any + ) -> Query[Any]: + """Return a new :class:`_query.Query` object corresponding to this + :class:`_orm.Session`. + + Note that the :class:`_query.Query` object is legacy as of + SQLAlchemy 2.0; the :func:`_sql.select` construct is now used + to construct ORM queries. + + .. seealso:: + + :ref:`unified_tutorial` + + :ref:`queryguide_toplevel` + + :ref:`query_api_toplevel` - legacy API doc + + """ + + return self._query_cls(entities, self, **kwargs) + + def _identity_lookup( + self, + mapper: Mapper[_O], + primary_key_identity: Union[Any, Tuple[Any, ...]], + identity_token: Any = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + lazy_loaded_from: Optional[InstanceState[Any]] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + ) -> Union[Optional[_O], LoaderCallableStatus]: + """Locate an object in the identity map. + + Given a primary key identity, constructs an identity key and then + looks in the session's identity map. If present, the object may + be run through unexpiration rules (e.g. load unloaded attributes, + check if was deleted). + + e.g.:: + + obj = session._identity_lookup(inspect(SomeClass), (1, )) + + :param mapper: mapper in use + :param primary_key_identity: the primary key we are searching for, as + a tuple. + :param identity_token: identity token that should be used to create + the identity key. Used as is, however overriding subclasses can + repurpose this in order to interpret the value in a special way, + such as if None then look among multiple target tokens. + :param passive: passive load flag passed to + :func:`.loading.get_from_identity`, which impacts the behavior if + the object is found; the object may be validated and/or unexpired + if the flag allows for SQL to be emitted. + :param lazy_loaded_from: an :class:`.InstanceState` that is + specifically asking for this identity as a related identity. Used + for sharding schemes where there is a correspondence between an object + and a related object being lazy-loaded (or otherwise + relationship-loaded). + + :return: None if the object is not found in the identity map, *or* + if the object was unexpired and found to have been deleted. + if passive flags disallow SQL and the object is expired, returns + PASSIVE_NO_RESULT. In all other cases the instance is returned. + + .. versionchanged:: 1.4.0 - the :meth:`.Session._identity_lookup` + method was moved from :class:`_query.Query` to + :class:`.Session`, to avoid having to instantiate the + :class:`_query.Query` object. + + + """ + + key = mapper.identity_key_from_primary_key( + primary_key_identity, identity_token=identity_token + ) + + # work around: https://github.com/python/typing/discussions/1143 + return_value = loading.get_from_identity(self, mapper, key, passive) + return return_value + + @util.non_memoized_property + @contextlib.contextmanager + def no_autoflush(self) -> Iterator[Session]: + """Return a context manager that disables autoflush. + + e.g.:: + + with session.no_autoflush: + + some_object = SomeClass() + session.add(some_object) + # won't autoflush + some_object.related_thing = session.query(SomeRelated).first() + + Operations that proceed within the ``with:`` block + will not be subject to flushes occurring upon query + access. This is useful when initializing a series + of objects which involve existing database queries, + where the uncompleted object should not yet be flushed. + + """ + autoflush = self.autoflush + self.autoflush = False + try: + yield self + finally: + self.autoflush = autoflush + + @util.langhelpers.tag_method_for_warnings( + "This warning originated from the Session 'autoflush' process, " + "which was invoked automatically in response to a user-initiated " + "operation.", + sa_exc.SAWarning, + ) + def _autoflush(self) -> None: + if self.autoflush and not self._flushing: + try: + self.flush() + except sa_exc.StatementError as e: + # note we are reraising StatementError as opposed to + # raising FlushError with "chaining" to remain compatible + # with code that catches StatementError, IntegrityError, + # etc. + e.add_detail( + "raised as a result of Query-invoked autoflush; " + "consider using a session.no_autoflush block if this " + "flush is occurring prematurely" + ) + raise e.with_traceback(sys.exc_info()[2]) + + def refresh( + self, + instance: object, + attribute_names: Optional[Iterable[str]] = None, + with_for_update: ForUpdateParameter = None, + ) -> None: + """Expire and refresh attributes on the given instance. + + The selected attributes will first be expired as they would when using + :meth:`_orm.Session.expire`; then a SELECT statement will be issued to + the database to refresh column-oriented attributes with the current + value available in the current transaction. + + :func:`_orm.relationship` oriented attributes will also be immediately + loaded if they were already eagerly loaded on the object, using the + same eager loading strategy that they were loaded with originally. + + .. versionadded:: 1.4 - the :meth:`_orm.Session.refresh` method + can also refresh eagerly loaded attributes. + + :func:`_orm.relationship` oriented attributes that would normally + load using the ``select`` (or "lazy") loader strategy will also + load **if they are named explicitly in the attribute_names + collection**, emitting a SELECT statement for the attribute using the + ``immediate`` loader strategy. If lazy-loaded relationships are not + named in :paramref:`_orm.Session.refresh.attribute_names`, then + they remain as "lazy loaded" attributes and are not implicitly + refreshed. + + .. versionchanged:: 2.0.4 The :meth:`_orm.Session.refresh` method + will now refresh lazy-loaded :func:`_orm.relationship` oriented + attributes for those which are named explicitly in the + :paramref:`_orm.Session.refresh.attribute_names` collection. + + .. tip:: + + While the :meth:`_orm.Session.refresh` method is capable of + refreshing both column and relationship oriented attributes, its + primary focus is on refreshing of local column-oriented attributes + on a single instance. For more open ended "refresh" functionality, + including the ability to refresh the attributes on many objects at + once while having explicit control over relationship loader + strategies, use the + :ref:`populate existing ` feature + instead. + + Note that a highly isolated transaction will return the same values as + were previously read in that same transaction, regardless of changes + in database state outside of that transaction. Refreshing + attributes usually only makes sense at the start of a transaction + where database rows have not yet been accessed. + + :param attribute_names: optional. An iterable collection of + string attribute names indicating a subset of attributes to + be refreshed. + + :param with_for_update: optional boolean ``True`` indicating FOR UPDATE + should be used, or may be a dictionary containing flags to + indicate a more specific set of FOR UPDATE flags for the SELECT; + flags should match the parameters of + :meth:`_query.Query.with_for_update`. + Supersedes the :paramref:`.Session.refresh.lockmode` parameter. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.expire_all` + + :ref:`orm_queryguide_populate_existing` - allows any ORM query + to refresh objects as they would be loaded normally. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + + self._expire_state(state, attribute_names) + + # this autoflush previously used to occur as a secondary effect + # of the load_on_ident below. Meaning we'd organize the SELECT + # based on current DB pks, then flush, then if pks changed in that + # flush, crash. this was unticketed but discovered as part of + # #8703. So here, autoflush up front, dont autoflush inside + # load_on_ident. + self._autoflush() + + if with_for_update == {}: + raise sa_exc.ArgumentError( + "with_for_update should be the boolean value " + "True, or a dictionary with options. " + "A blank dictionary is ambiguous." + ) + + with_for_update = ForUpdateArg._from_argument(with_for_update) + + stmt: Select[Any] = sql.select(object_mapper(instance)) + if ( + loading.load_on_ident( + self, + stmt, + state.key, + refresh_state=state, + with_for_update=with_for_update, + only_load_props=attribute_names, + require_pk_cols=True, + # technically unnecessary as we just did autoflush + # above, however removes the additional unnecessary + # call to _autoflush() + no_autoflush=True, + is_user_refresh=True, + ) + is None + ): + raise sa_exc.InvalidRequestError( + "Could not refresh instance '%s'" % instance_str(instance) + ) + + def expire_all(self) -> None: + """Expires all persistent instances within this Session. + + When any attributes on a persistent instance is next accessed, + a query will be issued using the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire individual objects and individual attributes + on those objects, use :meth:`Session.expire`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire_all` is not usually needed, + assuming the transaction is isolated. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + """ + for state in self.identity_map.all_states(): + state._expire(state.dict, self.identity_map._modified) + + def expire( + self, instance: object, attribute_names: Optional[Iterable[str]] = None + ) -> None: + """Expire the attributes on an instance. + + Marks the attributes of an instance as out of date. When an expired + attribute is next accessed, a query will be issued to the + :class:`.Session` object's current transactional context in order to + load all expired attributes for the given instance. Note that + a highly isolated transaction will return the same values as were + previously read in that same transaction, regardless of changes + in database state outside of that transaction. + + To expire all objects in the :class:`.Session` simultaneously, + use :meth:`Session.expire_all`. + + The :class:`.Session` object's default behavior is to + expire all state whenever the :meth:`Session.rollback` + or :meth:`Session.commit` methods are called, so that new + state can be loaded for the new transaction. For this reason, + calling :meth:`Session.expire` only makes sense for the specific + case that a non-ORM SQL statement was emitted in the current + transaction. + + :param instance: The instance to be refreshed. + :param attribute_names: optional list of string attribute names + indicating a subset of attributes to be expired. + + .. seealso:: + + :ref:`session_expire` - introductory material + + :meth:`.Session.expire` + + :meth:`.Session.refresh` + + :meth:`_orm.Query.populate_existing` + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + self._expire_state(state, attribute_names) + + def _expire_state( + self, + state: InstanceState[Any], + attribute_names: Optional[Iterable[str]], + ) -> None: + self._validate_persistent(state) + if attribute_names: + state._expire_attributes(state.dict, attribute_names) + else: + # pre-fetch the full cascade since the expire is going to + # remove associations + cascaded = list( + state.manager.mapper.cascade_iterator("refresh-expire", state) + ) + self._conditional_expire(state) + for o, m, st_, dct_ in cascaded: + self._conditional_expire(st_) + + def _conditional_expire( + self, state: InstanceState[Any], autoflush: Optional[bool] = None + ) -> None: + """Expire a state if persistent, else expunge if pending""" + + if state.key: + state._expire(state.dict, self.identity_map._modified) + elif state in self._new: + self._new.pop(state) + state._detach(self) + + def expunge(self, instance: object) -> None: + """Remove the `instance` from this ``Session``. + + This will free all internal references to the instance. Cascading + will be applied according to the *expunge* cascade rule. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + if state.session_id is not self.hash_key: + raise sa_exc.InvalidRequestError( + "Instance %s is not present in this Session" % state_str(state) + ) + + cascaded = list( + state.manager.mapper.cascade_iterator("expunge", state) + ) + self._expunge_states([state] + [st_ for o, m, st_, dct_ in cascaded]) + + def _expunge_states( + self, states: Iterable[InstanceState[Any]], to_transient: bool = False + ) -> None: + for state in states: + if state in self._new: + self._new.pop(state) + elif self.identity_map.contains_state(state): + self.identity_map.safe_discard(state) + self._deleted.pop(state, None) + elif self._transaction: + # state is "detached" from being deleted, but still present + # in the transaction snapshot + self._transaction._deleted.pop(state, None) + statelib.InstanceState._detach_states( + states, self, to_transient=to_transient + ) + + def _register_persistent(self, states: Set[InstanceState[Any]]) -> None: + """Register all persistent objects from a flush. + + This is used both for pending objects moving to the persistent + state as well as already persistent objects. + + """ + + pending_to_persistent = self.dispatch.pending_to_persistent or None + for state in states: + mapper = _state_mapper(state) + + # prevent against last minute dereferences of the object + obj = state.obj() + if obj is not None: + instance_key = mapper._identity_key_from_state(state) + + if ( + _none_set.intersection(instance_key[1]) + and not mapper.allow_partial_pks + or _none_set.issuperset(instance_key[1]) + ): + raise exc.FlushError( + "Instance %s has a NULL identity key. If this is an " + "auto-generated value, check that the database table " + "allows generation of new primary key values, and " + "that the mapped Column object is configured to " + "expect these generated values. Ensure also that " + "this flush() is not occurring at an inappropriate " + "time, such as within a load() event." + % state_str(state) + ) + + if state.key is None: + state.key = instance_key + elif state.key != instance_key: + # primary key switch. use safe_discard() in case another + # state has already replaced this one in the identity + # map (see test/orm/test_naturalpks.py ReversePKsTest) + self.identity_map.safe_discard(state) + trans = self._transaction + assert trans is not None + if state in trans._key_switches: + orig_key = trans._key_switches[state][0] + else: + orig_key = state.key + trans._key_switches[state] = ( + orig_key, + instance_key, + ) + state.key = instance_key + + # there can be an existing state in the identity map + # that is replaced when the primary keys of two instances + # are swapped; see test/orm/test_naturalpks.py -> test_reverse + old = self.identity_map.replace(state) + if ( + old is not None + and mapper._identity_key_from_state(old) == instance_key + and old.obj() is not None + ): + util.warn( + "Identity map already had an identity for %s, " + "replacing it with newly flushed object. Are there " + "load operations occurring inside of an event handler " + "within the flush?" % (instance_key,) + ) + state._orphaned_outside_of_session = False + + statelib.InstanceState._commit_all_states( + ((state, state.dict) for state in states), self.identity_map + ) + + self._register_altered(states) + + if pending_to_persistent is not None: + for state in states.intersection(self._new): + pending_to_persistent(self, state) + + # remove from new last, might be the last strong ref + for state in set(states).intersection(self._new): + self._new.pop(state) + + def _register_altered(self, states: Iterable[InstanceState[Any]]) -> None: + if self._transaction: + for state in states: + if state in self._new: + self._transaction._new[state] = True + else: + self._transaction._dirty[state] = True + + def _remove_newly_deleted( + self, states: Iterable[InstanceState[Any]] + ) -> None: + persistent_to_deleted = self.dispatch.persistent_to_deleted or None + for state in states: + if self._transaction: + self._transaction._deleted[state] = True + + if persistent_to_deleted is not None: + # get a strong reference before we pop out of + # self._deleted + obj = state.obj() # noqa + + self.identity_map.safe_discard(state) + self._deleted.pop(state, None) + state._deleted = True + # can't call state._detach() here, because this state + # is still in the transaction snapshot and needs to be + # tracked as part of that + if persistent_to_deleted is not None: + persistent_to_deleted(self, state) + + def add(self, instance: object, _warn: bool = True) -> None: + """Place an object into this :class:`_orm.Session`. + + Objects that are in the :term:`transient` state when passed to the + :meth:`_orm.Session.add` method will move to the + :term:`pending` state, until the next flush, at which point they + will move to the :term:`persistent` state. + + Objects that are in the :term:`detached` state when passed to the + :meth:`_orm.Session.add` method will move to the :term:`persistent` + state directly. + + If the transaction used by the :class:`_orm.Session` is rolled back, + objects which were transient when they were passed to + :meth:`_orm.Session.add` will be moved back to the + :term:`transient` state, and will no longer be present within this + :class:`_orm.Session`. + + .. seealso:: + + :meth:`_orm.Session.add_all` + + :ref:`session_adding` - at :ref:`session_basics` + + """ + if _warn and self._warn_on_events: + self._flush_warning("Session.add()") + + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + + self._save_or_update_state(state) + + def add_all(self, instances: Iterable[object]) -> None: + """Add the given collection of instances to this :class:`_orm.Session`. + + See the documentation for :meth:`_orm.Session.add` for a general + behavioral description. + + .. seealso:: + + :meth:`_orm.Session.add` + + :ref:`session_adding` - at :ref:`session_basics` + + """ + + if self._warn_on_events: + self._flush_warning("Session.add_all()") + + for instance in instances: + self.add(instance, _warn=False) + + def _save_or_update_state(self, state: InstanceState[Any]) -> None: + state._orphaned_outside_of_session = False + self._save_or_update_impl(state) + + mapper = _state_mapper(state) + for o, m, st_, dct_ in mapper.cascade_iterator( + "save-update", state, halt_on=self._contains_state + ): + self._save_or_update_impl(st_) + + def delete(self, instance: object) -> None: + """Mark an instance as deleted. + + The object is assumed to be either :term:`persistent` or + :term:`detached` when passed; after the method is called, the + object will remain in the :term:`persistent` state until the next + flush proceeds. During this time, the object will also be a member + of the :attr:`_orm.Session.deleted` collection. + + When the next flush proceeds, the object will move to the + :term:`deleted` state, indicating a ``DELETE`` statement was emitted + for its row within the current transaction. When the transaction + is successfully committed, + the deleted object is moved to the :term:`detached` state and is + no longer present within this :class:`_orm.Session`. + + .. seealso:: + + :ref:`session_deleting` - at :ref:`session_basics` + + """ + if self._warn_on_events: + self._flush_warning("Session.delete()") + + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + + self._delete_impl(state, instance, head=True) + + def _delete_impl( + self, state: InstanceState[Any], obj: object, head: bool + ) -> None: + if state.key is None: + if head: + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persisted" % state_str(state) + ) + else: + return + + to_attach = self._before_attach(state, obj) + + if state in self._deleted: + return + + self.identity_map.add(state) + + if to_attach: + self._after_attach(state, obj) + + if head: + # grab the cascades before adding the item to the deleted list + # so that autoflush does not delete the item + # the strong reference to the instance itself is significant here + cascade_states = list( + state.manager.mapper.cascade_iterator("delete", state) + ) + else: + cascade_states = None + + self._deleted[state] = obj + + if head: + if TYPE_CHECKING: + assert cascade_states is not None + for o, m, st_, dct_ in cascade_states: + self._delete_impl(st_, o, False) + + def get( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + ) -> Optional[_O]: + """Return an instance based on the given primary key identifier, + or ``None`` if not found. + + E.g.:: + + my_user = session.get(User, 5) + + some_object = session.get(VersionedFoo, (5, 10)) + + some_object = session.get( + VersionedFoo, + {"id": 5, "version_id": 10} + ) + + .. versionadded:: 1.4 Added :meth:`_orm.Session.get`, which is moved + from the now legacy :meth:`_orm.Query.get` method. + + :meth:`_orm.Session.get` is special in that it provides direct + access to the identity map of the :class:`.Session`. + If the given primary key identifier is present + in the local identity map, the object is returned + directly from this collection and no SQL is emitted, + unless the object has been marked fully expired. + If not present, + a SELECT is performed in order to locate the object. + + :meth:`_orm.Session.get` also will perform a check if + the object is present in the identity map and + marked as expired - a SELECT + is emitted to refresh the object as well as to + ensure that the row is still present. + If not, :class:`~sqlalchemy.orm.exc.ObjectDeletedError` is raised. + + :param entity: a mapped class or :class:`.Mapper` indicating the + type of entity to be loaded. + + :param ident: A scalar, tuple, or dictionary representing the + primary key. For a composite (e.g. multiple column) primary key, + a tuple or dictionary should be passed. + + For a single-column primary key, the scalar calling form is typically + the most expedient. If the primary key of a row is the value "5", + the call looks like:: + + my_object = session.get(SomeClass, 5) + + The tuple form contains primary key values typically in + the order in which they correspond to the mapped + :class:`_schema.Table` + object's primary key columns, or if the + :paramref:`_orm.Mapper.primary_key` configuration parameter were + used, in + the order used for that parameter. For example, if the primary key + of a row is represented by the integer + digits "5, 10" the call would look like:: + + my_object = session.get(SomeClass, (5, 10)) + + The dictionary form should include as keys the mapped attribute names + corresponding to each element of the primary key. If the mapped class + has the attributes ``id``, ``version_id`` as the attributes which + store the object's primary key value, the call would look like:: + + my_object = session.get(SomeClass, {"id": 5, "version_id": 10}) + + :param options: optional sequence of loader options which will be + applied to the query, if one is emitted. + + :param populate_existing: causes the method to unconditionally emit + a SQL query and refresh the object with the newly loaded data, + regardless of whether or not the object is already present. + + :param with_for_update: optional boolean ``True`` indicating FOR UPDATE + should be used, or may be a dictionary containing flags to + indicate a more specific set of FOR UPDATE flags for the SELECT; + flags should match the parameters of + :meth:`_query.Query.with_for_update`. + Supersedes the :paramref:`.Session.refresh.lockmode` parameter. + + :param execution_options: optional dictionary of execution options, + which will be associated with the query execution if one is emitted. + This dictionary can provide a subset of the options that are + accepted by :meth:`_engine.Connection.execution_options`, and may + also provide additional options understood only in an ORM context. + + .. versionadded:: 1.4.29 + + .. seealso:: + + :ref:`orm_queryguide_execution_options` - ORM-specific execution + options + + :param bind_arguments: dictionary of additional arguments to determine + the bind. May include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + + .. versionadded: 2.0.0rc1 + + :return: The object instance, or ``None``. + + """ + return self._get_impl( + entity, + ident, + loading.load_on_pk_identity, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + + def get_one( + self, + entity: _EntityBindKey[_O], + ident: _PKIdentityArgument, + *, + options: Optional[Sequence[ORMOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + ) -> _O: + """Return exactly one instance based on the given primary key + identifier, or raise an exception if not found. + + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query + selects no rows. + + For a detailed documentation of the arguments see the + method :meth:`.Session.get`. + + .. versionadded:: 2.0.22 + + :return: The object instance. + + .. seealso:: + + :meth:`.Session.get` - equivalent method that instead + returns ``None`` if no row was found with the provided primary + key + + """ + + instance = self.get( + entity, + ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + + if instance is None: + raise sa_exc.NoResultFound( + "No row was found when one was required" + ) + + return instance + + def _get_impl( + self, + entity: _EntityBindKey[_O], + primary_key_identity: _PKIdentityArgument, + db_load_fn: Callable[..., _O], + *, + options: Optional[Sequence[ExecutableOption]] = None, + populate_existing: bool = False, + with_for_update: ForUpdateParameter = None, + identity_token: Optional[Any] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + ) -> Optional[_O]: + # convert composite types to individual args + if ( + is_composite_class(primary_key_identity) + and type(primary_key_identity) + in descriptor_props._composite_getters + ): + getter = descriptor_props._composite_getters[ + type(primary_key_identity) + ] + primary_key_identity = getter(primary_key_identity) + + mapper: Optional[Mapper[_O]] = inspect(entity) + + if mapper is None or not mapper.is_mapper: + raise sa_exc.ArgumentError( + "Expected mapped class or mapper, got: %r" % entity + ) + + is_dict = isinstance(primary_key_identity, dict) + if not is_dict: + primary_key_identity = util.to_list( + primary_key_identity, default=[None] + ) + + if len(primary_key_identity) != len(mapper.primary_key): + raise sa_exc.InvalidRequestError( + "Incorrect number of values in identifier to formulate " + "primary key for session.get(); primary key columns " + "are %s" % ",".join("'%s'" % c for c in mapper.primary_key) + ) + + if is_dict: + pk_synonyms = mapper._pk_synonyms + + if pk_synonyms: + correct_keys = set(pk_synonyms).intersection( + primary_key_identity + ) + + if correct_keys: + primary_key_identity = dict(primary_key_identity) + for k in correct_keys: + primary_key_identity[pk_synonyms[k]] = ( + primary_key_identity[k] + ) + + try: + primary_key_identity = list( + primary_key_identity[prop.key] + for prop in mapper._identity_key_props + ) + + except KeyError as err: + raise sa_exc.InvalidRequestError( + "Incorrect names of values in identifier to formulate " + "primary key for session.get(); primary key attribute " + "names are %s (synonym names are also accepted)" + % ",".join( + "'%s'" % prop.key + for prop in mapper._identity_key_props + ) + ) from err + + if ( + not populate_existing + and not mapper.always_refresh + and with_for_update is None + ): + instance = self._identity_lookup( + mapper, + primary_key_identity, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + + if instance is not None: + # reject calls for id in identity map but class + # mismatch. + if not isinstance(instance, mapper.class_): + return None + return instance + + # TODO: this was being tested before, but this is not possible + assert instance is not LoaderCallableStatus.PASSIVE_CLASS_MISMATCH + + # set_label_style() not strictly necessary, however this will ensure + # that tablename_colname style is used which at the moment is + # asserted in a lot of unit tests :) + + load_options = context.QueryContext.default_load_options + + if populate_existing: + load_options += {"_populate_existing": populate_existing} + statement = sql.select(mapper).set_label_style( + LABEL_STYLE_TABLENAME_PLUS_COL + ) + if with_for_update is not None: + statement._for_update_arg = ForUpdateArg._from_argument( + with_for_update + ) + + if options: + statement = statement.options(*options) + return db_load_fn( + self, + statement, + primary_key_identity, + load_options=load_options, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, + ) + + def merge( + self, + instance: _O, + *, + load: bool = True, + options: Optional[Sequence[ORMOption]] = None, + ) -> _O: + """Copy the state of a given instance into a corresponding instance + within this :class:`.Session`. + + :meth:`.Session.merge` examines the primary key attributes of the + source instance, and attempts to reconcile it with an instance of the + same primary key in the session. If not found locally, it attempts + to load the object from the database based on primary key, and if + none can be located, creates a new instance. The state of each + attribute on the source instance is then copied to the target + instance. The resulting target instance is then returned by the + method; the original source instance is left unmodified, and + un-associated with the :class:`.Session` if not already. + + This operation cascades to associated instances if the association is + mapped with ``cascade="merge"``. + + See :ref:`unitofwork_merging` for a detailed discussion of merging. + + :param instance: Instance to be merged. + :param load: Boolean, when False, :meth:`.merge` switches into + a "high performance" mode which causes it to forego emitting history + events as well as all database access. This flag is used for + cases such as transferring graphs of objects into a :class:`.Session` + from a second level cache, or to transfer just-loaded objects + into the :class:`.Session` owned by a worker thread or process + without re-querying the database. + + The ``load=False`` use case adds the caveat that the given + object has to be in a "clean" state, that is, has no pending changes + to be flushed - even if the incoming object is detached from any + :class:`.Session`. This is so that when + the merge operation populates local attributes and + cascades to related objects and + collections, the values can be "stamped" onto the + target object as is, without generating any history or attribute + events, and without the need to reconcile the incoming data with + any existing related objects or collections that might not + be loaded. The resulting objects from ``load=False`` are always + produced as "clean", so it is only appropriate that the given objects + should be "clean" as well, else this suggests a mis-use of the + method. + :param options: optional sequence of loader options which will be + applied to the :meth:`_orm.Session.get` method when the merge + operation loads the existing version of the object from the database. + + .. versionadded:: 1.4.24 + + + .. seealso:: + + :func:`.make_transient_to_detached` - provides for an alternative + means of "merging" a single object into the :class:`.Session` + + """ + + if self._warn_on_events: + self._flush_warning("Session.merge()") + + _recursive: Dict[InstanceState[Any], object] = {} + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object] = {} + + if load: + # flush current contents if we expect to load data + self._autoflush() + + object_mapper(instance) # verify mapped + autoflush = self.autoflush + try: + self.autoflush = False + return self._merge( + attributes.instance_state(instance), + attributes.instance_dict(instance), + load=load, + options=options, + _recursive=_recursive, + _resolve_conflict_map=_resolve_conflict_map, + ) + finally: + self.autoflush = autoflush + + def _merge( + self, + state: InstanceState[_O], + state_dict: _InstanceDict, + *, + options: Optional[Sequence[ORMOption]] = None, + load: bool, + _recursive: Dict[Any, object], + _resolve_conflict_map: Dict[_IdentityKeyType[Any], object], + ) -> _O: + mapper: Mapper[_O] = _state_mapper(state) + if state in _recursive: + return cast(_O, _recursive[state]) + + new_instance = False + key = state.key + + merged: Optional[_O] + + if key is None: + if state in self._new: + util.warn( + "Instance %s is already pending in this Session yet is " + "being merged again; this is probably not what you want " + "to do" % state_str(state) + ) + + if not load: + raise sa_exc.InvalidRequestError( + "merge() with load=False option does not support " + "objects transient (i.e. unpersisted) objects. flush() " + "all changes on mapped instances before merging with " + "load=False." + ) + key = mapper._identity_key_from_state(state) + key_is_persistent = LoaderCallableStatus.NEVER_SET not in key[ + 1 + ] and ( + not _none_set.intersection(key[1]) + or ( + mapper.allow_partial_pks + and not _none_set.issuperset(key[1]) + ) + ) + else: + key_is_persistent = True + + if key in self.identity_map: + try: + merged = self.identity_map[key] + except KeyError: + # object was GC'ed right as we checked for it + merged = None + else: + merged = None + + if merged is None: + if key_is_persistent and key in _resolve_conflict_map: + merged = cast(_O, _resolve_conflict_map[key]) + + elif not load: + if state.modified: + raise sa_exc.InvalidRequestError( + "merge() with load=False option does not support " + "objects marked as 'dirty'. flush() all changes on " + "mapped instances before merging with load=False." + ) + merged = mapper.class_manager.new_instance() + merged_state = attributes.instance_state(merged) + merged_state.key = key + self._update_impl(merged_state) + new_instance = True + + elif key_is_persistent: + merged = self.get( + mapper.class_, + key[1], + identity_token=key[2], + options=options, + ) + + if merged is None: + merged = mapper.class_manager.new_instance() + merged_state = attributes.instance_state(merged) + merged_dict = attributes.instance_dict(merged) + new_instance = True + self._save_or_update_state(merged_state) + else: + merged_state = attributes.instance_state(merged) + merged_dict = attributes.instance_dict(merged) + + _recursive[state] = merged + _resolve_conflict_map[key] = merged + + # check that we didn't just pull the exact same + # state out. + if state is not merged_state: + # version check if applicable + if mapper.version_id_col is not None: + existing_version = mapper._get_state_attr_by_column( + state, + state_dict, + mapper.version_id_col, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, + ) + + merged_version = mapper._get_state_attr_by_column( + merged_state, + merged_dict, + mapper.version_id_col, + passive=PassiveFlag.PASSIVE_NO_INITIALIZE, + ) + + if ( + existing_version + is not LoaderCallableStatus.PASSIVE_NO_RESULT + and merged_version + is not LoaderCallableStatus.PASSIVE_NO_RESULT + and existing_version != merged_version + ): + raise exc.StaleDataError( + "Version id '%s' on merged state %s " + "does not match existing version '%s'. " + "Leave the version attribute unset when " + "merging to update the most recent version." + % ( + existing_version, + state_str(merged_state), + merged_version, + ) + ) + + merged_state.load_path = state.load_path + merged_state.load_options = state.load_options + + # since we are copying load_options, we need to copy + # the callables_ that would have been generated by those + # load_options. + # assumes that the callables we put in state.callables_ + # are not instance-specific (which they should not be) + merged_state._copy_callables(state) + + for prop in mapper.iterate_properties: + prop.merge( + self, + state, + state_dict, + merged_state, + merged_dict, + load, + _recursive, + _resolve_conflict_map, + ) + + if not load: + # remove any history + merged_state._commit_all(merged_dict, self.identity_map) + merged_state.manager.dispatch._sa_event_merge_wo_load( + merged_state, None + ) + + if new_instance: + merged_state.manager.dispatch.load(merged_state, None) + + return merged + + def _validate_persistent(self, state: InstanceState[Any]) -> None: + if not self.identity_map.contains_state(state): + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persistent within this Session" + % state_str(state) + ) + + def _save_impl(self, state: InstanceState[Any]) -> None: + if state.key is not None: + raise sa_exc.InvalidRequestError( + "Object '%s' already has an identity - " + "it can't be registered as pending" % state_str(state) + ) + + obj = state.obj() + to_attach = self._before_attach(state, obj) + if state not in self._new: + self._new[state] = obj + state.insert_order = len(self._new) + if to_attach: + self._after_attach(state, obj) + + def _update_impl( + self, state: InstanceState[Any], revert_deletion: bool = False + ) -> None: + if state.key is None: + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persisted" % state_str(state) + ) + + if state._deleted: + if revert_deletion: + if not state._attached: + return + del state._deleted + else: + raise sa_exc.InvalidRequestError( + "Instance '%s' has been deleted. " + "Use the make_transient() " + "function to send this object back " + "to the transient state." % state_str(state) + ) + + obj = state.obj() + + # check for late gc + if obj is None: + return + + to_attach = self._before_attach(state, obj) + + self._deleted.pop(state, None) + if revert_deletion: + self.identity_map.replace(state) + else: + self.identity_map.add(state) + + if to_attach: + self._after_attach(state, obj) + elif revert_deletion: + self.dispatch.deleted_to_persistent(self, state) + + def _save_or_update_impl(self, state: InstanceState[Any]) -> None: + if state.key is None: + self._save_impl(state) + else: + self._update_impl(state) + + def enable_relationship_loading(self, obj: object) -> None: + """Associate an object with this :class:`.Session` for related + object loading. + + .. warning:: + + :meth:`.enable_relationship_loading` exists to serve special + use cases and is not recommended for general use. + + Accesses of attributes mapped with :func:`_orm.relationship` + will attempt to load a value from the database using this + :class:`.Session` as the source of connectivity. The values + will be loaded based on foreign key and primary key values + present on this object - if not present, then those relationships + will be unavailable. + + The object will be attached to this session, but will + **not** participate in any persistence operations; its state + for almost all purposes will remain either "transient" or + "detached", except for the case of relationship loading. + + Also note that backrefs will often not work as expected. + Altering a relationship-bound attribute on the target object + may not fire off a backref event, if the effective value + is what was already loaded from a foreign-key-holding value. + + The :meth:`.Session.enable_relationship_loading` method is + similar to the ``load_on_pending`` flag on :func:`_orm.relationship`. + Unlike that flag, :meth:`.Session.enable_relationship_loading` allows + an object to remain transient while still being able to load + related items. + + To make a transient object associated with a :class:`.Session` + via :meth:`.Session.enable_relationship_loading` pending, add + it to the :class:`.Session` using :meth:`.Session.add` normally. + If the object instead represents an existing identity in the database, + it should be merged using :meth:`.Session.merge`. + + :meth:`.Session.enable_relationship_loading` does not improve + behavior when the ORM is used normally - object references should be + constructed at the object level, not at the foreign key level, so + that they are present in an ordinary way before flush() + proceeds. This method is not intended for general use. + + .. seealso:: + + :paramref:`_orm.relationship.load_on_pending` - this flag + allows per-relationship loading of many-to-ones on items that + are pending. + + :func:`.make_transient_to_detached` - allows for an object to + be added to a :class:`.Session` without SQL emitted, which then + will unexpire attributes on access. + + """ + try: + state = attributes.instance_state(obj) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(obj) from err + + to_attach = self._before_attach(state, obj) + state._load_pending = True + if to_attach: + self._after_attach(state, obj) + + def _before_attach(self, state: InstanceState[Any], obj: object) -> bool: + self._autobegin_t() + + if state.session_id == self.hash_key: + return False + + if state.session_id and state.session_id in _sessions: + raise sa_exc.InvalidRequestError( + "Object '%s' is already attached to session '%s' " + "(this is '%s')" + % (state_str(state), state.session_id, self.hash_key) + ) + + self.dispatch.before_attach(self, state) + + return True + + def _after_attach(self, state: InstanceState[Any], obj: object) -> None: + state.session_id = self.hash_key + if state.modified and state._strong_obj is None: + state._strong_obj = obj + self.dispatch.after_attach(self, state) + + if state.key: + self.dispatch.detached_to_persistent(self, state) + else: + self.dispatch.transient_to_pending(self, state) + + def __contains__(self, instance: object) -> bool: + """Return True if the instance is associated with this session. + + The instance may be pending or persistent within the Session for a + result of True. + + """ + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + return self._contains_state(state) + + def __iter__(self) -> Iterator[object]: + """Iterate over all pending or persistent instances within this + Session. + + """ + return iter( + list(self._new.values()) + list(self.identity_map.values()) + ) + + def _contains_state(self, state: InstanceState[Any]) -> bool: + return state in self._new or self.identity_map.contains_state(state) + + def flush(self, objects: Optional[Sequence[Any]] = None) -> None: + """Flush all the object changes to the database. + + Writes out all pending object creations, deletions and modifications + to the database as INSERTs, DELETEs, UPDATEs, etc. Operations are + automatically ordered by the Session's unit of work dependency + solver. + + Database operations will be issued in the current transactional + context and do not affect the state of the transaction, unless an + error occurs, in which case the entire transaction is rolled back. + You may flush() as often as you like within a transaction to move + changes from Python to the database's transaction buffer. + + :param objects: Optional; restricts the flush operation to operate + only on elements that are in the given collection. + + This feature is for an extremely narrow set of use cases where + particular objects may need to be operated upon before the + full flush() occurs. It is not intended for general use. + + """ + + if self._flushing: + raise sa_exc.InvalidRequestError("Session is already flushing") + + if self._is_clean(): + return + try: + self._flushing = True + self._flush(objects) + finally: + self._flushing = False + + def _flush_warning(self, method: Any) -> None: + util.warn( + "Usage of the '%s' operation is not currently supported " + "within the execution stage of the flush process. " + "Results may not be consistent. Consider using alternative " + "event listeners or connection-level operations instead." % method + ) + + def _is_clean(self) -> bool: + return ( + not self.identity_map.check_modified() + and not self._deleted + and not self._new + ) + + def _flush(self, objects: Optional[Sequence[object]] = None) -> None: + dirty = self._dirty_states + if not dirty and not self._deleted and not self._new: + self.identity_map._modified.clear() + return + + flush_context = UOWTransaction(self) + + if self.dispatch.before_flush: + self.dispatch.before_flush(self, flush_context, objects) + # re-establish "dirty states" in case the listeners + # added + dirty = self._dirty_states + + deleted = set(self._deleted) + new = set(self._new) + + dirty = set(dirty).difference(deleted) + + # create the set of all objects we want to operate upon + if objects: + # specific list passed in + objset = set() + for o in objects: + try: + state = attributes.instance_state(o) + + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(o) from err + objset.add(state) + else: + objset = None + + # store objects whose fate has been decided + processed = set() + + # put all saves/updates into the flush context. detect top-level + # orphans and throw them into deleted. + if objset: + proc = new.union(dirty).intersection(objset).difference(deleted) + else: + proc = new.union(dirty).difference(deleted) + + for state in proc: + is_orphan = _state_mapper(state)._is_orphan(state) + + is_persistent_orphan = is_orphan and state.has_identity + + if ( + is_orphan + and not is_persistent_orphan + and state._orphaned_outside_of_session + ): + self._expunge_states([state]) + else: + _reg = flush_context.register_object( + state, isdelete=is_persistent_orphan + ) + assert _reg, "Failed to add object to the flush context!" + processed.add(state) + + # put all remaining deletes into the flush context. + if objset: + proc = deleted.intersection(objset).difference(processed) + else: + proc = deleted.difference(processed) + for state in proc: + _reg = flush_context.register_object(state, isdelete=True) + assert _reg, "Failed to add object to the flush context!" + + if not flush_context.has_work: + return + + flush_context.transaction = transaction = self._autobegin_t()._begin() + try: + self._warn_on_events = True + try: + flush_context.execute() + finally: + self._warn_on_events = False + + self.dispatch.after_flush(self, flush_context) + + flush_context.finalize_flush_changes() + + if not objects and self.identity_map._modified: + len_ = len(self.identity_map._modified) + + statelib.InstanceState._commit_all_states( + [ + (state, state.dict) + for state in self.identity_map._modified + ], + instance_dict=self.identity_map, + ) + util.warn( + "Attribute history events accumulated on %d " + "previously clean instances " + "within inner-flush event handlers have been " + "reset, and will not result in database updates. " + "Consider using set_committed_value() within " + "inner-flush event handlers to avoid this warning." % len_ + ) + + # useful assertions: + # if not objects: + # assert not self.identity_map._modified + # else: + # assert self.identity_map._modified == \ + # self.identity_map._modified.difference(objects) + + self.dispatch.after_flush_postexec(self, flush_context) + + transaction.commit() + + except: + with util.safe_reraise(): + transaction.rollback(_capture_exception=True) + + def bulk_save_objects( + self, + objects: Iterable[object], + return_defaults: bool = False, + update_changed_only: bool = True, + preserve_order: bool = True, + ) -> None: + """Perform a bulk save of the given list of objects. + + .. legacy:: + + This method is a legacy feature as of the 2.0 series of + SQLAlchemy. For modern bulk INSERT and UPDATE, see + the sections :ref:`orm_queryguide_bulk_insert` and + :ref:`orm_queryguide_bulk_update`. + + For general INSERT and UPDATE of existing ORM mapped objects, + prefer standard :term:`unit of work` data management patterns, + introduced in the :ref:`unified_tutorial` at + :ref:`tutorial_orm_data_manipulation`. SQLAlchemy 2.0 + now uses :ref:`engine_insertmanyvalues` with modern dialects + which solves previous issues of bulk INSERT slowness. + + :param objects: a sequence of mapped object instances. The mapped + objects are persisted as is, and are **not** associated with the + :class:`.Session` afterwards. + + For each object, whether the object is sent as an INSERT or an + UPDATE is dependent on the same rules used by the :class:`.Session` + in traditional operation; if the object has the + :attr:`.InstanceState.key` + attribute set, then the object is assumed to be "detached" and + will result in an UPDATE. Otherwise, an INSERT is used. + + In the case of an UPDATE, statements are grouped based on which + attributes have changed, and are thus to be the subject of each + SET clause. If ``update_changed_only`` is False, then all + attributes present within each object are applied to the UPDATE + statement, which may help in allowing the statements to be grouped + together into a larger executemany(), and will also reduce the + overhead of checking history on attributes. + + :param return_defaults: when True, rows that are missing values which + generate defaults, namely integer primary key defaults and sequences, + will be inserted **one at a time**, so that the primary key value + is available. In particular this will allow joined-inheritance + and other multi-table mappings to insert correctly without the need + to provide primary key values ahead of time; however, + :paramref:`.Session.bulk_save_objects.return_defaults` **greatly + reduces the performance gains** of the method overall. It is strongly + advised to please use the standard :meth:`_orm.Session.add_all` + approach. + + :param update_changed_only: when True, UPDATE statements are rendered + based on those attributes in each state that have logged changes. + When False, all attributes present are rendered into the SET clause + with the exception of primary key attributes. + + :param preserve_order: when True, the order of inserts and updates + matches exactly the order in which the objects are given. When + False, common types of objects are grouped into inserts + and updates, to allow for more batching opportunities. + + .. seealso:: + + :doc:`queryguide/dml` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_update_mappings` + + """ + + obj_states: Iterable[InstanceState[Any]] + + obj_states = (attributes.instance_state(obj) for obj in objects) + + if not preserve_order: + # the purpose of this sort is just so that common mappers + # and persistence states are grouped together, so that groupby + # will return a single group for a particular type of mapper. + # it's not trying to be deterministic beyond that. + obj_states = sorted( + obj_states, + key=lambda state: (id(state.mapper), state.key is not None), + ) + + def grouping_key( + state: InstanceState[_O], + ) -> Tuple[Mapper[_O], bool]: + return (state.mapper, state.key is not None) + + for (mapper, isupdate), states in itertools.groupby( + obj_states, grouping_key + ): + self._bulk_save_mappings( + mapper, + states, + isupdate, + True, + return_defaults, + update_changed_only, + False, + ) + + def bulk_insert_mappings( + self, + mapper: Mapper[Any], + mappings: Iterable[Dict[str, Any]], + return_defaults: bool = False, + render_nulls: bool = False, + ) -> None: + """Perform a bulk insert of the given list of mapping dictionaries. + + .. legacy:: + + This method is a legacy feature as of the 2.0 series of + SQLAlchemy. For modern bulk INSERT and UPDATE, see + the sections :ref:`orm_queryguide_bulk_insert` and + :ref:`orm_queryguide_bulk_update`. The 2.0 API shares + implementation details with this method and adds new features + as well. + + :param mapper: a mapped class, or the actual :class:`_orm.Mapper` + object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a sequence of dictionaries, each one containing the + state of the mapped row to be inserted, in terms of the attribute + names on the mapped class. If the mapping refers to multiple tables, + such as a joined-inheritance mapping, each dictionary must contain all + keys to be populated into all tables. + + :param return_defaults: when True, the INSERT process will be altered + to ensure that newly generated primary key values will be fetched. + The rationale for this parameter is typically to enable + :ref:`Joined Table Inheritance ` mappings to + be bulk inserted. + + .. note:: for backends that don't support RETURNING, the + :paramref:`_orm.Session.bulk_insert_mappings.return_defaults` + parameter can significantly decrease performance as INSERT + statements can no longer be batched. See + :ref:`engine_insertmanyvalues` + for background on which backends are affected. + + :param render_nulls: When True, a value of ``None`` will result + in a NULL value being included in the INSERT statement, rather + than the column being omitted from the INSERT. This allows all + the rows being INSERTed to have the identical set of columns which + allows the full set of rows to be batched to the DBAPI. Normally, + each column-set that contains a different combination of NULL values + than the previous row must omit a different series of columns from + the rendered INSERT statement, which means it must be emitted as a + separate statement. By passing this flag, the full set of rows + are guaranteed to be batchable into one batch; the cost however is + that server-side defaults which are invoked by an omitted column will + be skipped, so care must be taken to ensure that these are not + necessary. + + .. warning:: + + When this flag is set, **server side default SQL values will + not be invoked** for those columns that are inserted as NULL; + the NULL value will be sent explicitly. Care must be taken + to ensure that no server-side default functions need to be + invoked for the operation as a whole. + + .. seealso:: + + :doc:`queryguide/dml` + + :meth:`.Session.bulk_save_objects` + + :meth:`.Session.bulk_update_mappings` + + """ + self._bulk_save_mappings( + mapper, + mappings, + False, + False, + return_defaults, + False, + render_nulls, + ) + + def bulk_update_mappings( + self, mapper: Mapper[Any], mappings: Iterable[Dict[str, Any]] + ) -> None: + """Perform a bulk update of the given list of mapping dictionaries. + + .. legacy:: + + This method is a legacy feature as of the 2.0 series of + SQLAlchemy. For modern bulk INSERT and UPDATE, see + the sections :ref:`orm_queryguide_bulk_insert` and + :ref:`orm_queryguide_bulk_update`. The 2.0 API shares + implementation details with this method and adds new features + as well. + + :param mapper: a mapped class, or the actual :class:`_orm.Mapper` + object, + representing the single kind of object represented within the mapping + list. + + :param mappings: a sequence of dictionaries, each one containing the + state of the mapped row to be updated, in terms of the attribute names + on the mapped class. If the mapping refers to multiple tables, such + as a joined-inheritance mapping, each dictionary may contain keys + corresponding to all tables. All those keys which are present and + are not part of the primary key are applied to the SET clause of the + UPDATE statement; the primary key values, which are required, are + applied to the WHERE clause. + + + .. seealso:: + + :doc:`queryguide/dml` + + :meth:`.Session.bulk_insert_mappings` + + :meth:`.Session.bulk_save_objects` + + """ + self._bulk_save_mappings( + mapper, mappings, True, False, False, False, False + ) + + def _bulk_save_mappings( + self, + mapper: Mapper[_O], + mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], + isupdate: bool, + isstates: bool, + return_defaults: bool, + update_changed_only: bool, + render_nulls: bool, + ) -> None: + mapper = _class_to_mapper(mapper) + self._flushing = True + + transaction = self._autobegin_t()._begin() + try: + if isupdate: + bulk_persistence._bulk_update( + mapper, + mappings, + transaction, + isstates, + update_changed_only, + ) + else: + bulk_persistence._bulk_insert( + mapper, + mappings, + transaction, + isstates, + return_defaults, + render_nulls, + ) + transaction.commit() + + except: + with util.safe_reraise(): + transaction.rollback(_capture_exception=True) + finally: + self._flushing = False + + def is_modified( + self, instance: object, include_collections: bool = True + ) -> bool: + r"""Return ``True`` if the given instance has locally + modified attributes. + + This method retrieves the history for each instrumented + attribute on the instance and performs a comparison of the current + value to its previously committed value, if any. + + It is in effect a more expensive and accurate + version of checking for the given instance in the + :attr:`.Session.dirty` collection; a full test for + each attribute's net "dirty" status is performed. + + E.g.:: + + return session.is_modified(someobject) + + A few caveats to this method apply: + + * Instances present in the :attr:`.Session.dirty` collection may + report ``False`` when tested with this method. This is because + the object may have received change events via attribute mutation, + thus placing it in :attr:`.Session.dirty`, but ultimately the state + is the same as that loaded from the database, resulting in no net + change here. + * Scalar attributes may not have recorded the previously set + value when a new value was applied, if the attribute was not loaded, + or was expired, at the time the new value was received - in these + cases, the attribute is assumed to have a change, even if there is + ultimately no net change against its database value. SQLAlchemy in + most cases does not need the "old" value when a set event occurs, so + it skips the expense of a SQL call if the old value isn't present, + based on the assumption that an UPDATE of the scalar value is + usually needed, and in those few cases where it isn't, is less + expensive on average than issuing a defensive SELECT. + + The "old" value is fetched unconditionally upon set only if the + attribute container has the ``active_history`` flag set to ``True``. + This flag is set typically for primary key attributes and scalar + object references that are not a simple many-to-one. To set this + flag for any arbitrary mapped column, use the ``active_history`` + argument with :func:`.column_property`. + + :param instance: mapped instance to be tested for pending changes. + :param include_collections: Indicates if multivalued collections + should be included in the operation. Setting this to ``False`` is a + way to detect only local-column based properties (i.e. scalar columns + or many-to-one foreign keys) that would result in an UPDATE for this + instance upon flush. + + """ + state = object_state(instance) + + if not state.modified: + return False + + dict_ = state.dict + + for attr in state.manager.attributes: + if ( + not include_collections + and hasattr(attr.impl, "get_collection") + ) or not hasattr(attr.impl, "get_history"): + continue + + (added, unchanged, deleted) = attr.impl.get_history( + state, dict_, passive=PassiveFlag.NO_CHANGE + ) + + if added or deleted: + return True + else: + return False + + @property + def is_active(self) -> bool: + """True if this :class:`.Session` not in "partial rollback" state. + + .. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins + a new transaction immediately, so this attribute will be False + when the :class:`_orm.Session` is first instantiated. + + "partial rollback" state typically indicates that the flush process + of the :class:`_orm.Session` has failed, and that the + :meth:`_orm.Session.rollback` method must be emitted in order to + fully roll back the transaction. + + If this :class:`_orm.Session` is not in a transaction at all, the + :class:`_orm.Session` will autobegin when it is first used, so in this + case :attr:`_orm.Session.is_active` will return True. + + Otherwise, if this :class:`_orm.Session` is within a transaction, + and that transaction has not been rolled back internally, the + :attr:`_orm.Session.is_active` will also return True. + + .. seealso:: + + :ref:`faq_session_rollback` + + :meth:`_orm.Session.in_transaction` + + """ + return self._transaction is None or self._transaction.is_active + + @property + def _dirty_states(self) -> Iterable[InstanceState[Any]]: + """The set of all persistent states considered dirty. + + This method returns all states that were modified including + those that were possibly deleted. + + """ + return self.identity_map._dirty_states() + + @property + def dirty(self) -> IdentitySet: + """The set of all persistent instances considered dirty. + + E.g.:: + + some_mapped_object in session.dirty + + Instances are considered dirty when they were modified but not + deleted. + + Note that this 'dirty' calculation is 'optimistic'; most + attribute-setting or collection modification operations will + mark an instance as 'dirty' and place it in this set, even if + there is no net change to the attribute's value. At flush + time, the value of each attribute is compared to its + previously saved value, and if there's no net change, no SQL + operation will occur (this is a more expensive operation so + it's only done at flush time). + + To check if an instance has actionable net changes to its + attributes, use the :meth:`.Session.is_modified` method. + + """ + return IdentitySet( + [ + state.obj() + for state in self._dirty_states + if state not in self._deleted + ] + ) + + @property + def deleted(self) -> IdentitySet: + "The set of all instances marked as 'deleted' within this ``Session``" + + return util.IdentitySet(list(self._deleted.values())) + + @property + def new(self) -> IdentitySet: + "The set of all instances marked as 'new' within this ``Session``." + + return util.IdentitySet(list(self._new.values())) + + +_S = TypeVar("_S", bound="Session") + + +class sessionmaker(_SessionClassMethods, Generic[_S]): + """A configurable :class:`.Session` factory. + + The :class:`.sessionmaker` factory generates new + :class:`.Session` objects when called, creating them given + the configurational arguments established here. + + e.g.:: + + from sqlalchemy import create_engine + from sqlalchemy.orm import sessionmaker + + # an Engine, which the Session will use for connection + # resources + engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/') + + Session = sessionmaker(engine) + + with Session() as session: + session.add(some_object) + session.add(some_other_object) + session.commit() + + Context manager use is optional; otherwise, the returned + :class:`_orm.Session` object may be closed explicitly via the + :meth:`_orm.Session.close` method. Using a + ``try:/finally:`` block is optional, however will ensure that the close + takes place even if there are database errors:: + + session = Session() + try: + session.add(some_object) + session.add(some_other_object) + session.commit() + finally: + session.close() + + :class:`.sessionmaker` acts as a factory for :class:`_orm.Session` + objects in the same way as an :class:`_engine.Engine` acts as a factory + for :class:`_engine.Connection` objects. In this way it also includes + a :meth:`_orm.sessionmaker.begin` method, that provides a context + manager which both begins and commits a transaction, as well as closes + out the :class:`_orm.Session` when complete, rolling back the transaction + if any errors occur:: + + Session = sessionmaker(engine) + + with Session.begin() as session: + session.add(some_object) + session.add(some_other_object) + # commits transaction, closes session + + .. versionadded:: 1.4 + + When calling upon :class:`_orm.sessionmaker` to construct a + :class:`_orm.Session`, keyword arguments may also be passed to the + method; these arguments will override that of the globally configured + parameters. Below we use a :class:`_orm.sessionmaker` bound to a certain + :class:`_engine.Engine` to produce a :class:`_orm.Session` that is instead + bound to a specific :class:`_engine.Connection` procured from that engine:: + + Session = sessionmaker(engine) + + # bind an individual session to a connection + + with engine.connect() as connection: + with Session(bind=connection) as session: + # work with session + + The class also includes a method :meth:`_orm.sessionmaker.configure`, which + can be used to specify additional keyword arguments to the factory, which + will take effect for subsequent :class:`.Session` objects generated. This + is usually used to associate one or more :class:`_engine.Engine` objects + with an existing + :class:`.sessionmaker` factory before it is first used:: + + # application starts, sessionmaker does not have + # an engine bound yet + Session = sessionmaker() + + # ... later, when an engine URL is read from a configuration + # file or other events allow the engine to be created + engine = create_engine('sqlite:///foo.db') + Session.configure(bind=engine) + + sess = Session() + # work with session + + .. seealso:: + + :ref:`session_getting` - introductory text on creating + sessions using :class:`.sessionmaker`. + + """ + + class_: Type[_S] + + @overload + def __init__( + self, + bind: Optional[_SessionBind] = ..., + *, + class_: Type[_S], + autoflush: bool = ..., + expire_on_commit: bool = ..., + info: Optional[_InfoType] = ..., + **kw: Any, + ): ... + + @overload + def __init__( + self: "sessionmaker[Session]", + bind: Optional[_SessionBind] = ..., + *, + autoflush: bool = ..., + expire_on_commit: bool = ..., + info: Optional[_InfoType] = ..., + **kw: Any, + ): ... + + def __init__( + self, + bind: Optional[_SessionBind] = None, + *, + class_: Type[_S] = Session, # type: ignore + autoflush: bool = True, + expire_on_commit: bool = True, + info: Optional[_InfoType] = None, + **kw: Any, + ): + r"""Construct a new :class:`.sessionmaker`. + + All arguments here except for ``class_`` correspond to arguments + accepted by :class:`.Session` directly. See the + :meth:`.Session.__init__` docstring for more details on parameters. + + :param bind: a :class:`_engine.Engine` or other :class:`.Connectable` + with + which newly created :class:`.Session` objects will be associated. + :param class\_: class to use in order to create new :class:`.Session` + objects. Defaults to :class:`.Session`. + :param autoflush: The autoflush setting to use with newly created + :class:`.Session` objects. + + .. seealso:: + + :ref:`session_flushing` - additional background on autoflush + + :param expire_on_commit=True: the + :paramref:`_orm.Session.expire_on_commit` setting to use + with newly created :class:`.Session` objects. + + :param info: optional dictionary of information that will be available + via :attr:`.Session.info`. Note this dictionary is *updated*, not + replaced, when the ``info`` parameter is specified to the specific + :class:`.Session` construction operation. + + :param \**kw: all other keyword arguments are passed to the + constructor of newly created :class:`.Session` objects. + + """ + kw["bind"] = bind + kw["autoflush"] = autoflush + kw["expire_on_commit"] = expire_on_commit + if info is not None: + kw["info"] = info + self.kw = kw + # make our own subclass of the given class, so that + # events can be associated with it specifically. + self.class_ = type(class_.__name__, (class_,), {}) + + def begin(self) -> contextlib.AbstractContextManager[_S]: + """Produce a context manager that both provides a new + :class:`_orm.Session` as well as a transaction that commits. + + + e.g.:: + + Session = sessionmaker(some_engine) + + with Session.begin() as session: + session.add(some_object) + + # commits transaction, closes session + + .. versionadded:: 1.4 + + + """ + + session = self() + return session._maker_context_manager() + + def __call__(self, **local_kw: Any) -> _S: + """Produce a new :class:`.Session` object using the configuration + established in this :class:`.sessionmaker`. + + In Python, the ``__call__`` method is invoked on an object when + it is "called" in the same way as a function:: + + Session = sessionmaker(some_engine) + session = Session() # invokes sessionmaker.__call__() + + """ + for k, v in self.kw.items(): + if k == "info" and "info" in local_kw: + d = v.copy() + d.update(local_kw["info"]) + local_kw["info"] = d + else: + local_kw.setdefault(k, v) + return self.class_(**local_kw) + + def configure(self, **new_kw: Any) -> None: + """(Re)configure the arguments for this sessionmaker. + + e.g.:: + + Session = sessionmaker() + + Session.configure(bind=create_engine('sqlite://')) + """ + self.kw.update(new_kw) + + def __repr__(self) -> str: + return "%s(class_=%r, %s)" % ( + self.__class__.__name__, + self.class_.__name__, + ", ".join("%s=%r" % (k, v) for k, v in self.kw.items()), + ) + + +def close_all_sessions() -> None: + """Close all sessions in memory. + + This function consults a global registry of all :class:`.Session` objects + and calls :meth:`.Session.close` on them, which resets them to a clean + state. + + This function is not for general use but may be useful for test suites + within the teardown scheme. + + .. versionadded:: 1.3 + + """ + + for sess in _sessions.values(): + sess.close() + + +def make_transient(instance: object) -> None: + """Alter the state of the given instance so that it is :term:`transient`. + + .. note:: + + :func:`.make_transient` is a special-case function for + advanced use cases only. + + The given mapped instance is assumed to be in the :term:`persistent` or + :term:`detached` state. The function will remove its association with any + :class:`.Session` as well as its :attr:`.InstanceState.identity`. The + effect is that the object will behave as though it were newly constructed, + except retaining any attribute / collection values that were loaded at the + time of the call. The :attr:`.InstanceState.deleted` flag is also reset + if this object had been deleted as a result of using + :meth:`.Session.delete`. + + .. warning:: + + :func:`.make_transient` does **not** "unexpire" or otherwise eagerly + load ORM-mapped attributes that are not currently loaded at the time + the function is called. This includes attributes which: + + * were expired via :meth:`.Session.expire` + + * were expired as the natural effect of committing a session + transaction, e.g. :meth:`.Session.commit` + + * are normally :term:`lazy loaded` but are not currently loaded + + * are "deferred" (see :ref:`orm_queryguide_column_deferral`) and are + not yet loaded + + * were not present in the query which loaded this object, such as that + which is common in joined table inheritance and other scenarios. + + After :func:`.make_transient` is called, unloaded attributes such + as those above will normally resolve to the value ``None`` when + accessed, or an empty collection for a collection-oriented attribute. + As the object is transient and un-associated with any database + identity, it will no longer retrieve these values. + + .. seealso:: + + :func:`.make_transient_to_detached` + + """ + state = attributes.instance_state(instance) + s = _state_session(state) + if s: + s._expunge_states([state]) + + # remove expired state + state.expired_attributes.clear() + + # remove deferred callables + if state.callables: + del state.callables + + if state.key: + del state.key + if state._deleted: + del state._deleted + + +def make_transient_to_detached(instance: object) -> None: + """Make the given transient instance :term:`detached`. + + .. note:: + + :func:`.make_transient_to_detached` is a special-case function for + advanced use cases only. + + All attribute history on the given instance + will be reset as though the instance were freshly loaded + from a query. Missing attributes will be marked as expired. + The primary key attributes of the object, which are required, will be made + into the "key" of the instance. + + The object can then be added to a session, or merged + possibly with the load=False flag, at which point it will look + as if it were loaded that way, without emitting SQL. + + This is a special use case function that differs from a normal + call to :meth:`.Session.merge` in that a given persistent state + can be manufactured without any SQL calls. + + .. seealso:: + + :func:`.make_transient` + + :meth:`.Session.enable_relationship_loading` + + """ + state = attributes.instance_state(instance) + if state.session_id or state.key: + raise sa_exc.InvalidRequestError("Given object must be transient") + state.key = state.mapper._identity_key_from_state(state) + if state._deleted: + del state._deleted + state._commit_all(state.dict) + state._expire_attributes(state.dict, state.unloaded) + + +def object_session(instance: object) -> Optional[Session]: + """Return the :class:`.Session` to which the given instance belongs. + + This is essentially the same as the :attr:`.InstanceState.session` + accessor. See that attribute for details. + + """ + + try: + state = attributes.instance_state(instance) + except exc.NO_STATE as err: + raise exc.UnmappedInstanceError(instance) from err + else: + return _state_session(state) + + +_new_sessionid = util.counter() diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/state.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/state.py new file mode 100644 index 0000000..03b81f9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/state.py @@ -0,0 +1,1136 @@ +# orm/state.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Defines instrumentation of instances. + +This module is usually not directly visible to user applications, but +defines a large part of the ORM's interactivity. + +""" + +from __future__ import annotations + +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Optional +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union +import weakref + +from . import base +from . import exc as orm_exc +from . import interfaces +from ._typing import _O +from ._typing import is_collection_impl +from .base import ATTR_WAS_SET +from .base import INIT_OK +from .base import LoaderCallableStatus +from .base import NEVER_SET +from .base import NO_VALUE +from .base import PASSIVE_NO_INITIALIZE +from .base import PASSIVE_NO_RESULT +from .base import PASSIVE_OFF +from .base import SQL_OK +from .path_registry import PathRegistry +from .. import exc as sa_exc +from .. import inspection +from .. import util +from ..util.typing import Literal +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ._typing import _IdentityKeyType + from ._typing import _InstanceDict + from ._typing import _LoaderCallable + from .attributes import AttributeImpl + from .attributes import History + from .base import PassiveFlag + from .collections import _AdaptedCollectionProtocol + from .identity import IdentityMap + from .instrumentation import ClassManager + from .interfaces import ORMOption + from .mapper import Mapper + from .session import Session + from ..engine import Row + from ..ext.asyncio.session import async_session as _async_provider + from ..ext.asyncio.session import AsyncSession + +if TYPE_CHECKING: + _sessions: weakref.WeakValueDictionary[int, Session] +else: + # late-populated by session.py + _sessions = None + + +if not TYPE_CHECKING: + # optionally late-provided by sqlalchemy.ext.asyncio.session + + _async_provider = None # noqa + + +class _InstanceDictProto(Protocol): + def __call__(self) -> Optional[IdentityMap]: ... + + +class _InstallLoaderCallableProto(Protocol[_O]): + """used at result loading time to install a _LoaderCallable callable + upon a specific InstanceState, which will be used to populate an + attribute when that attribute is accessed. + + Concrete examples are per-instance deferred column loaders and + relationship lazy loaders. + + """ + + def __call__( + self, state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] + ) -> None: ... + + +@inspection._self_inspects +class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): + """tracks state information at the instance level. + + The :class:`.InstanceState` is a key object used by the + SQLAlchemy ORM in order to track the state of an object; + it is created the moment an object is instantiated, typically + as a result of :term:`instrumentation` which SQLAlchemy applies + to the ``__init__()`` method of the class. + + :class:`.InstanceState` is also a semi-public object, + available for runtime inspection as to the state of a + mapped instance, including information such as its current + status within a particular :class:`.Session` and details + about data on individual attributes. The public API + in order to acquire a :class:`.InstanceState` object + is to use the :func:`_sa.inspect` system:: + + >>> from sqlalchemy import inspect + >>> insp = inspect(some_mapped_object) + >>> insp.attrs.nickname.history + History(added=['new nickname'], unchanged=(), deleted=['nickname']) + + .. seealso:: + + :ref:`orm_mapper_inspection_instancestate` + + """ + + __slots__ = ( + "__dict__", + "__weakref__", + "class_", + "manager", + "obj", + "committed_state", + "expired_attributes", + ) + + manager: ClassManager[_O] + session_id: Optional[int] = None + key: Optional[_IdentityKeyType[_O]] = None + runid: Optional[int] = None + load_options: Tuple[ORMOption, ...] = () + load_path: PathRegistry = PathRegistry.root + insert_order: Optional[int] = None + _strong_obj: Optional[object] = None + obj: weakref.ref[_O] + + committed_state: Dict[str, Any] + + modified: bool = False + expired: bool = False + _deleted: bool = False + _load_pending: bool = False + _orphaned_outside_of_session: bool = False + is_instance: bool = True + identity_token: object = None + _last_known_values: Optional[Dict[str, Any]] = None + + _instance_dict: _InstanceDictProto + """A weak reference, or in the default case a plain callable, that + returns a reference to the current :class:`.IdentityMap`, if any. + + """ + if not TYPE_CHECKING: + + def _instance_dict(self): + """default 'weak reference' for _instance_dict""" + return None + + expired_attributes: Set[str] + """The set of keys which are 'expired' to be loaded by + the manager's deferred scalar loader, assuming no pending + changes. + + see also the ``unmodified`` collection which is intersected + against this set when a refresh operation occurs.""" + + callables: Dict[str, Callable[[InstanceState[_O], PassiveFlag], Any]] + """A namespace where a per-state loader callable can be associated. + + In SQLAlchemy 1.0, this is only used for lazy loaders / deferred + loaders that were set up via query option. + + Previously, callables was used also to indicate expired attributes + by storing a link to the InstanceState itself in this dictionary. + This role is now handled by the expired_attributes set. + + """ + + if not TYPE_CHECKING: + callables = util.EMPTY_DICT + + def __init__(self, obj: _O, manager: ClassManager[_O]): + self.class_ = obj.__class__ + self.manager = manager + self.obj = weakref.ref(obj, self._cleanup) + self.committed_state = {} + self.expired_attributes = set() + + @util.memoized_property + def attrs(self) -> util.ReadOnlyProperties[AttributeState]: + """Return a namespace representing each attribute on + the mapped object, including its current value + and history. + + The returned object is an instance of :class:`.AttributeState`. + This object allows inspection of the current data + within an attribute as well as attribute history + since the last flush. + + """ + return util.ReadOnlyProperties( + {key: AttributeState(self, key) for key in self.manager} + ) + + @property + def transient(self) -> bool: + """Return ``True`` if the object is :term:`transient`. + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is None and not self._attached + + @property + def pending(self) -> bool: + """Return ``True`` if the object is :term:`pending`. + + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is None and self._attached + + @property + def deleted(self) -> bool: + """Return ``True`` if the object is :term:`deleted`. + + An object that is in the deleted state is guaranteed to + not be within the :attr:`.Session.identity_map` of its parent + :class:`.Session`; however if the session's transaction is rolled + back, the object will be restored to the persistent state and + the identity map. + + .. note:: + + The :attr:`.InstanceState.deleted` attribute refers to a specific + state of the object that occurs between the "persistent" and + "detached" states; once the object is :term:`detached`, the + :attr:`.InstanceState.deleted` attribute **no longer returns + True**; in order to detect that a state was deleted, regardless + of whether or not the object is associated with a + :class:`.Session`, use the :attr:`.InstanceState.was_deleted` + accessor. + + .. versionadded: 1.1 + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is not None and self._attached and self._deleted + + @property + def was_deleted(self) -> bool: + """Return True if this object is or was previously in the + "deleted" state and has not been reverted to persistent. + + This flag returns True once the object was deleted in flush. + When the object is expunged from the session either explicitly + or via transaction commit and enters the "detached" state, + this flag will continue to report True. + + .. seealso:: + + :attr:`.InstanceState.deleted` - refers to the "deleted" state + + :func:`.orm.util.was_deleted` - standalone function + + :ref:`session_object_states` + + """ + return self._deleted + + @property + def persistent(self) -> bool: + """Return ``True`` if the object is :term:`persistent`. + + An object that is in the persistent state is guaranteed to + be within the :attr:`.Session.identity_map` of its parent + :class:`.Session`. + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is not None and self._attached and not self._deleted + + @property + def detached(self) -> bool: + """Return ``True`` if the object is :term:`detached`. + + .. seealso:: + + :ref:`session_object_states` + + """ + return self.key is not None and not self._attached + + @util.non_memoized_property + @util.preload_module("sqlalchemy.orm.session") + def _attached(self) -> bool: + return ( + self.session_id is not None + and self.session_id in util.preloaded.orm_session._sessions + ) + + def _track_last_known_value(self, key: str) -> None: + """Track the last known value of a particular key after expiration + operations. + + .. versionadded:: 1.3 + + """ + + lkv = self._last_known_values + if lkv is None: + self._last_known_values = lkv = {} + if key not in lkv: + lkv[key] = NO_VALUE + + @property + def session(self) -> Optional[Session]: + """Return the owning :class:`.Session` for this instance, + or ``None`` if none available. + + Note that the result here can in some cases be *different* + from that of ``obj in session``; an object that's been deleted + will report as not ``in session``, however if the transaction is + still in progress, this attribute will still refer to that session. + Only when the transaction is completed does the object become + fully detached under normal circumstances. + + .. seealso:: + + :attr:`_orm.InstanceState.async_session` + + """ + if self.session_id: + try: + return _sessions[self.session_id] + except KeyError: + pass + return None + + @property + def async_session(self) -> Optional[AsyncSession]: + """Return the owning :class:`_asyncio.AsyncSession` for this instance, + or ``None`` if none available. + + This attribute is only non-None when the :mod:`sqlalchemy.ext.asyncio` + API is in use for this ORM object. The returned + :class:`_asyncio.AsyncSession` object will be a proxy for the + :class:`_orm.Session` object that would be returned from the + :attr:`_orm.InstanceState.session` attribute for this + :class:`_orm.InstanceState`. + + .. versionadded:: 1.4.18 + + .. seealso:: + + :ref:`asyncio_toplevel` + + """ + if _async_provider is None: + return None + + sess = self.session + if sess is not None: + return _async_provider(sess) + else: + return None + + @property + def object(self) -> Optional[_O]: + """Return the mapped object represented by this + :class:`.InstanceState`. + + Returns None if the object has been garbage collected + + """ + return self.obj() + + @property + def identity(self) -> Optional[Tuple[Any, ...]]: + """Return the mapped identity of the mapped object. + This is the primary key identity as persisted by the ORM + which can always be passed directly to + :meth:`_query.Query.get`. + + Returns ``None`` if the object has no primary key identity. + + .. note:: + An object which is :term:`transient` or :term:`pending` + does **not** have a mapped identity until it is flushed, + even if its attributes include primary key values. + + """ + if self.key is None: + return None + else: + return self.key[1] + + @property + def identity_key(self) -> Optional[_IdentityKeyType[_O]]: + """Return the identity key for the mapped object. + + This is the key used to locate the object within + the :attr:`.Session.identity_map` mapping. It contains + the identity as returned by :attr:`.identity` within it. + + + """ + return self.key + + @util.memoized_property + def parents(self) -> Dict[int, Union[Literal[False], InstanceState[Any]]]: + return {} + + @util.memoized_property + def _pending_mutations(self) -> Dict[str, PendingCollection]: + return {} + + @util.memoized_property + def _empty_collections(self) -> Dict[str, _AdaptedCollectionProtocol]: + return {} + + @util.memoized_property + def mapper(self) -> Mapper[_O]: + """Return the :class:`_orm.Mapper` used for this mapped object.""" + return self.manager.mapper + + @property + def has_identity(self) -> bool: + """Return ``True`` if this object has an identity key. + + This should always have the same value as the + expression ``state.persistent`` or ``state.detached``. + + """ + return bool(self.key) + + @classmethod + def _detach_states( + self, + states: Iterable[InstanceState[_O]], + session: Session, + to_transient: bool = False, + ) -> None: + persistent_to_detached = ( + session.dispatch.persistent_to_detached or None + ) + deleted_to_detached = session.dispatch.deleted_to_detached or None + pending_to_transient = session.dispatch.pending_to_transient or None + persistent_to_transient = ( + session.dispatch.persistent_to_transient or None + ) + + for state in states: + deleted = state._deleted + pending = state.key is None + persistent = not pending and not deleted + + state.session_id = None + + if to_transient and state.key: + del state.key + if persistent: + if to_transient: + if persistent_to_transient is not None: + persistent_to_transient(session, state) + elif persistent_to_detached is not None: + persistent_to_detached(session, state) + elif deleted and deleted_to_detached is not None: + deleted_to_detached(session, state) + elif pending and pending_to_transient is not None: + pending_to_transient(session, state) + + state._strong_obj = None + + def _detach(self, session: Optional[Session] = None) -> None: + if session: + InstanceState._detach_states([self], session) + else: + self.session_id = self._strong_obj = None + + def _dispose(self) -> None: + # used by the test suite, apparently + self._detach() + + def _cleanup(self, ref: weakref.ref[_O]) -> None: + """Weakref callback cleanup. + + This callable cleans out the state when it is being garbage + collected. + + this _cleanup **assumes** that there are no strong refs to us! + Will not work otherwise! + + """ + + # Python builtins become undefined during interpreter shutdown. + # Guard against exceptions during this phase, as the method cannot + # proceed in any case if builtins have been undefined. + if dict is None: + return + + instance_dict = self._instance_dict() + if instance_dict is not None: + instance_dict._fast_discard(self) + del self._instance_dict + + # we can't possibly be in instance_dict._modified + # b.c. this is weakref cleanup only, that set + # is strong referencing! + # assert self not in instance_dict._modified + + self.session_id = self._strong_obj = None + + @property + def dict(self) -> _InstanceDict: + """Return the instance dict used by the object. + + Under normal circumstances, this is always synonymous + with the ``__dict__`` attribute of the mapped object, + unless an alternative instrumentation system has been + configured. + + In the case that the actual object has been garbage + collected, this accessor returns a blank dictionary. + + """ + o = self.obj() + if o is not None: + return base.instance_dict(o) + else: + return {} + + def _initialize_instance(*mixed: Any, **kwargs: Any) -> None: + self, instance, args = mixed[0], mixed[1], mixed[2:] # noqa + manager = self.manager + + manager.dispatch.init(self, args, kwargs) + + try: + manager.original_init(*mixed[1:], **kwargs) + except: + with util.safe_reraise(): + manager.dispatch.init_failure(self, args, kwargs) + + def get_history(self, key: str, passive: PassiveFlag) -> History: + return self.manager[key].impl.get_history(self, self.dict, passive) + + def get_impl(self, key: str) -> AttributeImpl: + return self.manager[key].impl + + def _get_pending_mutation(self, key: str) -> PendingCollection: + if key not in self._pending_mutations: + self._pending_mutations[key] = PendingCollection() + return self._pending_mutations[key] + + def __getstate__(self) -> Dict[str, Any]: + state_dict: Dict[str, Any] = { + "instance": self.obj(), + "class_": self.class_, + "committed_state": self.committed_state, + "expired_attributes": self.expired_attributes, + } + state_dict.update( + (k, self.__dict__[k]) + for k in ( + "_pending_mutations", + "modified", + "expired", + "callables", + "key", + "parents", + "load_options", + "class_", + "expired_attributes", + "info", + ) + if k in self.__dict__ + ) + if self.load_path: + state_dict["load_path"] = self.load_path.serialize() + + state_dict["manager"] = self.manager._serialize(self, state_dict) + + return state_dict + + def __setstate__(self, state_dict: Dict[str, Any]) -> None: + inst = state_dict["instance"] + if inst is not None: + self.obj = weakref.ref(inst, self._cleanup) + self.class_ = inst.__class__ + else: + self.obj = lambda: None # type: ignore + self.class_ = state_dict["class_"] + + self.committed_state = state_dict.get("committed_state", {}) + self._pending_mutations = state_dict.get("_pending_mutations", {}) + self.parents = state_dict.get("parents", {}) + self.modified = state_dict.get("modified", False) + self.expired = state_dict.get("expired", False) + if "info" in state_dict: + self.info.update(state_dict["info"]) + if "callables" in state_dict: + self.callables = state_dict["callables"] + + self.expired_attributes = state_dict["expired_attributes"] + else: + if "expired_attributes" in state_dict: + self.expired_attributes = state_dict["expired_attributes"] + else: + self.expired_attributes = set() + + self.__dict__.update( + [ + (k, state_dict[k]) + for k in ("key", "load_options") + if k in state_dict + ] + ) + if self.key: + self.identity_token = self.key[2] + + if "load_path" in state_dict: + self.load_path = PathRegistry.deserialize(state_dict["load_path"]) + + state_dict["manager"](self, inst, state_dict) + + def _reset(self, dict_: _InstanceDict, key: str) -> None: + """Remove the given attribute and any + callables associated with it.""" + + old = dict_.pop(key, None) + manager_impl = self.manager[key].impl + if old is not None and is_collection_impl(manager_impl): + manager_impl._invalidate_collection(old) + self.expired_attributes.discard(key) + if self.callables: + self.callables.pop(key, None) + + def _copy_callables(self, from_: InstanceState[Any]) -> None: + if "callables" in from_.__dict__: + self.callables = dict(from_.callables) + + @classmethod + def _instance_level_callable_processor( + cls, manager: ClassManager[_O], fn: _LoaderCallable, key: Any + ) -> _InstallLoaderCallableProto[_O]: + impl = manager[key].impl + if is_collection_impl(impl): + fixed_impl = impl + + def _set_callable( + state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] + ) -> None: + if "callables" not in state.__dict__: + state.callables = {} + old = dict_.pop(key, None) + if old is not None: + fixed_impl._invalidate_collection(old) + state.callables[key] = fn + + else: + + def _set_callable( + state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] + ) -> None: + if "callables" not in state.__dict__: + state.callables = {} + state.callables[key] = fn + + return _set_callable + + def _expire( + self, dict_: _InstanceDict, modified_set: Set[InstanceState[Any]] + ) -> None: + self.expired = True + if self.modified: + modified_set.discard(self) + self.committed_state.clear() + self.modified = False + + self._strong_obj = None + + if "_pending_mutations" in self.__dict__: + del self.__dict__["_pending_mutations"] + + if "parents" in self.__dict__: + del self.__dict__["parents"] + + self.expired_attributes.update( + [impl.key for impl in self.manager._loader_impls] + ) + + if self.callables: + # the per state loader callables we can remove here are + # LoadDeferredColumns, which undefers a column at the instance + # level that is mapped with deferred, and LoadLazyAttribute, + # which lazy loads a relationship at the instance level that + # is mapped with "noload" or perhaps "immediateload". + # Before 1.4, only column-based + # attributes could be considered to be "expired", so here they + # were the only ones "unexpired", which means to make them deferred + # again. For the moment, as of 1.4 we also apply the same + # treatment relationships now, that is, an instance level lazy + # loader is reset in the same way as a column loader. + for k in self.expired_attributes.intersection(self.callables): + del self.callables[k] + + for k in self.manager._collection_impl_keys.intersection(dict_): + collection = dict_.pop(k) + collection._sa_adapter.invalidated = True + + if self._last_known_values: + self._last_known_values.update( + {k: dict_[k] for k in self._last_known_values if k in dict_} + ) + + for key in self.manager._all_key_set.intersection(dict_): + del dict_[key] + + self.manager.dispatch.expire(self, None) + + def _expire_attributes( + self, + dict_: _InstanceDict, + attribute_names: Iterable[str], + no_loader: bool = False, + ) -> None: + pending = self.__dict__.get("_pending_mutations", None) + + callables = self.callables + + for key in attribute_names: + impl = self.manager[key].impl + if impl.accepts_scalar_loader: + if no_loader and (impl.callable_ or key in callables): + continue + + self.expired_attributes.add(key) + if callables and key in callables: + del callables[key] + old = dict_.pop(key, NO_VALUE) + if is_collection_impl(impl) and old is not NO_VALUE: + impl._invalidate_collection(old) + + lkv = self._last_known_values + if lkv is not None and key in lkv and old is not NO_VALUE: + lkv[key] = old + + self.committed_state.pop(key, None) + if pending: + pending.pop(key, None) + + self.manager.dispatch.expire(self, attribute_names) + + def _load_expired( + self, state: InstanceState[_O], passive: PassiveFlag + ) -> LoaderCallableStatus: + """__call__ allows the InstanceState to act as a deferred + callable for loading expired attributes, which is also + serializable (picklable). + + """ + + if not passive & SQL_OK: + return PASSIVE_NO_RESULT + + toload = self.expired_attributes.intersection(self.unmodified) + toload = toload.difference( + attr + for attr in toload + if not self.manager[attr].impl.load_on_unexpire + ) + + self.manager.expired_attribute_loader(self, toload, passive) + + # if the loader failed, or this + # instance state didn't have an identity, + # the attributes still might be in the callables + # dict. ensure they are removed. + self.expired_attributes.clear() + + return ATTR_WAS_SET + + @property + def unmodified(self) -> Set[str]: + """Return the set of keys which have no uncommitted changes""" + + return set(self.manager).difference(self.committed_state) + + def unmodified_intersection(self, keys: Iterable[str]) -> Set[str]: + """Return self.unmodified.intersection(keys).""" + + return ( + set(keys) + .intersection(self.manager) + .difference(self.committed_state) + ) + + @property + def unloaded(self) -> Set[str]: + """Return the set of keys which do not have a loaded value. + + This includes expired attributes and any other attribute that was never + populated or modified. + + """ + return ( + set(self.manager) + .difference(self.committed_state) + .difference(self.dict) + ) + + @property + @util.deprecated( + "2.0", + "The :attr:`.InstanceState.unloaded_expirable` attribute is " + "deprecated. Please use :attr:`.InstanceState.unloaded`.", + ) + def unloaded_expirable(self) -> Set[str]: + """Synonymous with :attr:`.InstanceState.unloaded`. + + This attribute was added as an implementation-specific detail at some + point and should be considered to be private. + + """ + return self.unloaded + + @property + def _unloaded_non_object(self) -> Set[str]: + return self.unloaded.intersection( + attr + for attr in self.manager + if self.manager[attr].impl.accepts_scalar_loader + ) + + def _modified_event( + self, + dict_: _InstanceDict, + attr: Optional[AttributeImpl], + previous: Any, + collection: bool = False, + is_userland: bool = False, + ) -> None: + if attr: + if not attr.send_modified_events: + return + if is_userland and attr.key not in dict_: + raise sa_exc.InvalidRequestError( + "Can't flag attribute '%s' modified; it's not present in " + "the object state" % attr.key + ) + if attr.key not in self.committed_state or is_userland: + if collection: + if TYPE_CHECKING: + assert is_collection_impl(attr) + if previous is NEVER_SET: + if attr.key in dict_: + previous = dict_[attr.key] + + if previous not in (None, NO_VALUE, NEVER_SET): + previous = attr.copy(previous) + self.committed_state[attr.key] = previous + + lkv = self._last_known_values + if lkv is not None and attr.key in lkv: + lkv[attr.key] = NO_VALUE + + # assert self._strong_obj is None or self.modified + + if (self.session_id and self._strong_obj is None) or not self.modified: + self.modified = True + instance_dict = self._instance_dict() + if instance_dict: + has_modified = bool(instance_dict._modified) + instance_dict._modified.add(self) + else: + has_modified = False + + # only create _strong_obj link if attached + # to a session + + inst = self.obj() + if self.session_id: + self._strong_obj = inst + + # if identity map already had modified objects, + # assume autobegin already occurred, else check + # for autobegin + if not has_modified: + # inline of autobegin, to ensure session transaction + # snapshot is established + try: + session = _sessions[self.session_id] + except KeyError: + pass + else: + if session._transaction is None: + session._autobegin_t() + + if inst is None and attr: + raise orm_exc.ObjectDereferencedError( + "Can't emit change event for attribute '%s' - " + "parent object of type %s has been garbage " + "collected." + % (self.manager[attr.key], base.state_class_str(self)) + ) + + def _commit(self, dict_: _InstanceDict, keys: Iterable[str]) -> None: + """Commit attributes. + + This is used by a partial-attribute load operation to mark committed + those attributes which were refreshed from the database. + + Attributes marked as "expired" can potentially remain "expired" after + this step if a value was not populated in state.dict. + + """ + for key in keys: + self.committed_state.pop(key, None) + + self.expired = False + + self.expired_attributes.difference_update( + set(keys).intersection(dict_) + ) + + # the per-keys commit removes object-level callables, + # while that of commit_all does not. it's not clear + # if this behavior has a clear rationale, however tests do + # ensure this is what it does. + if self.callables: + for key in ( + set(self.callables).intersection(keys).intersection(dict_) + ): + del self.callables[key] + + def _commit_all( + self, dict_: _InstanceDict, instance_dict: Optional[IdentityMap] = None + ) -> None: + """commit all attributes unconditionally. + + This is used after a flush() or a full load/refresh + to remove all pending state from the instance. + + - all attributes are marked as "committed" + - the "strong dirty reference" is removed + - the "modified" flag is set to False + - any "expired" markers for scalar attributes loaded are removed. + - lazy load callables for objects / collections *stay* + + Attributes marked as "expired" can potentially remain + "expired" after this step if a value was not populated in state.dict. + + """ + self._commit_all_states([(self, dict_)], instance_dict) + + @classmethod + def _commit_all_states( + self, + iter_: Iterable[Tuple[InstanceState[Any], _InstanceDict]], + instance_dict: Optional[IdentityMap] = None, + ) -> None: + """Mass / highly inlined version of commit_all().""" + + for state, dict_ in iter_: + state_dict = state.__dict__ + + state.committed_state.clear() + + if "_pending_mutations" in state_dict: + del state_dict["_pending_mutations"] + + state.expired_attributes.difference_update(dict_) + + if instance_dict and state.modified: + instance_dict._modified.discard(state) + + state.modified = state.expired = False + state._strong_obj = None + + +class AttributeState: + """Provide an inspection interface corresponding + to a particular attribute on a particular mapped object. + + The :class:`.AttributeState` object is accessed + via the :attr:`.InstanceState.attrs` collection + of a particular :class:`.InstanceState`:: + + from sqlalchemy import inspect + + insp = inspect(some_mapped_object) + attr_state = insp.attrs.some_attribute + + """ + + __slots__ = ("state", "key") + + state: InstanceState[Any] + key: str + + def __init__(self, state: InstanceState[Any], key: str): + self.state = state + self.key = key + + @property + def loaded_value(self) -> Any: + """The current value of this attribute as loaded from the database. + + If the value has not been loaded, or is otherwise not present + in the object's dictionary, returns NO_VALUE. + + """ + return self.state.dict.get(self.key, NO_VALUE) + + @property + def value(self) -> Any: + """Return the value of this attribute. + + This operation is equivalent to accessing the object's + attribute directly or via ``getattr()``, and will fire + off any pending loader callables if needed. + + """ + return self.state.manager[self.key].__get__( + self.state.obj(), self.state.class_ + ) + + @property + def history(self) -> History: + """Return the current **pre-flush** change history for + this attribute, via the :class:`.History` interface. + + This method will **not** emit loader callables if the value of the + attribute is unloaded. + + .. note:: + + The attribute history system tracks changes on a **per flush + basis**. Each time the :class:`.Session` is flushed, the history + of each attribute is reset to empty. The :class:`.Session` by + default autoflushes each time a :class:`_query.Query` is invoked. + For + options on how to control this, see :ref:`session_flushing`. + + + .. seealso:: + + :meth:`.AttributeState.load_history` - retrieve history + using loader callables if the value is not locally present. + + :func:`.attributes.get_history` - underlying function + + """ + return self.state.get_history(self.key, PASSIVE_NO_INITIALIZE) + + def load_history(self) -> History: + """Return the current **pre-flush** change history for + this attribute, via the :class:`.History` interface. + + This method **will** emit loader callables if the value of the + attribute is unloaded. + + .. note:: + + The attribute history system tracks changes on a **per flush + basis**. Each time the :class:`.Session` is flushed, the history + of each attribute is reset to empty. The :class:`.Session` by + default autoflushes each time a :class:`_query.Query` is invoked. + For + options on how to control this, see :ref:`session_flushing`. + + .. seealso:: + + :attr:`.AttributeState.history` + + :func:`.attributes.get_history` - underlying function + + """ + return self.state.get_history(self.key, PASSIVE_OFF ^ INIT_OK) + + +class PendingCollection: + """A writable placeholder for an unloaded collection. + + Stores items appended to and removed from a collection that has not yet + been loaded. When the collection is loaded, the changes stored in + PendingCollection are applied to it to produce the final result. + + """ + + __slots__ = ("deleted_items", "added_items") + + deleted_items: util.IdentitySet + added_items: util.OrderedIdentitySet + + def __init__(self) -> None: + self.deleted_items = util.IdentitySet() + self.added_items = util.OrderedIdentitySet() + + def merge_with_history(self, history: History) -> History: + return history._merge(self.added_items, self.deleted_items) + + def append(self, value: Any) -> None: + if value in self.deleted_items: + self.deleted_items.remove(value) + else: + self.added_items.add(value) + + def remove(self, value: Any) -> None: + if value in self.added_items: + self.added_items.remove(value) + else: + self.deleted_items.add(value) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/state_changes.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/state_changes.py new file mode 100644 index 0000000..56963c6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/state_changes.py @@ -0,0 +1,198 @@ +# orm/state_changes.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""State tracking utilities used by :class:`_orm.Session`. + +""" + +from __future__ import annotations + +import contextlib +from enum import Enum +from typing import Any +from typing import Callable +from typing import cast +from typing import Iterator +from typing import NoReturn +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import Union + +from .. import exc as sa_exc +from .. import util +from ..util.typing import Literal + +_F = TypeVar("_F", bound=Callable[..., Any]) + + +class _StateChangeState(Enum): + pass + + +class _StateChangeStates(_StateChangeState): + ANY = 1 + NO_CHANGE = 2 + CHANGE_IN_PROGRESS = 3 + + +class _StateChange: + """Supplies state assertion decorators. + + The current use case is for the :class:`_orm.SessionTransaction` class. The + :class:`_StateChange` class itself is agnostic of the + :class:`_orm.SessionTransaction` class so could in theory be generalized + for other systems as well. + + """ + + _next_state: _StateChangeState = _StateChangeStates.ANY + _state: _StateChangeState = _StateChangeStates.NO_CHANGE + _current_fn: Optional[Callable[..., Any]] = None + + def _raise_for_prerequisite_state( + self, operation_name: str, state: _StateChangeState + ) -> NoReturn: + raise sa_exc.IllegalStateChangeError( + f"Can't run operation '{operation_name}()' when Session " + f"is in state {state!r}", + code="isce", + ) + + @classmethod + def declare_states( + cls, + prerequisite_states: Union[ + Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...] + ], + moves_to: _StateChangeState, + ) -> Callable[[_F], _F]: + """Method decorator declaring valid states. + + :param prerequisite_states: sequence of acceptable prerequisite + states. Can be the single constant _State.ANY to indicate no + prerequisite state + + :param moves_to: the expected state at the end of the method, assuming + no exceptions raised. Can be the constant _State.NO_CHANGE to + indicate state should not change at the end of the method. + + """ + assert prerequisite_states, "no prequisite states sent" + has_prerequisite_states = ( + prerequisite_states is not _StateChangeStates.ANY + ) + + prerequisite_state_collection = cast( + "Tuple[_StateChangeState, ...]", prerequisite_states + ) + expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE + + @util.decorator + def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any: + current_state = self._state + + if ( + has_prerequisite_states + and current_state not in prerequisite_state_collection + ): + self._raise_for_prerequisite_state(fn.__name__, current_state) + + next_state = self._next_state + existing_fn = self._current_fn + expect_state = moves_to if expect_state_change else current_state + + if ( + # destination states are restricted + next_state is not _StateChangeStates.ANY + # method seeks to change state + and expect_state_change + # destination state incorrect + and next_state is not expect_state + ): + if existing_fn and next_state in ( + _StateChangeStates.NO_CHANGE, + _StateChangeStates.CHANGE_IN_PROGRESS, + ): + raise sa_exc.IllegalStateChangeError( + f"Method '{fn.__name__}()' can't be called here; " + f"method '{existing_fn.__name__}()' is already " + f"in progress and this would cause an unexpected " + f"state change to {moves_to!r}", + code="isce", + ) + else: + raise sa_exc.IllegalStateChangeError( + f"Cant run operation '{fn.__name__}()' here; " + f"will move to state {moves_to!r} where we are " + f"expecting {next_state!r}", + code="isce", + ) + + self._current_fn = fn + self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS + try: + ret_value = fn(self, *arg, **kw) + except: + raise + else: + if self._state is expect_state: + return ret_value + + if self._state is current_state: + raise sa_exc.IllegalStateChangeError( + f"Method '{fn.__name__}()' failed to " + "change state " + f"to {moves_to!r} as expected", + code="isce", + ) + elif existing_fn: + raise sa_exc.IllegalStateChangeError( + f"While method '{existing_fn.__name__}()' was " + "running, " + f"method '{fn.__name__}()' caused an " + "unexpected " + f"state change to {self._state!r}", + code="isce", + ) + else: + raise sa_exc.IllegalStateChangeError( + f"Method '{fn.__name__}()' caused an unexpected " + f"state change to {self._state!r}", + code="isce", + ) + + finally: + self._next_state = next_state + self._current_fn = existing_fn + + return _go + + @contextlib.contextmanager + def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]: + """called within a method that changes states. + + method must also use the ``@declare_states()`` decorator. + + """ + assert self._next_state is _StateChangeStates.CHANGE_IN_PROGRESS, ( + "Unexpected call to _expect_state outside of " + "state-changing method" + ) + + self._next_state = expected + try: + yield + except: + raise + else: + if self._state is not expected: + raise sa_exc.IllegalStateChangeError( + f"Unexpected state change to {self._state!r}", code="isce" + ) + finally: + self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/strategies.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/strategies.py new file mode 100644 index 0000000..20c3b9c --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/strategies.py @@ -0,0 +1,3344 @@ +# orm/strategies.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""sqlalchemy.orm.interfaces.LoaderStrategy + implementations, and related MapperOptions.""" + +from __future__ import annotations + +import collections +import itertools +from typing import Any +from typing import Dict +from typing import Tuple +from typing import TYPE_CHECKING + +from . import attributes +from . import exc as orm_exc +from . import interfaces +from . import loading +from . import path_registry +from . import properties +from . import query +from . import relationships +from . import unitofwork +from . import util as orm_util +from .base import _DEFER_FOR_STATE +from .base import _RAISE_FOR_STATE +from .base import _SET_DEFERRED_EXPIRED +from .base import ATTR_WAS_SET +from .base import LoaderCallableStatus +from .base import PASSIVE_OFF +from .base import PassiveFlag +from .context import _column_descriptions +from .context import ORMCompileState +from .context import ORMSelectCompileState +from .context import QueryContext +from .interfaces import LoaderStrategy +from .interfaces import StrategizedProperty +from .session import _state_session +from .state import InstanceState +from .strategy_options import Load +from .util import _none_set +from .util import AliasedClass +from .. import event +from .. import exc as sa_exc +from .. import inspect +from .. import log +from .. import sql +from .. import util +from ..sql import util as sql_util +from ..sql import visitors +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL +from ..sql.selectable import Select + +if TYPE_CHECKING: + from .relationships import RelationshipProperty + from ..sql.elements import ColumnElement + + +def _register_attribute( + prop, + mapper, + useobject, + compare_function=None, + typecallable=None, + callable_=None, + proxy_property=None, + active_history=False, + impl_class=None, + **kw, +): + listen_hooks = [] + + uselist = useobject and prop.uselist + + if useobject and prop.single_parent: + listen_hooks.append(single_parent_validator) + + if prop.key in prop.parent.validators: + fn, opts = prop.parent.validators[prop.key] + listen_hooks.append( + lambda desc, prop: orm_util._validator_events( + desc, prop.key, fn, **opts + ) + ) + + if useobject: + listen_hooks.append(unitofwork.track_cascade_events) + + # need to assemble backref listeners + # after the singleparentvalidator, mapper validator + if useobject: + backref = prop.back_populates + if backref and prop._effective_sync_backref: + listen_hooks.append( + lambda desc, prop: attributes.backref_listeners( + desc, backref, uselist + ) + ) + + # a single MapperProperty is shared down a class inheritance + # hierarchy, so we set up attribute instrumentation and backref event + # for each mapper down the hierarchy. + + # typically, "mapper" is the same as prop.parent, due to the way + # the configure_mappers() process runs, however this is not strongly + # enforced, and in the case of a second configure_mappers() run the + # mapper here might not be prop.parent; also, a subclass mapper may + # be called here before a superclass mapper. That is, can't depend + # on mappers not already being set up so we have to check each one. + + for m in mapper.self_and_descendants: + if prop is m._props.get( + prop.key + ) and not m.class_manager._attr_has_impl(prop.key): + desc = attributes.register_attribute_impl( + m.class_, + prop.key, + parent_token=prop, + uselist=uselist, + compare_function=compare_function, + useobject=useobject, + trackparent=useobject + and ( + prop.single_parent + or prop.direction is interfaces.ONETOMANY + ), + typecallable=typecallable, + callable_=callable_, + active_history=active_history, + impl_class=impl_class, + send_modified_events=not useobject or not prop.viewonly, + doc=prop.doc, + **kw, + ) + + for hook in listen_hooks: + hook(desc, prop) + + +@properties.ColumnProperty.strategy_for(instrument=False, deferred=False) +class UninstrumentedColumnLoader(LoaderStrategy): + """Represent a non-instrumented MapperProperty. + + The polymorphic_on argument of mapper() often results in this, + if the argument is against the with_polymorphic selectable. + + """ + + __slots__ = ("columns",) + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.columns = self.parent_property.columns + + def setup_query( + self, + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection=None, + **kwargs, + ): + for c in self.columns: + if adapter: + c = adapter.columns[c] + compile_state._append_dedupe_col_collection(c, column_collection) + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + pass + + +@log.class_logger +@properties.ColumnProperty.strategy_for(instrument=True, deferred=False) +class ColumnLoader(LoaderStrategy): + """Provide loading behavior for a :class:`.ColumnProperty`.""" + + __slots__ = "columns", "is_composite" + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.columns = self.parent_property.columns + self.is_composite = hasattr(self.parent_property, "composite_class") + + def setup_query( + self, + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection, + memoized_populators, + check_for_adapt=False, + **kwargs, + ): + for c in self.columns: + if adapter: + if check_for_adapt: + c = adapter.adapt_check_present(c) + if c is None: + return + else: + c = adapter.columns[c] + + compile_state._append_dedupe_col_collection(c, column_collection) + + fetch = self.columns[0] + if adapter: + fetch = adapter.columns[fetch] + if fetch is None: + # None happens here only for dml bulk_persistence cases + # when context.DMLReturningColFilter is used + return + + memoized_populators[self.parent_property] = fetch + + def init_class_attribute(self, mapper): + self.is_class_level = True + coltype = self.columns[0].type + # TODO: check all columns ? check for foreign key as well? + active_history = ( + self.parent_property.active_history + or self.columns[0].primary_key + or ( + mapper.version_id_col is not None + and mapper._columntoproperty.get(mapper.version_id_col, None) + is self.parent_property + ) + ) + + _register_attribute( + self.parent_property, + mapper, + useobject=False, + compare_function=coltype.compare_values, + active_history=active_history, + ) + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + # look through list of columns represented here + # to see which, if any, is present in the row. + + for col in self.columns: + if adapter: + col = adapter.columns[col] + getter = result._getter(col, False) + if getter: + populators["quick"].append((self.key, getter)) + break + else: + populators["expire"].append((self.key, True)) + + +@log.class_logger +@properties.ColumnProperty.strategy_for(query_expression=True) +class ExpressionColumnLoader(ColumnLoader): + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + + # compare to the "default" expression that is mapped in + # the column. If it's sql.null, we don't need to render + # unless an expr is passed in the options. + null = sql.null().label(None) + self._have_default_expression = any( + not c.compare(null) for c in self.parent_property.columns + ) + + def setup_query( + self, + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection, + memoized_populators, + **kwargs, + ): + columns = None + if loadopt and loadopt._extra_criteria: + columns = loadopt._extra_criteria + + elif self._have_default_expression: + columns = self.parent_property.columns + + if columns is None: + return + + for c in columns: + if adapter: + c = adapter.columns[c] + compile_state._append_dedupe_col_collection(c, column_collection) + + fetch = columns[0] + if adapter: + fetch = adapter.columns[fetch] + if fetch is None: + # None is not expected to be the result of any + # adapter implementation here, however there may be theoretical + # usages of returning() with context.DMLReturningColFilter + return + + memoized_populators[self.parent_property] = fetch + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + # look through list of columns represented here + # to see which, if any, is present in the row. + if loadopt and loadopt._extra_criteria: + columns = loadopt._extra_criteria + + for col in columns: + if adapter: + col = adapter.columns[col] + getter = result._getter(col, False) + if getter: + populators["quick"].append((self.key, getter)) + break + else: + populators["expire"].append((self.key, True)) + + def init_class_attribute(self, mapper): + self.is_class_level = True + + _register_attribute( + self.parent_property, + mapper, + useobject=False, + compare_function=self.columns[0].type.compare_values, + accepts_scalar_loader=False, + ) + + +@log.class_logger +@properties.ColumnProperty.strategy_for(deferred=True, instrument=True) +@properties.ColumnProperty.strategy_for( + deferred=True, instrument=True, raiseload=True +) +@properties.ColumnProperty.strategy_for(do_nothing=True) +class DeferredColumnLoader(LoaderStrategy): + """Provide loading behavior for a deferred :class:`.ColumnProperty`.""" + + __slots__ = "columns", "group", "raiseload" + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + if hasattr(self.parent_property, "composite_class"): + raise NotImplementedError( + "Deferred loading for composite types not implemented yet" + ) + self.raiseload = self.strategy_opts.get("raiseload", False) + self.columns = self.parent_property.columns + self.group = self.parent_property.group + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + # for a DeferredColumnLoader, this method is only used during a + # "row processor only" query; see test_deferred.py -> + # tests with "rowproc_only" in their name. As of the 1.0 series, + # loading._instance_processor doesn't use a "row processing" function + # to populate columns, instead it uses data in the "populators" + # dictionary. Normally, the DeferredColumnLoader.setup_query() + # sets up that data in the "memoized_populators" dictionary + # and "create_row_processor()" here is never invoked. + + if ( + context.refresh_state + and context.query._compile_options._only_load_props + and self.key in context.query._compile_options._only_load_props + ): + self.parent_property._get_strategy( + (("deferred", False), ("instrument", True)) + ).create_row_processor( + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ) + + elif not self.is_class_level: + if self.raiseload: + set_deferred_for_local_state = ( + self.parent_property._raise_column_loader + ) + else: + set_deferred_for_local_state = ( + self.parent_property._deferred_column_loader + ) + populators["new"].append((self.key, set_deferred_for_local_state)) + else: + populators["expire"].append((self.key, False)) + + def init_class_attribute(self, mapper): + self.is_class_level = True + + _register_attribute( + self.parent_property, + mapper, + useobject=False, + compare_function=self.columns[0].type.compare_values, + callable_=self._load_for_state, + load_on_unexpire=False, + ) + + def setup_query( + self, + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection, + memoized_populators, + only_load_props=None, + **kw, + ): + if ( + ( + compile_state.compile_options._render_for_subquery + and self.parent_property._renders_in_subqueries + ) + or ( + loadopt + and set(self.columns).intersection( + self.parent._should_undefer_in_wildcard + ) + ) + or ( + loadopt + and self.group + and loadopt.local_opts.get( + "undefer_group_%s" % self.group, False + ) + ) + or (only_load_props and self.key in only_load_props) + ): + self.parent_property._get_strategy( + (("deferred", False), ("instrument", True)) + ).setup_query( + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection, + memoized_populators, + **kw, + ) + elif self.is_class_level: + memoized_populators[self.parent_property] = _SET_DEFERRED_EXPIRED + elif not self.raiseload: + memoized_populators[self.parent_property] = _DEFER_FOR_STATE + else: + memoized_populators[self.parent_property] = _RAISE_FOR_STATE + + def _load_for_state(self, state, passive): + if not state.key: + return LoaderCallableStatus.ATTR_EMPTY + + if not passive & PassiveFlag.SQL_OK: + return LoaderCallableStatus.PASSIVE_NO_RESULT + + localparent = state.manager.mapper + + if self.group: + toload = [ + p.key + for p in localparent.iterate_properties + if isinstance(p, StrategizedProperty) + and isinstance(p.strategy, DeferredColumnLoader) + and p.group == self.group + ] + else: + toload = [self.key] + + # narrow the keys down to just those which have no history + group = [k for k in toload if k in state.unmodified] + + session = _state_session(state) + if session is None: + raise orm_exc.DetachedInstanceError( + "Parent instance %s is not bound to a Session; " + "deferred load operation of attribute '%s' cannot proceed" + % (orm_util.state_str(state), self.key) + ) + + if self.raiseload: + self._invoke_raise_load(state, passive, "raise") + + loading.load_scalar_attributes( + state.mapper, state, set(group), PASSIVE_OFF + ) + + return LoaderCallableStatus.ATTR_WAS_SET + + def _invoke_raise_load(self, state, passive, lazy): + raise sa_exc.InvalidRequestError( + "'%s' is not available due to raiseload=True" % (self,) + ) + + +class LoadDeferredColumns: + """serializable loader object used by DeferredColumnLoader""" + + def __init__(self, key: str, raiseload: bool = False): + self.key = key + self.raiseload = raiseload + + def __call__(self, state, passive=attributes.PASSIVE_OFF): + key = self.key + + localparent = state.manager.mapper + prop = localparent._props[key] + if self.raiseload: + strategy_key = ( + ("deferred", True), + ("instrument", True), + ("raiseload", True), + ) + else: + strategy_key = (("deferred", True), ("instrument", True)) + strategy = prop._get_strategy(strategy_key) + return strategy._load_for_state(state, passive) + + +class AbstractRelationshipLoader(LoaderStrategy): + """LoaderStratgies which deal with related objects.""" + + __slots__ = "mapper", "target", "uselist", "entity" + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.mapper = self.parent_property.mapper + self.entity = self.parent_property.entity + self.target = self.parent_property.target + self.uselist = self.parent_property.uselist + + def _immediateload_create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + return self.parent_property._get_strategy( + (("lazy", "immediate"),) + ).create_row_processor( + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ) + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(do_nothing=True) +class DoNothingLoader(LoaderStrategy): + """Relationship loader that makes no change to the object's state. + + Compared to NoLoader, this loader does not initialize the + collection/attribute to empty/none; the usual default LazyLoader will + take effect. + + """ + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(lazy="noload") +@relationships.RelationshipProperty.strategy_for(lazy=None) +class NoLoader(AbstractRelationshipLoader): + """Provide loading behavior for a :class:`.Relationship` + with "lazy=None". + + """ + + __slots__ = () + + def init_class_attribute(self, mapper): + self.is_class_level = True + + _register_attribute( + self.parent_property, + mapper, + useobject=True, + typecallable=self.parent_property.collection_class, + ) + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + def invoke_no_load(state, dict_, row): + if self.uselist: + attributes.init_state_collection(state, dict_, self.key) + else: + dict_[self.key] = None + + populators["new"].append((self.key, invoke_no_load)) + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(lazy=True) +@relationships.RelationshipProperty.strategy_for(lazy="select") +@relationships.RelationshipProperty.strategy_for(lazy="raise") +@relationships.RelationshipProperty.strategy_for(lazy="raise_on_sql") +@relationships.RelationshipProperty.strategy_for(lazy="baked_select") +class LazyLoader( + AbstractRelationshipLoader, util.MemoizedSlots, log.Identified +): + """Provide loading behavior for a :class:`.Relationship` + with "lazy=True", that is loads when first accessed. + + """ + + __slots__ = ( + "_lazywhere", + "_rev_lazywhere", + "_lazyload_reverse_option", + "_order_by", + "use_get", + "is_aliased_class", + "_bind_to_col", + "_equated_columns", + "_rev_bind_to_col", + "_rev_equated_columns", + "_simple_lazy_clause", + "_raise_always", + "_raise_on_sql", + ) + + _lazywhere: ColumnElement[bool] + _bind_to_col: Dict[str, ColumnElement[Any]] + _rev_lazywhere: ColumnElement[bool] + _rev_bind_to_col: Dict[str, ColumnElement[Any]] + + parent_property: RelationshipProperty[Any] + + def __init__( + self, parent: RelationshipProperty[Any], strategy_key: Tuple[Any, ...] + ): + super().__init__(parent, strategy_key) + self._raise_always = self.strategy_opts["lazy"] == "raise" + self._raise_on_sql = self.strategy_opts["lazy"] == "raise_on_sql" + + self.is_aliased_class = inspect(self.entity).is_aliased_class + + join_condition = self.parent_property._join_condition + ( + self._lazywhere, + self._bind_to_col, + self._equated_columns, + ) = join_condition.create_lazy_clause() + + ( + self._rev_lazywhere, + self._rev_bind_to_col, + self._rev_equated_columns, + ) = join_condition.create_lazy_clause(reverse_direction=True) + + if self.parent_property.order_by: + self._order_by = [ + sql_util._deep_annotate(elem, {"_orm_adapt": True}) + for elem in util.to_list(self.parent_property.order_by) + ] + else: + self._order_by = None + + self.logger.info("%s lazy loading clause %s", self, self._lazywhere) + + # determine if our "lazywhere" clause is the same as the mapper's + # get() clause. then we can just use mapper.get() + # + # TODO: the "not self.uselist" can be taken out entirely; a m2o + # load that populates for a list (very unusual, but is possible with + # the API) can still set for "None" and the attribute system will + # populate as an empty list. + self.use_get = ( + not self.is_aliased_class + and not self.uselist + and self.entity._get_clause[0].compare( + self._lazywhere, + use_proxies=True, + compare_keys=False, + equivalents=self.mapper._equivalent_columns, + ) + ) + + if self.use_get: + for col in list(self._equated_columns): + if col in self.mapper._equivalent_columns: + for c in self.mapper._equivalent_columns[col]: + self._equated_columns[c] = self._equated_columns[col] + + self.logger.info( + "%s will use Session.get() to optimize instance loads", self + ) + + def init_class_attribute(self, mapper): + self.is_class_level = True + + _legacy_inactive_history_style = ( + self.parent_property._legacy_inactive_history_style + ) + + if self.parent_property.active_history: + active_history = True + _deferred_history = False + + elif ( + self.parent_property.direction is not interfaces.MANYTOONE + or not self.use_get + ): + if _legacy_inactive_history_style: + active_history = True + _deferred_history = False + else: + active_history = False + _deferred_history = True + else: + active_history = _deferred_history = False + + _register_attribute( + self.parent_property, + mapper, + useobject=True, + callable_=self._load_for_state, + typecallable=self.parent_property.collection_class, + active_history=active_history, + _deferred_history=_deferred_history, + ) + + def _memoized_attr__simple_lazy_clause(self): + lazywhere = sql_util._deep_annotate( + self._lazywhere, {"_orm_adapt": True} + ) + + criterion, bind_to_col = (lazywhere, self._bind_to_col) + + params = [] + + def visit_bindparam(bindparam): + bindparam.unique = False + + visitors.traverse(criterion, {}, {"bindparam": visit_bindparam}) + + def visit_bindparam(bindparam): + if bindparam._identifying_key in bind_to_col: + params.append( + ( + bindparam.key, + bind_to_col[bindparam._identifying_key], + None, + ) + ) + elif bindparam.callable is None: + params.append((bindparam.key, None, bindparam.value)) + + criterion = visitors.cloned_traverse( + criterion, {}, {"bindparam": visit_bindparam} + ) + + return criterion, params + + def _generate_lazy_clause(self, state, passive): + criterion, param_keys = self._simple_lazy_clause + + if state is None: + return sql_util.adapt_criterion_to_null( + criterion, [key for key, ident, value in param_keys] + ) + + mapper = self.parent_property.parent + + o = state.obj() # strong ref + dict_ = attributes.instance_dict(o) + + if passive & PassiveFlag.INIT_OK: + passive ^= PassiveFlag.INIT_OK + + params = {} + for key, ident, value in param_keys: + if ident is not None: + if passive and passive & PassiveFlag.LOAD_AGAINST_COMMITTED: + value = mapper._get_committed_state_attr_by_column( + state, dict_, ident, passive + ) + else: + value = mapper._get_state_attr_by_column( + state, dict_, ident, passive + ) + + params[key] = value + + return criterion, params + + def _invoke_raise_load(self, state, passive, lazy): + raise sa_exc.InvalidRequestError( + "'%s' is not available due to lazy='%s'" % (self, lazy) + ) + + def _load_for_state( + self, + state, + passive, + loadopt=None, + extra_criteria=(), + extra_options=(), + alternate_effective_path=None, + execution_options=util.EMPTY_DICT, + ): + if not state.key and ( + ( + not self.parent_property.load_on_pending + and not state._load_pending + ) + or not state.session_id + ): + return LoaderCallableStatus.ATTR_EMPTY + + pending = not state.key + primary_key_identity = None + + use_get = self.use_get and (not loadopt or not loadopt._extra_criteria) + + if (not passive & PassiveFlag.SQL_OK and not use_get) or ( + not passive & attributes.NON_PERSISTENT_OK and pending + ): + return LoaderCallableStatus.PASSIVE_NO_RESULT + + if ( + # we were given lazy="raise" + self._raise_always + # the no_raise history-related flag was not passed + and not passive & PassiveFlag.NO_RAISE + and ( + # if we are use_get and related_object_ok is disabled, + # which means we are at most looking in the identity map + # for history purposes or otherwise returning + # PASSIVE_NO_RESULT, don't raise. This is also a + # history-related flag + not use_get + or passive & PassiveFlag.RELATED_OBJECT_OK + ) + ): + self._invoke_raise_load(state, passive, "raise") + + session = _state_session(state) + if not session: + if passive & PassiveFlag.NO_RAISE: + return LoaderCallableStatus.PASSIVE_NO_RESULT + + raise orm_exc.DetachedInstanceError( + "Parent instance %s is not bound to a Session; " + "lazy load operation of attribute '%s' cannot proceed" + % (orm_util.state_str(state), self.key) + ) + + # if we have a simple primary key load, check the + # identity map without generating a Query at all + if use_get: + primary_key_identity = self._get_ident_for_use_get( + session, state, passive + ) + if LoaderCallableStatus.PASSIVE_NO_RESULT in primary_key_identity: + return LoaderCallableStatus.PASSIVE_NO_RESULT + elif LoaderCallableStatus.NEVER_SET in primary_key_identity: + return LoaderCallableStatus.NEVER_SET + + if _none_set.issuperset(primary_key_identity): + return None + + if ( + self.key in state.dict + and not passive & PassiveFlag.DEFERRED_HISTORY_LOAD + ): + return LoaderCallableStatus.ATTR_WAS_SET + + # look for this identity in the identity map. Delegate to the + # Query class in use, as it may have special rules for how it + # does this, including how it decides what the correct + # identity_token would be for this identity. + + instance = session._identity_lookup( + self.entity, + primary_key_identity, + passive=passive, + lazy_loaded_from=state, + ) + + if instance is not None: + if instance is LoaderCallableStatus.PASSIVE_CLASS_MISMATCH: + return None + else: + return instance + elif ( + not passive & PassiveFlag.SQL_OK + or not passive & PassiveFlag.RELATED_OBJECT_OK + ): + return LoaderCallableStatus.PASSIVE_NO_RESULT + + return self._emit_lazyload( + session, + state, + primary_key_identity, + passive, + loadopt, + extra_criteria, + extra_options, + alternate_effective_path, + execution_options, + ) + + def _get_ident_for_use_get(self, session, state, passive): + instance_mapper = state.manager.mapper + + if passive & PassiveFlag.LOAD_AGAINST_COMMITTED: + get_attr = instance_mapper._get_committed_state_attr_by_column + else: + get_attr = instance_mapper._get_state_attr_by_column + + dict_ = state.dict + + return [ + get_attr(state, dict_, self._equated_columns[pk], passive=passive) + for pk in self.mapper.primary_key + ] + + @util.preload_module("sqlalchemy.orm.strategy_options") + def _emit_lazyload( + self, + session, + state, + primary_key_identity, + passive, + loadopt, + extra_criteria, + extra_options, + alternate_effective_path, + execution_options, + ): + strategy_options = util.preloaded.orm_strategy_options + + clauseelement = self.entity.__clause_element__() + stmt = Select._create_raw_select( + _raw_columns=[clauseelement], + _propagate_attrs=clauseelement._propagate_attrs, + _label_style=LABEL_STYLE_TABLENAME_PLUS_COL, + _compile_options=ORMCompileState.default_compile_options, + ) + load_options = QueryContext.default_load_options + + load_options += { + "_invoke_all_eagers": False, + "_lazy_loaded_from": state, + } + + if self.parent_property.secondary is not None: + stmt = stmt.select_from( + self.mapper, self.parent_property.secondary + ) + + pending = not state.key + + # don't autoflush on pending + if pending or passive & attributes.NO_AUTOFLUSH: + stmt._execution_options = util.immutabledict({"autoflush": False}) + + use_get = self.use_get + + if state.load_options or (loadopt and loadopt._extra_criteria): + if alternate_effective_path is None: + effective_path = state.load_path[self.parent_property] + else: + effective_path = alternate_effective_path[self.parent_property] + + opts = state.load_options + + if loadopt and loadopt._extra_criteria: + use_get = False + opts += ( + orm_util.LoaderCriteriaOption(self.entity, extra_criteria), + ) + + stmt._with_options = opts + elif alternate_effective_path is None: + # this path is used if there are not already any options + # in the query, but an event may want to add them + effective_path = state.mapper._path_registry[self.parent_property] + else: + # added by immediateloader + effective_path = alternate_effective_path[self.parent_property] + + if extra_options: + stmt._with_options += extra_options + + stmt._compile_options += {"_current_path": effective_path} + + if use_get: + if self._raise_on_sql and not passive & PassiveFlag.NO_RAISE: + self._invoke_raise_load(state, passive, "raise_on_sql") + + return loading.load_on_pk_identity( + session, + stmt, + primary_key_identity, + load_options=load_options, + execution_options=execution_options, + ) + + if self._order_by: + stmt._order_by_clauses = self._order_by + + def _lazyload_reverse(compile_context): + for rev in self.parent_property._reverse_property: + # reverse props that are MANYTOONE are loading *this* + # object from get(), so don't need to eager out to those. + if ( + rev.direction is interfaces.MANYTOONE + and rev._use_get + and not isinstance(rev.strategy, LazyLoader) + ): + strategy_options.Load._construct_for_existing_path( + compile_context.compile_options._current_path[ + rev.parent + ] + ).lazyload(rev).process_compile_state(compile_context) + + stmt._with_context_options += ( + (_lazyload_reverse, self.parent_property), + ) + + lazy_clause, params = self._generate_lazy_clause(state, passive) + + if execution_options: + execution_options = util.EMPTY_DICT.merge_with( + execution_options, + { + "_sa_orm_load_options": load_options, + }, + ) + else: + execution_options = { + "_sa_orm_load_options": load_options, + } + + if ( + self.key in state.dict + and not passive & PassiveFlag.DEFERRED_HISTORY_LOAD + ): + return LoaderCallableStatus.ATTR_WAS_SET + + if pending: + if util.has_intersection(orm_util._none_set, params.values()): + return None + + elif util.has_intersection(orm_util._never_set, params.values()): + return None + + if self._raise_on_sql and not passive & PassiveFlag.NO_RAISE: + self._invoke_raise_load(state, passive, "raise_on_sql") + + stmt._where_criteria = (lazy_clause,) + + result = session.execute( + stmt, params, execution_options=execution_options + ) + + result = result.unique().scalars().all() + + if self.uselist: + return result + else: + l = len(result) + if l: + if l > 1: + util.warn( + "Multiple rows returned with " + "uselist=False for lazily-loaded attribute '%s' " + % self.parent_property + ) + + return result[0] + else: + return None + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + key = self.key + + if ( + context.load_options._is_user_refresh + and context.query._compile_options._only_load_props + and self.key in context.query._compile_options._only_load_props + ): + return self._immediateload_create_row_processor( + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ) + + if not self.is_class_level or (loadopt and loadopt._extra_criteria): + # we are not the primary manager for this attribute + # on this class - set up a + # per-instance lazyloader, which will override the + # class-level behavior. + # this currently only happens when using a + # "lazyload" option on a "no load" + # attribute - "eager" attributes always have a + # class-level lazyloader installed. + set_lazy_callable = ( + InstanceState._instance_level_callable_processor + )( + mapper.class_manager, + LoadLazyAttribute( + key, + self, + loadopt, + ( + loadopt._generate_extra_criteria(context) + if loadopt._extra_criteria + else None + ), + ), + key, + ) + + populators["new"].append((self.key, set_lazy_callable)) + elif context.populate_existing or mapper.always_refresh: + + def reset_for_lazy_callable(state, dict_, row): + # we are the primary manager for this attribute on + # this class - reset its + # per-instance attribute state, so that the class-level + # lazy loader is + # executed when next referenced on this instance. + # this is needed in + # populate_existing() types of scenarios to reset + # any existing state. + state._reset(dict_, key) + + populators["new"].append((self.key, reset_for_lazy_callable)) + + +class LoadLazyAttribute: + """semi-serializable loader object used by LazyLoader + + Historically, this object would be carried along with instances that + needed to run lazyloaders, so it had to be serializable to support + cached instances. + + this is no longer a general requirement, and the case where this object + is used is exactly the case where we can't really serialize easily, + which is when extra criteria in the loader option is present. + + We can't reliably serialize that as it refers to mapped entities and + AliasedClass objects that are local to the current process, which would + need to be matched up on deserialize e.g. the sqlalchemy.ext.serializer + approach. + + """ + + def __init__(self, key, initiating_strategy, loadopt, extra_criteria): + self.key = key + self.strategy_key = initiating_strategy.strategy_key + self.loadopt = loadopt + self.extra_criteria = extra_criteria + + def __getstate__(self): + if self.extra_criteria is not None: + util.warn( + "Can't reliably serialize a lazyload() option that " + "contains additional criteria; please use eager loading " + "for this case" + ) + return { + "key": self.key, + "strategy_key": self.strategy_key, + "loadopt": self.loadopt, + "extra_criteria": (), + } + + def __call__(self, state, passive=attributes.PASSIVE_OFF): + key = self.key + instance_mapper = state.manager.mapper + prop = instance_mapper._props[key] + strategy = prop._strategies[self.strategy_key] + + return strategy._load_for_state( + state, + passive, + loadopt=self.loadopt, + extra_criteria=self.extra_criteria, + ) + + +class PostLoader(AbstractRelationshipLoader): + """A relationship loader that emits a second SELECT statement.""" + + __slots__ = () + + def _setup_for_recursion(self, context, path, loadopt, join_depth=None): + effective_path = ( + context.compile_state.current_path or orm_util.PathRegistry.root + ) + path + + top_level_context = context._get_top_level_context() + execution_options = util.immutabledict( + {"sa_top_level_orm_context": top_level_context} + ) + + if loadopt: + recursion_depth = loadopt.local_opts.get("recursion_depth", None) + unlimited_recursion = recursion_depth == -1 + else: + recursion_depth = None + unlimited_recursion = False + + if recursion_depth is not None: + if not self.parent_property._is_self_referential: + raise sa_exc.InvalidRequestError( + f"recursion_depth option on relationship " + f"{self.parent_property} not valid for " + "non-self-referential relationship" + ) + recursion_depth = context.execution_options.get( + f"_recursion_depth_{id(self)}", recursion_depth + ) + + if not unlimited_recursion and recursion_depth < 0: + return ( + effective_path, + False, + execution_options, + recursion_depth, + ) + + if not unlimited_recursion: + execution_options = execution_options.union( + { + f"_recursion_depth_{id(self)}": recursion_depth - 1, + } + ) + + if loading.PostLoad.path_exists( + context, effective_path, self.parent_property + ): + return effective_path, False, execution_options, recursion_depth + + path_w_prop = path[self.parent_property] + effective_path_w_prop = effective_path[self.parent_property] + + if not path_w_prop.contains(context.attributes, "loader"): + if join_depth: + if effective_path_w_prop.length / 2 > join_depth: + return ( + effective_path, + False, + execution_options, + recursion_depth, + ) + elif effective_path_w_prop.contains_mapper(self.mapper): + return ( + effective_path, + False, + execution_options, + recursion_depth, + ) + + return effective_path, True, execution_options, recursion_depth + + +@relationships.RelationshipProperty.strategy_for(lazy="immediate") +class ImmediateLoader(PostLoader): + __slots__ = ("join_depth",) + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.join_depth = self.parent_property.join_depth + + def init_class_attribute(self, mapper): + self.parent_property._get_strategy( + (("lazy", "select"),) + ).init_class_attribute(mapper) + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + ( + effective_path, + run_loader, + execution_options, + recursion_depth, + ) = self._setup_for_recursion(context, path, loadopt, self.join_depth) + if not run_loader: + # this will not emit SQL and will only emit for a many-to-one + # "use get" load. the "_RELATED" part means it may return + # instance even if its expired, since this is a mutually-recursive + # load operation. + flags = attributes.PASSIVE_NO_FETCH_RELATED | PassiveFlag.NO_RAISE + else: + flags = attributes.PASSIVE_OFF | PassiveFlag.NO_RAISE + + loading.PostLoad.callable_for_path( + context, + effective_path, + self.parent, + self.parent_property, + self._load_for_path, + loadopt, + flags, + recursion_depth, + execution_options, + ) + + def _load_for_path( + self, + context, + path, + states, + load_only, + loadopt, + flags, + recursion_depth, + execution_options, + ): + if recursion_depth: + new_opt = Load(loadopt.path.entity) + new_opt.context = ( + loadopt, + loadopt._recurse(), + ) + alternate_effective_path = path._truncate_recursive() + extra_options = (new_opt,) + else: + new_opt = None + alternate_effective_path = path + extra_options = () + + key = self.key + lazyloader = self.parent_property._get_strategy((("lazy", "select"),)) + for state, overwrite in states: + dict_ = state.dict + + if overwrite or key not in dict_: + value = lazyloader._load_for_state( + state, + flags, + extra_options=extra_options, + alternate_effective_path=alternate_effective_path, + execution_options=execution_options, + ) + if value not in ( + ATTR_WAS_SET, + LoaderCallableStatus.PASSIVE_NO_RESULT, + ): + state.get_impl(key).set_committed_value( + state, dict_, value + ) + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(lazy="subquery") +class SubqueryLoader(PostLoader): + __slots__ = ("join_depth",) + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.join_depth = self.parent_property.join_depth + + def init_class_attribute(self, mapper): + self.parent_property._get_strategy( + (("lazy", "select"),) + ).init_class_attribute(mapper) + + def _get_leftmost( + self, + orig_query_entity_index, + subq_path, + current_compile_state, + is_root, + ): + given_subq_path = subq_path + subq_path = subq_path.path + subq_mapper = orm_util._class_to_mapper(subq_path[0]) + + # determine attributes of the leftmost mapper + if ( + self.parent.isa(subq_mapper) + and self.parent_property is subq_path[1] + ): + leftmost_mapper, leftmost_prop = self.parent, self.parent_property + else: + leftmost_mapper, leftmost_prop = subq_mapper, subq_path[1] + + if is_root: + # the subq_path is also coming from cached state, so when we start + # building up this path, it has to also be converted to be in terms + # of the current state. this is for the specific case of the entity + # is an AliasedClass against a subquery that's not otherwise going + # to adapt + new_subq_path = current_compile_state._entities[ + orig_query_entity_index + ].entity_zero._path_registry[leftmost_prop] + additional = len(subq_path) - len(new_subq_path) + if additional: + new_subq_path += path_registry.PathRegistry.coerce( + subq_path[-additional:] + ) + else: + new_subq_path = given_subq_path + + leftmost_cols = leftmost_prop.local_columns + + leftmost_attr = [ + getattr( + new_subq_path.path[0].entity, + leftmost_mapper._columntoproperty[c].key, + ) + for c in leftmost_cols + ] + + return leftmost_mapper, leftmost_attr, leftmost_prop, new_subq_path + + def _generate_from_original_query( + self, + orig_compile_state, + orig_query, + leftmost_mapper, + leftmost_attr, + leftmost_relationship, + orig_entity, + ): + # reformat the original query + # to look only for significant columns + q = orig_query._clone().correlate(None) + + # LEGACY: make a Query back from the select() !! + # This suits at least two legacy cases: + # 1. applications which expect before_compile() to be called + # below when we run .subquery() on this query (Keystone) + # 2. applications which are doing subqueryload with complex + # from_self() queries, as query.subquery() / .statement + # has to do the full compile context for multiply-nested + # from_self() (Neutron) - see test_subqload_from_self + # for demo. + q2 = query.Query.__new__(query.Query) + q2.__dict__.update(q.__dict__) + q = q2 + + # set the query's "FROM" list explicitly to what the + # FROM list would be in any case, as we will be limiting + # the columns in the SELECT list which may no longer include + # all entities mentioned in things like WHERE, JOIN, etc. + if not q._from_obj: + q._enable_assertions = False + q.select_from.non_generative( + q, + *{ + ent["entity"] + for ent in _column_descriptions( + orig_query, compile_state=orig_compile_state + ) + if ent["entity"] is not None + }, + ) + + # select from the identity columns of the outer (specifically, these + # are the 'local_cols' of the property). This will remove other + # columns from the query that might suggest the right entity which is + # why we do set select_from above. The attributes we have are + # coerced and adapted using the original query's adapter, which is + # needed only for the case of adapting a subclass column to + # that of a polymorphic selectable, e.g. we have + # Engineer.primary_language and the entity is Person. All other + # adaptations, e.g. from_self, select_entity_from(), will occur + # within the new query when it compiles, as the compile_state we are + # using here is only a partial one. If the subqueryload is from a + # with_polymorphic() or other aliased() object, left_attr will already + # be the correct attributes so no adaptation is needed. + target_cols = orig_compile_state._adapt_col_list( + [ + sql.coercions.expect(sql.roles.ColumnsClauseRole, o) + for o in leftmost_attr + ], + orig_compile_state._get_current_adapter(), + ) + q._raw_columns = target_cols + + distinct_target_key = leftmost_relationship.distinct_target_key + + if distinct_target_key is True: + q._distinct = True + elif distinct_target_key is None: + # if target_cols refer to a non-primary key or only + # part of a composite primary key, set the q as distinct + for t in {c.table for c in target_cols}: + if not set(target_cols).issuperset(t.primary_key): + q._distinct = True + break + + # don't need ORDER BY if no limit/offset + if not q._has_row_limiting_clause: + q._order_by_clauses = () + + if q._distinct is True and q._order_by_clauses: + # the logic to automatically add the order by columns to the query + # when distinct is True is deprecated in the query + to_add = sql_util.expand_column_list_from_order_by( + target_cols, q._order_by_clauses + ) + if to_add: + q._set_entities(target_cols + to_add) + + # the original query now becomes a subquery + # which we'll join onto. + # LEGACY: as "q" is a Query, the before_compile() event is invoked + # here. + embed_q = q.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL).subquery() + left_alias = orm_util.AliasedClass( + leftmost_mapper, embed_q, use_mapper_path=True + ) + return left_alias + + def _prep_for_joins(self, left_alias, subq_path): + # figure out what's being joined. a.k.a. the fun part + to_join = [] + pairs = list(subq_path.pairs()) + + for i, (mapper, prop) in enumerate(pairs): + if i > 0: + # look at the previous mapper in the chain - + # if it is as or more specific than this prop's + # mapper, use that instead. + # note we have an assumption here that + # the non-first element is always going to be a mapper, + # not an AliasedClass + + prev_mapper = pairs[i - 1][1].mapper + to_append = prev_mapper if prev_mapper.isa(mapper) else mapper + else: + to_append = mapper + + to_join.append((to_append, prop.key)) + + # determine the immediate parent class we are joining from, + # which needs to be aliased. + + if len(to_join) < 2: + # in the case of a one level eager load, this is the + # leftmost "left_alias". + parent_alias = left_alias + else: + info = inspect(to_join[-1][0]) + if info.is_aliased_class: + parent_alias = info.entity + else: + # alias a plain mapper as we may be + # joining multiple times + parent_alias = orm_util.AliasedClass( + info.entity, use_mapper_path=True + ) + + local_cols = self.parent_property.local_columns + + local_attr = [ + getattr(parent_alias, self.parent._columntoproperty[c].key) + for c in local_cols + ] + return to_join, local_attr, parent_alias + + def _apply_joins( + self, q, to_join, left_alias, parent_alias, effective_entity + ): + ltj = len(to_join) + if ltj == 1: + to_join = [ + getattr(left_alias, to_join[0][1]).of_type(effective_entity) + ] + elif ltj == 2: + to_join = [ + getattr(left_alias, to_join[0][1]).of_type(parent_alias), + getattr(parent_alias, to_join[-1][1]).of_type( + effective_entity + ), + ] + elif ltj > 2: + middle = [ + ( + ( + orm_util.AliasedClass(item[0]) + if not inspect(item[0]).is_aliased_class + else item[0].entity + ), + item[1], + ) + for item in to_join[1:-1] + ] + inner = [] + + while middle: + item = middle.pop(0) + attr = getattr(item[0], item[1]) + if middle: + attr = attr.of_type(middle[0][0]) + else: + attr = attr.of_type(parent_alias) + + inner.append(attr) + + to_join = ( + [getattr(left_alias, to_join[0][1]).of_type(inner[0].parent)] + + inner + + [ + getattr(parent_alias, to_join[-1][1]).of_type( + effective_entity + ) + ] + ) + + for attr in to_join: + q = q.join(attr) + + return q + + def _setup_options( + self, + context, + q, + subq_path, + rewritten_path, + orig_query, + effective_entity, + loadopt, + ): + # note that because the subqueryload object + # does not re-use the cached query, instead always making + # use of the current invoked query, while we have two queries + # here (orig and context.query), they are both non-cached + # queries and we can transfer the options as is without + # adjusting for new criteria. Some work on #6881 / #6889 + # brought this into question. + new_options = orig_query._with_options + + if loadopt and loadopt._extra_criteria: + new_options += ( + orm_util.LoaderCriteriaOption( + self.entity, + loadopt._generate_extra_criteria(context), + ), + ) + + # propagate loader options etc. to the new query. + # these will fire relative to subq_path. + q = q._with_current_path(rewritten_path) + q = q.options(*new_options) + + return q + + def _setup_outermost_orderby(self, q): + if self.parent_property.order_by: + + def _setup_outermost_orderby(compile_context): + compile_context.eager_order_by += tuple( + util.to_list(self.parent_property.order_by) + ) + + q = q._add_context_option( + _setup_outermost_orderby, self.parent_property + ) + + return q + + class _SubqCollections: + """Given a :class:`_query.Query` used to emit the "subquery load", + provide a load interface that executes the query at the + first moment a value is needed. + + """ + + __slots__ = ( + "session", + "execution_options", + "load_options", + "params", + "subq", + "_data", + ) + + def __init__(self, context, subq): + # avoid creating a cycle by storing context + # even though that's preferable + self.session = context.session + self.execution_options = context.execution_options + self.load_options = context.load_options + self.params = context.params or {} + self.subq = subq + self._data = None + + def get(self, key, default): + if self._data is None: + self._load() + return self._data.get(key, default) + + def _load(self): + self._data = collections.defaultdict(list) + + q = self.subq + assert q.session is None + + q = q.with_session(self.session) + + if self.load_options._populate_existing: + q = q.populate_existing() + # to work with baked query, the parameters may have been + # updated since this query was created, so take these into account + + rows = list(q.params(self.params)) + for k, v in itertools.groupby(rows, lambda x: x[1:]): + self._data[k].extend(vv[0] for vv in v) + + def loader(self, state, dict_, row): + if self._data is None: + self._load() + + def _setup_query_from_rowproc( + self, + context, + query_entity, + path, + entity, + loadopt, + adapter, + ): + compile_state = context.compile_state + if ( + not compile_state.compile_options._enable_eagerloads + or compile_state.compile_options._for_refresh_state + ): + return + + orig_query_entity_index = compile_state._entities.index(query_entity) + context.loaders_require_buffering = True + + path = path[self.parent_property] + + # build up a path indicating the path from the leftmost + # entity to the thing we're subquery loading. + with_poly_entity = path.get( + compile_state.attributes, "path_with_polymorphic", None + ) + if with_poly_entity is not None: + effective_entity = with_poly_entity + else: + effective_entity = self.entity + + subq_path, rewritten_path = context.query._execution_options.get( + ("subquery_paths", None), + (orm_util.PathRegistry.root, orm_util.PathRegistry.root), + ) + is_root = subq_path is orm_util.PathRegistry.root + subq_path = subq_path + path + rewritten_path = rewritten_path + path + + # use the current query being invoked, not the compile state + # one. this is so that we get the current parameters. however, + # it means we can't use the existing compile state, we have to make + # a new one. other approaches include possibly using the + # compiled query but swapping the params, seems only marginally + # less time spent but more complicated + orig_query = context.query._execution_options.get( + ("orig_query", SubqueryLoader), context.query + ) + + # make a new compile_state for the query that's probably cached, but + # we're sort of undoing a bit of that caching :( + compile_state_cls = ORMCompileState._get_plugin_class_for_plugin( + orig_query, "orm" + ) + + if orig_query._is_lambda_element: + if context.load_options._lazy_loaded_from is None: + util.warn( + 'subqueryloader for "%s" must invoke lambda callable ' + "at %r in " + "order to produce a new query, decreasing the efficiency " + "of caching for this statement. Consider using " + "selectinload() for more effective full-lambda caching" + % (self, orig_query) + ) + orig_query = orig_query._resolved + + # this is the more "quick" version, however it's not clear how + # much of this we need. in particular I can't get a test to + # fail if the "set_base_alias" is missing and not sure why that is. + orig_compile_state = compile_state_cls._create_entities_collection( + orig_query, legacy=False + ) + + ( + leftmost_mapper, + leftmost_attr, + leftmost_relationship, + rewritten_path, + ) = self._get_leftmost( + orig_query_entity_index, + rewritten_path, + orig_compile_state, + is_root, + ) + + # generate a new Query from the original, then + # produce a subquery from it. + left_alias = self._generate_from_original_query( + orig_compile_state, + orig_query, + leftmost_mapper, + leftmost_attr, + leftmost_relationship, + entity, + ) + + # generate another Query that will join the + # left alias to the target relationships. + # basically doing a longhand + # "from_self()". (from_self() itself not quite industrial + # strength enough for all contingencies...but very close) + + q = query.Query(effective_entity) + + q._execution_options = context.query._execution_options.merge_with( + context.execution_options, + { + ("orig_query", SubqueryLoader): orig_query, + ("subquery_paths", None): (subq_path, rewritten_path), + }, + ) + + q = q._set_enable_single_crit(False) + to_join, local_attr, parent_alias = self._prep_for_joins( + left_alias, subq_path + ) + + q = q.add_columns(*local_attr) + q = self._apply_joins( + q, to_join, left_alias, parent_alias, effective_entity + ) + + q = self._setup_options( + context, + q, + subq_path, + rewritten_path, + orig_query, + effective_entity, + loadopt, + ) + q = self._setup_outermost_orderby(q) + + return q + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + if context.refresh_state: + return self._immediateload_create_row_processor( + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ) + + _, run_loader, _, _ = self._setup_for_recursion( + context, path, loadopt, self.join_depth + ) + if not run_loader: + return + + if not isinstance(context.compile_state, ORMSelectCompileState): + # issue 7505 - subqueryload() in 1.3 and previous would silently + # degrade for from_statement() without warning. this behavior + # is restored here + return + + if not self.parent.class_manager[self.key].impl.supports_population: + raise sa_exc.InvalidRequestError( + "'%s' does not support object " + "population - eager loading cannot be applied." % self + ) + + # a little dance here as the "path" is still something that only + # semi-tracks the exact series of things we are loading, still not + # telling us about with_polymorphic() and stuff like that when it's at + # the root.. the initial MapperEntity is more accurate for this case. + if len(path) == 1: + if not orm_util._entity_isa(query_entity.entity_zero, self.parent): + return + elif not orm_util._entity_isa(path[-1], self.parent): + return + + subq = self._setup_query_from_rowproc( + context, + query_entity, + path, + path[-1], + loadopt, + adapter, + ) + + if subq is None: + return + + assert subq.session is None + + path = path[self.parent_property] + + local_cols = self.parent_property.local_columns + + # cache the loaded collections in the context + # so that inheriting mappers don't re-load when they + # call upon create_row_processor again + collections = path.get(context.attributes, "collections") + if collections is None: + collections = self._SubqCollections(context, subq) + path.set(context.attributes, "collections", collections) + + if adapter: + local_cols = [adapter.columns[c] for c in local_cols] + + if self.uselist: + self._create_collection_loader( + context, result, collections, local_cols, populators + ) + else: + self._create_scalar_loader( + context, result, collections, local_cols, populators + ) + + def _create_collection_loader( + self, context, result, collections, local_cols, populators + ): + tuple_getter = result._tuple_getter(local_cols) + + def load_collection_from_subq(state, dict_, row): + collection = collections.get(tuple_getter(row), ()) + state.get_impl(self.key).set_committed_value( + state, dict_, collection + ) + + def load_collection_from_subq_existing_row(state, dict_, row): + if self.key not in dict_: + load_collection_from_subq(state, dict_, row) + + populators["new"].append((self.key, load_collection_from_subq)) + populators["existing"].append( + (self.key, load_collection_from_subq_existing_row) + ) + + if context.invoke_all_eagers: + populators["eager"].append((self.key, collections.loader)) + + def _create_scalar_loader( + self, context, result, collections, local_cols, populators + ): + tuple_getter = result._tuple_getter(local_cols) + + def load_scalar_from_subq(state, dict_, row): + collection = collections.get(tuple_getter(row), (None,)) + if len(collection) > 1: + util.warn( + "Multiple rows returned with " + "uselist=False for eagerly-loaded attribute '%s' " % self + ) + + scalar = collection[0] + state.get_impl(self.key).set_committed_value(state, dict_, scalar) + + def load_scalar_from_subq_existing_row(state, dict_, row): + if self.key not in dict_: + load_scalar_from_subq(state, dict_, row) + + populators["new"].append((self.key, load_scalar_from_subq)) + populators["existing"].append( + (self.key, load_scalar_from_subq_existing_row) + ) + if context.invoke_all_eagers: + populators["eager"].append((self.key, collections.loader)) + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(lazy="joined") +@relationships.RelationshipProperty.strategy_for(lazy=False) +class JoinedLoader(AbstractRelationshipLoader): + """Provide loading behavior for a :class:`.Relationship` + using joined eager loading. + + """ + + __slots__ = "join_depth" + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.join_depth = self.parent_property.join_depth + + def init_class_attribute(self, mapper): + self.parent_property._get_strategy( + (("lazy", "select"),) + ).init_class_attribute(mapper) + + def setup_query( + self, + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection=None, + parentmapper=None, + chained_from_outerjoin=False, + **kwargs, + ): + """Add a left outer join to the statement that's being constructed.""" + + if not compile_state.compile_options._enable_eagerloads: + return + elif self.uselist: + compile_state.multi_row_eager_loaders = True + + path = path[self.parent_property] + + with_polymorphic = None + + user_defined_adapter = ( + self._init_user_defined_eager_proc( + loadopt, compile_state, compile_state.attributes + ) + if loadopt + else False + ) + + if user_defined_adapter is not False: + # setup an adapter but dont create any JOIN, assume it's already + # in the query + ( + clauses, + adapter, + add_to_collection, + ) = self._setup_query_on_user_defined_adapter( + compile_state, + query_entity, + path, + adapter, + user_defined_adapter, + ) + + # don't do "wrap" for multi-row, we want to wrap + # limited/distinct SELECT, + # because we want to put the JOIN on the outside. + + else: + # if not via query option, check for + # a cycle + if not path.contains(compile_state.attributes, "loader"): + if self.join_depth: + if path.length / 2 > self.join_depth: + return + elif path.contains_mapper(self.mapper): + return + + # add the JOIN and create an adapter + ( + clauses, + adapter, + add_to_collection, + chained_from_outerjoin, + ) = self._generate_row_adapter( + compile_state, + query_entity, + path, + loadopt, + adapter, + column_collection, + parentmapper, + chained_from_outerjoin, + ) + + # for multi-row, we want to wrap limited/distinct SELECT, + # because we want to put the JOIN on the outside. + compile_state.eager_adding_joins = True + + with_poly_entity = path.get( + compile_state.attributes, "path_with_polymorphic", None + ) + if with_poly_entity is not None: + with_polymorphic = inspect( + with_poly_entity + ).with_polymorphic_mappers + else: + with_polymorphic = None + + path = path[self.entity] + + loading._setup_entity_query( + compile_state, + self.mapper, + query_entity, + path, + clauses, + add_to_collection, + with_polymorphic=with_polymorphic, + parentmapper=self.mapper, + chained_from_outerjoin=chained_from_outerjoin, + ) + + has_nones = util.NONE_SET.intersection(compile_state.secondary_columns) + + if has_nones: + if with_poly_entity is not None: + raise sa_exc.InvalidRequestError( + "Detected unaliased columns when generating joined " + "load. Make sure to use aliased=True or flat=True " + "when using joined loading with with_polymorphic()." + ) + else: + compile_state.secondary_columns = [ + c for c in compile_state.secondary_columns if c is not None + ] + + def _init_user_defined_eager_proc( + self, loadopt, compile_state, target_attributes + ): + # check if the opt applies at all + if "eager_from_alias" not in loadopt.local_opts: + # nope + return False + + path = loadopt.path.parent + + # the option applies. check if the "user_defined_eager_row_processor" + # has been built up. + adapter = path.get( + compile_state.attributes, "user_defined_eager_row_processor", False + ) + if adapter is not False: + # just return it + return adapter + + # otherwise figure it out. + alias = loadopt.local_opts["eager_from_alias"] + root_mapper, prop = path[-2:] + + if alias is not None: + if isinstance(alias, str): + alias = prop.target.alias(alias) + adapter = orm_util.ORMAdapter( + orm_util._TraceAdaptRole.JOINEDLOAD_USER_DEFINED_ALIAS, + prop.mapper, + selectable=alias, + equivalents=prop.mapper._equivalent_columns, + limit_on_entity=False, + ) + else: + if path.contains( + compile_state.attributes, "path_with_polymorphic" + ): + with_poly_entity = path.get( + compile_state.attributes, "path_with_polymorphic" + ) + adapter = orm_util.ORMAdapter( + orm_util._TraceAdaptRole.JOINEDLOAD_PATH_WITH_POLYMORPHIC, + with_poly_entity, + equivalents=prop.mapper._equivalent_columns, + ) + else: + adapter = compile_state._polymorphic_adapters.get( + prop.mapper, None + ) + path.set( + target_attributes, + "user_defined_eager_row_processor", + adapter, + ) + + return adapter + + def _setup_query_on_user_defined_adapter( + self, context, entity, path, adapter, user_defined_adapter + ): + # apply some more wrapping to the "user defined adapter" + # if we are setting up the query for SQL render. + adapter = entity._get_entity_clauses(context) + + if adapter and user_defined_adapter: + user_defined_adapter = user_defined_adapter.wrap(adapter) + path.set( + context.attributes, + "user_defined_eager_row_processor", + user_defined_adapter, + ) + elif adapter: + user_defined_adapter = adapter + path.set( + context.attributes, + "user_defined_eager_row_processor", + user_defined_adapter, + ) + + add_to_collection = context.primary_columns + return user_defined_adapter, adapter, add_to_collection + + def _generate_row_adapter( + self, + compile_state, + entity, + path, + loadopt, + adapter, + column_collection, + parentmapper, + chained_from_outerjoin, + ): + with_poly_entity = path.get( + compile_state.attributes, "path_with_polymorphic", None + ) + if with_poly_entity: + to_adapt = with_poly_entity + else: + insp = inspect(self.entity) + if insp.is_aliased_class: + alt_selectable = insp.selectable + else: + alt_selectable = None + + to_adapt = orm_util.AliasedClass( + self.mapper, + alias=( + alt_selectable._anonymous_fromclause(flat=True) + if alt_selectable is not None + else None + ), + flat=True, + use_mapper_path=True, + ) + + to_adapt_insp = inspect(to_adapt) + + clauses = to_adapt_insp._memo( + ("joinedloader_ormadapter", self), + orm_util.ORMAdapter, + orm_util._TraceAdaptRole.JOINEDLOAD_MEMOIZED_ADAPTER, + to_adapt_insp, + equivalents=self.mapper._equivalent_columns, + adapt_required=True, + allow_label_resolve=False, + anonymize_labels=True, + ) + + assert clauses.is_aliased_class + + innerjoin = ( + loadopt.local_opts.get("innerjoin", self.parent_property.innerjoin) + if loadopt is not None + else self.parent_property.innerjoin + ) + + if not innerjoin: + # if this is an outer join, all non-nested eager joins from + # this path must also be outer joins + chained_from_outerjoin = True + + compile_state.create_eager_joins.append( + ( + self._create_eager_join, + entity, + path, + adapter, + parentmapper, + clauses, + innerjoin, + chained_from_outerjoin, + loadopt._extra_criteria if loadopt else (), + ) + ) + + add_to_collection = compile_state.secondary_columns + path.set(compile_state.attributes, "eager_row_processor", clauses) + + return clauses, adapter, add_to_collection, chained_from_outerjoin + + def _create_eager_join( + self, + compile_state, + query_entity, + path, + adapter, + parentmapper, + clauses, + innerjoin, + chained_from_outerjoin, + extra_criteria, + ): + if parentmapper is None: + localparent = query_entity.mapper + else: + localparent = parentmapper + + # whether or not the Query will wrap the selectable in a subquery, + # and then attach eager load joins to that (i.e., in the case of + # LIMIT/OFFSET etc.) + should_nest_selectable = ( + compile_state.multi_row_eager_loaders + and compile_state._should_nest_selectable + ) + + query_entity_key = None + + if ( + query_entity not in compile_state.eager_joins + and not should_nest_selectable + and compile_state.from_clauses + ): + indexes = sql_util.find_left_clause_that_matches_given( + compile_state.from_clauses, query_entity.selectable + ) + + if len(indexes) > 1: + # for the eager load case, I can't reproduce this right + # now. For query.join() I can. + raise sa_exc.InvalidRequestError( + "Can't identify which query entity in which to joined " + "eager load from. Please use an exact match when " + "specifying the join path." + ) + + if indexes: + clause = compile_state.from_clauses[indexes[0]] + # join to an existing FROM clause on the query. + # key it to its list index in the eager_joins dict. + # Query._compile_context will adapt as needed and + # append to the FROM clause of the select(). + query_entity_key, default_towrap = indexes[0], clause + + if query_entity_key is None: + query_entity_key, default_towrap = ( + query_entity, + query_entity.selectable, + ) + + towrap = compile_state.eager_joins.setdefault( + query_entity_key, default_towrap + ) + + if adapter: + if getattr(adapter, "is_aliased_class", False): + # joining from an adapted entity. The adapted entity + # might be a "with_polymorphic", so resolve that to our + # specific mapper's entity before looking for our attribute + # name on it. + efm = adapter.aliased_insp._entity_for_mapper( + localparent + if localparent.isa(self.parent) + else self.parent + ) + + # look for our attribute on the adapted entity, else fall back + # to our straight property + onclause = getattr(efm.entity, self.key, self.parent_property) + else: + onclause = getattr( + orm_util.AliasedClass( + self.parent, adapter.selectable, use_mapper_path=True + ), + self.key, + self.parent_property, + ) + + else: + onclause = self.parent_property + + assert clauses.is_aliased_class + + attach_on_outside = ( + not chained_from_outerjoin + or not innerjoin + or innerjoin == "unnested" + or query_entity.entity_zero.represents_outer_join + ) + + extra_join_criteria = extra_criteria + additional_entity_criteria = compile_state.global_attributes.get( + ("additional_entity_criteria", self.mapper), () + ) + if additional_entity_criteria: + extra_join_criteria += tuple( + ae._resolve_where_criteria(self.mapper) + for ae in additional_entity_criteria + if ae.propagate_to_loaders + ) + + if attach_on_outside: + # this is the "classic" eager join case. + eagerjoin = orm_util._ORMJoin( + towrap, + clauses.aliased_insp, + onclause, + isouter=not innerjoin + or query_entity.entity_zero.represents_outer_join + or (chained_from_outerjoin and isinstance(towrap, sql.Join)), + _left_memo=self.parent, + _right_memo=self.mapper, + _extra_criteria=extra_join_criteria, + ) + else: + # all other cases are innerjoin=='nested' approach + eagerjoin = self._splice_nested_inner_join( + path, towrap, clauses, onclause, extra_join_criteria + ) + + compile_state.eager_joins[query_entity_key] = eagerjoin + + # send a hint to the Query as to where it may "splice" this join + eagerjoin.stop_on = query_entity.selectable + + if not parentmapper: + # for parentclause that is the non-eager end of the join, + # ensure all the parent cols in the primaryjoin are actually + # in the + # columns clause (i.e. are not deferred), so that aliasing applied + # by the Query propagates those columns outward. + # This has the effect + # of "undefering" those columns. + for col in sql_util._find_columns( + self.parent_property.primaryjoin + ): + if localparent.persist_selectable.c.contains_column(col): + if adapter: + col = adapter.columns[col] + compile_state._append_dedupe_col_collection( + col, compile_state.primary_columns + ) + + if self.parent_property.order_by: + compile_state.eager_order_by += tuple( + (eagerjoin._target_adapter.copy_and_process)( + util.to_list(self.parent_property.order_by) + ) + ) + + def _splice_nested_inner_join( + self, path, join_obj, clauses, onclause, extra_criteria, splicing=False + ): + # recursive fn to splice a nested join into an existing one. + # splicing=False means this is the outermost call, and it + # should return a value. splicing= is the recursive + # form, where it can return None to indicate the end of the recursion + + if splicing is False: + # first call is always handed a join object + # from the outside + assert isinstance(join_obj, orm_util._ORMJoin) + elif isinstance(join_obj, sql.selectable.FromGrouping): + return self._splice_nested_inner_join( + path, + join_obj.element, + clauses, + onclause, + extra_criteria, + splicing, + ) + elif not isinstance(join_obj, orm_util._ORMJoin): + if path[-2].isa(splicing): + return orm_util._ORMJoin( + join_obj, + clauses.aliased_insp, + onclause, + isouter=False, + _left_memo=splicing, + _right_memo=path[-1].mapper, + _extra_criteria=extra_criteria, + ) + else: + return None + + target_join = self._splice_nested_inner_join( + path, + join_obj.right, + clauses, + onclause, + extra_criteria, + join_obj._right_memo, + ) + if target_join is None: + right_splice = False + target_join = self._splice_nested_inner_join( + path, + join_obj.left, + clauses, + onclause, + extra_criteria, + join_obj._left_memo, + ) + if target_join is None: + # should only return None when recursively called, + # e.g. splicing refers to a from obj + assert ( + splicing is not False + ), "assertion failed attempting to produce joined eager loads" + return None + else: + right_splice = True + + if right_splice: + # for a right splice, attempt to flatten out + # a JOIN b JOIN c JOIN .. to avoid needless + # parenthesis nesting + if not join_obj.isouter and not target_join.isouter: + eagerjoin = join_obj._splice_into_center(target_join) + else: + eagerjoin = orm_util._ORMJoin( + join_obj.left, + target_join, + join_obj.onclause, + isouter=join_obj.isouter, + _left_memo=join_obj._left_memo, + ) + else: + eagerjoin = orm_util._ORMJoin( + target_join, + join_obj.right, + join_obj.onclause, + isouter=join_obj.isouter, + _right_memo=join_obj._right_memo, + ) + + eagerjoin._target_adapter = target_join._target_adapter + return eagerjoin + + def _create_eager_adapter(self, context, result, adapter, path, loadopt): + compile_state = context.compile_state + + user_defined_adapter = ( + self._init_user_defined_eager_proc( + loadopt, compile_state, context.attributes + ) + if loadopt + else False + ) + + if user_defined_adapter is not False: + decorator = user_defined_adapter + # user defined eagerloads are part of the "primary" + # portion of the load. + # the adapters applied to the Query should be honored. + if compile_state.compound_eager_adapter and decorator: + decorator = decorator.wrap( + compile_state.compound_eager_adapter + ) + elif compile_state.compound_eager_adapter: + decorator = compile_state.compound_eager_adapter + else: + decorator = path.get( + compile_state.attributes, "eager_row_processor" + ) + if decorator is None: + return False + + if self.mapper._result_has_identity_key(result, decorator): + return decorator + else: + # no identity key - don't return a row + # processor, will cause a degrade to lazy + return False + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + if not self.parent.class_manager[self.key].impl.supports_population: + raise sa_exc.InvalidRequestError( + "'%s' does not support object " + "population - eager loading cannot be applied." % self + ) + + if self.uselist: + context.loaders_require_uniquing = True + + our_path = path[self.parent_property] + + eager_adapter = self._create_eager_adapter( + context, result, adapter, our_path, loadopt + ) + + if eager_adapter is not False: + key = self.key + + _instance = loading._instance_processor( + query_entity, + self.mapper, + context, + result, + our_path[self.entity], + eager_adapter, + ) + + if not self.uselist: + self._create_scalar_loader(context, key, _instance, populators) + else: + self._create_collection_loader( + context, key, _instance, populators + ) + else: + self.parent_property._get_strategy( + (("lazy", "select"),) + ).create_row_processor( + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ) + + def _create_collection_loader(self, context, key, _instance, populators): + def load_collection_from_joined_new_row(state, dict_, row): + # note this must unconditionally clear out any existing collection. + # an existing collection would be present only in the case of + # populate_existing(). + collection = attributes.init_state_collection(state, dict_, key) + result_list = util.UniqueAppender( + collection, "append_without_event" + ) + context.attributes[(state, key)] = result_list + inst = _instance(row) + if inst is not None: + result_list.append(inst) + + def load_collection_from_joined_existing_row(state, dict_, row): + if (state, key) in context.attributes: + result_list = context.attributes[(state, key)] + else: + # appender_key can be absent from context.attributes + # with isnew=False when self-referential eager loading + # is used; the same instance may be present in two + # distinct sets of result columns + collection = attributes.init_state_collection( + state, dict_, key + ) + result_list = util.UniqueAppender( + collection, "append_without_event" + ) + context.attributes[(state, key)] = result_list + inst = _instance(row) + if inst is not None: + result_list.append(inst) + + def load_collection_from_joined_exec(state, dict_, row): + _instance(row) + + populators["new"].append( + (self.key, load_collection_from_joined_new_row) + ) + populators["existing"].append( + (self.key, load_collection_from_joined_existing_row) + ) + if context.invoke_all_eagers: + populators["eager"].append( + (self.key, load_collection_from_joined_exec) + ) + + def _create_scalar_loader(self, context, key, _instance, populators): + def load_scalar_from_joined_new_row(state, dict_, row): + # set a scalar object instance directly on the parent + # object, bypassing InstrumentedAttribute event handlers. + dict_[key] = _instance(row) + + def load_scalar_from_joined_existing_row(state, dict_, row): + # call _instance on the row, even though the object has + # been created, so that we further descend into properties + existing = _instance(row) + + # conflicting value already loaded, this shouldn't happen + if key in dict_: + if existing is not dict_[key]: + util.warn( + "Multiple rows returned with " + "uselist=False for eagerly-loaded attribute '%s' " + % self + ) + else: + # this case is when one row has multiple loads of the + # same entity (e.g. via aliasing), one has an attribute + # that the other doesn't. + dict_[key] = existing + + def load_scalar_from_joined_exec(state, dict_, row): + _instance(row) + + populators["new"].append((self.key, load_scalar_from_joined_new_row)) + populators["existing"].append( + (self.key, load_scalar_from_joined_existing_row) + ) + if context.invoke_all_eagers: + populators["eager"].append( + (self.key, load_scalar_from_joined_exec) + ) + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(lazy="selectin") +class SelectInLoader(PostLoader, util.MemoizedSlots): + __slots__ = ( + "join_depth", + "omit_join", + "_parent_alias", + "_query_info", + "_fallback_query_info", + ) + + query_info = collections.namedtuple( + "queryinfo", + [ + "load_only_child", + "load_with_join", + "in_expr", + "pk_cols", + "zero_idx", + "child_lookup_cols", + ], + ) + + _chunksize = 500 + + def __init__(self, parent, strategy_key): + super().__init__(parent, strategy_key) + self.join_depth = self.parent_property.join_depth + is_m2o = self.parent_property.direction is interfaces.MANYTOONE + + if self.parent_property.omit_join is not None: + self.omit_join = self.parent_property.omit_join + else: + lazyloader = self.parent_property._get_strategy( + (("lazy", "select"),) + ) + if is_m2o: + self.omit_join = lazyloader.use_get + else: + self.omit_join = self.parent._get_clause[0].compare( + lazyloader._rev_lazywhere, + use_proxies=True, + compare_keys=False, + equivalents=self.parent._equivalent_columns, + ) + + if self.omit_join: + if is_m2o: + self._query_info = self._init_for_omit_join_m2o() + self._fallback_query_info = self._init_for_join() + else: + self._query_info = self._init_for_omit_join() + else: + self._query_info = self._init_for_join() + + def _init_for_omit_join(self): + pk_to_fk = dict( + self.parent_property._join_condition.local_remote_pairs + ) + pk_to_fk.update( + (equiv, pk_to_fk[k]) + for k in list(pk_to_fk) + for equiv in self.parent._equivalent_columns.get(k, ()) + ) + + pk_cols = fk_cols = [ + pk_to_fk[col] for col in self.parent.primary_key if col in pk_to_fk + ] + if len(fk_cols) > 1: + in_expr = sql.tuple_(*fk_cols) + zero_idx = False + else: + in_expr = fk_cols[0] + zero_idx = True + + return self.query_info(False, False, in_expr, pk_cols, zero_idx, None) + + def _init_for_omit_join_m2o(self): + pk_cols = self.mapper.primary_key + if len(pk_cols) > 1: + in_expr = sql.tuple_(*pk_cols) + zero_idx = False + else: + in_expr = pk_cols[0] + zero_idx = True + + lazyloader = self.parent_property._get_strategy((("lazy", "select"),)) + lookup_cols = [lazyloader._equated_columns[pk] for pk in pk_cols] + + return self.query_info( + True, False, in_expr, pk_cols, zero_idx, lookup_cols + ) + + def _init_for_join(self): + self._parent_alias = AliasedClass(self.parent.class_) + pa_insp = inspect(self._parent_alias) + pk_cols = [ + pa_insp._adapt_element(col) for col in self.parent.primary_key + ] + if len(pk_cols) > 1: + in_expr = sql.tuple_(*pk_cols) + zero_idx = False + else: + in_expr = pk_cols[0] + zero_idx = True + return self.query_info(False, True, in_expr, pk_cols, zero_idx, None) + + def init_class_attribute(self, mapper): + self.parent_property._get_strategy( + (("lazy", "select"),) + ).init_class_attribute(mapper) + + def create_row_processor( + self, + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ): + if context.refresh_state: + return self._immediateload_create_row_processor( + context, + query_entity, + path, + loadopt, + mapper, + result, + adapter, + populators, + ) + + ( + effective_path, + run_loader, + execution_options, + recursion_depth, + ) = self._setup_for_recursion( + context, path, loadopt, join_depth=self.join_depth + ) + + if not run_loader: + return + + if not self.parent.class_manager[self.key].impl.supports_population: + raise sa_exc.InvalidRequestError( + "'%s' does not support object " + "population - eager loading cannot be applied." % self + ) + + # a little dance here as the "path" is still something that only + # semi-tracks the exact series of things we are loading, still not + # telling us about with_polymorphic() and stuff like that when it's at + # the root.. the initial MapperEntity is more accurate for this case. + if len(path) == 1: + if not orm_util._entity_isa(query_entity.entity_zero, self.parent): + return + elif not orm_util._entity_isa(path[-1], self.parent): + return + + selectin_path = effective_path + + path_w_prop = path[self.parent_property] + + # build up a path indicating the path from the leftmost + # entity to the thing we're subquery loading. + with_poly_entity = path_w_prop.get( + context.attributes, "path_with_polymorphic", None + ) + if with_poly_entity is not None: + effective_entity = inspect(with_poly_entity) + else: + effective_entity = self.entity + + loading.PostLoad.callable_for_path( + context, + selectin_path, + self.parent, + self.parent_property, + self._load_for_path, + effective_entity, + loadopt, + recursion_depth, + execution_options, + ) + + def _load_for_path( + self, + context, + path, + states, + load_only, + effective_entity, + loadopt, + recursion_depth, + execution_options, + ): + if load_only and self.key not in load_only: + return + + query_info = self._query_info + + if query_info.load_only_child: + our_states = collections.defaultdict(list) + none_states = [] + + mapper = self.parent + + for state, overwrite in states: + state_dict = state.dict + related_ident = tuple( + mapper._get_state_attr_by_column( + state, + state_dict, + lk, + passive=attributes.PASSIVE_NO_FETCH, + ) + for lk in query_info.child_lookup_cols + ) + # if the loaded parent objects do not have the foreign key + # to the related item loaded, then degrade into the joined + # version of selectinload + if LoaderCallableStatus.PASSIVE_NO_RESULT in related_ident: + query_info = self._fallback_query_info + break + + # organize states into lists keyed to particular foreign + # key values. + if None not in related_ident: + our_states[related_ident].append( + (state, state_dict, overwrite) + ) + else: + # For FK values that have None, add them to a + # separate collection that will be populated separately + none_states.append((state, state_dict, overwrite)) + + # note the above conditional may have changed query_info + if not query_info.load_only_child: + our_states = [ + (state.key[1], state, state.dict, overwrite) + for state, overwrite in states + ] + + pk_cols = query_info.pk_cols + in_expr = query_info.in_expr + + if not query_info.load_with_join: + # in "omit join" mode, the primary key column and the + # "in" expression are in terms of the related entity. So + # if the related entity is polymorphic or otherwise aliased, + # we need to adapt our "pk_cols" and "in_expr" to that + # entity. in non-"omit join" mode, these are against the + # parent entity and do not need adaption. + if effective_entity.is_aliased_class: + pk_cols = [ + effective_entity._adapt_element(col) for col in pk_cols + ] + in_expr = effective_entity._adapt_element(in_expr) + + bundle_ent = orm_util.Bundle("pk", *pk_cols) + bundle_sql = bundle_ent.__clause_element__() + + entity_sql = effective_entity.__clause_element__() + q = Select._create_raw_select( + _raw_columns=[bundle_sql, entity_sql], + _label_style=LABEL_STYLE_TABLENAME_PLUS_COL, + _compile_options=ORMCompileState.default_compile_options, + _propagate_attrs={ + "compile_state_plugin": "orm", + "plugin_subject": effective_entity, + }, + ) + + if not query_info.load_with_join: + # the Bundle we have in the "omit_join" case is against raw, non + # annotated columns, so to ensure the Query knows its primary + # entity, we add it explicitly. If we made the Bundle against + # annotated columns, we hit a performance issue in this specific + # case, which is detailed in issue #4347. + q = q.select_from(effective_entity) + else: + # in the non-omit_join case, the Bundle is against the annotated/ + # mapped column of the parent entity, but the #4347 issue does not + # occur in this case. + q = q.select_from(self._parent_alias).join( + getattr(self._parent_alias, self.parent_property.key).of_type( + effective_entity + ) + ) + + q = q.filter(in_expr.in_(sql.bindparam("primary_keys"))) + + # a test which exercises what these comments talk about is + # test_selectin_relations.py -> test_twolevel_selectin_w_polymorphic + # + # effective_entity above is given to us in terms of the cached + # statement, namely this one: + orig_query = context.compile_state.select_statement + + # the actual statement that was requested is this one: + # context_query = context.query + # + # that's not the cached one, however. So while it is of the identical + # structure, if it has entities like AliasedInsp, which we get from + # aliased() or with_polymorphic(), the AliasedInsp will likely be a + # different object identity each time, and will not match up + # hashing-wise to the corresponding AliasedInsp that's in the + # cached query, meaning it won't match on paths and loader lookups + # and loaders like this one will be skipped if it is used in options. + # + # as it turns out, standard loader options like selectinload(), + # lazyload() that have a path need + # to come from the cached query so that the AliasedInsp etc. objects + # that are in the query line up with the object that's in the path + # of the strategy object. however other options like + # with_loader_criteria() that doesn't have a path (has a fixed entity) + # and needs to have access to the latest closure state in order to + # be correct, we need to use the uncached one. + # + # as of #8399 we let the loader option itself figure out what it + # wants to do given cached and uncached version of itself. + + effective_path = path[self.parent_property] + + if orig_query is context.query: + new_options = orig_query._with_options + else: + cached_options = orig_query._with_options + uncached_options = context.query._with_options + + # propagate compile state options from the original query, + # updating their "extra_criteria" as necessary. + # note this will create a different cache key than + # "orig" options if extra_criteria is present, because the copy + # of extra_criteria will have different boundparam than that of + # the QueryableAttribute in the path + new_options = [ + orig_opt._adapt_cached_option_to_uncached_option( + context, uncached_opt + ) + for orig_opt, uncached_opt in zip( + cached_options, uncached_options + ) + ] + + if loadopt and loadopt._extra_criteria: + new_options += ( + orm_util.LoaderCriteriaOption( + effective_entity, + loadopt._generate_extra_criteria(context), + ), + ) + + if recursion_depth is not None: + effective_path = effective_path._truncate_recursive() + + q = q.options(*new_options) + + q = q._update_compile_options({"_current_path": effective_path}) + if context.populate_existing: + q = q.execution_options(populate_existing=True) + + if self.parent_property.order_by: + if not query_info.load_with_join: + eager_order_by = self.parent_property.order_by + if effective_entity.is_aliased_class: + eager_order_by = [ + effective_entity._adapt_element(elem) + for elem in eager_order_by + ] + q = q.order_by(*eager_order_by) + else: + + def _setup_outermost_orderby(compile_context): + compile_context.eager_order_by += tuple( + util.to_list(self.parent_property.order_by) + ) + + q = q._add_context_option( + _setup_outermost_orderby, self.parent_property + ) + + if query_info.load_only_child: + self._load_via_child( + our_states, + none_states, + query_info, + q, + context, + execution_options, + ) + else: + self._load_via_parent( + our_states, query_info, q, context, execution_options + ) + + def _load_via_child( + self, + our_states, + none_states, + query_info, + q, + context, + execution_options, + ): + uselist = self.uselist + + # this sort is really for the benefit of the unit tests + our_keys = sorted(our_states) + while our_keys: + chunk = our_keys[0 : self._chunksize] + our_keys = our_keys[self._chunksize :] + data = { + k: v + for k, v in context.session.execute( + q, + params={ + "primary_keys": [ + key[0] if query_info.zero_idx else key + for key in chunk + ] + }, + execution_options=execution_options, + ).unique() + } + + for key in chunk: + # for a real foreign key and no concurrent changes to the + # DB while running this method, "key" is always present in + # data. However, for primaryjoins without real foreign keys + # a non-None primaryjoin condition may still refer to no + # related object. + related_obj = data.get(key, None) + for state, dict_, overwrite in our_states[key]: + if not overwrite and self.key in dict_: + continue + + state.get_impl(self.key).set_committed_value( + state, + dict_, + related_obj if not uselist else [related_obj], + ) + # populate none states with empty value / collection + for state, dict_, overwrite in none_states: + if not overwrite and self.key in dict_: + continue + + # note it's OK if this is a uselist=True attribute, the empty + # collection will be populated + state.get_impl(self.key).set_committed_value(state, dict_, None) + + def _load_via_parent( + self, our_states, query_info, q, context, execution_options + ): + uselist = self.uselist + _empty_result = () if uselist else None + + while our_states: + chunk = our_states[0 : self._chunksize] + our_states = our_states[self._chunksize :] + + primary_keys = [ + key[0] if query_info.zero_idx else key + for key, state, state_dict, overwrite in chunk + ] + + data = collections.defaultdict(list) + for k, v in itertools.groupby( + context.session.execute( + q, + params={"primary_keys": primary_keys}, + execution_options=execution_options, + ).unique(), + lambda x: x[0], + ): + data[k].extend(vv[1] for vv in v) + + for key, state, state_dict, overwrite in chunk: + if not overwrite and self.key in state_dict: + continue + + collection = data.get(key, _empty_result) + + if not uselist and collection: + if len(collection) > 1: + util.warn( + "Multiple rows returned with " + "uselist=False for eagerly-loaded " + "attribute '%s' " % self + ) + state.get_impl(self.key).set_committed_value( + state, state_dict, collection[0] + ) + else: + # note that empty tuple set on uselist=False sets the + # value to None + state.get_impl(self.key).set_committed_value( + state, state_dict, collection + ) + + +def single_parent_validator(desc, prop): + def _do_check(state, value, oldvalue, initiator): + if value is not None and initiator.key == prop.key: + hasparent = initiator.hasparent(attributes.instance_state(value)) + if hasparent and oldvalue is not value: + raise sa_exc.InvalidRequestError( + "Instance %s is already associated with an instance " + "of %s via its %s attribute, and is only allowed a " + "single parent." + % (orm_util.instance_str(value), state.class_, prop), + code="bbf1", + ) + return value + + def append(state, value, initiator): + return _do_check(state, value, None, initiator) + + def set_(state, value, oldvalue, initiator): + return _do_check(state, value, oldvalue, initiator) + + event.listen( + desc, "append", append, raw=True, retval=True, active_history=True + ) + event.listen(desc, "set", set_, raw=True, retval=True, active_history=True) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/strategy_options.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/strategy_options.py new file mode 100644 index 0000000..25c6332 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/strategy_options.py @@ -0,0 +1,2555 @@ +# orm/strategy_options.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +""" + +""" + +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union + +from . import util as orm_util +from ._typing import insp_is_aliased_class +from ._typing import insp_is_attribute +from ._typing import insp_is_mapper +from ._typing import insp_is_mapper_property +from .attributes import QueryableAttribute +from .base import InspectionAttr +from .interfaces import LoaderOption +from .path_registry import _DEFAULT_TOKEN +from .path_registry import _StrPathToken +from .path_registry import _WILDCARD_TOKEN +from .path_registry import AbstractEntityRegistry +from .path_registry import path_is_property +from .path_registry import PathRegistry +from .path_registry import TokenRegistry +from .util import _orm_full_deannotate +from .util import AliasedInsp +from .. import exc as sa_exc +from .. import inspect +from .. import util +from ..sql import and_ +from ..sql import cache_key +from ..sql import coercions +from ..sql import roles +from ..sql import traversals +from ..sql import visitors +from ..sql.base import _generative +from ..util.typing import Final +from ..util.typing import Literal +from ..util.typing import Self + +_RELATIONSHIP_TOKEN: Final[Literal["relationship"]] = "relationship" +_COLUMN_TOKEN: Final[Literal["column"]] = "column" + +_FN = TypeVar("_FN", bound="Callable[..., Any]") + +if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _InternalEntityType + from .context import _MapperEntity + from .context import ORMCompileState + from .context import QueryContext + from .interfaces import _StrategyKey + from .interfaces import MapperProperty + from .interfaces import ORMOption + from .mapper import Mapper + from .path_registry import _PathRepresentation + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _FromClauseArgument + from ..sql.cache_key import _CacheKeyTraversalType + from ..sql.cache_key import CacheKey + + +_AttrType = Union[Literal["*"], "QueryableAttribute[Any]"] + +_WildcardKeyType = Literal["relationship", "column"] +_StrategySpec = Dict[str, Any] +_OptsType = Dict[str, Any] +_AttrGroupType = Tuple[_AttrType, ...] + + +class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): + __slots__ = ("propagate_to_loaders",) + + _is_strategy_option = True + propagate_to_loaders: bool + + def contains_eager( + self, + attr: _AttrType, + alias: Optional[_FromClauseArgument] = None, + _is_chain: bool = False, + ) -> Self: + r"""Indicate that the given attribute should be eagerly loaded from + columns stated manually in the query. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + The option is used in conjunction with an explicit join that loads + the desired rows, i.e.:: + + sess.query(Order).join(Order.user).options( + contains_eager(Order.user) + ) + + The above query would join from the ``Order`` entity to its related + ``User`` entity, and the returned ``Order`` objects would have the + ``Order.user`` attribute pre-populated. + + It may also be used for customizing the entries in an eagerly loaded + collection; queries will normally want to use the + :ref:`orm_queryguide_populate_existing` execution option assuming the + primary collection of parent objects may already have been loaded:: + + sess.query(User).join(User.addresses).filter( + Address.email_address.like("%@aol.com") + ).options(contains_eager(User.addresses)).populate_existing() + + See the section :ref:`contains_eager` for complete usage details. + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`contains_eager` + + """ + if alias is not None: + if not isinstance(alias, str): + coerced_alias = coercions.expect(roles.FromClauseRole, alias) + else: + util.warn_deprecated( + "Passing a string name for the 'alias' argument to " + "'contains_eager()` is deprecated, and will not work in a " + "future release. Please use a sqlalchemy.alias() or " + "sqlalchemy.orm.aliased() construct.", + version="1.4", + ) + coerced_alias = alias + + elif getattr(attr, "_of_type", None): + assert isinstance(attr, QueryableAttribute) + ot: Optional[_InternalEntityType[Any]] = inspect(attr._of_type) + assert ot is not None + coerced_alias = ot.selectable + else: + coerced_alias = None + + cloned = self._set_relationship_strategy( + attr, + {"lazy": "joined"}, + propagate_to_loaders=False, + opts={"eager_from_alias": coerced_alias}, + _reconcile_to_other=True if _is_chain else None, + ) + return cloned + + def load_only(self, *attrs: _AttrType, raiseload: bool = False) -> Self: + r"""Indicate that for a particular entity, only the given list + of column-based attribute names should be loaded; all others will be + deferred. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + Example - given a class ``User``, load only the ``name`` and + ``fullname`` attributes:: + + session.query(User).options(load_only(User.name, User.fullname)) + + Example - given a relationship ``User.addresses -> Address``, specify + subquery loading for the ``User.addresses`` collection, but on each + ``Address`` object load only the ``email_address`` attribute:: + + session.query(User).options( + subqueryload(User.addresses).load_only(Address.email_address) + ) + + For a statement that has multiple entities, + the lead entity can be + specifically referred to using the :class:`_orm.Load` constructor:: + + stmt = ( + select(User, Address) + .join(User.addresses) + .options( + Load(User).load_only(User.name, User.fullname), + Load(Address).load_only(Address.email_address), + ) + ) + + When used together with the + :ref:`populate_existing ` + execution option only the attributes listed will be refreshed. + + :param \*attrs: Attributes to be loaded, all others will be deferred. + + :param raiseload: raise :class:`.InvalidRequestError` rather than + lazy loading a value when a deferred attribute is accessed. Used + to prevent unwanted SQL from being emitted. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`orm_queryguide_column_deferral` - in the + :ref:`queryguide_toplevel` + + :param \*attrs: Attributes to be loaded, all others will be deferred. + + :param raiseload: raise :class:`.InvalidRequestError` rather than + lazy loading a value when a deferred attribute is accessed. Used + to prevent unwanted SQL from being emitted. + + .. versionadded:: 2.0 + + """ + cloned = self._set_column_strategy( + attrs, + {"deferred": False, "instrument": True}, + ) + + wildcard_strategy = {"deferred": True, "instrument": True} + if raiseload: + wildcard_strategy["raiseload"] = True + + cloned = cloned._set_column_strategy( + ("*",), + wildcard_strategy, + ) + return cloned + + def joinedload( + self, + attr: _AttrType, + innerjoin: Optional[bool] = None, + ) -> Self: + """Indicate that the given attribute should be loaded using joined + eager loading. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + examples:: + + # joined-load the "orders" collection on "User" + select(User).options(joinedload(User.orders)) + + # joined-load Order.items and then Item.keywords + select(Order).options( + joinedload(Order.items).joinedload(Item.keywords) + ) + + # lazily load Order.items, but when Items are loaded, + # joined-load the keywords collection + select(Order).options( + lazyload(Order.items).joinedload(Item.keywords) + ) + + :param innerjoin: if ``True``, indicates that the joined eager load + should use an inner join instead of the default of left outer join:: + + select(Order).options(joinedload(Order.user, innerjoin=True)) + + In order to chain multiple eager joins together where some may be + OUTER and others INNER, right-nested joins are used to link them:: + + select(A).options( + joinedload(A.bs, innerjoin=False).joinedload( + B.cs, innerjoin=True + ) + ) + + The above query, linking A.bs via "outer" join and B.cs via "inner" + join would render the joins as "a LEFT OUTER JOIN (b JOIN c)". When + using older versions of SQLite (< 3.7.16), this form of JOIN is + translated to use full subqueries as this syntax is otherwise not + directly supported. + + The ``innerjoin`` flag can also be stated with the term ``"unnested"``. + This indicates that an INNER JOIN should be used, *unless* the join + is linked to a LEFT OUTER JOIN to the left, in which case it + will render as LEFT OUTER JOIN. For example, supposing ``A.bs`` + is an outerjoin:: + + select(A).options( + joinedload(A.bs).joinedload(B.cs, innerjoin="unnested") + ) + + + The above join will render as "a LEFT OUTER JOIN b LEFT OUTER JOIN c", + rather than as "a LEFT OUTER JOIN (b JOIN c)". + + .. note:: The "unnested" flag does **not** affect the JOIN rendered + from a many-to-many association table, e.g. a table configured as + :paramref:`_orm.relationship.secondary`, to the target table; for + correctness of results, these joins are always INNER and are + therefore right-nested if linked to an OUTER join. + + .. note:: + + The joins produced by :func:`_orm.joinedload` are **anonymously + aliased**. The criteria by which the join proceeds cannot be + modified, nor can the ORM-enabled :class:`_sql.Select` or legacy + :class:`_query.Query` refer to these joins in any way, including + ordering. See :ref:`zen_of_eager_loading` for further detail. + + To produce a specific SQL JOIN which is explicitly available, use + :meth:`_sql.Select.join` and :meth:`_query.Query.join`. To combine + explicit JOINs with eager loading of collections, use + :func:`_orm.contains_eager`; see :ref:`contains_eager`. + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`joined_eager_loading` + + """ + loader = self._set_relationship_strategy( + attr, + {"lazy": "joined"}, + opts=( + {"innerjoin": innerjoin} + if innerjoin is not None + else util.EMPTY_DICT + ), + ) + return loader + + def subqueryload(self, attr: _AttrType) -> Self: + """Indicate that the given attribute should be loaded using + subquery eager loading. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + examples:: + + # subquery-load the "orders" collection on "User" + select(User).options(subqueryload(User.orders)) + + # subquery-load Order.items and then Item.keywords + select(Order).options( + subqueryload(Order.items).subqueryload(Item.keywords) + ) + + # lazily load Order.items, but when Items are loaded, + # subquery-load the keywords collection + select(Order).options( + lazyload(Order.items).subqueryload(Item.keywords) + ) + + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`subquery_eager_loading` + + """ + return self._set_relationship_strategy(attr, {"lazy": "subquery"}) + + def selectinload( + self, + attr: _AttrType, + recursion_depth: Optional[int] = None, + ) -> Self: + """Indicate that the given attribute should be loaded using + SELECT IN eager loading. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + examples:: + + # selectin-load the "orders" collection on "User" + select(User).options(selectinload(User.orders)) + + # selectin-load Order.items and then Item.keywords + select(Order).options( + selectinload(Order.items).selectinload(Item.keywords) + ) + + # lazily load Order.items, but when Items are loaded, + # selectin-load the keywords collection + select(Order).options( + lazyload(Order.items).selectinload(Item.keywords) + ) + + :param recursion_depth: optional int; when set to a positive integer + in conjunction with a self-referential relationship, + indicates "selectin" loading will continue that many levels deep + automatically until no items are found. + + .. note:: The :paramref:`_orm.selectinload.recursion_depth` option + currently supports only self-referential relationships. There + is not yet an option to automatically traverse recursive structures + with more than one relationship involved. + + Additionally, the :paramref:`_orm.selectinload.recursion_depth` + parameter is new and experimental and should be treated as "alpha" + status for the 2.0 series. + + .. versionadded:: 2.0 added + :paramref:`_orm.selectinload.recursion_depth` + + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`selectin_eager_loading` + + """ + return self._set_relationship_strategy( + attr, + {"lazy": "selectin"}, + opts={"recursion_depth": recursion_depth}, + ) + + def lazyload(self, attr: _AttrType) -> Self: + """Indicate that the given attribute should be loaded using "lazy" + loading. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`lazy_loading` + + """ + return self._set_relationship_strategy(attr, {"lazy": "select"}) + + def immediateload( + self, + attr: _AttrType, + recursion_depth: Optional[int] = None, + ) -> Self: + """Indicate that the given attribute should be loaded using + an immediate load with a per-attribute SELECT statement. + + The load is achieved using the "lazyloader" strategy and does not + fire off any additional eager loaders. + + The :func:`.immediateload` option is superseded in general + by the :func:`.selectinload` option, which performs the same task + more efficiently by emitting a SELECT for all loaded objects. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + :param recursion_depth: optional int; when set to a positive integer + in conjunction with a self-referential relationship, + indicates "selectin" loading will continue that many levels deep + automatically until no items are found. + + .. note:: The :paramref:`_orm.immediateload.recursion_depth` option + currently supports only self-referential relationships. There + is not yet an option to automatically traverse recursive structures + with more than one relationship involved. + + .. warning:: This parameter is new and experimental and should be + treated as "alpha" status + + .. versionadded:: 2.0 added + :paramref:`_orm.immediateload.recursion_depth` + + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`selectin_eager_loading` + + """ + loader = self._set_relationship_strategy( + attr, + {"lazy": "immediate"}, + opts={"recursion_depth": recursion_depth}, + ) + return loader + + def noload(self, attr: _AttrType) -> Self: + """Indicate that the given relationship attribute should remain + unloaded. + + The relationship attribute will return ``None`` when accessed without + producing any loading effect. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + :func:`_orm.noload` applies to :func:`_orm.relationship` attributes + only. + + .. note:: Setting this loading strategy as the default strategy + for a relationship using the :paramref:`.orm.relationship.lazy` + parameter may cause issues with flushes, such if a delete operation + needs to load related objects and instead ``None`` was returned. + + .. seealso:: + + :ref:`loading_toplevel` + + """ + + return self._set_relationship_strategy(attr, {"lazy": "noload"}) + + def raiseload(self, attr: _AttrType, sql_only: bool = False) -> Self: + """Indicate that the given attribute should raise an error if accessed. + + A relationship attribute configured with :func:`_orm.raiseload` will + raise an :exc:`~sqlalchemy.exc.InvalidRequestError` upon access. The + typical way this is useful is when an application is attempting to + ensure that all relationship attributes that are accessed in a + particular context would have been already loaded via eager loading. + Instead of having to read through SQL logs to ensure lazy loads aren't + occurring, this strategy will cause them to raise immediately. + + :func:`_orm.raiseload` applies to :func:`_orm.relationship` attributes + only. In order to apply raise-on-SQL behavior to a column-based + attribute, use the :paramref:`.orm.defer.raiseload` parameter on the + :func:`.defer` loader option. + + :param sql_only: if True, raise only if the lazy load would emit SQL, + but not if it is only checking the identity map, or determining that + the related value should just be None due to missing keys. When False, + the strategy will raise for all varieties of relationship loading. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + .. seealso:: + + :ref:`loading_toplevel` + + :ref:`prevent_lazy_with_raiseload` + + :ref:`orm_queryguide_deferred_raiseload` + + """ + + return self._set_relationship_strategy( + attr, {"lazy": "raise_on_sql" if sql_only else "raise"} + ) + + def defaultload(self, attr: _AttrType) -> Self: + """Indicate an attribute should load using its predefined loader style. + + The behavior of this loading option is to not change the current + loading style of the attribute, meaning that the previously configured + one is used or, if no previous style was selected, the default + loading will be used. + + This method is used to link to other loader options further into + a chain of attributes without altering the loader style of the links + along the chain. For example, to set joined eager loading for an + element of an element:: + + session.query(MyClass).options( + defaultload(MyClass.someattribute).joinedload( + MyOtherClass.someotherattribute + ) + ) + + :func:`.defaultload` is also useful for setting column-level options on + a related class, namely that of :func:`.defer` and :func:`.undefer`:: + + session.scalars( + select(MyClass).options( + defaultload(MyClass.someattribute) + .defer("some_column") + .undefer("some_other_column") + ) + ) + + .. seealso:: + + :ref:`orm_queryguide_relationship_sub_options` + + :meth:`_orm.Load.options` + + """ + return self._set_relationship_strategy(attr, None) + + def defer(self, key: _AttrType, raiseload: bool = False) -> Self: + r"""Indicate that the given column-oriented attribute should be + deferred, e.g. not loaded until accessed. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + e.g.:: + + from sqlalchemy.orm import defer + + session.query(MyClass).options( + defer(MyClass.attribute_one), + defer(MyClass.attribute_two) + ) + + To specify a deferred load of an attribute on a related class, + the path can be specified one token at a time, specifying the loading + style for each link along the chain. To leave the loading style + for a link unchanged, use :func:`_orm.defaultload`:: + + session.query(MyClass).options( + defaultload(MyClass.someattr).defer(RelatedClass.some_column) + ) + + Multiple deferral options related to a relationship can be bundled + at once using :meth:`_orm.Load.options`:: + + + select(MyClass).options( + defaultload(MyClass.someattr).options( + defer(RelatedClass.some_column), + defer(RelatedClass.some_other_column), + defer(RelatedClass.another_column) + ) + ) + + :param key: Attribute to be deferred. + + :param raiseload: raise :class:`.InvalidRequestError` rather than + lazy loading a value when the deferred attribute is accessed. Used + to prevent unwanted SQL from being emitted. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`orm_queryguide_column_deferral` - in the + :ref:`queryguide_toplevel` + + :func:`_orm.load_only` + + :func:`_orm.undefer` + + """ + strategy = {"deferred": True, "instrument": True} + if raiseload: + strategy["raiseload"] = True + return self._set_column_strategy((key,), strategy) + + def undefer(self, key: _AttrType) -> Self: + r"""Indicate that the given column-oriented attribute should be + undeferred, e.g. specified within the SELECT statement of the entity + as a whole. + + The column being undeferred is typically set up on the mapping as a + :func:`.deferred` attribute. + + This function is part of the :class:`_orm.Load` interface and supports + both method-chained and standalone operation. + + Examples:: + + # undefer two columns + session.query(MyClass).options( + undefer(MyClass.col1), undefer(MyClass.col2) + ) + + # undefer all columns specific to a single class using Load + * + session.query(MyClass, MyOtherClass).options( + Load(MyClass).undefer("*") + ) + + # undefer a column on a related object + select(MyClass).options( + defaultload(MyClass.items).undefer(MyClass.text) + ) + + :param key: Attribute to be undeferred. + + .. seealso:: + + :ref:`orm_queryguide_column_deferral` - in the + :ref:`queryguide_toplevel` + + :func:`_orm.defer` + + :func:`_orm.undefer_group` + + """ + return self._set_column_strategy( + (key,), {"deferred": False, "instrument": True} + ) + + def undefer_group(self, name: str) -> Self: + """Indicate that columns within the given deferred group name should be + undeferred. + + The columns being undeferred are set up on the mapping as + :func:`.deferred` attributes and include a "group" name. + + E.g:: + + session.query(MyClass).options(undefer_group("large_attrs")) + + To undefer a group of attributes on a related entity, the path can be + spelled out using relationship loader options, such as + :func:`_orm.defaultload`:: + + select(MyClass).options( + defaultload("someattr").undefer_group("large_attrs") + ) + + .. seealso:: + + :ref:`orm_queryguide_column_deferral` - in the + :ref:`queryguide_toplevel` + + :func:`_orm.defer` + + :func:`_orm.undefer` + + """ + return self._set_column_strategy( + (_WILDCARD_TOKEN,), None, {f"undefer_group_{name}": True} + ) + + def with_expression( + self, + key: _AttrType, + expression: _ColumnExpressionArgument[Any], + ) -> Self: + r"""Apply an ad-hoc SQL expression to a "deferred expression" + attribute. + + This option is used in conjunction with the + :func:`_orm.query_expression` mapper-level construct that indicates an + attribute which should be the target of an ad-hoc SQL expression. + + E.g.:: + + stmt = select(SomeClass).options( + with_expression(SomeClass.x_y_expr, SomeClass.x + SomeClass.y) + ) + + .. versionadded:: 1.2 + + :param key: Attribute to be populated + + :param expr: SQL expression to be applied to the attribute. + + .. seealso:: + + :ref:`orm_queryguide_with_expression` - background and usage + examples + + """ + + expression = _orm_full_deannotate( + coercions.expect(roles.LabeledColumnExprRole, expression) + ) + + return self._set_column_strategy( + (key,), {"query_expression": True}, extra_criteria=(expression,) + ) + + def selectin_polymorphic(self, classes: Iterable[Type[Any]]) -> Self: + """Indicate an eager load should take place for all attributes + specific to a subclass. + + This uses an additional SELECT with IN against all matched primary + key values, and is the per-query analogue to the ``"selectin"`` + setting on the :paramref:`.mapper.polymorphic_load` parameter. + + .. versionadded:: 1.2 + + .. seealso:: + + :ref:`polymorphic_selectin` + + """ + self = self._set_class_strategy( + {"selectinload_polymorphic": True}, + opts={ + "entities": tuple( + sorted((inspect(cls) for cls in classes), key=id) + ) + }, + ) + return self + + @overload + def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey: ... + + @overload + def _coerce_strat(self, strategy: Literal[None]) -> None: ... + + def _coerce_strat( + self, strategy: Optional[_StrategySpec] + ) -> Optional[_StrategyKey]: + if strategy is not None: + strategy_key = tuple(sorted(strategy.items())) + else: + strategy_key = None + return strategy_key + + @_generative + def _set_relationship_strategy( + self, + attr: _AttrType, + strategy: Optional[_StrategySpec], + propagate_to_loaders: bool = True, + opts: Optional[_OptsType] = None, + _reconcile_to_other: Optional[bool] = None, + ) -> Self: + strategy_key = self._coerce_strat(strategy) + + self._clone_for_bind_strategy( + (attr,), + strategy_key, + _RELATIONSHIP_TOKEN, + opts=opts, + propagate_to_loaders=propagate_to_loaders, + reconcile_to_other=_reconcile_to_other, + ) + return self + + @_generative + def _set_column_strategy( + self, + attrs: Tuple[_AttrType, ...], + strategy: Optional[_StrategySpec], + opts: Optional[_OptsType] = None, + extra_criteria: Optional[Tuple[Any, ...]] = None, + ) -> Self: + strategy_key = self._coerce_strat(strategy) + + self._clone_for_bind_strategy( + attrs, + strategy_key, + _COLUMN_TOKEN, + opts=opts, + attr_group=attrs, + extra_criteria=extra_criteria, + ) + return self + + @_generative + def _set_generic_strategy( + self, + attrs: Tuple[_AttrType, ...], + strategy: _StrategySpec, + _reconcile_to_other: Optional[bool] = None, + ) -> Self: + strategy_key = self._coerce_strat(strategy) + self._clone_for_bind_strategy( + attrs, + strategy_key, + None, + propagate_to_loaders=True, + reconcile_to_other=_reconcile_to_other, + ) + return self + + @_generative + def _set_class_strategy( + self, strategy: _StrategySpec, opts: _OptsType + ) -> Self: + strategy_key = self._coerce_strat(strategy) + + self._clone_for_bind_strategy(None, strategy_key, None, opts=opts) + return self + + def _apply_to_parent(self, parent: Load) -> None: + """apply this :class:`_orm._AbstractLoad` object as a sub-option o + a :class:`_orm.Load` object. + + Implementation is provided by subclasses. + + """ + raise NotImplementedError() + + def options(self, *opts: _AbstractLoad) -> Self: + r"""Apply a series of options as sub-options to this + :class:`_orm._AbstractLoad` object. + + Implementation is provided by subclasses. + + """ + raise NotImplementedError() + + def _clone_for_bind_strategy( + self, + attrs: Optional[Tuple[_AttrType, ...]], + strategy: Optional[_StrategyKey], + wildcard_key: Optional[_WildcardKeyType], + opts: Optional[_OptsType] = None, + attr_group: Optional[_AttrGroupType] = None, + propagate_to_loaders: bool = True, + reconcile_to_other: Optional[bool] = None, + extra_criteria: Optional[Tuple[Any, ...]] = None, + ) -> Self: + raise NotImplementedError() + + def process_compile_state_replaced_entities( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + ) -> None: + if not compile_state.compile_options._enable_eagerloads: + return + + # process is being run here so that the options given are validated + # against what the lead entities were, as well as to accommodate + # for the entities having been replaced with equivalents + self._process( + compile_state, + mapper_entities, + not bool(compile_state.current_path), + ) + + def process_compile_state(self, compile_state: ORMCompileState) -> None: + if not compile_state.compile_options._enable_eagerloads: + return + + self._process( + compile_state, + compile_state._lead_mapper_entities, + not bool(compile_state.current_path) + and not compile_state.compile_options._for_refresh_state, + ) + + def _process( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + raiseerr: bool, + ) -> None: + """implemented by subclasses""" + raise NotImplementedError() + + @classmethod + def _chop_path( + cls, + to_chop: _PathRepresentation, + path: PathRegistry, + debug: bool = False, + ) -> Optional[_PathRepresentation]: + i = -1 + + for i, (c_token, p_token) in enumerate( + zip(to_chop, path.natural_path) + ): + if isinstance(c_token, str): + if i == 0 and ( + c_token.endswith(f":{_DEFAULT_TOKEN}") + or c_token.endswith(f":{_WILDCARD_TOKEN}") + ): + return to_chop + elif ( + c_token != f"{_RELATIONSHIP_TOKEN}:{_WILDCARD_TOKEN}" + and c_token != p_token.key # type: ignore + ): + return None + + if c_token is p_token: + continue + elif ( + isinstance(c_token, InspectionAttr) + and insp_is_mapper(c_token) + and insp_is_mapper(p_token) + and c_token.isa(p_token) + ): + continue + + else: + return None + return to_chop[i + 1 :] + + +class Load(_AbstractLoad): + """Represents loader options which modify the state of a + ORM-enabled :class:`_sql.Select` or a legacy :class:`_query.Query` in + order to affect how various mapped attributes are loaded. + + The :class:`_orm.Load` object is in most cases used implicitly behind the + scenes when one makes use of a query option like :func:`_orm.joinedload`, + :func:`_orm.defer`, or similar. It typically is not instantiated directly + except for in some very specific cases. + + .. seealso:: + + :ref:`orm_queryguide_relationship_per_entity_wildcard` - illustrates an + example where direct use of :class:`_orm.Load` may be useful + + """ + + __slots__ = ( + "path", + "context", + "additional_source_entities", + ) + + _traverse_internals = [ + ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ( + "context", + visitors.InternalTraversal.dp_has_cache_key_list, + ), + ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean), + ( + "additional_source_entities", + visitors.InternalTraversal.dp_has_cache_key_list, + ), + ] + _cache_key_traversal = None + + path: PathRegistry + context: Tuple[_LoadElement, ...] + additional_source_entities: Tuple[_InternalEntityType[Any], ...] + + def __init__(self, entity: _EntityType[Any]): + insp = cast("Union[Mapper[Any], AliasedInsp[Any]]", inspect(entity)) + insp._post_inspect + + self.path = insp._path_registry + self.context = () + self.propagate_to_loaders = False + self.additional_source_entities = () + + def __str__(self) -> str: + return f"Load({self.path[0]})" + + @classmethod + def _construct_for_existing_path( + cls, path: AbstractEntityRegistry + ) -> Load: + load = cls.__new__(cls) + load.path = path + load.context = () + load.propagate_to_loaders = False + load.additional_source_entities = () + return load + + def _adapt_cached_option_to_uncached_option( + self, context: QueryContext, uncached_opt: ORMOption + ) -> ORMOption: + if uncached_opt is self: + return self + return self._adjust_for_extra_criteria(context) + + def _prepend_path(self, path: PathRegistry) -> Load: + cloned = self._clone() + cloned.context = tuple( + element._prepend_path(path) for element in self.context + ) + return cloned + + def _adjust_for_extra_criteria(self, context: QueryContext) -> Load: + """Apply the current bound parameters in a QueryContext to all + occurrences "extra_criteria" stored within this ``Load`` object, + returning a new instance of this ``Load`` object. + + """ + + # avoid generating cache keys for the queries if we don't + # actually have any extra_criteria options, which is the + # common case + for value in self.context: + if value._extra_criteria: + break + else: + return self + + replacement_cache_key = context.query._generate_cache_key() + + if replacement_cache_key is None: + return self + + orig_query = context.compile_state.select_statement + orig_cache_key = orig_query._generate_cache_key() + assert orig_cache_key is not None + + def process( + opt: _LoadElement, + replacement_cache_key: CacheKey, + orig_cache_key: CacheKey, + ) -> _LoadElement: + cloned_opt = opt._clone() + + cloned_opt._extra_criteria = tuple( + replacement_cache_key._apply_params_to_element( + orig_cache_key, crit + ) + for crit in cloned_opt._extra_criteria + ) + + return cloned_opt + + cloned = self._clone() + cloned.context = tuple( + ( + process(value, replacement_cache_key, orig_cache_key) + if value._extra_criteria + else value + ) + for value in self.context + ) + return cloned + + def _reconcile_query_entities_with_us(self, mapper_entities, raiseerr): + """called at process time to allow adjustment of the root + entity inside of _LoadElement objects. + + """ + path = self.path + + ezero = None + for ent in mapper_entities: + ezero = ent.entity_zero + if ezero and orm_util._entity_corresponds_to( + # technically this can be a token also, but this is + # safe to pass to _entity_corresponds_to() + ezero, + cast("_InternalEntityType[Any]", path[0]), + ): + return ezero + + return None + + def _process( + self, + compile_state: ORMCompileState, + mapper_entities: Sequence[_MapperEntity], + raiseerr: bool, + ) -> None: + reconciled_lead_entity = self._reconcile_query_entities_with_us( + mapper_entities, raiseerr + ) + + for loader in self.context: + loader.process_compile_state( + self, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ) + + def _apply_to_parent(self, parent: Load) -> None: + """apply this :class:`_orm.Load` object as a sub-option of another + :class:`_orm.Load` object. + + This method is used by the :meth:`_orm.Load.options` method. + + """ + cloned = self._generate() + + assert cloned.propagate_to_loaders == self.propagate_to_loaders + + if not any( + orm_util._entity_corresponds_to_use_path_impl( + elem, cloned.path.odd_element(0) + ) + for elem in (parent.path.odd_element(-1),) + + parent.additional_source_entities + ): + if len(cloned.path) > 1: + attrname = cloned.path[1] + parent_entity = cloned.path[0] + else: + attrname = cloned.path[0] + parent_entity = cloned.path[0] + _raise_for_does_not_link(parent.path, attrname, parent_entity) + + cloned.path = PathRegistry.coerce(parent.path[0:-1] + cloned.path[:]) + + if self.context: + cloned.context = tuple( + value._prepend_path_from(parent) for value in self.context + ) + + if cloned.context: + parent.context += cloned.context + parent.additional_source_entities += ( + cloned.additional_source_entities + ) + + @_generative + def options(self, *opts: _AbstractLoad) -> Self: + r"""Apply a series of options as sub-options to this + :class:`_orm.Load` + object. + + E.g.:: + + query = session.query(Author) + query = query.options( + joinedload(Author.book).options( + load_only(Book.summary, Book.excerpt), + joinedload(Book.citations).options( + joinedload(Citation.author) + ) + ) + ) + + :param \*opts: A series of loader option objects (ultimately + :class:`_orm.Load` objects) which should be applied to the path + specified by this :class:`_orm.Load` object. + + .. versionadded:: 1.3.6 + + .. seealso:: + + :func:`.defaultload` + + :ref:`orm_queryguide_relationship_sub_options` + + """ + for opt in opts: + try: + opt._apply_to_parent(self) + except AttributeError as ae: + if not isinstance(opt, _AbstractLoad): + raise sa_exc.ArgumentError( + f"Loader option {opt} is not compatible with the " + "Load.options() method." + ) from ae + else: + raise + return self + + def _clone_for_bind_strategy( + self, + attrs: Optional[Tuple[_AttrType, ...]], + strategy: Optional[_StrategyKey], + wildcard_key: Optional[_WildcardKeyType], + opts: Optional[_OptsType] = None, + attr_group: Optional[_AttrGroupType] = None, + propagate_to_loaders: bool = True, + reconcile_to_other: Optional[bool] = None, + extra_criteria: Optional[Tuple[Any, ...]] = None, + ) -> Self: + # for individual strategy that needs to propagate, set the whole + # Load container to also propagate, so that it shows up in + # InstanceState.load_options + if propagate_to_loaders: + self.propagate_to_loaders = True + + if self.path.is_token: + raise sa_exc.ArgumentError( + "Wildcard token cannot be followed by another entity" + ) + + elif path_is_property(self.path): + # re-use the lookup which will raise a nicely formatted + # LoaderStrategyException + if strategy: + self.path.prop._strategy_lookup(self.path.prop, strategy[0]) + else: + raise sa_exc.ArgumentError( + f"Mapped attribute '{self.path.prop}' does not " + "refer to a mapped entity" + ) + + if attrs is None: + load_element = _ClassStrategyLoad.create( + self.path, + None, + strategy, + wildcard_key, + opts, + propagate_to_loaders, + attr_group=attr_group, + reconcile_to_other=reconcile_to_other, + extra_criteria=extra_criteria, + ) + if load_element: + self.context += (load_element,) + assert opts is not None + self.additional_source_entities += cast( + "Tuple[_InternalEntityType[Any]]", opts["entities"] + ) + + else: + for attr in attrs: + if isinstance(attr, str): + load_element = _TokenStrategyLoad.create( + self.path, + attr, + strategy, + wildcard_key, + opts, + propagate_to_loaders, + attr_group=attr_group, + reconcile_to_other=reconcile_to_other, + extra_criteria=extra_criteria, + ) + else: + load_element = _AttributeStrategyLoad.create( + self.path, + attr, + strategy, + wildcard_key, + opts, + propagate_to_loaders, + attr_group=attr_group, + reconcile_to_other=reconcile_to_other, + extra_criteria=extra_criteria, + ) + + if load_element: + # for relationship options, update self.path on this Load + # object with the latest path. + if wildcard_key is _RELATIONSHIP_TOKEN: + self.path = load_element.path + self.context += (load_element,) + + # this seems to be effective for selectinloader, + # giving the extra match to one more level deep. + # but does not work for immediateloader, which still + # must add additional options at load time + if load_element.local_opts.get("recursion_depth", False): + r1 = load_element._recurse() + self.context += (r1,) + + return self + + def __getstate__(self): + d = self._shallow_to_dict() + d["path"] = self.path.serialize() + return d + + def __setstate__(self, state): + state["path"] = PathRegistry.deserialize(state["path"]) + self._shallow_from_dict(state) + + +class _WildcardLoad(_AbstractLoad): + """represent a standalone '*' load operation""" + + __slots__ = ("strategy", "path", "local_opts") + + _traverse_internals = [ + ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("path", visitors.ExtendedInternalTraversal.dp_plain_obj), + ( + "local_opts", + visitors.ExtendedInternalTraversal.dp_string_multi_dict, + ), + ] + cache_key_traversal: _CacheKeyTraversalType = None + + strategy: Optional[Tuple[Any, ...]] + local_opts: _OptsType + path: Union[Tuple[()], Tuple[str]] + propagate_to_loaders = False + + def __init__(self) -> None: + self.path = () + self.strategy = None + self.local_opts = util.EMPTY_DICT + + def _clone_for_bind_strategy( + self, + attrs, + strategy, + wildcard_key, + opts=None, + attr_group=None, + propagate_to_loaders=True, + reconcile_to_other=None, + extra_criteria=None, + ): + assert attrs is not None + attr = attrs[0] + assert ( + wildcard_key + and isinstance(attr, str) + and attr in (_WILDCARD_TOKEN, _DEFAULT_TOKEN) + ) + + attr = f"{wildcard_key}:{attr}" + + self.strategy = strategy + self.path = (attr,) + if opts: + self.local_opts = util.immutabledict(opts) + + assert extra_criteria is None + + def options(self, *opts: _AbstractLoad) -> Self: + raise NotImplementedError("Star option does not support sub-options") + + def _apply_to_parent(self, parent: Load) -> None: + """apply this :class:`_orm._WildcardLoad` object as a sub-option of + a :class:`_orm.Load` object. + + This method is used by the :meth:`_orm.Load.options` method. Note + that :class:`_orm.WildcardLoad` itself can't have sub-options, but + it may be used as the sub-option of a :class:`_orm.Load` object. + + """ + assert self.path + attr = self.path[0] + if attr.endswith(_DEFAULT_TOKEN): + attr = f"{attr.split(':')[0]}:{_WILDCARD_TOKEN}" + + effective_path = cast(AbstractEntityRegistry, parent.path).token(attr) + + assert effective_path.is_token + + loader = _TokenStrategyLoad.create( + effective_path, + None, + self.strategy, + None, + self.local_opts, + self.propagate_to_loaders, + ) + + parent.context += (loader,) + + def _process(self, compile_state, mapper_entities, raiseerr): + is_refresh = compile_state.compile_options._for_refresh_state + + if is_refresh and not self.propagate_to_loaders: + return + + entities = [ent.entity_zero for ent in mapper_entities] + current_path = compile_state.current_path + + start_path: _PathRepresentation = self.path + + if current_path: + # TODO: no cases in test suite where we actually get + # None back here + new_path = self._chop_path(start_path, current_path) + if new_path is None: + return + + # chop_path does not actually "chop" a wildcard token path, + # just returns it + assert new_path == start_path + + # start_path is a single-token tuple + assert start_path and len(start_path) == 1 + + token = start_path[0] + assert isinstance(token, str) + entity = self._find_entity_basestring(entities, token, raiseerr) + + if not entity: + return + + path_element = entity + + # transfer our entity-less state into a Load() object + # with a real entity path. Start with the lead entity + # we just located, then go through the rest of our path + # tokens and populate into the Load(). + + assert isinstance(token, str) + loader = _TokenStrategyLoad.create( + path_element._path_registry, + token, + self.strategy, + None, + self.local_opts, + self.propagate_to_loaders, + raiseerr=raiseerr, + ) + if not loader: + return + + assert loader.path.is_token + + # don't pass a reconciled lead entity here + loader.process_compile_state( + self, compile_state, mapper_entities, None, raiseerr + ) + + return loader + + def _find_entity_basestring( + self, + entities: Iterable[_InternalEntityType[Any]], + token: str, + raiseerr: bool, + ) -> Optional[_InternalEntityType[Any]]: + if token.endswith(f":{_WILDCARD_TOKEN}"): + if len(list(entities)) != 1: + if raiseerr: + raise sa_exc.ArgumentError( + "Can't apply wildcard ('*') or load_only() " + f"loader option to multiple entities " + f"{', '.join(str(ent) for ent in entities)}. Specify " + "loader options for each entity individually, such as " + f"""{ + ", ".join( + f"Load({ent}).some_option('*')" + for ent in entities + ) + }.""" + ) + elif token.endswith(_DEFAULT_TOKEN): + raiseerr = False + + for ent in entities: + # return only the first _MapperEntity when searching + # based on string prop name. Ideally object + # attributes are used to specify more exactly. + return ent + else: + if raiseerr: + raise sa_exc.ArgumentError( + "Query has only expression-based entities - " + f'can\'t find property named "{token}".' + ) + else: + return None + + def __getstate__(self) -> Dict[str, Any]: + d = self._shallow_to_dict() + return d + + def __setstate__(self, state: Dict[str, Any]) -> None: + self._shallow_from_dict(state) + + +class _LoadElement( + cache_key.HasCacheKey, traversals.HasShallowCopy, visitors.Traversible +): + """represents strategy information to select for a LoaderStrategy + and pass options to it. + + :class:`._LoadElement` objects provide the inner datastructure + stored by a :class:`_orm.Load` object and are also the object passed + to methods like :meth:`.LoaderStrategy.setup_query`. + + .. versionadded:: 2.0 + + """ + + __slots__ = ( + "path", + "strategy", + "propagate_to_loaders", + "local_opts", + "_extra_criteria", + "_reconcile_to_other", + ) + __visit_name__ = "load_element" + + _traverse_internals = [ + ("path", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("strategy", visitors.ExtendedInternalTraversal.dp_plain_obj), + ( + "local_opts", + visitors.ExtendedInternalTraversal.dp_string_multi_dict, + ), + ("_extra_criteria", visitors.InternalTraversal.dp_clauseelement_list), + ("propagate_to_loaders", visitors.InternalTraversal.dp_plain_obj), + ("_reconcile_to_other", visitors.InternalTraversal.dp_plain_obj), + ] + _cache_key_traversal = None + + _extra_criteria: Tuple[Any, ...] + + _reconcile_to_other: Optional[bool] + strategy: Optional[_StrategyKey] + path: PathRegistry + propagate_to_loaders: bool + + local_opts: util.immutabledict[str, Any] + + is_token_strategy: bool + is_class_strategy: bool + + def __hash__(self) -> int: + return id(self) + + def __eq__(self, other): + return traversals.compare(self, other) + + @property + def is_opts_only(self) -> bool: + return bool(self.local_opts and self.strategy is None) + + def _clone(self, **kw: Any) -> _LoadElement: + cls = self.__class__ + s = cls.__new__(cls) + + self._shallow_copy_to(s) + return s + + def _update_opts(self, **kw: Any) -> _LoadElement: + new = self._clone() + new.local_opts = new.local_opts.union(kw) + return new + + def __getstate__(self) -> Dict[str, Any]: + d = self._shallow_to_dict() + d["path"] = self.path.serialize() + return d + + def __setstate__(self, state: Dict[str, Any]) -> None: + state["path"] = PathRegistry.deserialize(state["path"]) + self._shallow_from_dict(state) + + def _raise_for_no_match(self, parent_loader, mapper_entities): + path = parent_loader.path + + found_entities = False + for ent in mapper_entities: + ezero = ent.entity_zero + if ezero: + found_entities = True + break + + if not found_entities: + raise sa_exc.ArgumentError( + "Query has only expression-based entities; " + f"attribute loader options for {path[0]} can't " + "be applied here." + ) + else: + raise sa_exc.ArgumentError( + f"Mapped class {path[0]} does not apply to any of the " + f"root entities in this query, e.g. " + f"""{ + ", ".join( + str(x.entity_zero) + for x in mapper_entities if x.entity_zero + )}. Please """ + "specify the full path " + "from one of the root entities to the target " + "attribute. " + ) + + def _adjust_effective_path_for_current_path( + self, effective_path: PathRegistry, current_path: PathRegistry + ) -> Optional[PathRegistry]: + """receives the 'current_path' entry from an :class:`.ORMCompileState` + instance, which is set during lazy loads and secondary loader strategy + loads, and adjusts the given path to be relative to the + current_path. + + E.g. given a loader path and current path:: + + lp: User -> orders -> Order -> items -> Item -> keywords -> Keyword + + cp: User -> orders -> Order -> items + + The adjusted path would be:: + + Item -> keywords -> Keyword + + + """ + chopped_start_path = Load._chop_path( + effective_path.natural_path, current_path + ) + if not chopped_start_path: + return None + + tokens_removed_from_start_path = len(effective_path) - len( + chopped_start_path + ) + + loader_lead_path_element = self.path[tokens_removed_from_start_path] + + effective_path = PathRegistry.coerce( + (loader_lead_path_element,) + chopped_start_path[1:] + ) + + return effective_path + + def _init_path( + self, path, attr, wildcard_key, attr_group, raiseerr, extra_criteria + ): + """Apply ORM attributes and/or wildcard to an existing path, producing + a new path. + + This method is used within the :meth:`.create` method to initialize + a :class:`._LoadElement` object. + + """ + raise NotImplementedError() + + def _prepare_for_compile_state( + self, + parent_loader, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ): + """implemented by subclasses.""" + raise NotImplementedError() + + def process_compile_state( + self, + parent_loader, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ): + """populate ORMCompileState.attributes with loader state for this + _LoadElement. + + """ + keys = self._prepare_for_compile_state( + parent_loader, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ) + for key in keys: + if key in compile_state.attributes: + compile_state.attributes[key] = _LoadElement._reconcile( + self, compile_state.attributes[key] + ) + else: + compile_state.attributes[key] = self + + @classmethod + def create( + cls, + path: PathRegistry, + attr: Union[_AttrType, _StrPathToken, None], + strategy: Optional[_StrategyKey], + wildcard_key: Optional[_WildcardKeyType], + local_opts: Optional[_OptsType], + propagate_to_loaders: bool, + raiseerr: bool = True, + attr_group: Optional[_AttrGroupType] = None, + reconcile_to_other: Optional[bool] = None, + extra_criteria: Optional[Tuple[Any, ...]] = None, + ) -> _LoadElement: + """Create a new :class:`._LoadElement` object.""" + + opt = cls.__new__(cls) + opt.path = path + opt.strategy = strategy + opt.propagate_to_loaders = propagate_to_loaders + opt.local_opts = ( + util.immutabledict(local_opts) if local_opts else util.EMPTY_DICT + ) + opt._extra_criteria = () + + if reconcile_to_other is not None: + opt._reconcile_to_other = reconcile_to_other + elif strategy is None and not local_opts: + opt._reconcile_to_other = True + else: + opt._reconcile_to_other = None + + path = opt._init_path( + path, attr, wildcard_key, attr_group, raiseerr, extra_criteria + ) + + if not path: + return None # type: ignore + + assert opt.is_token_strategy == path.is_token + + opt.path = path + return opt + + def __init__(self) -> None: + raise NotImplementedError() + + def _recurse(self) -> _LoadElement: + cloned = self._clone() + cloned.path = PathRegistry.coerce(self.path[:] + self.path[-2:]) + + return cloned + + def _prepend_path_from(self, parent: Load) -> _LoadElement: + """adjust the path of this :class:`._LoadElement` to be + a subpath of that of the given parent :class:`_orm.Load` object's + path. + + This is used by the :meth:`_orm.Load._apply_to_parent` method, + which is in turn part of the :meth:`_orm.Load.options` method. + + """ + + if not any( + orm_util._entity_corresponds_to_use_path_impl( + elem, + self.path.odd_element(0), + ) + for elem in (parent.path.odd_element(-1),) + + parent.additional_source_entities + ): + raise sa_exc.ArgumentError( + f'Attribute "{self.path[1]}" does not link ' + f'from element "{parent.path[-1]}".' + ) + + return self._prepend_path(parent.path) + + def _prepend_path(self, path: PathRegistry) -> _LoadElement: + cloned = self._clone() + + assert cloned.strategy == self.strategy + assert cloned.local_opts == self.local_opts + assert cloned.is_class_strategy == self.is_class_strategy + + cloned.path = PathRegistry.coerce(path[0:-1] + cloned.path[:]) + + return cloned + + @staticmethod + def _reconcile( + replacement: _LoadElement, existing: _LoadElement + ) -> _LoadElement: + """define behavior for when two Load objects are to be put into + the context.attributes under the same key. + + :param replacement: ``_LoadElement`` that seeks to replace the + existing one + + :param existing: ``_LoadElement`` that is already present. + + """ + # mapper inheritance loading requires fine-grained "block other + # options" / "allow these options to be overridden" behaviors + # see test_poly_loading.py + + if replacement._reconcile_to_other: + return existing + elif replacement._reconcile_to_other is False: + return replacement + elif existing._reconcile_to_other: + return replacement + elif existing._reconcile_to_other is False: + return existing + + if existing is replacement: + return replacement + elif ( + existing.strategy == replacement.strategy + and existing.local_opts == replacement.local_opts + ): + return replacement + elif replacement.is_opts_only: + existing = existing._clone() + existing.local_opts = existing.local_opts.union( + replacement.local_opts + ) + existing._extra_criteria += replacement._extra_criteria + return existing + elif existing.is_opts_only: + replacement = replacement._clone() + replacement.local_opts = replacement.local_opts.union( + existing.local_opts + ) + replacement._extra_criteria += existing._extra_criteria + return replacement + elif replacement.path.is_token: + # use 'last one wins' logic for wildcard options. this is also + # kind of inconsistent vs. options that are specific paths which + # will raise as below + return replacement + + raise sa_exc.InvalidRequestError( + f"Loader strategies for {replacement.path} conflict" + ) + + +class _AttributeStrategyLoad(_LoadElement): + """Loader strategies against specific relationship or column paths. + + e.g.:: + + joinedload(User.addresses) + defer(Order.name) + selectinload(User.orders).lazyload(Order.items) + + """ + + __slots__ = ("_of_type", "_path_with_polymorphic_path") + + __visit_name__ = "attribute_strategy_load_element" + + _traverse_internals = _LoadElement._traverse_internals + [ + ("_of_type", visitors.ExtendedInternalTraversal.dp_multi), + ( + "_path_with_polymorphic_path", + visitors.ExtendedInternalTraversal.dp_has_cache_key, + ), + ] + + _of_type: Union[Mapper[Any], AliasedInsp[Any], None] + _path_with_polymorphic_path: Optional[PathRegistry] + + is_class_strategy = False + is_token_strategy = False + + def _init_path( + self, path, attr, wildcard_key, attr_group, raiseerr, extra_criteria + ): + assert attr is not None + self._of_type = None + self._path_with_polymorphic_path = None + insp, _, prop = _parse_attr_argument(attr) + + if insp.is_property: + # direct property can be sent from internal strategy logic + # that sets up specific loaders, such as + # emit_lazyload->_lazyload_reverse + # prop = found_property = attr + prop = attr + path = path[prop] + + if path.has_entity: + path = path.entity_path + return path + + elif not insp.is_attribute: + # should not reach here; + assert False + + # here we assume we have user-passed InstrumentedAttribute + if not orm_util._entity_corresponds_to_use_path_impl( + path[-1], attr.parent + ): + if raiseerr: + if attr_group and attr is not attr_group[0]: + raise sa_exc.ArgumentError( + "Can't apply wildcard ('*') or load_only() " + "loader option to multiple entities in the " + "same option. Use separate options per entity." + ) + else: + _raise_for_does_not_link(path, str(attr), attr.parent) + else: + return None + + # note the essential logic of this attribute was very different in + # 1.4, where there were caching failures in e.g. + # test_relationship_criteria.py::RelationshipCriteriaTest:: + # test_selectinload_nested_criteria[True] if an existing + # "_extra_criteria" on a Load object were replaced with that coming + # from an attribute. This appears to have been an artifact of how + # _UnboundLoad / Load interacted together, which was opaque and + # poorly defined. + if extra_criteria: + assert not attr._extra_criteria + self._extra_criteria = extra_criteria + else: + self._extra_criteria = attr._extra_criteria + + if getattr(attr, "_of_type", None): + ac = attr._of_type + ext_info = inspect(ac) + self._of_type = ext_info + + self._path_with_polymorphic_path = path.entity_path[prop] + + path = path[prop][ext_info] + + else: + path = path[prop] + + if path.has_entity: + path = path.entity_path + + return path + + def _generate_extra_criteria(self, context): + """Apply the current bound parameters in a QueryContext to the + immediate "extra_criteria" stored with this Load object. + + Load objects are typically pulled from the cached version of + the statement from a QueryContext. The statement currently being + executed will have new values (and keys) for bound parameters in the + extra criteria which need to be applied by loader strategies when + they handle this criteria for a result set. + + """ + + assert ( + self._extra_criteria + ), "this should only be called if _extra_criteria is present" + + orig_query = context.compile_state.select_statement + current_query = context.query + + # NOTE: while it seems like we should not do the "apply" operation + # here if orig_query is current_query, skipping it in the "optimized" + # case causes the query to be different from a cache key perspective, + # because we are creating a copy of the criteria which is no longer + # the same identity of the _extra_criteria in the loader option + # itself. cache key logic produces a different key for + # (A, copy_of_A) vs. (A, A), because in the latter case it shortens + # the second part of the key to just indicate on identity. + + # if orig_query is current_query: + # not cached yet. just do the and_() + # return and_(*self._extra_criteria) + + k1 = orig_query._generate_cache_key() + k2 = current_query._generate_cache_key() + + return k2._apply_params_to_element(k1, and_(*self._extra_criteria)) + + def _set_of_type_info(self, context, current_path): + assert self._path_with_polymorphic_path + + pwpi = self._of_type + assert pwpi + if not pwpi.is_aliased_class: + pwpi = inspect( + orm_util.AliasedInsp._with_polymorphic_factory( + pwpi.mapper.base_mapper, + (pwpi.mapper,), + aliased=True, + _use_mapper_path=True, + ) + ) + start_path = self._path_with_polymorphic_path + if current_path: + new_path = self._adjust_effective_path_for_current_path( + start_path, current_path + ) + if new_path is None: + return + start_path = new_path + + key = ("path_with_polymorphic", start_path.natural_path) + if key in context: + existing_aliased_insp = context[key] + this_aliased_insp = pwpi + new_aliased_insp = existing_aliased_insp._merge_with( + this_aliased_insp + ) + context[key] = new_aliased_insp + else: + context[key] = pwpi + + def _prepare_for_compile_state( + self, + parent_loader, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ): + # _AttributeStrategyLoad + + current_path = compile_state.current_path + is_refresh = compile_state.compile_options._for_refresh_state + assert not self.path.is_token + + if is_refresh and not self.propagate_to_loaders: + return [] + + if self._of_type: + # apply additional with_polymorphic alias that may have been + # generated. this has to happen even if this is a defaultload + self._set_of_type_info(compile_state.attributes, current_path) + + # omit setting loader attributes for a "defaultload" type of option + if not self.strategy and not self.local_opts: + return [] + + if raiseerr and not reconciled_lead_entity: + self._raise_for_no_match(parent_loader, mapper_entities) + + if self.path.has_entity: + effective_path = self.path.parent + else: + effective_path = self.path + + if current_path: + assert effective_path is not None + effective_path = self._adjust_effective_path_for_current_path( + effective_path, current_path + ) + if effective_path is None: + return [] + + return [("loader", cast(PathRegistry, effective_path).natural_path)] + + def __getstate__(self): + d = super().__getstate__() + + # can't pickle this. See + # test_pickled.py -> test_lazyload_extra_criteria_not_supported + # where we should be emitting a warning for the usual case where this + # would be non-None + d["_extra_criteria"] = () + + if self._path_with_polymorphic_path: + d["_path_with_polymorphic_path"] = ( + self._path_with_polymorphic_path.serialize() + ) + + if self._of_type: + if self._of_type.is_aliased_class: + d["_of_type"] = None + elif self._of_type.is_mapper: + d["_of_type"] = self._of_type.class_ + else: + assert False, "unexpected object for _of_type" + + return d + + def __setstate__(self, state): + super().__setstate__(state) + + if state.get("_path_with_polymorphic_path", None): + self._path_with_polymorphic_path = PathRegistry.deserialize( + state["_path_with_polymorphic_path"] + ) + else: + self._path_with_polymorphic_path = None + + if state.get("_of_type", None): + self._of_type = inspect(state["_of_type"]) + else: + self._of_type = None + + +class _TokenStrategyLoad(_LoadElement): + """Loader strategies against wildcard attributes + + e.g.:: + + raiseload('*') + Load(User).lazyload('*') + defer('*') + load_only(User.name, User.email) # will create a defer('*') + joinedload(User.addresses).raiseload('*') + + """ + + __visit_name__ = "token_strategy_load_element" + + inherit_cache = True + is_class_strategy = False + is_token_strategy = True + + def _init_path( + self, path, attr, wildcard_key, attr_group, raiseerr, extra_criteria + ): + # assert isinstance(attr, str) or attr is None + if attr is not None: + default_token = attr.endswith(_DEFAULT_TOKEN) + if attr.endswith(_WILDCARD_TOKEN) or default_token: + if wildcard_key: + attr = f"{wildcard_key}:{attr}" + + path = path.token(attr) + return path + else: + raise sa_exc.ArgumentError( + "Strings are not accepted for attribute names in loader " + "options; please use class-bound attributes directly." + ) + return path + + def _prepare_for_compile_state( + self, + parent_loader, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ): + # _TokenStrategyLoad + + current_path = compile_state.current_path + is_refresh = compile_state.compile_options._for_refresh_state + + assert self.path.is_token + + if is_refresh and not self.propagate_to_loaders: + return [] + + # omit setting attributes for a "defaultload" type of option + if not self.strategy and not self.local_opts: + return [] + + effective_path = self.path + if reconciled_lead_entity: + effective_path = PathRegistry.coerce( + (reconciled_lead_entity,) + effective_path.path[1:] + ) + + if current_path: + new_effective_path = self._adjust_effective_path_for_current_path( + effective_path, current_path + ) + if new_effective_path is None: + return [] + effective_path = new_effective_path + + # for a wildcard token, expand out the path we set + # to encompass everything from the query entity on + # forward. not clear if this is necessary when current_path + # is set. + + return [ + ("loader", natural_path) + for natural_path in ( + cast( + TokenRegistry, effective_path + )._generate_natural_for_superclasses() + ) + ] + + +class _ClassStrategyLoad(_LoadElement): + """Loader strategies that deals with a class as a target, not + an attribute path + + e.g.:: + + q = s.query(Person).options( + selectin_polymorphic(Person, [Engineer, Manager]) + ) + + """ + + inherit_cache = True + is_class_strategy = True + is_token_strategy = False + + __visit_name__ = "class_strategy_load_element" + + def _init_path( + self, path, attr, wildcard_key, attr_group, raiseerr, extra_criteria + ): + return path + + def _prepare_for_compile_state( + self, + parent_loader, + compile_state, + mapper_entities, + reconciled_lead_entity, + raiseerr, + ): + # _ClassStrategyLoad + + current_path = compile_state.current_path + is_refresh = compile_state.compile_options._for_refresh_state + + if is_refresh and not self.propagate_to_loaders: + return [] + + # omit setting attributes for a "defaultload" type of option + if not self.strategy and not self.local_opts: + return [] + + effective_path = self.path + + if current_path: + new_effective_path = self._adjust_effective_path_for_current_path( + effective_path, current_path + ) + if new_effective_path is None: + return [] + effective_path = new_effective_path + + return [("loader", effective_path.natural_path)] + + +def _generate_from_keys( + meth: Callable[..., _AbstractLoad], + keys: Tuple[_AttrType, ...], + chained: bool, + kw: Any, +) -> _AbstractLoad: + lead_element: Optional[_AbstractLoad] = None + + attr: Any + for is_default, _keys in (True, keys[0:-1]), (False, keys[-1:]): + for attr in _keys: + if isinstance(attr, str): + if attr.startswith("." + _WILDCARD_TOKEN): + util.warn_deprecated( + "The undocumented `.{WILDCARD}` format is " + "deprecated " + "and will be removed in a future version as " + "it is " + "believed to be unused. " + "If you have been using this functionality, " + "please " + "comment on Issue #4390 on the SQLAlchemy project " + "tracker.", + version="1.4", + ) + attr = attr[1:] + + if attr == _WILDCARD_TOKEN: + if is_default: + raise sa_exc.ArgumentError( + "Wildcard token cannot be followed by " + "another entity", + ) + + if lead_element is None: + lead_element = _WildcardLoad() + + lead_element = meth(lead_element, _DEFAULT_TOKEN, **kw) + + else: + raise sa_exc.ArgumentError( + "Strings are not accepted for attribute names in " + "loader options; please use class-bound " + "attributes directly.", + ) + else: + if lead_element is None: + _, lead_entity, _ = _parse_attr_argument(attr) + lead_element = Load(lead_entity) + + if is_default: + if not chained: + lead_element = lead_element.defaultload(attr) + else: + lead_element = meth( + lead_element, attr, _is_chain=True, **kw + ) + else: + lead_element = meth(lead_element, attr, **kw) + + assert lead_element + return lead_element + + +def _parse_attr_argument( + attr: _AttrType, +) -> Tuple[InspectionAttr, _InternalEntityType[Any], MapperProperty[Any]]: + """parse an attribute or wildcard argument to produce an + :class:`._AbstractLoad` instance. + + This is used by the standalone loader strategy functions like + ``joinedload()``, ``defer()``, etc. to produce :class:`_orm.Load` or + :class:`._WildcardLoad` objects. + + """ + try: + # TODO: need to figure out this None thing being returned by + # inspect(), it should not have None as an option in most cases + # if at all + insp: InspectionAttr = inspect(attr) # type: ignore + except sa_exc.NoInspectionAvailable as err: + raise sa_exc.ArgumentError( + "expected ORM mapped attribute for loader strategy argument" + ) from err + + lead_entity: _InternalEntityType[Any] + + if insp_is_mapper_property(insp): + lead_entity = insp.parent + prop = insp + elif insp_is_attribute(insp): + lead_entity = insp.parent + prop = insp.prop + else: + raise sa_exc.ArgumentError( + "expected ORM mapped attribute for loader strategy argument" + ) + + return insp, lead_entity, prop + + +def loader_unbound_fn(fn: _FN) -> _FN: + """decorator that applies docstrings between standalone loader functions + and the loader methods on :class:`._AbstractLoad`. + + """ + bound_fn = getattr(_AbstractLoad, fn.__name__) + fn_doc = bound_fn.__doc__ + bound_fn.__doc__ = f"""Produce a new :class:`_orm.Load` object with the +:func:`_orm.{fn.__name__}` option applied. + +See :func:`_orm.{fn.__name__}` for usage examples. + +""" + + fn.__doc__ = fn_doc + return fn + + +# standalone functions follow. docstrings are filled in +# by the ``@loader_unbound_fn`` decorator. + + +@loader_unbound_fn +def contains_eager(*keys: _AttrType, **kw: Any) -> _AbstractLoad: + return _generate_from_keys(Load.contains_eager, keys, True, kw) + + +@loader_unbound_fn +def load_only(*attrs: _AttrType, raiseload: bool = False) -> _AbstractLoad: + # TODO: attrs against different classes. we likely have to + # add some extra state to Load of some kind + _, lead_element, _ = _parse_attr_argument(attrs[0]) + return Load(lead_element).load_only(*attrs, raiseload=raiseload) + + +@loader_unbound_fn +def joinedload(*keys: _AttrType, **kw: Any) -> _AbstractLoad: + return _generate_from_keys(Load.joinedload, keys, False, kw) + + +@loader_unbound_fn +def subqueryload(*keys: _AttrType) -> _AbstractLoad: + return _generate_from_keys(Load.subqueryload, keys, False, {}) + + +@loader_unbound_fn +def selectinload( + *keys: _AttrType, recursion_depth: Optional[int] = None +) -> _AbstractLoad: + return _generate_from_keys( + Load.selectinload, keys, False, {"recursion_depth": recursion_depth} + ) + + +@loader_unbound_fn +def lazyload(*keys: _AttrType) -> _AbstractLoad: + return _generate_from_keys(Load.lazyload, keys, False, {}) + + +@loader_unbound_fn +def immediateload( + *keys: _AttrType, recursion_depth: Optional[int] = None +) -> _AbstractLoad: + return _generate_from_keys( + Load.immediateload, keys, False, {"recursion_depth": recursion_depth} + ) + + +@loader_unbound_fn +def noload(*keys: _AttrType) -> _AbstractLoad: + return _generate_from_keys(Load.noload, keys, False, {}) + + +@loader_unbound_fn +def raiseload(*keys: _AttrType, **kw: Any) -> _AbstractLoad: + return _generate_from_keys(Load.raiseload, keys, False, kw) + + +@loader_unbound_fn +def defaultload(*keys: _AttrType) -> _AbstractLoad: + return _generate_from_keys(Load.defaultload, keys, False, {}) + + +@loader_unbound_fn +def defer( + key: _AttrType, *addl_attrs: _AttrType, raiseload: bool = False +) -> _AbstractLoad: + if addl_attrs: + util.warn_deprecated( + "The *addl_attrs on orm.defer is deprecated. Please use " + "method chaining in conjunction with defaultload() to " + "indicate a path.", + version="1.3", + ) + + if raiseload: + kw = {"raiseload": raiseload} + else: + kw = {} + + return _generate_from_keys(Load.defer, (key,) + addl_attrs, False, kw) + + +@loader_unbound_fn +def undefer(key: _AttrType, *addl_attrs: _AttrType) -> _AbstractLoad: + if addl_attrs: + util.warn_deprecated( + "The *addl_attrs on orm.undefer is deprecated. Please use " + "method chaining in conjunction with defaultload() to " + "indicate a path.", + version="1.3", + ) + return _generate_from_keys(Load.undefer, (key,) + addl_attrs, False, {}) + + +@loader_unbound_fn +def undefer_group(name: str) -> _AbstractLoad: + element = _WildcardLoad() + return element.undefer_group(name) + + +@loader_unbound_fn +def with_expression( + key: _AttrType, expression: _ColumnExpressionArgument[Any] +) -> _AbstractLoad: + return _generate_from_keys( + Load.with_expression, (key,), False, {"expression": expression} + ) + + +@loader_unbound_fn +def selectin_polymorphic( + base_cls: _EntityType[Any], classes: Iterable[Type[Any]] +) -> _AbstractLoad: + ul = Load(base_cls) + return ul.selectin_polymorphic(classes) + + +def _raise_for_does_not_link(path, attrname, parent_entity): + if len(path) > 1: + path_is_of_type = path[-1].entity is not path[-2].mapper.class_ + if insp_is_aliased_class(parent_entity): + parent_entity_str = str(parent_entity) + else: + parent_entity_str = parent_entity.class_.__name__ + + raise sa_exc.ArgumentError( + f'ORM mapped entity or attribute "{attrname}" does not ' + f'link from relationship "{path[-2]}%s".%s' + % ( + f".of_type({path[-1]})" if path_is_of_type else "", + ( + " Did you mean to use " + f'"{path[-2]}' + f'.of_type({parent_entity_str})" or "loadopt.options(' + f"selectin_polymorphic({path[-2].mapper.class_.__name__}, " + f'[{parent_entity_str}]), ...)" ?' + if not path_is_of_type + and not path[-1].is_aliased_class + and orm_util._entity_corresponds_to( + path.entity, inspect(parent_entity).mapper + ) + else "" + ), + ) + ) + else: + raise sa_exc.ArgumentError( + f'ORM mapped attribute "{attrname}" does not ' + f'link mapped class "{path[-1]}"' + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/sync.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/sync.py new file mode 100644 index 0000000..db09a3e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/sync.py @@ -0,0 +1,164 @@ +# orm/sync.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + + +"""private module containing functions used for copying data +between instances based on join conditions. + +""" + +from __future__ import annotations + +from . import exc +from . import util as orm_util +from .base import PassiveFlag + + +def populate( + source, + source_mapper, + dest, + dest_mapper, + synchronize_pairs, + uowcommit, + flag_cascaded_pks, +): + source_dict = source.dict + dest_dict = dest.dict + + for l, r in synchronize_pairs: + try: + # inline of source_mapper._get_state_attr_by_column + prop = source_mapper._columntoproperty[l] + value = source.manager[prop.key].impl.get( + source, source_dict, PassiveFlag.PASSIVE_OFF + ) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, dest_mapper, r, err) + + try: + # inline of dest_mapper._set_state_attr_by_column + prop = dest_mapper._columntoproperty[r] + dest.manager[prop.key].impl.set(dest, dest_dict, value, None) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(True, source_mapper, l, dest_mapper, r, err) + + # technically the "r.primary_key" check isn't + # needed here, but we check for this condition to limit + # how often this logic is invoked for memory/performance + # reasons, since we only need this info for a primary key + # destination. + if ( + flag_cascaded_pks + and l.primary_key + and r.primary_key + and r.references(l) + ): + uowcommit.attributes[("pk_cascaded", dest, r)] = True + + +def bulk_populate_inherit_keys(source_dict, source_mapper, synchronize_pairs): + # a simplified version of populate() used by bulk insert mode + for l, r in synchronize_pairs: + try: + prop = source_mapper._columntoproperty[l] + value = source_dict[prop.key] + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, source_mapper, r, err) + + try: + prop = source_mapper._columntoproperty[r] + source_dict[prop.key] = value + except exc.UnmappedColumnError as err: + _raise_col_to_prop(True, source_mapper, l, source_mapper, r, err) + + +def clear(dest, dest_mapper, synchronize_pairs): + for l, r in synchronize_pairs: + if ( + r.primary_key + and dest_mapper._get_state_attr_by_column(dest, dest.dict, r) + not in orm_util._none_set + ): + raise AssertionError( + f"Dependency rule on column '{l}' " + "tried to blank-out primary key " + f"column '{r}' on instance '{orm_util.state_str(dest)}'" + ) + try: + dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(True, None, l, dest_mapper, r, err) + + +def update(source, source_mapper, dest, old_prefix, synchronize_pairs): + for l, r in synchronize_pairs: + try: + oldvalue = source_mapper._get_committed_attr_by_column( + source.obj(), l + ) + value = source_mapper._get_state_attr_by_column( + source, source.dict, l, passive=PassiveFlag.PASSIVE_OFF + ) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, None, r, err) + dest[r.key] = value + dest[old_prefix + r.key] = oldvalue + + +def populate_dict(source, source_mapper, dict_, synchronize_pairs): + for l, r in synchronize_pairs: + try: + value = source_mapper._get_state_attr_by_column( + source, source.dict, l, passive=PassiveFlag.PASSIVE_OFF + ) + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, None, r, err) + + dict_[r.key] = value + + +def source_modified(uowcommit, source, source_mapper, synchronize_pairs): + """return true if the source object has changes from an old to a + new value on the given synchronize pairs + + """ + for l, r in synchronize_pairs: + try: + prop = source_mapper._columntoproperty[l] + except exc.UnmappedColumnError as err: + _raise_col_to_prop(False, source_mapper, l, None, r, err) + history = uowcommit.get_attribute_history( + source, prop.key, PassiveFlag.PASSIVE_NO_INITIALIZE + ) + if bool(history.deleted): + return True + else: + return False + + +def _raise_col_to_prop( + isdest, source_mapper, source_column, dest_mapper, dest_column, err +): + if isdest: + raise exc.UnmappedColumnError( + "Can't execute sync rule for " + "destination column '%s'; mapper '%s' does not map " + "this column. Try using an explicit `foreign_keys` " + "collection which does not include this column (or use " + "a viewonly=True relation)." % (dest_column, dest_mapper) + ) from err + else: + raise exc.UnmappedColumnError( + "Can't execute sync rule for " + "source column '%s'; mapper '%s' does not map this " + "column. Try using an explicit `foreign_keys` " + "collection which does not include destination column " + "'%s' (or use a viewonly=True relation)." + % (source_column, source_mapper, dest_column) + ) from err diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/unitofwork.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/unitofwork.py new file mode 100644 index 0000000..7e2df2b --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/unitofwork.py @@ -0,0 +1,796 @@ +# orm/unitofwork.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""The internals for the unit of work system. + +The session's flush() process passes objects to a contextual object +here, which assembles flush tasks based on mappers and their properties, +organizes them in order of dependency, and executes. + +""" + +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import Optional +from typing import Set +from typing import TYPE_CHECKING + +from . import attributes +from . import exc as orm_exc +from . import util as orm_util +from .. import event +from .. import util +from ..util import topological + + +if TYPE_CHECKING: + from .dependency import DependencyProcessor + from .interfaces import MapperProperty + from .mapper import Mapper + from .session import Session + from .session import SessionTransaction + from .state import InstanceState + + +def track_cascade_events(descriptor, prop): + """Establish event listeners on object attributes which handle + cascade-on-set/append. + + """ + key = prop.key + + def append(state, item, initiator, **kw): + # process "save_update" cascade rules for when + # an instance is appended to the list of another instance + + if item is None: + return + + sess = state.session + if sess: + if sess._warn_on_events: + sess._flush_warning("collection append") + + prop = state.manager.mapper._props[key] + item_state = attributes.instance_state(item) + + if ( + prop._cascade.save_update + and (key == initiator.key) + and not sess._contains_state(item_state) + ): + sess._save_or_update_state(item_state) + return item + + def remove(state, item, initiator, **kw): + if item is None: + return + + sess = state.session + + prop = state.manager.mapper._props[key] + + if sess and sess._warn_on_events: + sess._flush_warning( + "collection remove" + if prop.uselist + else "related attribute delete" + ) + + if ( + item is not None + and item is not attributes.NEVER_SET + and item is not attributes.PASSIVE_NO_RESULT + and prop._cascade.delete_orphan + ): + # expunge pending orphans + item_state = attributes.instance_state(item) + + if prop.mapper._is_orphan(item_state): + if sess and item_state in sess._new: + sess.expunge(item) + else: + # the related item may or may not itself be in a + # Session, however the parent for which we are catching + # the event is not in a session, so memoize this on the + # item + item_state._orphaned_outside_of_session = True + + def set_(state, newvalue, oldvalue, initiator, **kw): + # process "save_update" cascade rules for when an instance + # is attached to another instance + if oldvalue is newvalue: + return newvalue + + sess = state.session + if sess: + if sess._warn_on_events: + sess._flush_warning("related attribute set") + + prop = state.manager.mapper._props[key] + if newvalue is not None: + newvalue_state = attributes.instance_state(newvalue) + if ( + prop._cascade.save_update + and (key == initiator.key) + and not sess._contains_state(newvalue_state) + ): + sess._save_or_update_state(newvalue_state) + + if ( + oldvalue is not None + and oldvalue is not attributes.NEVER_SET + and oldvalue is not attributes.PASSIVE_NO_RESULT + and prop._cascade.delete_orphan + ): + # possible to reach here with attributes.NEVER_SET ? + oldvalue_state = attributes.instance_state(oldvalue) + + if oldvalue_state in sess._new and prop.mapper._is_orphan( + oldvalue_state + ): + sess.expunge(oldvalue) + return newvalue + + event.listen( + descriptor, "append_wo_mutation", append, raw=True, include_key=True + ) + event.listen( + descriptor, "append", append, raw=True, retval=True, include_key=True + ) + event.listen( + descriptor, "remove", remove, raw=True, retval=True, include_key=True + ) + event.listen( + descriptor, "set", set_, raw=True, retval=True, include_key=True + ) + + +class UOWTransaction: + session: Session + transaction: SessionTransaction + attributes: Dict[str, Any] + deps: util.defaultdict[Mapper[Any], Set[DependencyProcessor]] + mappers: util.defaultdict[Mapper[Any], Set[InstanceState[Any]]] + + def __init__(self, session: Session): + self.session = session + + # dictionary used by external actors to + # store arbitrary state information. + self.attributes = {} + + # dictionary of mappers to sets of + # DependencyProcessors, which are also + # set to be part of the sorted flush actions, + # which have that mapper as a parent. + self.deps = util.defaultdict(set) + + # dictionary of mappers to sets of InstanceState + # items pending for flush which have that mapper + # as a parent. + self.mappers = util.defaultdict(set) + + # a dictionary of Preprocess objects, which gather + # additional states impacted by the flush + # and determine if a flush action is needed + self.presort_actions = {} + + # dictionary of PostSortRec objects, each + # one issues work during the flush within + # a certain ordering. + self.postsort_actions = {} + + # a set of 2-tuples, each containing two + # PostSortRec objects where the second + # is dependent on the first being executed + # first + self.dependencies = set() + + # dictionary of InstanceState-> (isdelete, listonly) + # tuples, indicating if this state is to be deleted + # or insert/updated, or just refreshed + self.states = {} + + # tracks InstanceStates which will be receiving + # a "post update" call. Keys are mappers, + # values are a set of states and a set of the + # columns which should be included in the update. + self.post_update_states = util.defaultdict(lambda: (set(), set())) + + @property + def has_work(self): + return bool(self.states) + + def was_already_deleted(self, state): + """Return ``True`` if the given state is expired and was deleted + previously. + """ + if state.expired: + try: + state._load_expired(state, attributes.PASSIVE_OFF) + except orm_exc.ObjectDeletedError: + self.session._remove_newly_deleted([state]) + return True + return False + + def is_deleted(self, state): + """Return ``True`` if the given state is marked as deleted + within this uowtransaction.""" + + return state in self.states and self.states[state][0] + + def memo(self, key, callable_): + if key in self.attributes: + return self.attributes[key] + else: + self.attributes[key] = ret = callable_() + return ret + + def remove_state_actions(self, state): + """Remove pending actions for a state from the uowtransaction.""" + + isdelete = self.states[state][0] + + self.states[state] = (isdelete, True) + + def get_attribute_history( + self, state, key, passive=attributes.PASSIVE_NO_INITIALIZE + ): + """Facade to attributes.get_state_history(), including + caching of results.""" + + hashkey = ("history", state, key) + + # cache the objects, not the states; the strong reference here + # prevents newly loaded objects from being dereferenced during the + # flush process + + if hashkey in self.attributes: + history, state_history, cached_passive = self.attributes[hashkey] + # if the cached lookup was "passive" and now + # we want non-passive, do a non-passive lookup and re-cache + + if ( + not cached_passive & attributes.SQL_OK + and passive & attributes.SQL_OK + ): + impl = state.manager[key].impl + history = impl.get_history( + state, + state.dict, + attributes.PASSIVE_OFF + | attributes.LOAD_AGAINST_COMMITTED + | attributes.NO_RAISE, + ) + if history and impl.uses_objects: + state_history = history.as_state() + else: + state_history = history + self.attributes[hashkey] = (history, state_history, passive) + else: + impl = state.manager[key].impl + # TODO: store the history as (state, object) tuples + # so we don't have to keep converting here + history = impl.get_history( + state, + state.dict, + passive + | attributes.LOAD_AGAINST_COMMITTED + | attributes.NO_RAISE, + ) + if history and impl.uses_objects: + state_history = history.as_state() + else: + state_history = history + self.attributes[hashkey] = (history, state_history, passive) + + return state_history + + def has_dep(self, processor): + return (processor, True) in self.presort_actions + + def register_preprocessor(self, processor, fromparent): + key = (processor, fromparent) + if key not in self.presort_actions: + self.presort_actions[key] = Preprocess(processor, fromparent) + + def register_object( + self, + state: InstanceState[Any], + isdelete: bool = False, + listonly: bool = False, + cancel_delete: bool = False, + operation: Optional[str] = None, + prop: Optional[MapperProperty] = None, + ) -> bool: + if not self.session._contains_state(state): + # this condition is normal when objects are registered + # as part of a relationship cascade operation. it should + # not occur for the top-level register from Session.flush(). + if not state.deleted and operation is not None: + util.warn( + "Object of type %s not in session, %s operation " + "along '%s' will not proceed" + % (orm_util.state_class_str(state), operation, prop) + ) + return False + + if state not in self.states: + mapper = state.manager.mapper + + if mapper not in self.mappers: + self._per_mapper_flush_actions(mapper) + + self.mappers[mapper].add(state) + self.states[state] = (isdelete, listonly) + else: + if not listonly and (isdelete or cancel_delete): + self.states[state] = (isdelete, False) + return True + + def register_post_update(self, state, post_update_cols): + mapper = state.manager.mapper.base_mapper + states, cols = self.post_update_states[mapper] + states.add(state) + cols.update(post_update_cols) + + def _per_mapper_flush_actions(self, mapper): + saves = SaveUpdateAll(self, mapper.base_mapper) + deletes = DeleteAll(self, mapper.base_mapper) + self.dependencies.add((saves, deletes)) + + for dep in mapper._dependency_processors: + dep.per_property_preprocessors(self) + + for prop in mapper.relationships: + if prop.viewonly: + continue + dep = prop._dependency_processor + dep.per_property_preprocessors(self) + + @util.memoized_property + def _mapper_for_dep(self): + """return a dynamic mapping of (Mapper, DependencyProcessor) to + True or False, indicating if the DependencyProcessor operates + on objects of that Mapper. + + The result is stored in the dictionary persistently once + calculated. + + """ + return util.PopulateDict( + lambda tup: tup[0]._props.get(tup[1].key) is tup[1].prop + ) + + def filter_states_for_dep(self, dep, states): + """Filter the given list of InstanceStates to those relevant to the + given DependencyProcessor. + + """ + mapper_for_dep = self._mapper_for_dep + return [s for s in states if mapper_for_dep[(s.manager.mapper, dep)]] + + def states_for_mapper_hierarchy(self, mapper, isdelete, listonly): + checktup = (isdelete, listonly) + for mapper in mapper.base_mapper.self_and_descendants: + for state in self.mappers[mapper]: + if self.states[state] == checktup: + yield state + + def _generate_actions(self): + """Generate the full, unsorted collection of PostSortRecs as + well as dependency pairs for this UOWTransaction. + + """ + # execute presort_actions, until all states + # have been processed. a presort_action might + # add new states to the uow. + while True: + ret = False + for action in list(self.presort_actions.values()): + if action.execute(self): + ret = True + if not ret: + break + + # see if the graph of mapper dependencies has cycles. + self.cycles = cycles = topological.find_cycles( + self.dependencies, list(self.postsort_actions.values()) + ) + + if cycles: + # if yes, break the per-mapper actions into + # per-state actions + convert = { + rec: set(rec.per_state_flush_actions(self)) for rec in cycles + } + + # rewrite the existing dependencies to point to + # the per-state actions for those per-mapper actions + # that were broken up. + for edge in list(self.dependencies): + if ( + None in edge + or edge[0].disabled + or edge[1].disabled + or cycles.issuperset(edge) + ): + self.dependencies.remove(edge) + elif edge[0] in cycles: + self.dependencies.remove(edge) + for dep in convert[edge[0]]: + self.dependencies.add((dep, edge[1])) + elif edge[1] in cycles: + self.dependencies.remove(edge) + for dep in convert[edge[1]]: + self.dependencies.add((edge[0], dep)) + + return { + a for a in self.postsort_actions.values() if not a.disabled + }.difference(cycles) + + def execute(self) -> None: + postsort_actions = self._generate_actions() + + postsort_actions = sorted( + postsort_actions, + key=lambda item: item.sort_key, + ) + # sort = topological.sort(self.dependencies, postsort_actions) + # print "--------------" + # print "\ndependencies:", self.dependencies + # print "\ncycles:", self.cycles + # print "\nsort:", list(sort) + # print "\nCOUNT OF POSTSORT ACTIONS", len(postsort_actions) + + # execute + if self.cycles: + for subset in topological.sort_as_subsets( + self.dependencies, postsort_actions + ): + set_ = set(subset) + while set_: + n = set_.pop() + n.execute_aggregate(self, set_) + else: + for rec in topological.sort(self.dependencies, postsort_actions): + rec.execute(self) + + def finalize_flush_changes(self) -> None: + """Mark processed objects as clean / deleted after a successful + flush(). + + This method is called within the flush() method after the + execute() method has succeeded and the transaction has been committed. + + """ + if not self.states: + return + + states = set(self.states) + isdel = { + s for (s, (isdelete, listonly)) in self.states.items() if isdelete + } + other = states.difference(isdel) + if isdel: + self.session._remove_newly_deleted(isdel) + if other: + self.session._register_persistent(other) + + +class IterateMappersMixin: + __slots__ = () + + def _mappers(self, uow): + if self.fromparent: + return iter( + m + for m in self.dependency_processor.parent.self_and_descendants + if uow._mapper_for_dep[(m, self.dependency_processor)] + ) + else: + return self.dependency_processor.mapper.self_and_descendants + + +class Preprocess(IterateMappersMixin): + __slots__ = ( + "dependency_processor", + "fromparent", + "processed", + "setup_flush_actions", + ) + + def __init__(self, dependency_processor, fromparent): + self.dependency_processor = dependency_processor + self.fromparent = fromparent + self.processed = set() + self.setup_flush_actions = False + + def execute(self, uow): + delete_states = set() + save_states = set() + + for mapper in self._mappers(uow): + for state in uow.mappers[mapper].difference(self.processed): + (isdelete, listonly) = uow.states[state] + if not listonly: + if isdelete: + delete_states.add(state) + else: + save_states.add(state) + + if delete_states: + self.dependency_processor.presort_deletes(uow, delete_states) + self.processed.update(delete_states) + if save_states: + self.dependency_processor.presort_saves(uow, save_states) + self.processed.update(save_states) + + if delete_states or save_states: + if not self.setup_flush_actions and ( + self.dependency_processor.prop_has_changes( + uow, delete_states, True + ) + or self.dependency_processor.prop_has_changes( + uow, save_states, False + ) + ): + self.dependency_processor.per_property_flush_actions(uow) + self.setup_flush_actions = True + return True + else: + return False + + +class PostSortRec: + __slots__ = ("disabled",) + + def __new__(cls, uow, *args): + key = (cls,) + args + if key in uow.postsort_actions: + return uow.postsort_actions[key] + else: + uow.postsort_actions[key] = ret = object.__new__(cls) + ret.disabled = False + return ret + + def execute_aggregate(self, uow, recs): + self.execute(uow) + + +class ProcessAll(IterateMappersMixin, PostSortRec): + __slots__ = "dependency_processor", "isdelete", "fromparent", "sort_key" + + def __init__(self, uow, dependency_processor, isdelete, fromparent): + self.dependency_processor = dependency_processor + self.sort_key = ( + "ProcessAll", + self.dependency_processor.sort_key, + isdelete, + ) + self.isdelete = isdelete + self.fromparent = fromparent + uow.deps[dependency_processor.parent.base_mapper].add( + dependency_processor + ) + + def execute(self, uow): + states = self._elements(uow) + if self.isdelete: + self.dependency_processor.process_deletes(uow, states) + else: + self.dependency_processor.process_saves(uow, states) + + def per_state_flush_actions(self, uow): + # this is handled by SaveUpdateAll and DeleteAll, + # since a ProcessAll should unconditionally be pulled + # into per-state if either the parent/child mappers + # are part of a cycle + return iter([]) + + def __repr__(self): + return "%s(%s, isdelete=%s)" % ( + self.__class__.__name__, + self.dependency_processor, + self.isdelete, + ) + + def _elements(self, uow): + for mapper in self._mappers(uow): + for state in uow.mappers[mapper]: + (isdelete, listonly) = uow.states[state] + if isdelete == self.isdelete and not listonly: + yield state + + +class PostUpdateAll(PostSortRec): + __slots__ = "mapper", "isdelete", "sort_key" + + def __init__(self, uow, mapper, isdelete): + self.mapper = mapper + self.isdelete = isdelete + self.sort_key = ("PostUpdateAll", mapper._sort_key, isdelete) + + @util.preload_module("sqlalchemy.orm.persistence") + def execute(self, uow): + persistence = util.preloaded.orm_persistence + states, cols = uow.post_update_states[self.mapper] + states = [s for s in states if uow.states[s][0] == self.isdelete] + + persistence.post_update(self.mapper, states, uow, cols) + + +class SaveUpdateAll(PostSortRec): + __slots__ = ("mapper", "sort_key") + + def __init__(self, uow, mapper): + self.mapper = mapper + self.sort_key = ("SaveUpdateAll", mapper._sort_key) + assert mapper is mapper.base_mapper + + @util.preload_module("sqlalchemy.orm.persistence") + def execute(self, uow): + util.preloaded.orm_persistence.save_obj( + self.mapper, + uow.states_for_mapper_hierarchy(self.mapper, False, False), + uow, + ) + + def per_state_flush_actions(self, uow): + states = list( + uow.states_for_mapper_hierarchy(self.mapper, False, False) + ) + base_mapper = self.mapper.base_mapper + delete_all = DeleteAll(uow, base_mapper) + for state in states: + # keep saves before deletes - + # this ensures 'row switch' operations work + action = SaveUpdateState(uow, state) + uow.dependencies.add((action, delete_all)) + yield action + + for dep in uow.deps[self.mapper]: + states_for_prop = uow.filter_states_for_dep(dep, states) + dep.per_state_flush_actions(uow, states_for_prop, False) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.mapper) + + +class DeleteAll(PostSortRec): + __slots__ = ("mapper", "sort_key") + + def __init__(self, uow, mapper): + self.mapper = mapper + self.sort_key = ("DeleteAll", mapper._sort_key) + assert mapper is mapper.base_mapper + + @util.preload_module("sqlalchemy.orm.persistence") + def execute(self, uow): + util.preloaded.orm_persistence.delete_obj( + self.mapper, + uow.states_for_mapper_hierarchy(self.mapper, True, False), + uow, + ) + + def per_state_flush_actions(self, uow): + states = list( + uow.states_for_mapper_hierarchy(self.mapper, True, False) + ) + base_mapper = self.mapper.base_mapper + save_all = SaveUpdateAll(uow, base_mapper) + for state in states: + # keep saves before deletes - + # this ensures 'row switch' operations work + action = DeleteState(uow, state) + uow.dependencies.add((save_all, action)) + yield action + + for dep in uow.deps[self.mapper]: + states_for_prop = uow.filter_states_for_dep(dep, states) + dep.per_state_flush_actions(uow, states_for_prop, True) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.mapper) + + +class ProcessState(PostSortRec): + __slots__ = "dependency_processor", "isdelete", "state", "sort_key" + + def __init__(self, uow, dependency_processor, isdelete, state): + self.dependency_processor = dependency_processor + self.sort_key = ("ProcessState", dependency_processor.sort_key) + self.isdelete = isdelete + self.state = state + + def execute_aggregate(self, uow, recs): + cls_ = self.__class__ + dependency_processor = self.dependency_processor + isdelete = self.isdelete + our_recs = [ + r + for r in recs + if r.__class__ is cls_ + and r.dependency_processor is dependency_processor + and r.isdelete is isdelete + ] + recs.difference_update(our_recs) + states = [self.state] + [r.state for r in our_recs] + if isdelete: + dependency_processor.process_deletes(uow, states) + else: + dependency_processor.process_saves(uow, states) + + def __repr__(self): + return "%s(%s, %s, delete=%s)" % ( + self.__class__.__name__, + self.dependency_processor, + orm_util.state_str(self.state), + self.isdelete, + ) + + +class SaveUpdateState(PostSortRec): + __slots__ = "state", "mapper", "sort_key" + + def __init__(self, uow, state): + self.state = state + self.mapper = state.mapper.base_mapper + self.sort_key = ("ProcessState", self.mapper._sort_key) + + @util.preload_module("sqlalchemy.orm.persistence") + def execute_aggregate(self, uow, recs): + persistence = util.preloaded.orm_persistence + cls_ = self.__class__ + mapper = self.mapper + our_recs = [ + r for r in recs if r.__class__ is cls_ and r.mapper is mapper + ] + recs.difference_update(our_recs) + persistence.save_obj( + mapper, [self.state] + [r.state for r in our_recs], uow + ) + + def __repr__(self): + return "%s(%s)" % ( + self.__class__.__name__, + orm_util.state_str(self.state), + ) + + +class DeleteState(PostSortRec): + __slots__ = "state", "mapper", "sort_key" + + def __init__(self, uow, state): + self.state = state + self.mapper = state.mapper.base_mapper + self.sort_key = ("DeleteState", self.mapper._sort_key) + + @util.preload_module("sqlalchemy.orm.persistence") + def execute_aggregate(self, uow, recs): + persistence = util.preloaded.orm_persistence + cls_ = self.__class__ + mapper = self.mapper + our_recs = [ + r for r in recs if r.__class__ is cls_ and r.mapper is mapper + ] + recs.difference_update(our_recs) + states = [self.state] + [r.state for r in our_recs] + persistence.delete_obj( + mapper, [s for s in states if uow.states[s][0]], uow + ) + + def __repr__(self): + return "%s(%s)" % ( + self.__class__.__name__, + orm_util.state_str(self.state), + ) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/util.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/util.py new file mode 100644 index 0000000..8e153e6 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/util.py @@ -0,0 +1,2416 @@ +# orm/util.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +from __future__ import annotations + +import enum +import functools +import re +import types +import typing +from typing import AbstractSet +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Match +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union +import weakref + +from . import attributes # noqa +from . import exc +from ._typing import _O +from ._typing import insp_is_aliased_class +from ._typing import insp_is_mapper +from ._typing import prop_is_relationship +from .base import _class_to_mapper as _class_to_mapper +from .base import _MappedAnnotationBase +from .base import _never_set as _never_set # noqa: F401 +from .base import _none_set as _none_set # noqa: F401 +from .base import attribute_str as attribute_str # noqa: F401 +from .base import class_mapper as class_mapper +from .base import DynamicMapped +from .base import InspectionAttr as InspectionAttr +from .base import instance_str as instance_str # noqa: F401 +from .base import Mapped +from .base import object_mapper as object_mapper +from .base import object_state as object_state # noqa: F401 +from .base import opt_manager_of_class +from .base import ORMDescriptor +from .base import state_attribute_str as state_attribute_str # noqa: F401 +from .base import state_class_str as state_class_str # noqa: F401 +from .base import state_str as state_str # noqa: F401 +from .base import WriteOnlyMapped +from .interfaces import CriteriaOption +from .interfaces import MapperProperty as MapperProperty +from .interfaces import ORMColumnsClauseRole +from .interfaces import ORMEntityColumnsClauseRole +from .interfaces import ORMFromClauseRole +from .path_registry import PathRegistry as PathRegistry +from .. import event +from .. import exc as sa_exc +from .. import inspection +from .. import sql +from .. import util +from ..engine.result import result_tuple +from ..sql import coercions +from ..sql import expression +from ..sql import lambdas +from ..sql import roles +from ..sql import util as sql_util +from ..sql import visitors +from ..sql._typing import is_selectable +from ..sql.annotation import SupportsCloneAnnotations +from ..sql.base import ColumnCollection +from ..sql.cache_key import HasCacheKey +from ..sql.cache_key import MemoizedHasCacheKey +from ..sql.elements import ColumnElement +from ..sql.elements import KeyedColumnElement +from ..sql.selectable import FromClause +from ..util.langhelpers import MemoizedSlots +from ..util.typing import de_stringify_annotation as _de_stringify_annotation +from ..util.typing import ( + de_stringify_union_elements as _de_stringify_union_elements, +) +from ..util.typing import eval_name_only as _eval_name_only +from ..util.typing import is_origin_of_cls +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import typing_get_origin + +if typing.TYPE_CHECKING: + from ._typing import _EntityType + from ._typing import _IdentityKeyType + from ._typing import _InternalEntityType + from ._typing import _ORMCOLEXPR + from .context import _MapperEntity + from .context import ORMCompileState + from .mapper import Mapper + from .path_registry import AbstractEntityRegistry + from .query import Query + from .relationships import RelationshipProperty + from ..engine import Row + from ..engine import RowMapping + from ..sql._typing import _CE + from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _EquivalentColumnMap + from ..sql._typing import _FromClauseArgument + from ..sql._typing import _OnClauseArgument + from ..sql._typing import _PropagateAttrsType + from ..sql.annotation import _SA + from ..sql.base import ReadOnlyColumnCollection + from ..sql.elements import BindParameter + from ..sql.selectable import _ColumnsClauseElement + from ..sql.selectable import Select + from ..sql.selectable import Selectable + from ..sql.visitors import anon_map + from ..util.typing import _AnnotationScanType + from ..util.typing import ArgsTypeProcotol + +_T = TypeVar("_T", bound=Any) + +all_cascades = frozenset( + ( + "delete", + "delete-orphan", + "all", + "merge", + "expunge", + "save-update", + "refresh-expire", + "none", + ) +) + + +_de_stringify_partial = functools.partial( + functools.partial, + locals_=util.immutabledict( + { + "Mapped": Mapped, + "WriteOnlyMapped": WriteOnlyMapped, + "DynamicMapped": DynamicMapped, + } + ), +) + +# partial is practically useless as we have to write out the whole +# function and maintain the signature anyway + + +class _DeStringifyAnnotation(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: _AnnotationScanType, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + include_generic: bool = False, + ) -> Type[Any]: ... + + +de_stringify_annotation = cast( + _DeStringifyAnnotation, _de_stringify_partial(_de_stringify_annotation) +) + + +class _DeStringifyUnionElements(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: ArgsTypeProcotol, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + ) -> Type[Any]: ... + + +de_stringify_union_elements = cast( + _DeStringifyUnionElements, + _de_stringify_partial(_de_stringify_union_elements), +) + + +class _EvalNameOnly(Protocol): + def __call__(self, name: str, module_name: str) -> Any: ... + + +eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only)) + + +class CascadeOptions(FrozenSet[str]): + """Keeps track of the options sent to + :paramref:`.relationship.cascade`""" + + _add_w_all_cascades = all_cascades.difference( + ["all", "none", "delete-orphan"] + ) + _allowed_cascades = all_cascades + + _viewonly_cascades = ["expunge", "all", "none", "refresh-expire", "merge"] + + __slots__ = ( + "save_update", + "delete", + "refresh_expire", + "merge", + "expunge", + "delete_orphan", + ) + + save_update: bool + delete: bool + refresh_expire: bool + merge: bool + expunge: bool + delete_orphan: bool + + def __new__( + cls, value_list: Optional[Union[Iterable[str], str]] + ) -> CascadeOptions: + if isinstance(value_list, str) or value_list is None: + return cls.from_string(value_list) # type: ignore + values = set(value_list) + if values.difference(cls._allowed_cascades): + raise sa_exc.ArgumentError( + "Invalid cascade option(s): %s" + % ", ".join( + [ + repr(x) + for x in sorted( + values.difference(cls._allowed_cascades) + ) + ] + ) + ) + + if "all" in values: + values.update(cls._add_w_all_cascades) + if "none" in values: + values.clear() + values.discard("all") + + self = super().__new__(cls, values) + self.save_update = "save-update" in values + self.delete = "delete" in values + self.refresh_expire = "refresh-expire" in values + self.merge = "merge" in values + self.expunge = "expunge" in values + self.delete_orphan = "delete-orphan" in values + + if self.delete_orphan and not self.delete: + util.warn("The 'delete-orphan' cascade option requires 'delete'.") + return self + + def __repr__(self): + return "CascadeOptions(%r)" % (",".join([x for x in sorted(self)])) + + @classmethod + def from_string(cls, arg): + values = [c for c in re.split(r"\s*,\s*", arg or "") if c] + return cls(values) + + +def _validator_events(desc, key, validator, include_removes, include_backrefs): + """Runs a validation method on an attribute value to be set or + appended. + """ + + if not include_backrefs: + + def detect_is_backref(state, initiator): + impl = state.manager[key].impl + return initiator.impl is not impl + + if include_removes: + + def append(state, value, initiator): + if initiator.op is not attributes.OP_BULK_REPLACE and ( + include_backrefs or not detect_is_backref(state, initiator) + ): + return validator(state.obj(), key, value, False) + else: + return value + + def bulk_set(state, values, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + obj = state.obj() + values[:] = [ + validator(obj, key, value, False) for value in values + ] + + def set_(state, value, oldvalue, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value, False) + else: + return value + + def remove(state, value, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + validator(state.obj(), key, value, True) + + else: + + def append(state, value, initiator): + if initiator.op is not attributes.OP_BULK_REPLACE and ( + include_backrefs or not detect_is_backref(state, initiator) + ): + return validator(state.obj(), key, value) + else: + return value + + def bulk_set(state, values, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + obj = state.obj() + values[:] = [validator(obj, key, value) for value in values] + + def set_(state, value, oldvalue, initiator): + if include_backrefs or not detect_is_backref(state, initiator): + return validator(state.obj(), key, value) + else: + return value + + event.listen(desc, "append", append, raw=True, retval=True) + event.listen(desc, "bulk_replace", bulk_set, raw=True) + event.listen(desc, "set", set_, raw=True, retval=True) + if include_removes: + event.listen(desc, "remove", remove, raw=True, retval=True) + + +def polymorphic_union( + table_map, typecolname, aliasname="p_union", cast_nulls=True +): + """Create a ``UNION`` statement used by a polymorphic mapper. + + See :ref:`concrete_inheritance` for an example of how + this is used. + + :param table_map: mapping of polymorphic identities to + :class:`_schema.Table` objects. + :param typecolname: string name of a "discriminator" column, which will be + derived from the query, producing the polymorphic identity for + each row. If ``None``, no polymorphic discriminator is generated. + :param aliasname: name of the :func:`~sqlalchemy.sql.expression.alias()` + construct generated. + :param cast_nulls: if True, non-existent columns, which are represented + as labeled NULLs, will be passed into CAST. This is a legacy behavior + that is problematic on some backends such as Oracle - in which case it + can be set to False. + + """ + + colnames: util.OrderedSet[str] = util.OrderedSet() + colnamemaps = {} + types = {} + for key in table_map: + table = table_map[key] + + table = coercions.expect( + roles.StrictFromClauseRole, table, allow_select=True + ) + table_map[key] = table + + m = {} + for c in table.c: + if c.key == typecolname: + raise sa_exc.InvalidRequestError( + "Polymorphic union can't use '%s' as the discriminator " + "column due to mapped column %r; please apply the " + "'typecolname' " + "argument; this is available on " + "ConcreteBase as '_concrete_discriminator_name'" + % (typecolname, c) + ) + colnames.add(c.key) + m[c.key] = c + types[c.key] = c.type + colnamemaps[table] = m + + def col(name, table): + try: + return colnamemaps[table][name] + except KeyError: + if cast_nulls: + return sql.cast(sql.null(), types[name]).label(name) + else: + return sql.type_coerce(sql.null(), types[name]).label(name) + + result = [] + for type_, table in table_map.items(): + if typecolname is not None: + result.append( + sql.select( + *( + [col(name, table) for name in colnames] + + [ + sql.literal_column( + sql_util._quote_ddl_expr(type_) + ).label(typecolname) + ] + ) + ).select_from(table) + ) + else: + result.append( + sql.select( + *[col(name, table) for name in colnames] + ).select_from(table) + ) + return sql.union_all(*result).alias(aliasname) + + +def identity_key( + class_: Optional[Type[_T]] = None, + ident: Union[Any, Tuple[Any, ...]] = None, + *, + instance: Optional[_T] = None, + row: Optional[Union[Row[Any], RowMapping]] = None, + identity_token: Optional[Any] = None, +) -> _IdentityKeyType[_T]: + r"""Generate "identity key" tuples, as are used as keys in the + :attr:`.Session.identity_map` dictionary. + + This function has several call styles: + + * ``identity_key(class, ident, identity_token=token)`` + + This form receives a mapped class and a primary key scalar or + tuple as an argument. + + E.g.:: + + >>> identity_key(MyClass, (1, 2)) + (, (1, 2), None) + + :param class: mapped class (must be a positional argument) + :param ident: primary key, may be a scalar or tuple argument. + :param identity_token: optional identity token + + .. versionadded:: 1.2 added identity_token + + + * ``identity_key(instance=instance)`` + + This form will produce the identity key for a given instance. The + instance need not be persistent, only that its primary key attributes + are populated (else the key will contain ``None`` for those missing + values). + + E.g.:: + + >>> instance = MyClass(1, 2) + >>> identity_key(instance=instance) + (, (1, 2), None) + + In this form, the given instance is ultimately run though + :meth:`_orm.Mapper.identity_key_from_instance`, which will have the + effect of performing a database check for the corresponding row + if the object is expired. + + :param instance: object instance (must be given as a keyword arg) + + * ``identity_key(class, row=row, identity_token=token)`` + + This form is similar to the class/tuple form, except is passed a + database result row as a :class:`.Row` or :class:`.RowMapping` object. + + E.g.:: + + >>> row = engine.execute(\ + text("select * from table where a=1 and b=2")\ + ).first() + >>> identity_key(MyClass, row=row) + (, (1, 2), None) + + :param class: mapped class (must be a positional argument) + :param row: :class:`.Row` row returned by a :class:`_engine.CursorResult` + (must be given as a keyword arg) + :param identity_token: optional identity token + + .. versionadded:: 1.2 added identity_token + + """ + if class_ is not None: + mapper = class_mapper(class_) + if row is None: + if ident is None: + raise sa_exc.ArgumentError("ident or row is required") + return mapper.identity_key_from_primary_key( + tuple(util.to_list(ident)), identity_token=identity_token + ) + else: + return mapper.identity_key_from_row( + row, identity_token=identity_token + ) + elif instance is not None: + mapper = object_mapper(instance) + return mapper.identity_key_from_instance(instance) + else: + raise sa_exc.ArgumentError("class or instance is required") + + +class _TraceAdaptRole(enum.Enum): + """Enumeration of all the use cases for ORMAdapter. + + ORMAdapter remains one of the most complicated aspects of the ORM, as it is + used for in-place adaption of column expressions to be applied to a SELECT, + replacing :class:`.Table` and other objects that are mapped to classes with + aliases of those tables in the case of joined eager loading, or in the case + of polymorphic loading as used with concrete mappings or other custom "with + polymorphic" parameters, with whole user-defined subqueries. The + enumerations provide an overview of all the use cases used by ORMAdapter, a + layer of formality as to the introduction of new ORMAdapter use cases (of + which none are anticipated), as well as a means to trace the origins of a + particular ORMAdapter within runtime debugging. + + SQLAlchemy 2.0 has greatly scaled back ORM features which relied heavily on + open-ended statement adaption, including the ``Query.with_polymorphic()`` + method and the ``Query.select_from_entity()`` methods, favoring + user-explicit aliasing schemes using the ``aliased()`` and + ``with_polymorphic()`` standalone constructs; these still use adaption, + however the adaption is applied in a narrower scope. + + """ + + # aliased() use that is used to adapt individual attributes at query + # construction time + ALIASED_INSP = enum.auto() + + # joinedload cases; typically adapt an ON clause of a relationship + # join + JOINEDLOAD_USER_DEFINED_ALIAS = enum.auto() + JOINEDLOAD_PATH_WITH_POLYMORPHIC = enum.auto() + JOINEDLOAD_MEMOIZED_ADAPTER = enum.auto() + + # polymorphic cases - these are complex ones that replace FROM + # clauses, replacing tables with subqueries + MAPPER_POLYMORPHIC_ADAPTER = enum.auto() + WITH_POLYMORPHIC_ADAPTER = enum.auto() + WITH_POLYMORPHIC_ADAPTER_RIGHT_JOIN = enum.auto() + DEPRECATED_JOIN_ADAPT_RIGHT_SIDE = enum.auto() + + # the from_statement() case, used only to adapt individual attributes + # from a given statement to local ORM attributes at result fetching + # time. assigned to ORMCompileState._from_obj_alias + ADAPT_FROM_STATEMENT = enum.auto() + + # the joinedload for queries that have LIMIT/OFFSET/DISTINCT case; + # the query is placed inside of a subquery with the LIMIT/OFFSET/etc., + # joinedloads are then placed on the outside. + # assigned to ORMCompileState.compound_eager_adapter + COMPOUND_EAGER_STATEMENT = enum.auto() + + # the legacy Query._set_select_from() case. + # this is needed for Query's set operations (i.e. UNION, etc. ) + # as well as "legacy from_self()", which while removed from 2.0 as + # public API, is used for the Query.count() method. this one + # still does full statement traversal + # assigned to ORMCompileState._from_obj_alias + LEGACY_SELECT_FROM_ALIAS = enum.auto() + + +class ORMStatementAdapter(sql_util.ColumnAdapter): + """ColumnAdapter which includes a role attribute.""" + + __slots__ = ("role",) + + def __init__( + self, + role: _TraceAdaptRole, + selectable: Selectable, + *, + equivalents: Optional[_EquivalentColumnMap] = None, + adapt_required: bool = False, + allow_label_resolve: bool = True, + anonymize_labels: bool = False, + adapt_on_names: bool = False, + adapt_from_selectables: Optional[AbstractSet[FromClause]] = None, + ): + self.role = role + super().__init__( + selectable, + equivalents=equivalents, + adapt_required=adapt_required, + allow_label_resolve=allow_label_resolve, + anonymize_labels=anonymize_labels, + adapt_on_names=adapt_on_names, + adapt_from_selectables=adapt_from_selectables, + ) + + +class ORMAdapter(sql_util.ColumnAdapter): + """ColumnAdapter subclass which excludes adaptation of entities from + non-matching mappers. + + """ + + __slots__ = ("role", "mapper", "is_aliased_class", "aliased_insp") + + is_aliased_class: bool + aliased_insp: Optional[AliasedInsp[Any]] + + def __init__( + self, + role: _TraceAdaptRole, + entity: _InternalEntityType[Any], + *, + equivalents: Optional[_EquivalentColumnMap] = None, + adapt_required: bool = False, + allow_label_resolve: bool = True, + anonymize_labels: bool = False, + selectable: Optional[Selectable] = None, + limit_on_entity: bool = True, + adapt_on_names: bool = False, + adapt_from_selectables: Optional[AbstractSet[FromClause]] = None, + ): + self.role = role + self.mapper = entity.mapper + if selectable is None: + selectable = entity.selectable + if insp_is_aliased_class(entity): + self.is_aliased_class = True + self.aliased_insp = entity + else: + self.is_aliased_class = False + self.aliased_insp = None + + super().__init__( + selectable, + equivalents, + adapt_required=adapt_required, + allow_label_resolve=allow_label_resolve, + anonymize_labels=anonymize_labels, + include_fn=self._include_fn if limit_on_entity else None, + adapt_on_names=adapt_on_names, + adapt_from_selectables=adapt_from_selectables, + ) + + def _include_fn(self, elem): + entity = elem._annotations.get("parentmapper", None) + + return not entity or entity.isa(self.mapper) or self.mapper.isa(entity) + + +class AliasedClass( + inspection.Inspectable["AliasedInsp[_O]"], ORMColumnsClauseRole[_O] +): + r"""Represents an "aliased" form of a mapped class for usage with Query. + + The ORM equivalent of a :func:`~sqlalchemy.sql.expression.alias` + construct, this object mimics the mapped class using a + ``__getattr__`` scheme and maintains a reference to a + real :class:`~sqlalchemy.sql.expression.Alias` object. + + A primary purpose of :class:`.AliasedClass` is to serve as an alternate + within a SQL statement generated by the ORM, such that an existing + mapped entity can be used in multiple contexts. A simple example:: + + # find all pairs of users with the same name + user_alias = aliased(User) + session.query(User, user_alias).\ + join((user_alias, User.id > user_alias.id)).\ + filter(User.name == user_alias.name) + + :class:`.AliasedClass` is also capable of mapping an existing mapped + class to an entirely new selectable, provided this selectable is column- + compatible with the existing mapped selectable, and it can also be + configured in a mapping as the target of a :func:`_orm.relationship`. + See the links below for examples. + + The :class:`.AliasedClass` object is constructed typically using the + :func:`_orm.aliased` function. It also is produced with additional + configuration when using the :func:`_orm.with_polymorphic` function. + + The resulting object is an instance of :class:`.AliasedClass`. + This object implements an attribute scheme which produces the + same attribute and method interface as the original mapped + class, allowing :class:`.AliasedClass` to be compatible + with any attribute technique which works on the original class, + including hybrid attributes (see :ref:`hybrids_toplevel`). + + The :class:`.AliasedClass` can be inspected for its underlying + :class:`_orm.Mapper`, aliased selectable, and other information + using :func:`_sa.inspect`:: + + from sqlalchemy import inspect + my_alias = aliased(MyClass) + insp = inspect(my_alias) + + The resulting inspection object is an instance of :class:`.AliasedInsp`. + + + .. seealso:: + + :func:`.aliased` + + :func:`.with_polymorphic` + + :ref:`relationship_aliased_class` + + :ref:`relationship_to_window_function` + + + """ + + __name__: str + + def __init__( + self, + mapped_class_or_ac: _EntityType[_O], + alias: Optional[FromClause] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, + with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]] = None, + with_polymorphic_discriminator: Optional[ColumnElement[Any]] = None, + base_alias: Optional[AliasedInsp[Any]] = None, + use_mapper_path: bool = False, + represents_outer_join: bool = False, + ): + insp = cast( + "_InternalEntityType[_O]", inspection.inspect(mapped_class_or_ac) + ) + mapper = insp.mapper + + nest_adapters = False + + if alias is None: + if insp.is_aliased_class and insp.selectable._is_subquery: + alias = insp.selectable.alias() + else: + alias = ( + mapper._with_polymorphic_selectable._anonymous_fromclause( + name=name, + flat=flat, + ) + ) + elif insp.is_aliased_class: + nest_adapters = True + + assert alias is not None + self._aliased_insp = AliasedInsp( + self, + insp, + alias, + name, + ( + with_polymorphic_mappers + if with_polymorphic_mappers + else mapper.with_polymorphic_mappers + ), + ( + with_polymorphic_discriminator + if with_polymorphic_discriminator is not None + else mapper.polymorphic_on + ), + base_alias, + use_mapper_path, + adapt_on_names, + represents_outer_join, + nest_adapters, + ) + + self.__name__ = f"aliased({mapper.class_.__name__})" + + @classmethod + def _reconstitute_from_aliased_insp( + cls, aliased_insp: AliasedInsp[_O] + ) -> AliasedClass[_O]: + obj = cls.__new__(cls) + obj.__name__ = f"aliased({aliased_insp.mapper.class_.__name__})" + obj._aliased_insp = aliased_insp + + if aliased_insp._is_with_polymorphic: + for sub_aliased_insp in aliased_insp._with_polymorphic_entities: + if sub_aliased_insp is not aliased_insp: + ent = AliasedClass._reconstitute_from_aliased_insp( + sub_aliased_insp + ) + setattr(obj, sub_aliased_insp.class_.__name__, ent) + + return obj + + def __getattr__(self, key: str) -> Any: + try: + _aliased_insp = self.__dict__["_aliased_insp"] + except KeyError: + raise AttributeError() + else: + target = _aliased_insp._target + # maintain all getattr mechanics + attr = getattr(target, key) + + # attribute is a method, that will be invoked against a + # "self"; so just return a new method with the same function and + # new self + if hasattr(attr, "__call__") and hasattr(attr, "__self__"): + return types.MethodType(attr.__func__, self) + + # attribute is a descriptor, that will be invoked against a + # "self"; so invoke the descriptor against this self + if hasattr(attr, "__get__"): + attr = attr.__get__(None, self) + + # attributes within the QueryableAttribute system will want this + # to be invoked so the object can be adapted + if hasattr(attr, "adapt_to_entity"): + attr = attr.adapt_to_entity(_aliased_insp) + setattr(self, key, attr) + + return attr + + def _get_from_serialized( + self, key: str, mapped_class: _O, aliased_insp: AliasedInsp[_O] + ) -> Any: + # this method is only used in terms of the + # sqlalchemy.ext.serializer extension + attr = getattr(mapped_class, key) + if hasattr(attr, "__call__") and hasattr(attr, "__self__"): + return types.MethodType(attr.__func__, self) + + # attribute is a descriptor, that will be invoked against a + # "self"; so invoke the descriptor against this self + if hasattr(attr, "__get__"): + attr = attr.__get__(None, self) + + # attributes within the QueryableAttribute system will want this + # to be invoked so the object can be adapted + if hasattr(attr, "adapt_to_entity"): + aliased_insp._weak_entity = weakref.ref(self) + attr = attr.adapt_to_entity(aliased_insp) + setattr(self, key, attr) + + return attr + + def __repr__(self) -> str: + return "" % ( + id(self), + self._aliased_insp._target.__name__, + ) + + def __str__(self) -> str: + return str(self._aliased_insp) + + +@inspection._self_inspects +class AliasedInsp( + ORMEntityColumnsClauseRole[_O], + ORMFromClauseRole, + HasCacheKey, + InspectionAttr, + MemoizedSlots, + inspection.Inspectable["AliasedInsp[_O]"], + Generic[_O], +): + """Provide an inspection interface for an + :class:`.AliasedClass` object. + + The :class:`.AliasedInsp` object is returned + given an :class:`.AliasedClass` using the + :func:`_sa.inspect` function:: + + from sqlalchemy import inspect + from sqlalchemy.orm import aliased + + my_alias = aliased(MyMappedClass) + insp = inspect(my_alias) + + Attributes on :class:`.AliasedInsp` + include: + + * ``entity`` - the :class:`.AliasedClass` represented. + * ``mapper`` - the :class:`_orm.Mapper` mapping the underlying class. + * ``selectable`` - the :class:`_expression.Alias` + construct which ultimately + represents an aliased :class:`_schema.Table` or + :class:`_expression.Select` + construct. + * ``name`` - the name of the alias. Also is used as the attribute + name when returned in a result tuple from :class:`_query.Query`. + * ``with_polymorphic_mappers`` - collection of :class:`_orm.Mapper` + objects + indicating all those mappers expressed in the select construct + for the :class:`.AliasedClass`. + * ``polymorphic_on`` - an alternate column or SQL expression which + will be used as the "discriminator" for a polymorphic load. + + .. seealso:: + + :ref:`inspection_toplevel` + + """ + + __slots__ = ( + "__weakref__", + "_weak_entity", + "mapper", + "selectable", + "name", + "_adapt_on_names", + "with_polymorphic_mappers", + "polymorphic_on", + "_use_mapper_path", + "_base_alias", + "represents_outer_join", + "persist_selectable", + "local_table", + "_is_with_polymorphic", + "_with_polymorphic_entities", + "_adapter", + "_target", + "__clause_element__", + "_memoized_values", + "_all_column_expressions", + "_nest_adapters", + ) + + _cache_key_traversal = [ + ("name", visitors.ExtendedInternalTraversal.dp_string), + ("_adapt_on_names", visitors.ExtendedInternalTraversal.dp_boolean), + ("_use_mapper_path", visitors.ExtendedInternalTraversal.dp_boolean), + ("_target", visitors.ExtendedInternalTraversal.dp_inspectable), + ("selectable", visitors.ExtendedInternalTraversal.dp_clauseelement), + ( + "with_polymorphic_mappers", + visitors.InternalTraversal.dp_has_cache_key_list, + ), + ("polymorphic_on", visitors.InternalTraversal.dp_clauseelement), + ] + + mapper: Mapper[_O] + selectable: FromClause + _adapter: ORMAdapter + with_polymorphic_mappers: Sequence[Mapper[Any]] + _with_polymorphic_entities: Sequence[AliasedInsp[Any]] + + _weak_entity: weakref.ref[AliasedClass[_O]] + """the AliasedClass that refers to this AliasedInsp""" + + _target: Union[Type[_O], AliasedClass[_O]] + """the thing referenced by the AliasedClass/AliasedInsp. + + In the vast majority of cases, this is the mapped class. However + it may also be another AliasedClass (alias of alias). + + """ + + def __init__( + self, + entity: AliasedClass[_O], + inspected: _InternalEntityType[_O], + selectable: FromClause, + name: Optional[str], + with_polymorphic_mappers: Optional[Sequence[Mapper[Any]]], + polymorphic_on: Optional[ColumnElement[Any]], + _base_alias: Optional[AliasedInsp[Any]], + _use_mapper_path: bool, + adapt_on_names: bool, + represents_outer_join: bool, + nest_adapters: bool, + ): + mapped_class_or_ac = inspected.entity + mapper = inspected.mapper + + self._weak_entity = weakref.ref(entity) + self.mapper = mapper + self.selectable = self.persist_selectable = self.local_table = ( + selectable + ) + self.name = name + self.polymorphic_on = polymorphic_on + self._base_alias = weakref.ref(_base_alias or self) + self._use_mapper_path = _use_mapper_path + self.represents_outer_join = represents_outer_join + self._nest_adapters = nest_adapters + + if with_polymorphic_mappers: + self._is_with_polymorphic = True + self.with_polymorphic_mappers = with_polymorphic_mappers + self._with_polymorphic_entities = [] + for poly in self.with_polymorphic_mappers: + if poly is not mapper: + ent = AliasedClass( + poly.class_, + selectable, + base_alias=self, + adapt_on_names=adapt_on_names, + use_mapper_path=_use_mapper_path, + ) + + setattr(self.entity, poly.class_.__name__, ent) + self._with_polymorphic_entities.append(ent._aliased_insp) + + else: + self._is_with_polymorphic = False + self.with_polymorphic_mappers = [mapper] + + self._adapter = ORMAdapter( + _TraceAdaptRole.ALIASED_INSP, + mapper, + selectable=selectable, + equivalents=mapper._equivalent_columns, + adapt_on_names=adapt_on_names, + anonymize_labels=True, + # make sure the adapter doesn't try to grab other tables that + # are not even the thing we are mapping, such as embedded + # selectables in subqueries or CTEs. See issue #6060 + adapt_from_selectables={ + m.selectable + for m in self.with_polymorphic_mappers + if not adapt_on_names + }, + limit_on_entity=False, + ) + + if nest_adapters: + # supports "aliased class of aliased class" use case + assert isinstance(inspected, AliasedInsp) + self._adapter = inspected._adapter.wrap(self._adapter) + + self._adapt_on_names = adapt_on_names + self._target = mapped_class_or_ac + + @classmethod + def _alias_factory( + cls, + element: Union[_EntityType[_O], FromClause], + alias: Optional[FromClause] = None, + name: Optional[str] = None, + flat: bool = False, + adapt_on_names: bool = False, + ) -> Union[AliasedClass[_O], FromClause]: + if isinstance(element, FromClause): + if adapt_on_names: + raise sa_exc.ArgumentError( + "adapt_on_names only applies to ORM elements" + ) + if name: + return element.alias(name=name, flat=flat) + else: + return coercions.expect( + roles.AnonymizedFromClauseRole, element, flat=flat + ) + else: + return AliasedClass( + element, + alias=alias, + flat=flat, + name=name, + adapt_on_names=adapt_on_names, + ) + + @classmethod + def _with_polymorphic_factory( + cls, + base: Union[Type[_O], Mapper[_O]], + classes: Union[Literal["*"], Iterable[_EntityType[Any]]], + selectable: Union[Literal[False, None], FromClause] = False, + flat: bool = False, + polymorphic_on: Optional[ColumnElement[Any]] = None, + aliased: bool = False, + innerjoin: bool = False, + adapt_on_names: bool = False, + _use_mapper_path: bool = False, + ) -> AliasedClass[_O]: + primary_mapper = _class_to_mapper(base) + + if selectable not in (None, False) and flat: + raise sa_exc.ArgumentError( + "the 'flat' and 'selectable' arguments cannot be passed " + "simultaneously to with_polymorphic()" + ) + + mappers, selectable = primary_mapper._with_polymorphic_args( + classes, selectable, innerjoin=innerjoin + ) + if aliased or flat: + assert selectable is not None + selectable = selectable._anonymous_fromclause(flat=flat) + + return AliasedClass( + base, + selectable, + with_polymorphic_mappers=mappers, + adapt_on_names=adapt_on_names, + with_polymorphic_discriminator=polymorphic_on, + use_mapper_path=_use_mapper_path, + represents_outer_join=not innerjoin, + ) + + @property + def entity(self) -> AliasedClass[_O]: + # to eliminate reference cycles, the AliasedClass is held weakly. + # this produces some situations where the AliasedClass gets lost, + # particularly when one is created internally and only the AliasedInsp + # is passed around. + # to work around this case, we just generate a new one when we need + # it, as it is a simple class with very little initial state on it. + ent = self._weak_entity() + if ent is None: + ent = AliasedClass._reconstitute_from_aliased_insp(self) + self._weak_entity = weakref.ref(ent) + return ent + + is_aliased_class = True + "always returns True" + + def _memoized_method___clause_element__(self) -> FromClause: + return self.selectable._annotate( + { + "parentmapper": self.mapper, + "parententity": self, + "entity_namespace": self, + } + )._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + + @property + def entity_namespace(self) -> AliasedClass[_O]: + return self.entity + + @property + def class_(self) -> Type[_O]: + """Return the mapped class ultimately represented by this + :class:`.AliasedInsp`.""" + return self.mapper.class_ + + @property + def _path_registry(self) -> AbstractEntityRegistry: + if self._use_mapper_path: + return self.mapper._path_registry + else: + return PathRegistry.per_mapper(self) + + def __getstate__(self) -> Dict[str, Any]: + return { + "entity": self.entity, + "mapper": self.mapper, + "alias": self.selectable, + "name": self.name, + "adapt_on_names": self._adapt_on_names, + "with_polymorphic_mappers": self.with_polymorphic_mappers, + "with_polymorphic_discriminator": self.polymorphic_on, + "base_alias": self._base_alias(), + "use_mapper_path": self._use_mapper_path, + "represents_outer_join": self.represents_outer_join, + "nest_adapters": self._nest_adapters, + } + + def __setstate__(self, state: Dict[str, Any]) -> None: + self.__init__( # type: ignore + state["entity"], + state["mapper"], + state["alias"], + state["name"], + state["with_polymorphic_mappers"], + state["with_polymorphic_discriminator"], + state["base_alias"], + state["use_mapper_path"], + state["adapt_on_names"], + state["represents_outer_join"], + state["nest_adapters"], + ) + + def _merge_with(self, other: AliasedInsp[_O]) -> AliasedInsp[_O]: + # assert self._is_with_polymorphic + # assert other._is_with_polymorphic + + primary_mapper = other.mapper + + assert self.mapper is primary_mapper + + our_classes = util.to_set( + mp.class_ for mp in self.with_polymorphic_mappers + ) + new_classes = {mp.class_ for mp in other.with_polymorphic_mappers} + if our_classes == new_classes: + return other + else: + classes = our_classes.union(new_classes) + + mappers, selectable = primary_mapper._with_polymorphic_args( + classes, None, innerjoin=not other.represents_outer_join + ) + selectable = selectable._anonymous_fromclause(flat=True) + return AliasedClass( + primary_mapper, + selectable, + with_polymorphic_mappers=mappers, + with_polymorphic_discriminator=other.polymorphic_on, + use_mapper_path=other._use_mapper_path, + represents_outer_join=other.represents_outer_join, + )._aliased_insp + + def _adapt_element( + self, expr: _ORMCOLEXPR, key: Optional[str] = None + ) -> _ORMCOLEXPR: + assert isinstance(expr, ColumnElement) + d: Dict[str, Any] = { + "parententity": self, + "parentmapper": self.mapper, + } + if key: + d["proxy_key"] = key + + # IMO mypy should see this one also as returning the same type + # we put into it, but it's not + return ( + self._adapter.traverse(expr) + ._annotate(d) + ._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + ) + + if TYPE_CHECKING: + # establish compatibility with the _ORMAdapterProto protocol, + # which in turn is compatible with _CoreAdapterProto. + + def _orm_adapt_element( + self, + obj: _CE, + key: Optional[str] = None, + ) -> _CE: ... + + else: + _orm_adapt_element = _adapt_element + + def _entity_for_mapper(self, mapper): + self_poly = self.with_polymorphic_mappers + if mapper in self_poly: + if mapper is self.mapper: + return self + else: + return getattr( + self.entity, mapper.class_.__name__ + )._aliased_insp + elif mapper.isa(self.mapper): + return self + else: + assert False, "mapper %s doesn't correspond to %s" % (mapper, self) + + def _memoized_attr__get_clause(self): + onclause, replacemap = self.mapper._get_clause + return ( + self._adapter.traverse(onclause), + { + self._adapter.traverse(col): param + for col, param in replacemap.items() + }, + ) + + def _memoized_attr__memoized_values(self): + return {} + + def _memoized_attr__all_column_expressions(self): + if self._is_with_polymorphic: + cols_plus_keys = self.mapper._columns_plus_keys( + [ent.mapper for ent in self._with_polymorphic_entities] + ) + else: + cols_plus_keys = self.mapper._columns_plus_keys() + + cols_plus_keys = [ + (key, self._adapt_element(col)) for key, col in cols_plus_keys + ] + + return ColumnCollection(cols_plus_keys) + + def _memo(self, key, callable_, *args, **kw): + if key in self._memoized_values: + return self._memoized_values[key] + else: + self._memoized_values[key] = value = callable_(*args, **kw) + return value + + def __repr__(self): + if self.with_polymorphic_mappers: + with_poly = "(%s)" % ", ".join( + mp.class_.__name__ for mp in self.with_polymorphic_mappers + ) + else: + with_poly = "" + return "" % ( + id(self), + self.class_.__name__, + with_poly, + ) + + def __str__(self): + if self._is_with_polymorphic: + return "with_polymorphic(%s, [%s])" % ( + self._target.__name__, + ", ".join( + mp.class_.__name__ + for mp in self.with_polymorphic_mappers + if mp is not self.mapper + ), + ) + else: + return "aliased(%s)" % (self._target.__name__,) + + +class _WrapUserEntity: + """A wrapper used within the loader_criteria lambda caller so that + we can bypass declared_attr descriptors on unmapped mixins, which + normally emit a warning for such use. + + might also be useful for other per-lambda instrumentations should + the need arise. + + """ + + __slots__ = ("subject",) + + def __init__(self, subject): + self.subject = subject + + @util.preload_module("sqlalchemy.orm.decl_api") + def __getattribute__(self, name): + decl_api = util.preloaded.orm.decl_api + + subject = object.__getattribute__(self, "subject") + if name in subject.__dict__ and isinstance( + subject.__dict__[name], decl_api.declared_attr + ): + return subject.__dict__[name].fget(subject) + else: + return getattr(subject, name) + + +class LoaderCriteriaOption(CriteriaOption): + """Add additional WHERE criteria to the load for all occurrences of + a particular entity. + + :class:`_orm.LoaderCriteriaOption` is invoked using the + :func:`_orm.with_loader_criteria` function; see that function for + details. + + .. versionadded:: 1.4 + + """ + + __slots__ = ( + "root_entity", + "entity", + "deferred_where_criteria", + "where_criteria", + "_where_crit_orig", + "include_aliases", + "propagate_to_loaders", + ) + + _traverse_internals = [ + ("root_entity", visitors.ExtendedInternalTraversal.dp_plain_obj), + ("entity", visitors.ExtendedInternalTraversal.dp_has_cache_key), + ("where_criteria", visitors.InternalTraversal.dp_clauseelement), + ("include_aliases", visitors.InternalTraversal.dp_boolean), + ("propagate_to_loaders", visitors.InternalTraversal.dp_boolean), + ] + + root_entity: Optional[Type[Any]] + entity: Optional[_InternalEntityType[Any]] + where_criteria: Union[ColumnElement[bool], lambdas.DeferredLambdaElement] + deferred_where_criteria: bool + include_aliases: bool + propagate_to_loaders: bool + + _where_crit_orig: Any + + def __init__( + self, + entity_or_base: _EntityType[Any], + where_criteria: Union[ + _ColumnExpressionArgument[bool], + Callable[[Any], _ColumnExpressionArgument[bool]], + ], + loader_only: bool = False, + include_aliases: bool = False, + propagate_to_loaders: bool = True, + track_closure_variables: bool = True, + ): + entity = cast( + "_InternalEntityType[Any]", + inspection.inspect(entity_or_base, False), + ) + if entity is None: + self.root_entity = cast("Type[Any]", entity_or_base) + self.entity = None + else: + self.root_entity = None + self.entity = entity + + self._where_crit_orig = where_criteria + if callable(where_criteria): + if self.root_entity is not None: + wrap_entity = self.root_entity + else: + assert entity is not None + wrap_entity = entity.entity + + self.deferred_where_criteria = True + self.where_criteria = lambdas.DeferredLambdaElement( + where_criteria, + roles.WhereHavingRole, + lambda_args=(_WrapUserEntity(wrap_entity),), + opts=lambdas.LambdaOptions( + track_closure_variables=track_closure_variables + ), + ) + else: + self.deferred_where_criteria = False + self.where_criteria = coercions.expect( + roles.WhereHavingRole, where_criteria + ) + + self.include_aliases = include_aliases + self.propagate_to_loaders = propagate_to_loaders + + @classmethod + def _unreduce( + cls, entity, where_criteria, include_aliases, propagate_to_loaders + ): + return LoaderCriteriaOption( + entity, + where_criteria, + include_aliases=include_aliases, + propagate_to_loaders=propagate_to_loaders, + ) + + def __reduce__(self): + return ( + LoaderCriteriaOption._unreduce, + ( + self.entity.class_ if self.entity else self.root_entity, + self._where_crit_orig, + self.include_aliases, + self.propagate_to_loaders, + ), + ) + + def _all_mappers(self) -> Iterator[Mapper[Any]]: + if self.entity: + yield from self.entity.mapper.self_and_descendants + else: + assert self.root_entity + stack = list(self.root_entity.__subclasses__()) + while stack: + subclass = stack.pop(0) + ent = cast( + "_InternalEntityType[Any]", + inspection.inspect(subclass, raiseerr=False), + ) + if ent: + yield from ent.mapper.self_and_descendants + else: + stack.extend(subclass.__subclasses__()) + + def _should_include(self, compile_state: ORMCompileState) -> bool: + if ( + compile_state.select_statement._annotations.get( + "for_loader_criteria", None + ) + is self + ): + return False + return True + + def _resolve_where_criteria( + self, ext_info: _InternalEntityType[Any] + ) -> ColumnElement[bool]: + if self.deferred_where_criteria: + crit = cast( + "ColumnElement[bool]", + self.where_criteria._resolve_with_args(ext_info.entity), + ) + else: + crit = self.where_criteria # type: ignore + assert isinstance(crit, ColumnElement) + return sql_util._deep_annotate( + crit, + {"for_loader_criteria": self}, + detect_subquery_cols=True, + ind_cols_on_fromclause=True, + ) + + def process_compile_state_replaced_entities( + self, + compile_state: ORMCompileState, + mapper_entities: Iterable[_MapperEntity], + ) -> None: + self.process_compile_state(compile_state) + + def process_compile_state(self, compile_state: ORMCompileState) -> None: + """Apply a modification to a given :class:`.CompileState`.""" + + # if options to limit the criteria to immediate query only, + # use compile_state.attributes instead + + self.get_global_criteria(compile_state.global_attributes) + + def get_global_criteria(self, attributes: Dict[Any, Any]) -> None: + for mp in self._all_mappers(): + load_criteria = attributes.setdefault( + ("additional_entity_criteria", mp), [] + ) + + load_criteria.append(self) + + +inspection._inspects(AliasedClass)(lambda target: target._aliased_insp) + + +@inspection._inspects(type) +def _inspect_mc( + class_: Type[_O], +) -> Optional[Mapper[_O]]: + try: + class_manager = opt_manager_of_class(class_) + if class_manager is None or not class_manager.is_mapped: + return None + mapper = class_manager.mapper + except exc.NO_STATE: + return None + else: + return mapper + + +GenericAlias = type(List[Any]) + + +@inspection._inspects(GenericAlias) +def _inspect_generic_alias( + class_: Type[_O], +) -> Optional[Mapper[_O]]: + origin = cast("Type[_O]", typing_get_origin(class_)) + return _inspect_mc(origin) + + +@inspection._self_inspects +class Bundle( + ORMColumnsClauseRole[_T], + SupportsCloneAnnotations, + MemoizedHasCacheKey, + inspection.Inspectable["Bundle[_T]"], + InspectionAttr, +): + """A grouping of SQL expressions that are returned by a :class:`.Query` + under one namespace. + + The :class:`.Bundle` essentially allows nesting of the tuple-based + results returned by a column-oriented :class:`_query.Query` object. + It also + is extensible via simple subclassing, where the primary capability + to override is that of how the set of expressions should be returned, + allowing post-processing as well as custom return types, without + involving ORM identity-mapped classes. + + .. seealso:: + + :ref:`bundles` + + + """ + + single_entity = False + """If True, queries for a single Bundle will be returned as a single + entity, rather than an element within a keyed tuple.""" + + is_clause_element = False + + is_mapper = False + + is_aliased_class = False + + is_bundle = True + + _propagate_attrs: _PropagateAttrsType = util.immutabledict() + + proxy_set = util.EMPTY_SET # type: ignore + + exprs: List[_ColumnsClauseElement] + + def __init__( + self, name: str, *exprs: _ColumnExpressionArgument[Any], **kw: Any + ): + r"""Construct a new :class:`.Bundle`. + + e.g.:: + + bn = Bundle("mybundle", MyClass.x, MyClass.y) + + for row in session.query(bn).filter( + bn.c.x == 5).filter(bn.c.y == 4): + print(row.mybundle.x, row.mybundle.y) + + :param name: name of the bundle. + :param \*exprs: columns or SQL expressions comprising the bundle. + :param single_entity=False: if True, rows for this :class:`.Bundle` + can be returned as a "single entity" outside of any enclosing tuple + in the same manner as a mapped entity. + + """ + self.name = self._label = name + coerced_exprs = [ + coercions.expect( + roles.ColumnsClauseRole, expr, apply_propagate_attrs=self + ) + for expr in exprs + ] + self.exprs = coerced_exprs + + self.c = self.columns = ColumnCollection( + (getattr(col, "key", col._label), col) + for col in [e._annotations.get("bundle", e) for e in coerced_exprs] + ).as_readonly() + self.single_entity = kw.pop("single_entity", self.single_entity) + + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Tuple[Any, ...]: + return (self.__class__, self.name, self.single_entity) + tuple( + [expr._gen_cache_key(anon_map, bindparams) for expr in self.exprs] + ) + + @property + def mapper(self) -> Optional[Mapper[Any]]: + mp: Optional[Mapper[Any]] = self.exprs[0]._annotations.get( + "parentmapper", None + ) + return mp + + @property + def entity(self) -> Optional[_InternalEntityType[Any]]: + ie: Optional[_InternalEntityType[Any]] = self.exprs[ + 0 + ]._annotations.get("parententity", None) + return ie + + @property + def entity_namespace( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: + return self.c + + columns: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]] + + """A namespace of SQL expressions referred to by this :class:`.Bundle`. + + e.g.:: + + bn = Bundle("mybundle", MyClass.x, MyClass.y) + + q = sess.query(bn).filter(bn.c.x == 5) + + Nesting of bundles is also supported:: + + b1 = Bundle("b1", + Bundle('b2', MyClass.a, MyClass.b), + Bundle('b3', MyClass.x, MyClass.y) + ) + + q = sess.query(b1).filter( + b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9) + + .. seealso:: + + :attr:`.Bundle.c` + + """ + + c: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]] + """An alias for :attr:`.Bundle.columns`.""" + + def _clone(self): + cloned = self.__class__.__new__(self.__class__) + cloned.__dict__.update(self.__dict__) + return cloned + + def __clause_element__(self): + # ensure existing entity_namespace remains + annotations = {"bundle": self, "entity_namespace": self} + annotations.update(self._annotations) + + plugin_subject = self.exprs[0]._propagate_attrs.get( + "plugin_subject", self.entity + ) + return ( + expression.ClauseList( + _literal_as_text_role=roles.ColumnsClauseRole, + group=False, + *[e._annotations.get("bundle", e) for e in self.exprs], + ) + ._annotate(annotations) + ._set_propagate_attrs( + # the Bundle *must* use the orm plugin no matter what. the + # subject can be None but it's much better if it's not. + { + "compile_state_plugin": "orm", + "plugin_subject": plugin_subject, + } + ) + ) + + @property + def clauses(self): + return self.__clause_element__().clauses + + def label(self, name): + """Provide a copy of this :class:`.Bundle` passing a new label.""" + + cloned = self._clone() + cloned.name = name + return cloned + + def create_row_processor( + self, + query: Select[Any], + procs: Sequence[Callable[[Row[Any]], Any]], + labels: Sequence[str], + ) -> Callable[[Row[Any]], Any]: + """Produce the "row processing" function for this :class:`.Bundle`. + + May be overridden by subclasses to provide custom behaviors when + results are fetched. The method is passed the statement object and a + set of "row processor" functions at query execution time; these + processor functions when given a result row will return the individual + attribute value, which can then be adapted into any kind of return data + structure. + + The example below illustrates replacing the usual :class:`.Row` + return structure with a straight Python dictionary:: + + from sqlalchemy.orm import Bundle + + class DictBundle(Bundle): + def create_row_processor(self, query, procs, labels): + 'Override create_row_processor to return values as + dictionaries' + + def proc(row): + return dict( + zip(labels, (proc(row) for proc in procs)) + ) + return proc + + A result from the above :class:`_orm.Bundle` will return dictionary + values:: + + bn = DictBundle('mybundle', MyClass.data1, MyClass.data2) + for row in session.execute(select(bn)).where(bn.c.data1 == 'd1'): + print(row.mybundle['data1'], row.mybundle['data2']) + + """ + keyed_tuple = result_tuple(labels, [() for l in labels]) + + def proc(row: Row[Any]) -> Any: + return keyed_tuple([proc(row) for proc in procs]) + + return proc + + +def _orm_annotate(element: _SA, exclude: Optional[Any] = None) -> _SA: + """Deep copy the given ClauseElement, annotating each element with the + "_orm_adapt" flag. + + Elements within the exclude collection will be cloned but not annotated. + + """ + return sql_util._deep_annotate(element, {"_orm_adapt": True}, exclude) + + +def _orm_deannotate(element: _SA) -> _SA: + """Remove annotations that link a column to a particular mapping. + + Note this doesn't affect "remote" and "foreign" annotations + passed by the :func:`_orm.foreign` and :func:`_orm.remote` + annotators. + + """ + + return sql_util._deep_deannotate( + element, values=("_orm_adapt", "parententity") + ) + + +def _orm_full_deannotate(element: _SA) -> _SA: + return sql_util._deep_deannotate(element) + + +class _ORMJoin(expression.Join): + """Extend Join to support ORM constructs as input.""" + + __visit_name__ = expression.Join.__visit_name__ + + inherit_cache = True + + def __init__( + self, + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + _left_memo: Optional[Any] = None, + _right_memo: Optional[Any] = None, + _extra_criteria: Tuple[ColumnElement[bool], ...] = (), + ): + left_info = cast( + "Union[FromClause, _InternalEntityType[Any]]", + inspection.inspect(left), + ) + + right_info = cast( + "Union[FromClause, _InternalEntityType[Any]]", + inspection.inspect(right), + ) + adapt_to = right_info.selectable + + # used by joined eager loader + self._left_memo = _left_memo + self._right_memo = _right_memo + + if isinstance(onclause, attributes.QueryableAttribute): + if TYPE_CHECKING: + assert isinstance( + onclause.comparator, RelationshipProperty.Comparator + ) + on_selectable = onclause.comparator._source_selectable() + prop = onclause.property + _extra_criteria += onclause._extra_criteria + elif isinstance(onclause, MapperProperty): + # used internally by joined eager loader...possibly not ideal + prop = onclause + on_selectable = prop.parent.selectable + else: + prop = None + on_selectable = None + + left_selectable = left_info.selectable + if prop: + adapt_from: Optional[FromClause] + if sql_util.clause_is_present(on_selectable, left_selectable): + adapt_from = on_selectable + else: + assert isinstance(left_selectable, FromClause) + adapt_from = left_selectable + + ( + pj, + sj, + source, + dest, + secondary, + target_adapter, + ) = prop._create_joins( + source_selectable=adapt_from, + dest_selectable=adapt_to, + source_polymorphic=True, + of_type_entity=right_info, + alias_secondary=True, + extra_criteria=_extra_criteria, + ) + + if sj is not None: + if isouter: + # note this is an inner join from secondary->right + right = sql.join(secondary, right, sj) + onclause = pj + else: + left = sql.join(left, secondary, pj, isouter) + onclause = sj + else: + onclause = pj + + self._target_adapter = target_adapter + + # we don't use the normal coercions logic for _ORMJoin + # (probably should), so do some gymnastics to get the entity. + # logic here is for #8721, which was a major bug in 1.4 + # for almost two years, not reported/fixed until 1.4.43 (!) + if is_selectable(left_info): + parententity = left_selectable._annotations.get( + "parententity", None + ) + elif insp_is_mapper(left_info) or insp_is_aliased_class(left_info): + parententity = left_info + else: + parententity = None + + if parententity is not None: + self._annotations = self._annotations.union( + {"parententity": parententity} + ) + + augment_onclause = bool(_extra_criteria) and not prop + expression.Join.__init__(self, left, right, onclause, isouter, full) + + assert self.onclause is not None + + if augment_onclause: + self.onclause &= sql.and_(*_extra_criteria) + + if ( + not prop + and getattr(right_info, "mapper", None) + and right_info.mapper.single # type: ignore + ): + right_info = cast("_InternalEntityType[Any]", right_info) + # if single inheritance target and we are using a manual + # or implicit ON clause, augment it the same way we'd augment the + # WHERE. + single_crit = right_info.mapper._single_table_criterion + if single_crit is not None: + if insp_is_aliased_class(right_info): + single_crit = right_info._adapter.traverse(single_crit) + self.onclause = self.onclause & single_crit + + def _splice_into_center(self, other): + """Splice a join into the center. + + Given join(a, b) and join(b, c), return join(a, b).join(c) + + """ + leftmost = other + while isinstance(leftmost, sql.Join): + leftmost = leftmost.left + + assert self.right is leftmost + + left = _ORMJoin( + self.left, + other.left, + self.onclause, + isouter=self.isouter, + _left_memo=self._left_memo, + _right_memo=other._left_memo, + ) + + return _ORMJoin( + left, + other.right, + other.onclause, + isouter=other.isouter, + _right_memo=other._right_memo, + ) + + def join( + self, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, + ) -> _ORMJoin: + return _ORMJoin(self, right, onclause, full=full, isouter=isouter) + + def outerjoin( + self, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + full: bool = False, + ) -> _ORMJoin: + return _ORMJoin(self, right, onclause, isouter=True, full=full) + + +def with_parent( + instance: object, + prop: attributes.QueryableAttribute[Any], + from_entity: Optional[_EntityType[Any]] = None, +) -> ColumnElement[bool]: + """Create filtering criterion that relates this query's primary entity + to the given related instance, using established + :func:`_orm.relationship()` + configuration. + + E.g.:: + + stmt = select(Address).where(with_parent(some_user, User.addresses)) + + + The SQL rendered is the same as that rendered when a lazy loader + would fire off from the given parent on that attribute, meaning + that the appropriate state is taken from the parent object in + Python without the need to render joins to the parent table + in the rendered statement. + + The given property may also make use of :meth:`_orm.PropComparator.of_type` + to indicate the left side of the criteria:: + + + a1 = aliased(Address) + a2 = aliased(Address) + stmt = select(a1, a2).where( + with_parent(u1, User.addresses.of_type(a2)) + ) + + The above use is equivalent to using the + :func:`_orm.with_parent.from_entity` argument:: + + a1 = aliased(Address) + a2 = aliased(Address) + stmt = select(a1, a2).where( + with_parent(u1, User.addresses, from_entity=a2) + ) + + :param instance: + An instance which has some :func:`_orm.relationship`. + + :param property: + Class-bound attribute, which indicates + what relationship from the instance should be used to reconcile the + parent/child relationship. + + :param from_entity: + Entity in which to consider as the left side. This defaults to the + "zero" entity of the :class:`_query.Query` itself. + + .. versionadded:: 1.2 + + """ + prop_t: RelationshipProperty[Any] + + if isinstance(prop, str): + raise sa_exc.ArgumentError( + "with_parent() accepts class-bound mapped attributes, not strings" + ) + elif isinstance(prop, attributes.QueryableAttribute): + if prop._of_type: + from_entity = prop._of_type + mapper_property = prop.property + if mapper_property is None or not prop_is_relationship( + mapper_property + ): + raise sa_exc.ArgumentError( + f"Expected relationship property for with_parent(), " + f"got {mapper_property}" + ) + prop_t = mapper_property + else: + prop_t = prop + + return prop_t._with_parent(instance, from_entity=from_entity) + + +def has_identity(object_: object) -> bool: + """Return True if the given object has a database + identity. + + This typically corresponds to the object being + in either the persistent or detached state. + + .. seealso:: + + :func:`.was_deleted` + + """ + state = attributes.instance_state(object_) + return state.has_identity + + +def was_deleted(object_: object) -> bool: + """Return True if the given object was deleted + within a session flush. + + This is regardless of whether or not the object is + persistent or detached. + + .. seealso:: + + :attr:`.InstanceState.was_deleted` + + """ + + state = attributes.instance_state(object_) + return state.was_deleted + + +def _entity_corresponds_to( + given: _InternalEntityType[Any], entity: _InternalEntityType[Any] +) -> bool: + """determine if 'given' corresponds to 'entity', in terms + of an entity passed to Query that would match the same entity + being referred to elsewhere in the query. + + """ + if insp_is_aliased_class(entity): + if insp_is_aliased_class(given): + if entity._base_alias() is given._base_alias(): + return True + return False + elif insp_is_aliased_class(given): + if given._use_mapper_path: + return entity in given.with_polymorphic_mappers + else: + return entity is given + + assert insp_is_mapper(given) + return entity.common_parent(given) + + +def _entity_corresponds_to_use_path_impl( + given: _InternalEntityType[Any], entity: _InternalEntityType[Any] +) -> bool: + """determine if 'given' corresponds to 'entity', in terms + of a path of loader options where a mapped attribute is taken to + be a member of a parent entity. + + e.g.:: + + someoption(A).someoption(A.b) # -> fn(A, A) -> True + someoption(A).someoption(C.d) # -> fn(A, C) -> False + + a1 = aliased(A) + someoption(a1).someoption(A.b) # -> fn(a1, A) -> False + someoption(a1).someoption(a1.b) # -> fn(a1, a1) -> True + + wp = with_polymorphic(A, [A1, A2]) + someoption(wp).someoption(A1.foo) # -> fn(wp, A1) -> False + someoption(wp).someoption(wp.A1.foo) # -> fn(wp, wp.A1) -> True + + + """ + if insp_is_aliased_class(given): + return ( + insp_is_aliased_class(entity) + and not entity._use_mapper_path + and (given is entity or entity in given._with_polymorphic_entities) + ) + elif not insp_is_aliased_class(entity): + return given.isa(entity.mapper) + else: + return ( + entity._use_mapper_path + and given in entity.with_polymorphic_mappers + ) + + +def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool: + """determine if 'given' "is a" mapper, in terms of the given + would load rows of type 'mapper'. + + """ + if given.is_aliased_class: + return mapper in given.with_polymorphic_mappers or given.mapper.isa( + mapper + ) + elif given.with_polymorphic_mappers: + return mapper in given.with_polymorphic_mappers + else: + return given.isa(mapper) + + +def _getitem(iterable_query: Query[Any], item: Any) -> Any: + """calculate __getitem__ in terms of an iterable query object + that also has a slice() method. + + """ + + def _no_negative_indexes(): + raise IndexError( + "negative indexes are not accepted by SQL " + "index / slice operators" + ) + + if isinstance(item, slice): + start, stop, step = util.decode_slice(item) + + if ( + isinstance(stop, int) + and isinstance(start, int) + and stop - start <= 0 + ): + return [] + + elif (isinstance(start, int) and start < 0) or ( + isinstance(stop, int) and stop < 0 + ): + _no_negative_indexes() + + res = iterable_query.slice(start, stop) + if step is not None: + return list(res)[None : None : item.step] + else: + return list(res) + else: + if item == -1: + _no_negative_indexes() + else: + return list(iterable_query[item : item + 1])[0] + + +def _is_mapped_annotation( + raw_annotation: _AnnotationScanType, + cls: Type[Any], + originating_cls: Type[Any], +) -> bool: + try: + annotated = de_stringify_annotation( + cls, raw_annotation, originating_cls.__module__ + ) + except NameError: + # in most cases, at least within our own tests, we can raise + # here, which is more accurate as it prevents us from returning + # false negatives. However, in the real world, try to avoid getting + # involved with end-user annotations that have nothing to do with us. + # see issue #8888 where we bypass using this function in the case + # that we want to detect an unresolvable Mapped[] type. + return False + else: + return is_origin_of_cls(annotated, _MappedAnnotationBase) + + +class _CleanupError(Exception): + pass + + +def _cleanup_mapped_str_annotation( + annotation: str, originating_module: str +) -> str: + # fix up an annotation that comes in as the form: + # 'Mapped[List[Address]]' so that it instead looks like: + # 'Mapped[List["Address"]]' , which will allow us to get + # "Address" as a string + + # additionally, resolve symbols for these names since this is where + # we'd have to do it + + inner: Optional[Match[str]] + + mm = re.match(r"^(.+?)\[(.+)\]$", annotation) + + if not mm: + return annotation + + # ticket #8759. Resolve the Mapped name to a real symbol. + # originally this just checked the name. + try: + obj = eval_name_only(mm.group(1), originating_module) + except NameError as ne: + raise _CleanupError( + f'For annotation "{annotation}", could not resolve ' + f'container type "{mm.group(1)}". ' + "Please ensure this type is imported at the module level " + "outside of TYPE_CHECKING blocks" + ) from ne + + if obj is typing.ClassVar: + real_symbol = "ClassVar" + else: + try: + if issubclass(obj, _MappedAnnotationBase): + real_symbol = obj.__name__ + else: + return annotation + except TypeError: + # avoid isinstance(obj, type) check, just catch TypeError + return annotation + + # note: if one of the codepaths above didn't define real_symbol and + # then didn't return, real_symbol raises UnboundLocalError + # which is actually a NameError, and the calling routines don't + # notice this since they are catching NameError anyway. Just in case + # this is being modified in the future, something to be aware of. + + stack = [] + inner = mm + while True: + stack.append(real_symbol if mm is inner else inner.group(1)) + g2 = inner.group(2) + inner = re.match(r"^(.+?)\[(.+)\]$", g2) + if inner is None: + stack.append(g2) + break + + # stacks we want to rewrite, that is, quote the last entry which + # we think is a relationship class name: + # + # ['Mapped', 'List', 'Address'] + # ['Mapped', 'A'] + # + # stacks we dont want to rewrite, which are generally MappedColumn + # use cases: + # + # ['Mapped', "'Optional[Dict[str, str]]'"] + # ['Mapped', 'dict[str, str] | None'] + + if ( + # avoid already quoted symbols such as + # ['Mapped', "'Optional[Dict[str, str]]'"] + not re.match(r"""^["'].*["']$""", stack[-1]) + # avoid further generics like Dict[] such as + # ['Mapped', 'dict[str, str] | None'] + and not re.match(r".*\[.*\]", stack[-1]) + ): + stripchars = "\"' " + stack[-1] = ", ".join( + f'"{elem.strip(stripchars)}"' for elem in stack[-1].split(",") + ) + + annotation = "[".join(stack) + ("]" * (len(stack) - 1)) + + return annotation + + +def _extract_mapped_subtype( + raw_annotation: Optional[_AnnotationScanType], + cls: type, + originating_module: str, + key: str, + attr_cls: Type[Any], + required: bool, + is_dataclass_field: bool, + expect_mapped: bool = True, + raiseerr: bool = True, +) -> Optional[Tuple[Union[type, str], Optional[type]]]: + """given an annotation, figure out if it's ``Mapped[something]`` and if + so, return the ``something`` part. + + Includes error raise scenarios and other options. + + """ + + if raw_annotation is None: + if required: + raise sa_exc.ArgumentError( + f"Python typing annotation is required for attribute " + f'"{cls.__name__}.{key}" when primary argument(s) for ' + f'"{attr_cls.__name__}" construct are None or not present' + ) + return None + + try: + annotated = de_stringify_annotation( + cls, + raw_annotation, + originating_module, + str_cleanup_fn=_cleanup_mapped_str_annotation, + ) + except _CleanupError as ce: + raise sa_exc.ArgumentError( + f"Could not interpret annotation {raw_annotation}. " + "Check that it uses names that are correctly imported at the " + "module level. See chained stack trace for more hints." + ) from ce + except NameError as ne: + if raiseerr and "Mapped[" in raw_annotation: # type: ignore + raise sa_exc.ArgumentError( + f"Could not interpret annotation {raw_annotation}. " + "Check that it uses names that are correctly imported at the " + "module level. See chained stack trace for more hints." + ) from ne + + annotated = raw_annotation # type: ignore + + if is_dataclass_field: + return annotated, None + else: + if not hasattr(annotated, "__origin__") or not is_origin_of_cls( + annotated, _MappedAnnotationBase + ): + if expect_mapped: + if not raiseerr: + return None + + origin = getattr(annotated, "__origin__", None) + if origin is typing.ClassVar: + return None + + # check for other kind of ORM descriptor like AssociationProxy, + # don't raise for that (issue #9957) + elif isinstance(origin, type) and issubclass( + origin, ORMDescriptor + ): + return None + + raise sa_exc.ArgumentError( + f'Type annotation for "{cls.__name__}.{key}" ' + "can't be correctly interpreted for " + "Annotated Declarative Table form. ORM annotations " + "should normally make use of the ``Mapped[]`` generic " + "type, or other ORM-compatible generic type, as a " + "container for the actual type, which indicates the " + "intent that the attribute is mapped. " + "Class variables that are not intended to be mapped " + "by the ORM should use ClassVar[]. " + "To allow Annotated Declarative to disregard legacy " + "annotations which don't use Mapped[] to pass, set " + '"__allow_unmapped__ = True" on the class or a ' + "superclass this class.", + code="zlpr", + ) + + else: + return annotated, None + + if len(annotated.__args__) != 1: + raise sa_exc.ArgumentError( + "Expected sub-type for Mapped[] annotation" + ) + + return annotated.__args__[0], annotated.__origin__ + + +def _mapper_property_as_plain_name(prop: Type[Any]) -> str: + if hasattr(prop, "_mapper_property_name"): + name = prop._mapper_property_name() + else: + name = None + return util.clsname_as_plain_name(prop, name) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/writeonly.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/writeonly.py new file mode 100644 index 0000000..5680cc7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/writeonly.py @@ -0,0 +1,678 @@ +# orm/writeonly.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Write-only collection API. + +This is an alternate mapped attribute style that only supports single-item +collection mutation operations. To read the collection, a select() +object must be executed each time. + +.. versionadded:: 2.0 + + +""" + +from __future__ import annotations + +from typing import Any +from typing import Collection +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from sqlalchemy.sql import bindparam +from . import attributes +from . import interfaces +from . import relationships +from . import strategies +from .base import NEVER_SET +from .base import object_mapper +from .base import PassiveFlag +from .base import RelationshipDirection +from .. import exc +from .. import inspect +from .. import log +from .. import util +from ..sql import delete +from ..sql import insert +from ..sql import select +from ..sql import update +from ..sql.dml import Delete +from ..sql.dml import Insert +from ..sql.dml import Update +from ..util.typing import Literal + +if TYPE_CHECKING: + from . import QueryableAttribute + from ._typing import _InstanceDict + from .attributes import AttributeEventToken + from .base import LoaderCallableStatus + from .collections import _AdaptedCollectionProtocol + from .collections import CollectionAdapter + from .mapper import Mapper + from .relationships import _RelationshipOrderByArg + from .state import InstanceState + from .util import AliasedClass + from ..event import _Dispatch + from ..sql.selectable import FromClause + from ..sql.selectable import Select + +_T = TypeVar("_T", bound=Any) + + +class WriteOnlyHistory(Generic[_T]): + """Overrides AttributeHistory to receive append/remove events directly.""" + + unchanged_items: util.OrderedIdentitySet + added_items: util.OrderedIdentitySet + deleted_items: util.OrderedIdentitySet + _reconcile_collection: bool + + def __init__( + self, + attr: WriteOnlyAttributeImpl, + state: InstanceState[_T], + passive: PassiveFlag, + apply_to: Optional[WriteOnlyHistory[_T]] = None, + ) -> None: + if apply_to: + if passive & PassiveFlag.SQL_OK: + raise exc.InvalidRequestError( + f"Attribute {attr} can't load the existing state from the " + "database for this operation; full iteration is not " + "permitted. If this is a delete operation, configure " + f"passive_deletes=True on the {attr} relationship in " + "order to resolve this error." + ) + + self.unchanged_items = apply_to.unchanged_items + self.added_items = apply_to.added_items + self.deleted_items = apply_to.deleted_items + self._reconcile_collection = apply_to._reconcile_collection + else: + self.deleted_items = util.OrderedIdentitySet() + self.added_items = util.OrderedIdentitySet() + self.unchanged_items = util.OrderedIdentitySet() + self._reconcile_collection = False + + @property + def added_plus_unchanged(self) -> List[_T]: + return list(self.added_items.union(self.unchanged_items)) + + @property + def all_items(self) -> List[_T]: + return list( + self.added_items.union(self.unchanged_items).union( + self.deleted_items + ) + ) + + def as_history(self) -> attributes.History: + if self._reconcile_collection: + added = self.added_items.difference(self.unchanged_items) + deleted = self.deleted_items.intersection(self.unchanged_items) + unchanged = self.unchanged_items.difference(deleted) + else: + added, unchanged, deleted = ( + self.added_items, + self.unchanged_items, + self.deleted_items, + ) + return attributes.History(list(added), list(unchanged), list(deleted)) + + def indexed(self, index: Union[int, slice]) -> Union[List[_T], _T]: + return list(self.added_items)[index] + + def add_added(self, value: _T) -> None: + self.added_items.add(value) + + def add_removed(self, value: _T) -> None: + if value in self.added_items: + self.added_items.remove(value) + else: + self.deleted_items.add(value) + + +class WriteOnlyAttributeImpl( + attributes.HasCollectionAdapter, attributes.AttributeImpl +): + uses_objects: bool = True + default_accepts_scalar_loader: bool = False + supports_population: bool = False + _supports_dynamic_iteration: bool = False + collection: bool = False + dynamic: bool = True + order_by: _RelationshipOrderByArg = () + collection_history_cls: Type[WriteOnlyHistory[Any]] = WriteOnlyHistory + + query_class: Type[WriteOnlyCollection[Any]] + + def __init__( + self, + class_: Union[Type[Any], AliasedClass[Any]], + key: str, + dispatch: _Dispatch[QueryableAttribute[Any]], + target_mapper: Mapper[_T], + order_by: _RelationshipOrderByArg, + **kw: Any, + ): + super().__init__(class_, key, None, dispatch, **kw) + self.target_mapper = target_mapper + self.query_class = WriteOnlyCollection + if order_by: + self.order_by = tuple(order_by) + + def get( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> Union[util.OrderedIdentitySet, WriteOnlyCollection[Any]]: + if not passive & PassiveFlag.SQL_OK: + return self._get_collection_history( + state, PassiveFlag.PASSIVE_NO_INITIALIZE + ).added_items + else: + return self.query_class(self, state) + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Literal[None] = ..., + passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., + ) -> CollectionAdapter: ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: _AdaptedCollectionProtocol = ..., + passive: PassiveFlag = ..., + ) -> CollectionAdapter: ... + + @overload + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = ..., + passive: PassiveFlag = ..., + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: ... + + def get_collection( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + user_data: Optional[_AdaptedCollectionProtocol] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + ) -> Union[ + Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter + ]: + data: Collection[Any] + if not passive & PassiveFlag.SQL_OK: + data = self._get_collection_history(state, passive).added_items + else: + history = self._get_collection_history(state, passive) + data = history.added_plus_unchanged + return DynamicCollectionAdapter(data) # type: ignore[return-value] + + @util.memoized_property + def _append_token( # type:ignore[override] + self, + ) -> attributes.AttributeEventToken: + return attributes.AttributeEventToken(self, attributes.OP_APPEND) + + @util.memoized_property + def _remove_token( # type:ignore[override] + self, + ) -> attributes.AttributeEventToken: + return attributes.AttributeEventToken(self, attributes.OP_REMOVE) + + def fire_append_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + collection_history: Optional[WriteOnlyHistory[Any]] = None, + ) -> None: + if collection_history is None: + collection_history = self._modified_event(state, dict_) + + collection_history.add_added(value) + + for fn in self.dispatch.append: + value = fn(state, value, initiator or self._append_token) + + if self.trackparent and value is not None: + self.sethasparent(attributes.instance_state(value), state, True) + + def fire_remove_event( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + collection_history: Optional[WriteOnlyHistory[Any]] = None, + ) -> None: + if collection_history is None: + collection_history = self._modified_event(state, dict_) + + collection_history.add_removed(value) + + if self.trackparent and value is not None: + self.sethasparent(attributes.instance_state(value), state, False) + + for fn in self.dispatch.remove: + fn(state, value, initiator or self._remove_token) + + def _modified_event( + self, state: InstanceState[Any], dict_: _InstanceDict + ) -> WriteOnlyHistory[Any]: + if self.key not in state.committed_state: + state.committed_state[self.key] = self.collection_history_cls( + self, state, PassiveFlag.PASSIVE_NO_FETCH + ) + + state._modified_event(dict_, self, NEVER_SET) + + # this is a hack to allow the entities.ComparableEntity fixture + # to work + dict_[self.key] = True + return state.committed_state[self.key] # type: ignore[no-any-return] + + def set( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken] = None, + passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, + check_old: Any = None, + pop: bool = False, + _adapt: bool = True, + ) -> None: + if initiator and initiator.parent_token is self.parent_token: + return + + if pop and value is None: + return + + iterable = value + new_values = list(iterable) + if state.has_identity: + if not self._supports_dynamic_iteration: + raise exc.InvalidRequestError( + f'Collection "{self}" does not support implicit ' + "iteration; collection replacement operations " + "can't be used" + ) + old_collection = util.IdentitySet( + self.get(state, dict_, passive=passive) + ) + + collection_history = self._modified_event(state, dict_) + if not state.has_identity: + old_collection = collection_history.added_items + else: + old_collection = old_collection.union( + collection_history.added_items + ) + + constants = old_collection.intersection(new_values) + additions = util.IdentitySet(new_values).difference(constants) + removals = old_collection.difference(constants) + + for member in new_values: + if member in additions: + self.fire_append_event( + state, + dict_, + member, + None, + collection_history=collection_history, + ) + + for member in removals: + self.fire_remove_event( + state, + dict_, + member, + None, + collection_history=collection_history, + ) + + def delete(self, *args: Any, **kwargs: Any) -> NoReturn: + raise NotImplementedError() + + def set_committed_value( + self, state: InstanceState[Any], dict_: _InstanceDict, value: Any + ) -> NoReturn: + raise NotImplementedError( + "Dynamic attributes don't support collection population." + ) + + def get_history( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_NO_FETCH, + ) -> attributes.History: + c = self._get_collection_history(state, passive) + return c.as_history() + + def get_all_pending( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + passive: PassiveFlag = PassiveFlag.PASSIVE_NO_INITIALIZE, + ) -> List[Tuple[InstanceState[Any], Any]]: + c = self._get_collection_history(state, passive) + return [(attributes.instance_state(x), x) for x in c.all_items] + + def _get_collection_history( + self, state: InstanceState[Any], passive: PassiveFlag + ) -> WriteOnlyHistory[Any]: + c: WriteOnlyHistory[Any] + if self.key in state.committed_state: + c = state.committed_state[self.key] + else: + c = self.collection_history_cls( + self, state, PassiveFlag.PASSIVE_NO_FETCH + ) + + if state.has_identity and (passive & PassiveFlag.INIT_OK): + return self.collection_history_cls( + self, state, passive, apply_to=c + ) + else: + return c + + def append( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PassiveFlag.PASSIVE_NO_FETCH, + ) -> None: + if initiator is not self: + self.fire_append_event(state, dict_, value, initiator) + + def remove( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PassiveFlag.PASSIVE_NO_FETCH, + ) -> None: + if initiator is not self: + self.fire_remove_event(state, dict_, value, initiator) + + def pop( + self, + state: InstanceState[Any], + dict_: _InstanceDict, + value: Any, + initiator: Optional[AttributeEventToken], + passive: PassiveFlag = PassiveFlag.PASSIVE_NO_FETCH, + ) -> None: + self.remove(state, dict_, value, initiator, passive=passive) + + +@log.class_logger +@relationships.RelationshipProperty.strategy_for(lazy="write_only") +class WriteOnlyLoader(strategies.AbstractRelationshipLoader, log.Identified): + impl_class = WriteOnlyAttributeImpl + + def init_class_attribute(self, mapper: Mapper[Any]) -> None: + self.is_class_level = True + if not self.uselist or self.parent_property.direction not in ( + interfaces.ONETOMANY, + interfaces.MANYTOMANY, + ): + raise exc.InvalidRequestError( + "On relationship %s, 'dynamic' loaders cannot be used with " + "many-to-one/one-to-one relationships and/or " + "uselist=False." % self.parent_property + ) + + strategies._register_attribute( # type: ignore[no-untyped-call] + self.parent_property, + mapper, + useobject=True, + impl_class=self.impl_class, + target_mapper=self.parent_property.mapper, + order_by=self.parent_property.order_by, + query_class=self.parent_property.query_class, + ) + + +class DynamicCollectionAdapter: + """simplified CollectionAdapter for internal API consistency""" + + data: Collection[Any] + + def __init__(self, data: Collection[Any]): + self.data = data + + def __iter__(self) -> Iterator[Any]: + return iter(self.data) + + def _reset_empty(self) -> None: + pass + + def __len__(self) -> int: + return len(self.data) + + def __bool__(self) -> bool: + return True + + +class AbstractCollectionWriter(Generic[_T]): + """Virtual collection which includes append/remove methods that synchronize + into the attribute event system. + + """ + + if not TYPE_CHECKING: + __slots__ = () + + instance: _T + _from_obj: Tuple[FromClause, ...] + + def __init__(self, attr: WriteOnlyAttributeImpl, state: InstanceState[_T]): + instance = state.obj() + if TYPE_CHECKING: + assert instance + self.instance = instance + self.attr = attr + + mapper = object_mapper(instance) + prop = mapper._props[self.attr.key] + + if prop.secondary is not None: + # this is a hack right now. The Query only knows how to + # make subsequent joins() without a given left-hand side + # from self._from_obj[0]. We need to ensure prop.secondary + # is in the FROM. So we purposely put the mapper selectable + # in _from_obj[0] to ensure a user-defined join() later on + # doesn't fail, and secondary is then in _from_obj[1]. + + # note also, we are using the official ORM-annotated selectable + # from __clause_element__(), see #7868 + self._from_obj = (prop.mapper.__clause_element__(), prop.secondary) + else: + self._from_obj = () + + self._where_criteria = ( + prop._with_parent(instance, alias_secondary=False), + ) + + if self.attr.order_by: + self._order_by_clauses = self.attr.order_by + else: + self._order_by_clauses = () + + def _add_all_impl(self, iterator: Iterable[_T]) -> None: + for item in iterator: + self.attr.append( + attributes.instance_state(self.instance), + attributes.instance_dict(self.instance), + item, + None, + ) + + def _remove_impl(self, item: _T) -> None: + self.attr.remove( + attributes.instance_state(self.instance), + attributes.instance_dict(self.instance), + item, + None, + ) + + +class WriteOnlyCollection(AbstractCollectionWriter[_T]): + """Write-only collection which can synchronize changes into the + attribute event system. + + The :class:`.WriteOnlyCollection` is used in a mapping by + using the ``"write_only"`` lazy loading strategy with + :func:`_orm.relationship`. For background on this configuration, + see :ref:`write_only_relationship`. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`write_only_relationship` + + """ + + __slots__ = ( + "instance", + "attr", + "_where_criteria", + "_from_obj", + "_order_by_clauses", + ) + + def __iter__(self) -> NoReturn: + raise TypeError( + "WriteOnly collections don't support iteration in-place; " + "to query for collection items, use the select() method to " + "produce a SQL statement and execute it with session.scalars()." + ) + + def select(self) -> Select[Tuple[_T]]: + """Produce a :class:`_sql.Select` construct that represents the + rows within this instance-local :class:`_orm.WriteOnlyCollection`. + + """ + stmt = select(self.attr.target_mapper).where(*self._where_criteria) + if self._from_obj: + stmt = stmt.select_from(*self._from_obj) + if self._order_by_clauses: + stmt = stmt.order_by(*self._order_by_clauses) + return stmt + + def insert(self) -> Insert: + """For one-to-many collections, produce a :class:`_dml.Insert` which + will insert new rows in terms of this this instance-local + :class:`_orm.WriteOnlyCollection`. + + This construct is only supported for a :class:`_orm.Relationship` + that does **not** include the :paramref:`_orm.relationship.secondary` + parameter. For relationships that refer to a many-to-many table, + use ordinary bulk insert techniques to produce new objects, then + use :meth:`_orm.AbstractCollectionWriter.add_all` to associate them + with the collection. + + + """ + + state = inspect(self.instance) + mapper = state.mapper + prop = mapper._props[self.attr.key] + + if prop.direction is not RelationshipDirection.ONETOMANY: + raise exc.InvalidRequestError( + "Write only bulk INSERT only supported for one-to-many " + "collections; for many-to-many, use a separate bulk " + "INSERT along with add_all()." + ) + + dict_: Dict[str, Any] = {} + + for l, r in prop.synchronize_pairs: + fn = prop._get_attr_w_warn_on_none( + mapper, + state, + state.dict, + l, + ) + + dict_[r.key] = bindparam(None, callable_=fn) + + return insert(self.attr.target_mapper).values(**dict_) + + def update(self) -> Update: + """Produce a :class:`_dml.Update` which will refer to rows in terms + of this instance-local :class:`_orm.WriteOnlyCollection`. + + """ + return update(self.attr.target_mapper).where(*self._where_criteria) + + def delete(self) -> Delete: + """Produce a :class:`_dml.Delete` which will refer to rows in terms + of this instance-local :class:`_orm.WriteOnlyCollection`. + + """ + return delete(self.attr.target_mapper).where(*self._where_criteria) + + def add_all(self, iterator: Iterable[_T]) -> None: + """Add an iterable of items to this :class:`_orm.WriteOnlyCollection`. + + The given items will be persisted to the database in terms of + the parent instance's collection on the next flush. + + """ + self._add_all_impl(iterator) + + def add(self, item: _T) -> None: + """Add an item to this :class:`_orm.WriteOnlyCollection`. + + The given item will be persisted to the database in terms of + the parent instance's collection on the next flush. + + """ + self._add_all_impl([item]) + + def remove(self, item: _T) -> None: + """Remove an item from this :class:`_orm.WriteOnlyCollection`. + + The given item will be removed from the parent instance's collection on + the next flush. + + """ + self._remove_impl(item) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/pool/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/pool/__init__.py new file mode 100644 index 0000000..29fd652 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/pool/__init__.py @@ -0,0 +1,44 @@ +# pool/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +"""Connection pooling for DB-API connections. + +Provides a number of connection pool implementations for a variety of +usage scenarios and thread behavior requirements imposed by the +application, DB-API or database itself. + +Also provides a DB-API 2.0 connection proxying mechanism allowing +regular DB-API connect() methods to be transparently managed by a +SQLAlchemy connection pool. +""" + +from . import events +from .base import _AdhocProxiedConnection as _AdhocProxiedConnection +from .base import _ConnectionFairy as _ConnectionFairy +from .base import _ConnectionRecord +from .base import _CreatorFnType as _CreatorFnType +from .base import _CreatorWRecFnType as _CreatorWRecFnType +from .base import _finalize_fairy +from .base import _ResetStyleArgType as _ResetStyleArgType +from .base import ConnectionPoolEntry as ConnectionPoolEntry +from .base import ManagesConnection as ManagesConnection +from .base import Pool as Pool +from .base import PoolProxiedConnection as PoolProxiedConnection +from .base import PoolResetState as PoolResetState +from .base import reset_commit as reset_commit +from .base import reset_none as reset_none +from .base import reset_rollback as reset_rollback +from .impl import AssertionPool as AssertionPool +from .impl import AsyncAdaptedQueuePool as AsyncAdaptedQueuePool +from .impl import ( + FallbackAsyncAdaptedQueuePool as FallbackAsyncAdaptedQueuePool, +) +from .impl import NullPool as NullPool +from .impl import QueuePool as QueuePool +from .impl import SingletonThreadPool as SingletonThreadPool +from .impl import StaticPool as StaticPool diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..6bb2901 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..14fc526 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/base.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/events.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/events.cpython-311.pyc new file mode 100644 index 0000000..0fb8362 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/events.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/impl.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/impl.cpython-311.pyc new file mode 100644 index 0000000..7939b43 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/pool/__pycache__/impl.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/pool/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/pool/base.py new file mode 100644 index 0000000..98d2027 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/pool/base.py @@ -0,0 +1,1515 @@ +# pool/base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +"""Base constructs for connection pools. + +""" + +from __future__ import annotations + +from collections import deque +import dataclasses +from enum import Enum +import threading +import time +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Deque +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union +import weakref + +from .. import event +from .. import exc +from .. import log +from .. import util +from ..util.typing import Literal +from ..util.typing import Protocol + +if TYPE_CHECKING: + from ..engine.interfaces import DBAPIConnection + from ..engine.interfaces import DBAPICursor + from ..engine.interfaces import Dialect + from ..event import _DispatchCommon + from ..event import _ListenerFnType + from ..event import dispatcher + from ..sql._typing import _InfoType + + +@dataclasses.dataclass(frozen=True) +class PoolResetState: + """describes the state of a DBAPI connection as it is being passed to + the :meth:`.PoolEvents.reset` connection pool event. + + .. versionadded:: 2.0.0b3 + + """ + + __slots__ = ("transaction_was_reset", "terminate_only", "asyncio_safe") + + transaction_was_reset: bool + """Indicates if the transaction on the DBAPI connection was already + essentially "reset" back by the :class:`.Connection` object. + + This boolean is True if the :class:`.Connection` had transactional + state present upon it, which was then not closed using the + :meth:`.Connection.rollback` or :meth:`.Connection.commit` method; + instead, the transaction was closed inline within the + :meth:`.Connection.close` method so is guaranteed to remain non-present + when this event is reached. + + """ + + terminate_only: bool + """indicates if the connection is to be immediately terminated and + not checked in to the pool. + + This occurs for connections that were invalidated, as well as asyncio + connections that were not cleanly handled by the calling code that + are instead being garbage collected. In the latter case, + operations can't be safely run on asyncio connections within garbage + collection as there is not necessarily an event loop present. + + """ + + asyncio_safe: bool + """Indicates if the reset operation is occurring within a scope where + an enclosing event loop is expected to be present for asyncio applications. + + Will be False in the case that the connection is being garbage collected. + + """ + + +class ResetStyle(Enum): + """Describe options for "reset on return" behaviors.""" + + reset_rollback = 0 + reset_commit = 1 + reset_none = 2 + + +_ResetStyleArgType = Union[ + ResetStyle, + Literal[True, None, False, "commit", "rollback"], +] +reset_rollback, reset_commit, reset_none = list(ResetStyle) + + +class _ConnDialect: + """partial implementation of :class:`.Dialect` + which provides DBAPI connection methods. + + When a :class:`_pool.Pool` is combined with an :class:`_engine.Engine`, + the :class:`_engine.Engine` replaces this with its own + :class:`.Dialect`. + + """ + + is_async = False + has_terminate = False + + def do_rollback(self, dbapi_connection: PoolProxiedConnection) -> None: + dbapi_connection.rollback() + + def do_commit(self, dbapi_connection: PoolProxiedConnection) -> None: + dbapi_connection.commit() + + def do_terminate(self, dbapi_connection: DBAPIConnection) -> None: + dbapi_connection.close() + + def do_close(self, dbapi_connection: DBAPIConnection) -> None: + dbapi_connection.close() + + def _do_ping_w_event(self, dbapi_connection: DBAPIConnection) -> bool: + raise NotImplementedError( + "The ping feature requires that a dialect is " + "passed to the connection pool." + ) + + def get_driver_connection(self, connection: DBAPIConnection) -> Any: + return connection + + +class _AsyncConnDialect(_ConnDialect): + is_async = True + + +class _CreatorFnType(Protocol): + def __call__(self) -> DBAPIConnection: ... + + +class _CreatorWRecFnType(Protocol): + def __call__(self, rec: ConnectionPoolEntry) -> DBAPIConnection: ... + + +class Pool(log.Identified, event.EventTarget): + """Abstract base class for connection pools.""" + + dispatch: dispatcher[Pool] + echo: log._EchoFlagType + + _orig_logging_name: Optional[str] + _dialect: Union[_ConnDialect, Dialect] = _ConnDialect() + _creator_arg: Union[_CreatorFnType, _CreatorWRecFnType] + _invoke_creator: _CreatorWRecFnType + _invalidate_time: float + + def __init__( + self, + creator: Union[_CreatorFnType, _CreatorWRecFnType], + recycle: int = -1, + echo: log._EchoFlagType = None, + logging_name: Optional[str] = None, + reset_on_return: _ResetStyleArgType = True, + events: Optional[List[Tuple[_ListenerFnType, str]]] = None, + dialect: Optional[Union[_ConnDialect, Dialect]] = None, + pre_ping: bool = False, + _dispatch: Optional[_DispatchCommon[Pool]] = None, + ): + """ + Construct a Pool. + + :param creator: a callable function that returns a DB-API + connection object. The function will be called with + parameters. + + :param recycle: If set to a value other than -1, number of + seconds between connection recycling, which means upon + checkout, if this timeout is surpassed the connection will be + closed and replaced with a newly opened connection. Defaults to -1. + + :param logging_name: String identifier which will be used within + the "name" field of logging records generated within the + "sqlalchemy.pool" logger. Defaults to a hexstring of the object's + id. + + :param echo: if True, the connection pool will log + informational output such as when connections are invalidated + as well as when connections are recycled to the default log handler, + which defaults to ``sys.stdout`` for output.. If set to the string + ``"debug"``, the logging will include pool checkouts and checkins. + + The :paramref:`_pool.Pool.echo` parameter can also be set from the + :func:`_sa.create_engine` call by using the + :paramref:`_sa.create_engine.echo_pool` parameter. + + .. seealso:: + + :ref:`dbengine_logging` - further detail on how to configure + logging. + + :param reset_on_return: Determine steps to take on + connections as they are returned to the pool, which were + not otherwise handled by a :class:`_engine.Connection`. + Available from :func:`_sa.create_engine` via the + :paramref:`_sa.create_engine.pool_reset_on_return` parameter. + + :paramref:`_pool.Pool.reset_on_return` can have any of these values: + + * ``"rollback"`` - call rollback() on the connection, + to release locks and transaction resources. + This is the default value. The vast majority + of use cases should leave this value set. + * ``"commit"`` - call commit() on the connection, + to release locks and transaction resources. + A commit here may be desirable for databases that + cache query plans if a commit is emitted, + such as Microsoft SQL Server. However, this + value is more dangerous than 'rollback' because + any data changes present on the transaction + are committed unconditionally. + * ``None`` - don't do anything on the connection. + This setting may be appropriate if the database / DBAPI + works in pure "autocommit" mode at all times, or if + a custom reset handler is established using the + :meth:`.PoolEvents.reset` event handler. + + * ``True`` - same as 'rollback', this is here for + backwards compatibility. + * ``False`` - same as None, this is here for + backwards compatibility. + + For further customization of reset on return, the + :meth:`.PoolEvents.reset` event hook may be used which can perform + any connection activity desired on reset. + + .. seealso:: + + :ref:`pool_reset_on_return` + + :meth:`.PoolEvents.reset` + + :param events: a list of 2-tuples, each of the form + ``(callable, target)`` which will be passed to :func:`.event.listen` + upon construction. Provided here so that event listeners + can be assigned via :func:`_sa.create_engine` before dialect-level + listeners are applied. + + :param dialect: a :class:`.Dialect` that will handle the job + of calling rollback(), close(), or commit() on DBAPI connections. + If omitted, a built-in "stub" dialect is used. Applications that + make use of :func:`_sa.create_engine` should not use this parameter + as it is handled by the engine creation strategy. + + :param pre_ping: if True, the pool will emit a "ping" (typically + "SELECT 1", but is dialect-specific) on the connection + upon checkout, to test if the connection is alive or not. If not, + the connection is transparently re-connected and upon success, all + other pooled connections established prior to that timestamp are + invalidated. Requires that a dialect is passed as well to + interpret the disconnection error. + + .. versionadded:: 1.2 + + """ + if logging_name: + self.logging_name = self._orig_logging_name = logging_name + else: + self._orig_logging_name = None + + log.instance_logger(self, echoflag=echo) + self._creator = creator + self._recycle = recycle + self._invalidate_time = 0 + self._pre_ping = pre_ping + self._reset_on_return = util.parse_user_argument_for_enum( + reset_on_return, + { + ResetStyle.reset_rollback: ["rollback", True], + ResetStyle.reset_none: ["none", None, False], + ResetStyle.reset_commit: ["commit"], + }, + "reset_on_return", + ) + + self.echo = echo + + if _dispatch: + self.dispatch._update(_dispatch, only_propagate=False) + if dialect: + self._dialect = dialect + if events: + for fn, target in events: + event.listen(self, target, fn) + + @util.hybridproperty + def _is_asyncio(self) -> bool: + return self._dialect.is_async + + @property + def _creator(self) -> Union[_CreatorFnType, _CreatorWRecFnType]: + return self._creator_arg + + @_creator.setter + def _creator( + self, creator: Union[_CreatorFnType, _CreatorWRecFnType] + ) -> None: + self._creator_arg = creator + + # mypy seems to get super confused assigning functions to + # attributes + self._invoke_creator = self._should_wrap_creator(creator) + + @_creator.deleter + def _creator(self) -> None: + # needed for mock testing + del self._creator_arg + del self._invoke_creator + + def _should_wrap_creator( + self, creator: Union[_CreatorFnType, _CreatorWRecFnType] + ) -> _CreatorWRecFnType: + """Detect if creator accepts a single argument, or is sent + as a legacy style no-arg function. + + """ + + try: + argspec = util.get_callable_argspec(self._creator, no_self=True) + except TypeError: + creator_fn = cast(_CreatorFnType, creator) + return lambda rec: creator_fn() + + if argspec.defaults is not None: + defaulted = len(argspec.defaults) + else: + defaulted = 0 + positionals = len(argspec[0]) - defaulted + + # look for the exact arg signature that DefaultStrategy + # sends us + if (argspec[0], argspec[3]) == (["connection_record"], (None,)): + return cast(_CreatorWRecFnType, creator) + # or just a single positional + elif positionals == 1: + return cast(_CreatorWRecFnType, creator) + # all other cases, just wrap and assume legacy "creator" callable + # thing + else: + creator_fn = cast(_CreatorFnType, creator) + return lambda rec: creator_fn() + + def _close_connection( + self, connection: DBAPIConnection, *, terminate: bool = False + ) -> None: + self.logger.debug( + "%s connection %r", + "Hard-closing" if terminate else "Closing", + connection, + ) + try: + if terminate: + self._dialect.do_terminate(connection) + else: + self._dialect.do_close(connection) + except BaseException as e: + self.logger.error( + f"Exception {'terminating' if terminate else 'closing'} " + f"connection %r", + connection, + exc_info=True, + ) + if not isinstance(e, Exception): + raise + + def _create_connection(self) -> ConnectionPoolEntry: + """Called by subclasses to create a new ConnectionRecord.""" + + return _ConnectionRecord(self) + + def _invalidate( + self, + connection: PoolProxiedConnection, + exception: Optional[BaseException] = None, + _checkin: bool = True, + ) -> None: + """Mark all connections established within the generation + of the given connection as invalidated. + + If this pool's last invalidate time is before when the given + connection was created, update the timestamp til now. Otherwise, + no action is performed. + + Connections with a start time prior to this pool's invalidation + time will be recycled upon next checkout. + """ + rec = getattr(connection, "_connection_record", None) + if not rec or self._invalidate_time < rec.starttime: + self._invalidate_time = time.time() + if _checkin and getattr(connection, "is_valid", False): + connection.invalidate(exception) + + def recreate(self) -> Pool: + """Return a new :class:`_pool.Pool`, of the same class as this one + and configured with identical creation arguments. + + This method is used in conjunction with :meth:`dispose` + to close out an entire :class:`_pool.Pool` and create a new one in + its place. + + """ + + raise NotImplementedError() + + def dispose(self) -> None: + """Dispose of this pool. + + This method leaves the possibility of checked-out connections + remaining open, as it only affects connections that are + idle in the pool. + + .. seealso:: + + :meth:`Pool.recreate` + + """ + + raise NotImplementedError() + + def connect(self) -> PoolProxiedConnection: + """Return a DBAPI connection from the pool. + + The connection is instrumented such that when its + ``close()`` method is called, the connection will be returned to + the pool. + + """ + return _ConnectionFairy._checkout(self) + + def _return_conn(self, record: ConnectionPoolEntry) -> None: + """Given a _ConnectionRecord, return it to the :class:`_pool.Pool`. + + This method is called when an instrumented DBAPI connection + has its ``close()`` method called. + + """ + self._do_return_conn(record) + + def _do_get(self) -> ConnectionPoolEntry: + """Implementation for :meth:`get`, supplied by subclasses.""" + + raise NotImplementedError() + + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + """Implementation for :meth:`return_conn`, supplied by subclasses.""" + + raise NotImplementedError() + + def status(self) -> str: + raise NotImplementedError() + + +class ManagesConnection: + """Common base for the two connection-management interfaces + :class:`.PoolProxiedConnection` and :class:`.ConnectionPoolEntry`. + + These two objects are typically exposed in the public facing API + via the connection pool event hooks, documented at :class:`.PoolEvents`. + + .. versionadded:: 2.0 + + """ + + __slots__ = () + + dbapi_connection: Optional[DBAPIConnection] + """A reference to the actual DBAPI connection being tracked. + + This is a :pep:`249`-compliant object that for traditional sync-style + dialects is provided by the third-party + DBAPI implementation in use. For asyncio dialects, the implementation + is typically an adapter object provided by the SQLAlchemy dialect + itself; the underlying asyncio object is available via the + :attr:`.ManagesConnection.driver_connection` attribute. + + SQLAlchemy's interface for the DBAPI connection is based on the + :class:`.DBAPIConnection` protocol object + + .. seealso:: + + :attr:`.ManagesConnection.driver_connection` + + :ref:`faq_dbapi_connection` + + """ + + driver_connection: Optional[Any] + """The "driver level" connection object as used by the Python + DBAPI or database driver. + + For traditional :pep:`249` DBAPI implementations, this object will + be the same object as that of + :attr:`.ManagesConnection.dbapi_connection`. For an asyncio database + driver, this will be the ultimate "connection" object used by that + driver, such as the ``asyncpg.Connection`` object which will not have + standard pep-249 methods. + + .. versionadded:: 1.4.24 + + .. seealso:: + + :attr:`.ManagesConnection.dbapi_connection` + + :ref:`faq_dbapi_connection` + + """ + + @util.ro_memoized_property + def info(self) -> _InfoType: + """Info dictionary associated with the underlying DBAPI connection + referred to by this :class:`.ManagesConnection` instance, allowing + user-defined data to be associated with the connection. + + The data in this dictionary is persistent for the lifespan + of the DBAPI connection itself, including across pool checkins + and checkouts. When the connection is invalidated + and replaced with a new one, this dictionary is cleared. + + For a :class:`.PoolProxiedConnection` instance that's not associated + with a :class:`.ConnectionPoolEntry`, such as if it were detached, the + attribute returns a dictionary that is local to that + :class:`.ConnectionPoolEntry`. Therefore the + :attr:`.ManagesConnection.info` attribute will always provide a Python + dictionary. + + .. seealso:: + + :attr:`.ManagesConnection.record_info` + + + """ + raise NotImplementedError() + + @util.ro_memoized_property + def record_info(self) -> Optional[_InfoType]: + """Persistent info dictionary associated with this + :class:`.ManagesConnection`. + + Unlike the :attr:`.ManagesConnection.info` dictionary, the lifespan + of this dictionary is that of the :class:`.ConnectionPoolEntry` + which owns it; therefore this dictionary will persist across + reconnects and connection invalidation for a particular entry + in the connection pool. + + For a :class:`.PoolProxiedConnection` instance that's not associated + with a :class:`.ConnectionPoolEntry`, such as if it were detached, the + attribute returns None. Contrast to the :attr:`.ManagesConnection.info` + dictionary which is never None. + + + .. seealso:: + + :attr:`.ManagesConnection.info` + + """ + raise NotImplementedError() + + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: + """Mark the managed connection as invalidated. + + :param e: an exception object indicating a reason for the invalidation. + + :param soft: if True, the connection isn't closed; instead, this + connection will be recycled on next checkout. + + .. seealso:: + + :ref:`pool_connection_invalidation` + + + """ + raise NotImplementedError() + + +class ConnectionPoolEntry(ManagesConnection): + """Interface for the object that maintains an individual database + connection on behalf of a :class:`_pool.Pool` instance. + + The :class:`.ConnectionPoolEntry` object represents the long term + maintainance of a particular connection for a pool, including expiring or + invalidating that connection to have it replaced with a new one, which will + continue to be maintained by that same :class:`.ConnectionPoolEntry` + instance. Compared to :class:`.PoolProxiedConnection`, which is the + short-term, per-checkout connection manager, this object lasts for the + lifespan of a particular "slot" within a connection pool. + + The :class:`.ConnectionPoolEntry` object is mostly visible to public-facing + API code when it is delivered to connection pool event hooks, such as + :meth:`_events.PoolEvents.connect` and :meth:`_events.PoolEvents.checkout`. + + .. versionadded:: 2.0 :class:`.ConnectionPoolEntry` provides the public + facing interface for the :class:`._ConnectionRecord` internal class. + + """ + + __slots__ = () + + @property + def in_use(self) -> bool: + """Return True the connection is currently checked out""" + + raise NotImplementedError() + + def close(self) -> None: + """Close the DBAPI connection managed by this connection pool entry.""" + raise NotImplementedError() + + +class _ConnectionRecord(ConnectionPoolEntry): + """Maintains a position in a connection pool which references a pooled + connection. + + This is an internal object used by the :class:`_pool.Pool` implementation + to provide context management to a DBAPI connection maintained by + that :class:`_pool.Pool`. The public facing interface for this class + is described by the :class:`.ConnectionPoolEntry` class. See that + class for public API details. + + .. seealso:: + + :class:`.ConnectionPoolEntry` + + :class:`.PoolProxiedConnection` + + """ + + __slots__ = ( + "__pool", + "fairy_ref", + "finalize_callback", + "fresh", + "starttime", + "dbapi_connection", + "__weakref__", + "__dict__", + ) + + finalize_callback: Deque[Callable[[DBAPIConnection], None]] + fresh: bool + fairy_ref: Optional[weakref.ref[_ConnectionFairy]] + starttime: float + + def __init__(self, pool: Pool, connect: bool = True): + self.fresh = False + self.fairy_ref = None + self.starttime = 0 + self.dbapi_connection = None + + self.__pool = pool + if connect: + self.__connect() + self.finalize_callback = deque() + + dbapi_connection: Optional[DBAPIConnection] + + @property + def driver_connection(self) -> Optional[Any]: # type: ignore[override] # mypy#4125 # noqa: E501 + if self.dbapi_connection is None: + return None + else: + return self.__pool._dialect.get_driver_connection( + self.dbapi_connection + ) + + @property + @util.deprecated( + "2.0", + "The _ConnectionRecord.connection attribute is deprecated; " + "please use 'driver_connection'", + ) + def connection(self) -> Optional[DBAPIConnection]: + return self.dbapi_connection + + _soft_invalidate_time: float = 0 + + @util.ro_memoized_property + def info(self) -> _InfoType: + return {} + + @util.ro_memoized_property + def record_info(self) -> Optional[_InfoType]: + return {} + + @classmethod + def checkout(cls, pool: Pool) -> _ConnectionFairy: + if TYPE_CHECKING: + rec = cast(_ConnectionRecord, pool._do_get()) + else: + rec = pool._do_get() + + try: + dbapi_connection = rec.get_connection() + except BaseException as err: + with util.safe_reraise(): + rec._checkin_failed(err, _fairy_was_created=False) + + # not reached, for code linters only + raise + + echo = pool._should_log_debug() + fairy = _ConnectionFairy(pool, dbapi_connection, rec, echo) + + rec.fairy_ref = ref = weakref.ref( + fairy, + lambda ref: ( + _finalize_fairy( + None, rec, pool, ref, echo, transaction_was_reset=False + ) + if _finalize_fairy is not None + else None + ), + ) + _strong_ref_connection_records[ref] = rec + if echo: + pool.logger.debug( + "Connection %r checked out from pool", dbapi_connection + ) + return fairy + + def _checkin_failed( + self, err: BaseException, _fairy_was_created: bool = True + ) -> None: + self.invalidate(e=err) + self.checkin( + _fairy_was_created=_fairy_was_created, + ) + + def checkin(self, _fairy_was_created: bool = True) -> None: + if self.fairy_ref is None and _fairy_was_created: + # _fairy_was_created is False for the initial get connection phase; + # meaning there was no _ConnectionFairy and we must unconditionally + # do a checkin. + # + # otherwise, if fairy_was_created==True, if fairy_ref is None here + # that means we were checked in already, so this looks like + # a double checkin. + util.warn("Double checkin attempted on %s" % self) + return + self.fairy_ref = None + connection = self.dbapi_connection + pool = self.__pool + while self.finalize_callback: + finalizer = self.finalize_callback.pop() + if connection is not None: + finalizer(connection) + if pool.dispatch.checkin: + pool.dispatch.checkin(connection, self) + + pool._return_conn(self) + + @property + def in_use(self) -> bool: + return self.fairy_ref is not None + + @property + def last_connect_time(self) -> float: + return self.starttime + + def close(self) -> None: + if self.dbapi_connection is not None: + self.__close() + + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: + # already invalidated + if self.dbapi_connection is None: + return + if soft: + self.__pool.dispatch.soft_invalidate( + self.dbapi_connection, self, e + ) + else: + self.__pool.dispatch.invalidate(self.dbapi_connection, self, e) + if e is not None: + self.__pool.logger.info( + "%sInvalidate connection %r (reason: %s:%s)", + "Soft " if soft else "", + self.dbapi_connection, + e.__class__.__name__, + e, + ) + else: + self.__pool.logger.info( + "%sInvalidate connection %r", + "Soft " if soft else "", + self.dbapi_connection, + ) + + if soft: + self._soft_invalidate_time = time.time() + else: + self.__close(terminate=True) + self.dbapi_connection = None + + def get_connection(self) -> DBAPIConnection: + recycle = False + + # NOTE: the various comparisons here are assuming that measurable time + # passes between these state changes. however, time.time() is not + # guaranteed to have sub-second precision. comparisons of + # "invalidation time" to "starttime" should perhaps use >= so that the + # state change can take place assuming no measurable time has passed, + # however this does not guarantee correct behavior here as if time + # continues to not pass, it will try to reconnect repeatedly until + # these timestamps diverge, so in that sense using > is safer. Per + # https://stackoverflow.com/a/1938096/34549, Windows time.time() may be + # within 16 milliseconds accuracy, so unit tests for connection + # invalidation need a sleep of at least this long between initial start + # time and invalidation for the logic below to work reliably. + + if self.dbapi_connection is None: + self.info.clear() + self.__connect() + elif ( + self.__pool._recycle > -1 + and time.time() - self.starttime > self.__pool._recycle + ): + self.__pool.logger.info( + "Connection %r exceeded timeout; recycling", + self.dbapi_connection, + ) + recycle = True + elif self.__pool._invalidate_time > self.starttime: + self.__pool.logger.info( + "Connection %r invalidated due to pool invalidation; " + + "recycling", + self.dbapi_connection, + ) + recycle = True + elif self._soft_invalidate_time > self.starttime: + self.__pool.logger.info( + "Connection %r invalidated due to local soft invalidation; " + + "recycling", + self.dbapi_connection, + ) + recycle = True + + if recycle: + self.__close(terminate=True) + self.info.clear() + + self.__connect() + + assert self.dbapi_connection is not None + return self.dbapi_connection + + def _is_hard_or_soft_invalidated(self) -> bool: + return ( + self.dbapi_connection is None + or self.__pool._invalidate_time > self.starttime + or (self._soft_invalidate_time > self.starttime) + ) + + def __close(self, *, terminate: bool = False) -> None: + self.finalize_callback.clear() + if self.__pool.dispatch.close: + self.__pool.dispatch.close(self.dbapi_connection, self) + assert self.dbapi_connection is not None + self.__pool._close_connection( + self.dbapi_connection, terminate=terminate + ) + self.dbapi_connection = None + + def __connect(self) -> None: + pool = self.__pool + + # ensure any existing connection is removed, so that if + # creator fails, this attribute stays None + self.dbapi_connection = None + try: + self.starttime = time.time() + self.dbapi_connection = connection = pool._invoke_creator(self) + pool.logger.debug("Created new connection %r", connection) + self.fresh = True + except BaseException as e: + with util.safe_reraise(): + pool.logger.debug("Error on connect(): %s", e) + else: + # in SQLAlchemy 1.4 the first_connect event is not used by + # the engine, so this will usually not be set + if pool.dispatch.first_connect: + pool.dispatch.first_connect.for_modify( + pool.dispatch + ).exec_once_unless_exception(self.dbapi_connection, self) + + # init of the dialect now takes place within the connect + # event, so ensure a mutex is used on the first run + pool.dispatch.connect.for_modify( + pool.dispatch + )._exec_w_sync_on_first_run(self.dbapi_connection, self) + + +def _finalize_fairy( + dbapi_connection: Optional[DBAPIConnection], + connection_record: Optional[_ConnectionRecord], + pool: Pool, + ref: Optional[ + weakref.ref[_ConnectionFairy] + ], # this is None when called directly, not by the gc + echo: Optional[log._EchoFlagType], + transaction_was_reset: bool = False, + fairy: Optional[_ConnectionFairy] = None, +) -> None: + """Cleanup for a :class:`._ConnectionFairy` whether or not it's already + been garbage collected. + + When using an async dialect no IO can happen here (without using + a dedicated thread), since this is called outside the greenlet + context and with an already running loop. In this case function + will only log a message and raise a warning. + """ + + is_gc_cleanup = ref is not None + + if is_gc_cleanup: + assert ref is not None + _strong_ref_connection_records.pop(ref, None) + assert connection_record is not None + if connection_record.fairy_ref is not ref: + return + assert dbapi_connection is None + dbapi_connection = connection_record.dbapi_connection + + elif fairy: + _strong_ref_connection_records.pop(weakref.ref(fairy), None) + + # null pool is not _is_asyncio but can be used also with async dialects + dont_restore_gced = pool._dialect.is_async + + if dont_restore_gced: + detach = connection_record is None or is_gc_cleanup + can_manipulate_connection = not is_gc_cleanup + can_close_or_terminate_connection = ( + not pool._dialect.is_async or pool._dialect.has_terminate + ) + requires_terminate_for_close = ( + pool._dialect.is_async and pool._dialect.has_terminate + ) + + else: + detach = connection_record is None + can_manipulate_connection = can_close_or_terminate_connection = True + requires_terminate_for_close = False + + if dbapi_connection is not None: + if connection_record and echo: + pool.logger.debug( + "Connection %r being returned to pool", dbapi_connection + ) + + try: + if not fairy: + assert connection_record is not None + fairy = _ConnectionFairy( + pool, + dbapi_connection, + connection_record, + echo, + ) + assert fairy.dbapi_connection is dbapi_connection + + fairy._reset( + pool, + transaction_was_reset=transaction_was_reset, + terminate_only=detach, + asyncio_safe=can_manipulate_connection, + ) + + if detach: + if connection_record: + fairy._pool = pool + fairy.detach() + + if can_close_or_terminate_connection: + if pool.dispatch.close_detached: + pool.dispatch.close_detached(dbapi_connection) + + pool._close_connection( + dbapi_connection, + terminate=requires_terminate_for_close, + ) + + except BaseException as e: + pool.logger.error( + "Exception during reset or similar", exc_info=True + ) + if connection_record: + connection_record.invalidate(e=e) + if not isinstance(e, Exception): + raise + finally: + if detach and is_gc_cleanup and dont_restore_gced: + message = ( + "The garbage collector is trying to clean up " + f"non-checked-in connection {dbapi_connection!r}, " + f"""which will be { + 'dropped, as it cannot be safely terminated' + if not can_close_or_terminate_connection + else 'terminated' + }. """ + "Please ensure that SQLAlchemy pooled connections are " + "returned to " + "the pool explicitly, either by calling ``close()`` " + "or by using appropriate context managers to manage " + "their lifecycle." + ) + pool.logger.error(message) + util.warn(message) + + if connection_record and connection_record.fairy_ref is not None: + connection_record.checkin() + + # give gc some help. See + # test/engine/test_pool.py::PoolEventsTest::test_checkin_event_gc[True] + # which actually started failing when pytest warnings plugin was + # turned on, due to util.warn() above + if fairy is not None: + fairy.dbapi_connection = None # type: ignore + fairy._connection_record = None + del dbapi_connection + del connection_record + del fairy + + +# a dictionary of the _ConnectionFairy weakrefs to _ConnectionRecord, so that +# GC under pypy will call ConnectionFairy finalizers. linked directly to the +# weakref that will empty itself when collected so that it should not create +# any unmanaged memory references. +_strong_ref_connection_records: Dict[ + weakref.ref[_ConnectionFairy], _ConnectionRecord +] = {} + + +class PoolProxiedConnection(ManagesConnection): + """A connection-like adapter for a :pep:`249` DBAPI connection, which + includes additional methods specific to the :class:`.Pool` implementation. + + :class:`.PoolProxiedConnection` is the public-facing interface for the + internal :class:`._ConnectionFairy` implementation object; users familiar + with :class:`._ConnectionFairy` can consider this object to be equivalent. + + .. versionadded:: 2.0 :class:`.PoolProxiedConnection` provides the public- + facing interface for the :class:`._ConnectionFairy` internal class. + + """ + + __slots__ = () + + if typing.TYPE_CHECKING: + + def commit(self) -> None: ... + + def cursor(self) -> DBAPICursor: ... + + def rollback(self) -> None: ... + + @property + def is_valid(self) -> bool: + """Return True if this :class:`.PoolProxiedConnection` still refers + to an active DBAPI connection.""" + + raise NotImplementedError() + + @property + def is_detached(self) -> bool: + """Return True if this :class:`.PoolProxiedConnection` is detached + from its pool.""" + + raise NotImplementedError() + + def detach(self) -> None: + """Separate this connection from its Pool. + + This means that the connection will no longer be returned to the + pool when closed, and will instead be literally closed. The + associated :class:`.ConnectionPoolEntry` is de-associated from this + DBAPI connection. + + Note that any overall connection limiting constraints imposed by a + Pool implementation may be violated after a detach, as the detached + connection is removed from the pool's knowledge and control. + + """ + + raise NotImplementedError() + + def close(self) -> None: + """Release this connection back to the pool. + + The :meth:`.PoolProxiedConnection.close` method shadows the + :pep:`249` ``.close()`` method, altering its behavior to instead + :term:`release` the proxied connection back to the connection pool. + + Upon release to the pool, whether the connection stays "opened" and + pooled in the Python process, versus actually closed out and removed + from the Python process, is based on the pool implementation in use and + its configuration and current state. + + """ + raise NotImplementedError() + + +class _AdhocProxiedConnection(PoolProxiedConnection): + """provides the :class:`.PoolProxiedConnection` interface for cases where + the DBAPI connection is not actually proxied. + + This is used by the engine internals to pass a consistent + :class:`.PoolProxiedConnection` object to consuming dialects in response to + pool events that may not always have the :class:`._ConnectionFairy` + available. + + """ + + __slots__ = ("dbapi_connection", "_connection_record", "_is_valid") + + dbapi_connection: DBAPIConnection + _connection_record: ConnectionPoolEntry + + def __init__( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ): + self.dbapi_connection = dbapi_connection + self._connection_record = connection_record + self._is_valid = True + + @property + def driver_connection(self) -> Any: # type: ignore[override] # mypy#4125 + return self._connection_record.driver_connection + + @property + def connection(self) -> DBAPIConnection: + return self.dbapi_connection + + @property + def is_valid(self) -> bool: + """Implement is_valid state attribute. + + for the adhoc proxied connection it's assumed the connection is valid + as there is no "invalidate" routine. + + """ + return self._is_valid + + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: + self._is_valid = False + + @util.ro_non_memoized_property + def record_info(self) -> Optional[_InfoType]: + return self._connection_record.record_info + + def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: + return self.dbapi_connection.cursor(*args, **kwargs) + + def __getattr__(self, key: Any) -> Any: + return getattr(self.dbapi_connection, key) + + +class _ConnectionFairy(PoolProxiedConnection): + """Proxies a DBAPI connection and provides return-on-dereference + support. + + This is an internal object used by the :class:`_pool.Pool` implementation + to provide context management to a DBAPI connection delivered by + that :class:`_pool.Pool`. The public facing interface for this class + is described by the :class:`.PoolProxiedConnection` class. See that + class for public API details. + + The name "fairy" is inspired by the fact that the + :class:`._ConnectionFairy` object's lifespan is transitory, as it lasts + only for the length of a specific DBAPI connection being checked out from + the pool, and additionally that as a transparent proxy, it is mostly + invisible. + + .. seealso:: + + :class:`.PoolProxiedConnection` + + :class:`.ConnectionPoolEntry` + + + """ + + __slots__ = ( + "dbapi_connection", + "_connection_record", + "_echo", + "_pool", + "_counter", + "__weakref__", + "__dict__", + ) + + pool: Pool + dbapi_connection: DBAPIConnection + _echo: log._EchoFlagType + + def __init__( + self, + pool: Pool, + dbapi_connection: DBAPIConnection, + connection_record: _ConnectionRecord, + echo: log._EchoFlagType, + ): + self._pool = pool + self._counter = 0 + self.dbapi_connection = dbapi_connection + self._connection_record = connection_record + self._echo = echo + + _connection_record: Optional[_ConnectionRecord] + + @property + def driver_connection(self) -> Optional[Any]: # type: ignore[override] # mypy#4125 # noqa: E501 + if self._connection_record is None: + return None + return self._connection_record.driver_connection + + @property + @util.deprecated( + "2.0", + "The _ConnectionFairy.connection attribute is deprecated; " + "please use 'driver_connection'", + ) + def connection(self) -> DBAPIConnection: + return self.dbapi_connection + + @classmethod + def _checkout( + cls, + pool: Pool, + threadconns: Optional[threading.local] = None, + fairy: Optional[_ConnectionFairy] = None, + ) -> _ConnectionFairy: + if not fairy: + fairy = _ConnectionRecord.checkout(pool) + + if threadconns is not None: + threadconns.current = weakref.ref(fairy) + + assert ( + fairy._connection_record is not None + ), "can't 'checkout' a detached connection fairy" + assert ( + fairy.dbapi_connection is not None + ), "can't 'checkout' an invalidated connection fairy" + + fairy._counter += 1 + if ( + not pool.dispatch.checkout and not pool._pre_ping + ) or fairy._counter != 1: + return fairy + + # Pool listeners can trigger a reconnection on checkout, as well + # as the pre-pinger. + # there are three attempts made here, but note that if the database + # is not accessible from a connection standpoint, those won't proceed + # here. + + attempts = 2 + + while attempts > 0: + connection_is_fresh = fairy._connection_record.fresh + fairy._connection_record.fresh = False + try: + if pool._pre_ping: + if not connection_is_fresh: + if fairy._echo: + pool.logger.debug( + "Pool pre-ping on connection %s", + fairy.dbapi_connection, + ) + result = pool._dialect._do_ping_w_event( + fairy.dbapi_connection + ) + if not result: + if fairy._echo: + pool.logger.debug( + "Pool pre-ping on connection %s failed, " + "will invalidate pool", + fairy.dbapi_connection, + ) + raise exc.InvalidatePoolError() + elif fairy._echo: + pool.logger.debug( + "Connection %s is fresh, skipping pre-ping", + fairy.dbapi_connection, + ) + + pool.dispatch.checkout( + fairy.dbapi_connection, fairy._connection_record, fairy + ) + return fairy + except exc.DisconnectionError as e: + if e.invalidate_pool: + pool.logger.info( + "Disconnection detected on checkout, " + "invalidating all pooled connections prior to " + "current timestamp (reason: %r)", + e, + ) + fairy._connection_record.invalidate(e) + pool._invalidate(fairy, e, _checkin=False) + else: + pool.logger.info( + "Disconnection detected on checkout, " + "invalidating individual connection %s (reason: %r)", + fairy.dbapi_connection, + e, + ) + fairy._connection_record.invalidate(e) + try: + fairy.dbapi_connection = ( + fairy._connection_record.get_connection() + ) + except BaseException as err: + with util.safe_reraise(): + fairy._connection_record._checkin_failed( + err, + _fairy_was_created=True, + ) + + # prevent _ConnectionFairy from being carried + # in the stack trace. Do this after the + # connection record has been checked in, so that + # if the del triggers a finalize fairy, it won't + # try to checkin a second time. + del fairy + + # never called, this is for code linters + raise + + attempts -= 1 + except BaseException as be_outer: + with util.safe_reraise(): + rec = fairy._connection_record + if rec is not None: + rec._checkin_failed( + be_outer, + _fairy_was_created=True, + ) + + # prevent _ConnectionFairy from being carried + # in the stack trace, see above + del fairy + + # never called, this is for code linters + raise + + pool.logger.info("Reconnection attempts exhausted on checkout") + fairy.invalidate() + raise exc.InvalidRequestError("This connection is closed") + + def _checkout_existing(self) -> _ConnectionFairy: + return _ConnectionFairy._checkout(self._pool, fairy=self) + + def _checkin(self, transaction_was_reset: bool = False) -> None: + _finalize_fairy( + self.dbapi_connection, + self._connection_record, + self._pool, + None, + self._echo, + transaction_was_reset=transaction_was_reset, + fairy=self, + ) + + def _close(self) -> None: + self._checkin() + + def _reset( + self, + pool: Pool, + transaction_was_reset: bool, + terminate_only: bool, + asyncio_safe: bool, + ) -> None: + if pool.dispatch.reset: + pool.dispatch.reset( + self.dbapi_connection, + self._connection_record, + PoolResetState( + transaction_was_reset=transaction_was_reset, + terminate_only=terminate_only, + asyncio_safe=asyncio_safe, + ), + ) + + if not asyncio_safe: + return + + if pool._reset_on_return is reset_rollback: + if transaction_was_reset: + if self._echo: + pool.logger.debug( + "Connection %s reset, transaction already reset", + self.dbapi_connection, + ) + else: + if self._echo: + pool.logger.debug( + "Connection %s rollback-on-return", + self.dbapi_connection, + ) + pool._dialect.do_rollback(self) + elif pool._reset_on_return is reset_commit: + if self._echo: + pool.logger.debug( + "Connection %s commit-on-return", + self.dbapi_connection, + ) + pool._dialect.do_commit(self) + + @property + def _logger(self) -> log._IdentifiedLoggerType: + return self._pool.logger + + @property + def is_valid(self) -> bool: + return self.dbapi_connection is not None + + @property + def is_detached(self) -> bool: + return self._connection_record is None + + @util.ro_memoized_property + def info(self) -> _InfoType: + if self._connection_record is None: + return {} + else: + return self._connection_record.info + + @util.ro_non_memoized_property + def record_info(self) -> Optional[_InfoType]: + if self._connection_record is None: + return None + else: + return self._connection_record.record_info + + def invalidate( + self, e: Optional[BaseException] = None, soft: bool = False + ) -> None: + if self.dbapi_connection is None: + util.warn("Can't invalidate an already-closed connection.") + return + if self._connection_record: + self._connection_record.invalidate(e=e, soft=soft) + if not soft: + # prevent any rollback / reset actions etc. on + # the connection + self.dbapi_connection = None # type: ignore + + # finalize + self._checkin() + + def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: + assert self.dbapi_connection is not None + return self.dbapi_connection.cursor(*args, **kwargs) + + def __getattr__(self, key: str) -> Any: + return getattr(self.dbapi_connection, key) + + def detach(self) -> None: + if self._connection_record is not None: + rec = self._connection_record + rec.fairy_ref = None + rec.dbapi_connection = None + # TODO: should this be _return_conn? + self._pool._do_return_conn(self._connection_record) + + # can't get the descriptor assignment to work here + # in pylance. mypy is OK w/ it + self.info = self.info.copy() # type: ignore + + self._connection_record = None + + if self._pool.dispatch.detach: + self._pool.dispatch.detach(self.dbapi_connection, rec) + + def close(self) -> None: + self._counter -= 1 + if self._counter == 0: + self._checkin() + + def _close_special(self, transaction_reset: bool = False) -> None: + self._counter -= 1 + if self._counter == 0: + self._checkin(transaction_was_reset=transaction_reset) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/pool/events.py b/venv/lib/python3.11/site-packages/sqlalchemy/pool/events.py new file mode 100644 index 0000000..4b4f4e4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/pool/events.py @@ -0,0 +1,370 @@ +# pool/events.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + +import typing +from typing import Any +from typing import Optional +from typing import Type +from typing import Union + +from .base import ConnectionPoolEntry +from .base import Pool +from .base import PoolProxiedConnection +from .base import PoolResetState +from .. import event +from .. import util + +if typing.TYPE_CHECKING: + from ..engine import Engine + from ..engine.interfaces import DBAPIConnection + + +class PoolEvents(event.Events[Pool]): + """Available events for :class:`_pool.Pool`. + + The methods here define the name of an event as well + as the names of members that are passed to listener + functions. + + e.g.:: + + from sqlalchemy import event + + def my_on_checkout(dbapi_conn, connection_rec, connection_proxy): + "handle an on checkout event" + + event.listen(Pool, 'checkout', my_on_checkout) + + In addition to accepting the :class:`_pool.Pool` class and + :class:`_pool.Pool` instances, :class:`_events.PoolEvents` also accepts + :class:`_engine.Engine` objects and the :class:`_engine.Engine` class as + targets, which will be resolved to the ``.pool`` attribute of the + given engine or the :class:`_pool.Pool` class:: + + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") + + # will associate with engine.pool + event.listen(engine, 'checkout', my_on_checkout) + + """ # noqa: E501 + + _target_class_doc = "SomeEngineOrPool" + _dispatch_target = Pool + + @util.preload_module("sqlalchemy.engine") + @classmethod + def _accept_with( + cls, + target: Union[Pool, Type[Pool], Engine, Type[Engine]], + identifier: str, + ) -> Optional[Union[Pool, Type[Pool]]]: + if not typing.TYPE_CHECKING: + Engine = util.preloaded.engine.Engine + + if isinstance(target, type): + if issubclass(target, Engine): + return Pool + else: + assert issubclass(target, Pool) + return target + elif isinstance(target, Engine): + return target.pool + elif isinstance(target, Pool): + return target + elif hasattr(target, "_no_async_engine_events"): + target._no_async_engine_events() + else: + return None + + @classmethod + def _listen( + cls, + event_key: event._EventKey[Pool], + **kw: Any, + ) -> None: + target = event_key.dispatch_target + + kw.setdefault("asyncio", target._is_asyncio) + + event_key.base_listen(**kw) + + def connect( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + """Called at the moment a particular DBAPI connection is first + created for a given :class:`_pool.Pool`. + + This event allows one to capture the point directly after which + the DBAPI module-level ``.connect()`` method has been used in order + to produce a new DBAPI connection. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + """ + + def first_connect( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + """Called exactly once for the first time a DBAPI connection is + checked out from a particular :class:`_pool.Pool`. + + The rationale for :meth:`_events.PoolEvents.first_connect` + is to determine + information about a particular series of database connections based + on the settings used for all connections. Since a particular + :class:`_pool.Pool` + refers to a single "creator" function (which in terms + of a :class:`_engine.Engine` + refers to the URL and connection options used), + it is typically valid to make observations about a single connection + that can be safely assumed to be valid about all subsequent + connections, such as the database version, the server and client + encoding settings, collation settings, and many others. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + """ + + def checkout( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + connection_proxy: PoolProxiedConnection, + ) -> None: + """Called when a connection is retrieved from the Pool. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + :param connection_proxy: the :class:`.PoolProxiedConnection` object + which will proxy the public interface of the DBAPI connection for the + lifespan of the checkout. + + If you raise a :class:`~sqlalchemy.exc.DisconnectionError`, the current + connection will be disposed and a fresh connection retrieved. + Processing of all checkout listeners will abort and restart + using the new connection. + + .. seealso:: :meth:`_events.ConnectionEvents.engine_connect` + - a similar event + which occurs upon creation of a new :class:`_engine.Connection`. + + """ + + def checkin( + self, + dbapi_connection: Optional[DBAPIConnection], + connection_record: ConnectionPoolEntry, + ) -> None: + """Called when a connection returns to the pool. + + Note that the connection may be closed, and may be None if the + connection has been invalidated. ``checkin`` will not be called + for detached connections. (They do not return to the pool.) + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + """ + + @event._legacy_signature( + "2.0", + ["dbapi_connection", "connection_record"], + lambda dbapi_connection, connection_record, reset_state: ( + dbapi_connection, + connection_record, + ), + ) + def reset( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + reset_state: PoolResetState, + ) -> None: + """Called before the "reset" action occurs for a pooled connection. + + This event represents + when the ``rollback()`` method is called on the DBAPI connection + before it is returned to the pool or discarded. + A custom "reset" strategy may be implemented using this event hook, + which may also be combined with disabling the default "reset" + behavior using the :paramref:`_pool.Pool.reset_on_return` parameter. + + The primary difference between the :meth:`_events.PoolEvents.reset` and + :meth:`_events.PoolEvents.checkin` events are that + :meth:`_events.PoolEvents.reset` is called not just for pooled + connections that are being returned to the pool, but also for + connections that were detached using the + :meth:`_engine.Connection.detach` method as well as asyncio connections + that are being discarded due to garbage collection taking place on + connections before the connection was checked in. + + Note that the event **is not** invoked for connections that were + invalidated using :meth:`_engine.Connection.invalidate`. These + events may be intercepted using the :meth:`.PoolEvents.soft_invalidate` + and :meth:`.PoolEvents.invalidate` event hooks, and all "connection + close" events may be intercepted using :meth:`.PoolEvents.close`. + + The :meth:`_events.PoolEvents.reset` event is usually followed by the + :meth:`_events.PoolEvents.checkin` event, except in those + cases where the connection is discarded immediately after reset. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + :param reset_state: :class:`.PoolResetState` instance which provides + information about the circumstances under which the connection + is being reset. + + .. versionadded:: 2.0 + + .. seealso:: + + :ref:`pool_reset_on_return` + + :meth:`_events.ConnectionEvents.rollback` + + :meth:`_events.ConnectionEvents.commit` + + """ + + def invalidate( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + exception: Optional[BaseException], + ) -> None: + """Called when a DBAPI connection is to be "invalidated". + + This event is called any time the + :meth:`.ConnectionPoolEntry.invalidate` method is invoked, either from + API usage or via "auto-invalidation", without the ``soft`` flag. + + The event occurs before a final attempt to call ``.close()`` on the + connection occurs. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + :param exception: the exception object corresponding to the reason + for this invalidation, if any. May be ``None``. + + .. seealso:: + + :ref:`pool_connection_invalidation` + + """ + + def soft_invalidate( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + exception: Optional[BaseException], + ) -> None: + """Called when a DBAPI connection is to be "soft invalidated". + + This event is called any time the + :meth:`.ConnectionPoolEntry.invalidate` + method is invoked with the ``soft`` flag. + + Soft invalidation refers to when the connection record that tracks + this connection will force a reconnect after the current connection + is checked in. It does not actively close the dbapi_connection + at the point at which it is called. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + :param exception: the exception object corresponding to the reason + for this invalidation, if any. May be ``None``. + + """ + + def close( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + """Called when a DBAPI connection is closed. + + The event is emitted before the close occurs. + + The close of a connection can fail; typically this is because + the connection is already closed. If the close operation fails, + the connection is discarded. + + The :meth:`.close` event corresponds to a connection that's still + associated with the pool. To intercept close events for detached + connections use :meth:`.close_detached`. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + """ + + def detach( + self, + dbapi_connection: DBAPIConnection, + connection_record: ConnectionPoolEntry, + ) -> None: + """Called when a DBAPI connection is "detached" from a pool. + + This event is emitted after the detach occurs. The connection + is no longer associated with the given connection record. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + :param connection_record: the :class:`.ConnectionPoolEntry` managing + the DBAPI connection. + + """ + + def close_detached(self, dbapi_connection: DBAPIConnection) -> None: + """Called when a detached DBAPI connection is closed. + + The event is emitted before the close occurs. + + The close of a connection can fail; typically this is because + the connection is already closed. If the close operation fails, + the connection is discarded. + + :param dbapi_connection: a DBAPI connection. + The :attr:`.ConnectionPoolEntry.dbapi_connection` attribute. + + """ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/pool/impl.py b/venv/lib/python3.11/site-packages/sqlalchemy/pool/impl.py new file mode 100644 index 0000000..157455c --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/pool/impl.py @@ -0,0 +1,581 @@ +# pool/impl.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + + +"""Pool implementation classes. + +""" +from __future__ import annotations + +import threading +import traceback +import typing +from typing import Any +from typing import cast +from typing import List +from typing import Optional +from typing import Set +from typing import Type +from typing import TYPE_CHECKING +from typing import Union +import weakref + +from .base import _AsyncConnDialect +from .base import _ConnectionFairy +from .base import _ConnectionRecord +from .base import _CreatorFnType +from .base import _CreatorWRecFnType +from .base import ConnectionPoolEntry +from .base import Pool +from .base import PoolProxiedConnection +from .. import exc +from .. import util +from ..util import chop_traceback +from ..util import queue as sqla_queue +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from ..engine.interfaces import DBAPIConnection + + +class QueuePool(Pool): + """A :class:`_pool.Pool` + that imposes a limit on the number of open connections. + + :class:`.QueuePool` is the default pooling implementation used for + all :class:`_engine.Engine` objects other than SQLite with a ``:memory:`` + database. + + The :class:`.QueuePool` class **is not compatible** with asyncio and + :func:`_asyncio.create_async_engine`. The + :class:`.AsyncAdaptedQueuePool` class is used automatically when + using :func:`_asyncio.create_async_engine`, if no other kind of pool + is specified. + + .. seealso:: + + :class:`.AsyncAdaptedQueuePool` + + """ + + _is_asyncio = False # type: ignore[assignment] + + _queue_class: Type[sqla_queue.QueueCommon[ConnectionPoolEntry]] = ( + sqla_queue.Queue + ) + + _pool: sqla_queue.QueueCommon[ConnectionPoolEntry] + + def __init__( + self, + creator: Union[_CreatorFnType, _CreatorWRecFnType], + pool_size: int = 5, + max_overflow: int = 10, + timeout: float = 30.0, + use_lifo: bool = False, + **kw: Any, + ): + r""" + Construct a QueuePool. + + :param creator: a callable function that returns a DB-API + connection object, same as that of :paramref:`_pool.Pool.creator`. + + :param pool_size: The size of the pool to be maintained, + defaults to 5. This is the largest number of connections that + will be kept persistently in the pool. Note that the pool + begins with no connections; once this number of connections + is requested, that number of connections will remain. + ``pool_size`` can be set to 0 to indicate no size limit; to + disable pooling, use a :class:`~sqlalchemy.pool.NullPool` + instead. + + :param max_overflow: The maximum overflow size of the + pool. When the number of checked-out connections reaches the + size set in pool_size, additional connections will be + returned up to this limit. When those additional connections + are returned to the pool, they are disconnected and + discarded. It follows then that the total number of + simultaneous connections the pool will allow is pool_size + + `max_overflow`, and the total number of "sleeping" + connections the pool will allow is pool_size. `max_overflow` + can be set to -1 to indicate no overflow limit; no limit + will be placed on the total number of concurrent + connections. Defaults to 10. + + :param timeout: The number of seconds to wait before giving up + on returning a connection. Defaults to 30.0. This can be a float + but is subject to the limitations of Python time functions which + may not be reliable in the tens of milliseconds. + + :param use_lifo: use LIFO (last-in-first-out) when retrieving + connections instead of FIFO (first-in-first-out). Using LIFO, a + server-side timeout scheme can reduce the number of connections used + during non-peak periods of use. When planning for server-side + timeouts, ensure that a recycle or pre-ping strategy is in use to + gracefully handle stale connections. + + .. versionadded:: 1.3 + + .. seealso:: + + :ref:`pool_use_lifo` + + :ref:`pool_disconnects` + + :param \**kw: Other keyword arguments including + :paramref:`_pool.Pool.recycle`, :paramref:`_pool.Pool.echo`, + :paramref:`_pool.Pool.reset_on_return` and others are passed to the + :class:`_pool.Pool` constructor. + + """ + + Pool.__init__(self, creator, **kw) + self._pool = self._queue_class(pool_size, use_lifo=use_lifo) + self._overflow = 0 - pool_size + self._max_overflow = -1 if pool_size == 0 else max_overflow + self._timeout = timeout + self._overflow_lock = threading.Lock() + + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + try: + self._pool.put(record, False) + except sqla_queue.Full: + try: + record.close() + finally: + self._dec_overflow() + + def _do_get(self) -> ConnectionPoolEntry: + use_overflow = self._max_overflow > -1 + + wait = use_overflow and self._overflow >= self._max_overflow + try: + return self._pool.get(wait, self._timeout) + except sqla_queue.Empty: + # don't do things inside of "except Empty", because when we say + # we timed out or can't connect and raise, Python 3 tells + # people the real error is queue.Empty which it isn't. + pass + if use_overflow and self._overflow >= self._max_overflow: + if not wait: + return self._do_get() + else: + raise exc.TimeoutError( + "QueuePool limit of size %d overflow %d reached, " + "connection timed out, timeout %0.2f" + % (self.size(), self.overflow(), self._timeout), + code="3o7r", + ) + + if self._inc_overflow(): + try: + return self._create_connection() + except: + with util.safe_reraise(): + self._dec_overflow() + raise + else: + return self._do_get() + + def _inc_overflow(self) -> bool: + if self._max_overflow == -1: + self._overflow += 1 + return True + with self._overflow_lock: + if self._overflow < self._max_overflow: + self._overflow += 1 + return True + else: + return False + + def _dec_overflow(self) -> Literal[True]: + if self._max_overflow == -1: + self._overflow -= 1 + return True + with self._overflow_lock: + self._overflow -= 1 + return True + + def recreate(self) -> QueuePool: + self.logger.info("Pool recreating") + return self.__class__( + self._creator, + pool_size=self._pool.maxsize, + max_overflow=self._max_overflow, + pre_ping=self._pre_ping, + use_lifo=self._pool.use_lifo, + timeout=self._timeout, + recycle=self._recycle, + echo=self.echo, + logging_name=self._orig_logging_name, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + dialect=self._dialect, + ) + + def dispose(self) -> None: + while True: + try: + conn = self._pool.get(False) + conn.close() + except sqla_queue.Empty: + break + + self._overflow = 0 - self.size() + self.logger.info("Pool disposed. %s", self.status()) + + def status(self) -> str: + return ( + "Pool size: %d Connections in pool: %d " + "Current Overflow: %d Current Checked out " + "connections: %d" + % ( + self.size(), + self.checkedin(), + self.overflow(), + self.checkedout(), + ) + ) + + def size(self) -> int: + return self._pool.maxsize + + def timeout(self) -> float: + return self._timeout + + def checkedin(self) -> int: + return self._pool.qsize() + + def overflow(self) -> int: + return self._overflow if self._pool.maxsize else 0 + + def checkedout(self) -> int: + return self._pool.maxsize - self._pool.qsize() + self._overflow + + +class AsyncAdaptedQueuePool(QueuePool): + """An asyncio-compatible version of :class:`.QueuePool`. + + This pool is used by default when using :class:`.AsyncEngine` engines that + were generated from :func:`_asyncio.create_async_engine`. It uses an + asyncio-compatible queue implementation that does not use + ``threading.Lock``. + + The arguments and operation of :class:`.AsyncAdaptedQueuePool` are + otherwise identical to that of :class:`.QueuePool`. + + """ + + _is_asyncio = True # type: ignore[assignment] + _queue_class: Type[sqla_queue.QueueCommon[ConnectionPoolEntry]] = ( + sqla_queue.AsyncAdaptedQueue + ) + + _dialect = _AsyncConnDialect() + + +class FallbackAsyncAdaptedQueuePool(AsyncAdaptedQueuePool): + _queue_class = sqla_queue.FallbackAsyncAdaptedQueue + + +class NullPool(Pool): + """A Pool which does not pool connections. + + Instead it literally opens and closes the underlying DB-API connection + per each connection open/close. + + Reconnect-related functions such as ``recycle`` and connection + invalidation are not supported by this Pool implementation, since + no connections are held persistently. + + The :class:`.NullPool` class **is compatible** with asyncio and + :func:`_asyncio.create_async_engine`. + + """ + + def status(self) -> str: + return "NullPool" + + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + record.close() + + def _do_get(self) -> ConnectionPoolEntry: + return self._create_connection() + + def recreate(self) -> NullPool: + self.logger.info("Pool recreating") + + return self.__class__( + self._creator, + recycle=self._recycle, + echo=self.echo, + logging_name=self._orig_logging_name, + reset_on_return=self._reset_on_return, + pre_ping=self._pre_ping, + _dispatch=self.dispatch, + dialect=self._dialect, + ) + + def dispose(self) -> None: + pass + + +class SingletonThreadPool(Pool): + """A Pool that maintains one connection per thread. + + Maintains one connection per each thread, never moving a connection to a + thread other than the one which it was created in. + + .. warning:: the :class:`.SingletonThreadPool` will call ``.close()`` + on arbitrary connections that exist beyond the size setting of + ``pool_size``, e.g. if more unique **thread identities** + than what ``pool_size`` states are used. This cleanup is + non-deterministic and not sensitive to whether or not the connections + linked to those thread identities are currently in use. + + :class:`.SingletonThreadPool` may be improved in a future release, + however in its current status it is generally used only for test + scenarios using a SQLite ``:memory:`` database and is not recommended + for production use. + + The :class:`.SingletonThreadPool` class **is not compatible** with asyncio + and :func:`_asyncio.create_async_engine`. + + + Options are the same as those of :class:`_pool.Pool`, as well as: + + :param pool_size: The number of threads in which to maintain connections + at once. Defaults to five. + + :class:`.SingletonThreadPool` is used by the SQLite dialect + automatically when a memory-based database is used. + See :ref:`sqlite_toplevel`. + + """ + + _is_asyncio = False # type: ignore[assignment] + + def __init__( + self, + creator: Union[_CreatorFnType, _CreatorWRecFnType], + pool_size: int = 5, + **kw: Any, + ): + Pool.__init__(self, creator, **kw) + self._conn = threading.local() + self._fairy = threading.local() + self._all_conns: Set[ConnectionPoolEntry] = set() + self.size = pool_size + + def recreate(self) -> SingletonThreadPool: + self.logger.info("Pool recreating") + return self.__class__( + self._creator, + pool_size=self.size, + recycle=self._recycle, + echo=self.echo, + pre_ping=self._pre_ping, + logging_name=self._orig_logging_name, + reset_on_return=self._reset_on_return, + _dispatch=self.dispatch, + dialect=self._dialect, + ) + + def dispose(self) -> None: + """Dispose of this pool.""" + + for conn in self._all_conns: + try: + conn.close() + except Exception: + # pysqlite won't even let you close a conn from a thread + # that didn't create it + pass + + self._all_conns.clear() + + def _cleanup(self) -> None: + while len(self._all_conns) >= self.size: + c = self._all_conns.pop() + c.close() + + def status(self) -> str: + return "SingletonThreadPool id:%d size: %d" % ( + id(self), + len(self._all_conns), + ) + + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + try: + del self._fairy.current + except AttributeError: + pass + + def _do_get(self) -> ConnectionPoolEntry: + try: + if TYPE_CHECKING: + c = cast(ConnectionPoolEntry, self._conn.current()) + else: + c = self._conn.current() + if c: + return c + except AttributeError: + pass + c = self._create_connection() + self._conn.current = weakref.ref(c) + if len(self._all_conns) >= self.size: + self._cleanup() + self._all_conns.add(c) + return c + + def connect(self) -> PoolProxiedConnection: + # vendored from Pool to include the now removed use_threadlocal + # behavior + try: + rec = cast(_ConnectionFairy, self._fairy.current()) + except AttributeError: + pass + else: + if rec is not None: + return rec._checkout_existing() + + return _ConnectionFairy._checkout(self, self._fairy) + + +class StaticPool(Pool): + """A Pool of exactly one connection, used for all requests. + + Reconnect-related functions such as ``recycle`` and connection + invalidation (which is also used to support auto-reconnect) are only + partially supported right now and may not yield good results. + + The :class:`.StaticPool` class **is compatible** with asyncio and + :func:`_asyncio.create_async_engine`. + + """ + + @util.memoized_property + def connection(self) -> _ConnectionRecord: + return _ConnectionRecord(self) + + def status(self) -> str: + return "StaticPool" + + def dispose(self) -> None: + if ( + "connection" in self.__dict__ + and self.connection.dbapi_connection is not None + ): + self.connection.close() + del self.__dict__["connection"] + + def recreate(self) -> StaticPool: + self.logger.info("Pool recreating") + return self.__class__( + creator=self._creator, + recycle=self._recycle, + reset_on_return=self._reset_on_return, + pre_ping=self._pre_ping, + echo=self.echo, + logging_name=self._orig_logging_name, + _dispatch=self.dispatch, + dialect=self._dialect, + ) + + def _transfer_from(self, other_static_pool: StaticPool) -> None: + # used by the test suite to make a new engine / pool without + # losing the state of an existing SQLite :memory: connection + def creator(rec: ConnectionPoolEntry) -> DBAPIConnection: + conn = other_static_pool.connection.dbapi_connection + assert conn is not None + return conn + + self._invoke_creator = creator + + def _create_connection(self) -> ConnectionPoolEntry: + raise NotImplementedError() + + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + pass + + def _do_get(self) -> ConnectionPoolEntry: + rec = self.connection + if rec._is_hard_or_soft_invalidated(): + del self.__dict__["connection"] + rec = self.connection + + return rec + + +class AssertionPool(Pool): + """A :class:`_pool.Pool` that allows at most one checked out connection at + any given time. + + This will raise an exception if more than one connection is checked out + at a time. Useful for debugging code that is using more connections + than desired. + + The :class:`.AssertionPool` class **is compatible** with asyncio and + :func:`_asyncio.create_async_engine`. + + """ + + _conn: Optional[ConnectionPoolEntry] + _checkout_traceback: Optional[List[str]] + + def __init__(self, *args: Any, **kw: Any): + self._conn = None + self._checked_out = False + self._store_traceback = kw.pop("store_traceback", True) + self._checkout_traceback = None + Pool.__init__(self, *args, **kw) + + def status(self) -> str: + return "AssertionPool" + + def _do_return_conn(self, record: ConnectionPoolEntry) -> None: + if not self._checked_out: + raise AssertionError("connection is not checked out") + self._checked_out = False + assert record is self._conn + + def dispose(self) -> None: + self._checked_out = False + if self._conn: + self._conn.close() + + def recreate(self) -> AssertionPool: + self.logger.info("Pool recreating") + return self.__class__( + self._creator, + echo=self.echo, + pre_ping=self._pre_ping, + recycle=self._recycle, + reset_on_return=self._reset_on_return, + logging_name=self._orig_logging_name, + _dispatch=self.dispatch, + dialect=self._dialect, + ) + + def _do_get(self) -> ConnectionPoolEntry: + if self._checked_out: + if self._checkout_traceback: + suffix = " at:\n%s" % "".join( + chop_traceback(self._checkout_traceback) + ) + else: + suffix = "" + raise AssertionError("connection is already checked out" + suffix) + + if not self._conn: + self._conn = self._create_connection() + + self._checked_out = True + if self._store_traceback: + self._checkout_traceback = traceback.format_stack() + return self._conn diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/py.typed b/venv/lib/python3.11/site-packages/sqlalchemy/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/schema.py b/venv/lib/python3.11/site-packages/sqlalchemy/schema.py new file mode 100644 index 0000000..9edca4e --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/schema.py @@ -0,0 +1,70 @@ +# schema.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Compatibility namespace for sqlalchemy.sql.schema and related. + +""" + +from __future__ import annotations + +from .sql.base import SchemaVisitor as SchemaVisitor +from .sql.ddl import _CreateDropBase as _CreateDropBase +from .sql.ddl import _DropView as _DropView +from .sql.ddl import AddConstraint as AddConstraint +from .sql.ddl import BaseDDLElement as BaseDDLElement +from .sql.ddl import CreateColumn as CreateColumn +from .sql.ddl import CreateIndex as CreateIndex +from .sql.ddl import CreateSchema as CreateSchema +from .sql.ddl import CreateSequence as CreateSequence +from .sql.ddl import CreateTable as CreateTable +from .sql.ddl import DDL as DDL +from .sql.ddl import DDLElement as DDLElement +from .sql.ddl import DropColumnComment as DropColumnComment +from .sql.ddl import DropConstraint as DropConstraint +from .sql.ddl import DropConstraintComment as DropConstraintComment +from .sql.ddl import DropIndex as DropIndex +from .sql.ddl import DropSchema as DropSchema +from .sql.ddl import DropSequence as DropSequence +from .sql.ddl import DropTable as DropTable +from .sql.ddl import DropTableComment as DropTableComment +from .sql.ddl import ExecutableDDLElement as ExecutableDDLElement +from .sql.ddl import InvokeDDLBase as InvokeDDLBase +from .sql.ddl import SetColumnComment as SetColumnComment +from .sql.ddl import SetConstraintComment as SetConstraintComment +from .sql.ddl import SetTableComment as SetTableComment +from .sql.ddl import sort_tables as sort_tables +from .sql.ddl import ( + sort_tables_and_constraints as sort_tables_and_constraints, +) +from .sql.naming import conv as conv +from .sql.schema import _get_table_key as _get_table_key +from .sql.schema import BLANK_SCHEMA as BLANK_SCHEMA +from .sql.schema import CheckConstraint as CheckConstraint +from .sql.schema import Column as Column +from .sql.schema import ( + ColumnCollectionConstraint as ColumnCollectionConstraint, +) +from .sql.schema import ColumnCollectionMixin as ColumnCollectionMixin +from .sql.schema import ColumnDefault as ColumnDefault +from .sql.schema import Computed as Computed +from .sql.schema import Constraint as Constraint +from .sql.schema import DefaultClause as DefaultClause +from .sql.schema import DefaultGenerator as DefaultGenerator +from .sql.schema import FetchedValue as FetchedValue +from .sql.schema import ForeignKey as ForeignKey +from .sql.schema import ForeignKeyConstraint as ForeignKeyConstraint +from .sql.schema import HasConditionalDDL as HasConditionalDDL +from .sql.schema import Identity as Identity +from .sql.schema import Index as Index +from .sql.schema import insert_sentinel as insert_sentinel +from .sql.schema import MetaData as MetaData +from .sql.schema import PrimaryKeyConstraint as PrimaryKeyConstraint +from .sql.schema import SchemaConst as SchemaConst +from .sql.schema import SchemaItem as SchemaItem +from .sql.schema import Sequence as Sequence +from .sql.schema import Table as Table +from .sql.schema import UniqueConstraint as UniqueConstraint diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__init__.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__init__.py new file mode 100644 index 0000000..9e0d2ca --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__init__.py @@ -0,0 +1,145 @@ +# sql/__init__.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +from typing import Any +from typing import TYPE_CHECKING + +from ._typing import ColumnExpressionArgument as ColumnExpressionArgument +from ._typing import NotNullable as NotNullable +from ._typing import Nullable as Nullable +from .base import Executable as Executable +from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS +from .compiler import FROM_LINTING as FROM_LINTING +from .compiler import NO_LINTING as NO_LINTING +from .compiler import WARN_LINTING as WARN_LINTING +from .ddl import BaseDDLElement as BaseDDLElement +from .ddl import DDL as DDL +from .ddl import DDLElement as DDLElement +from .ddl import ExecutableDDLElement as ExecutableDDLElement +from .expression import Alias as Alias +from .expression import alias as alias +from .expression import all_ as all_ +from .expression import and_ as and_ +from .expression import any_ as any_ +from .expression import asc as asc +from .expression import between as between +from .expression import bindparam as bindparam +from .expression import case as case +from .expression import cast as cast +from .expression import ClauseElement as ClauseElement +from .expression import collate as collate +from .expression import column as column +from .expression import ColumnCollection as ColumnCollection +from .expression import ColumnElement as ColumnElement +from .expression import CompoundSelect as CompoundSelect +from .expression import cte as cte +from .expression import Delete as Delete +from .expression import delete as delete +from .expression import desc as desc +from .expression import distinct as distinct +from .expression import except_ as except_ +from .expression import except_all as except_all +from .expression import exists as exists +from .expression import extract as extract +from .expression import false as false +from .expression import False_ as False_ +from .expression import FromClause as FromClause +from .expression import func as func +from .expression import funcfilter as funcfilter +from .expression import Insert as Insert +from .expression import insert as insert +from .expression import intersect as intersect +from .expression import intersect_all as intersect_all +from .expression import Join as Join +from .expression import join as join +from .expression import label as label +from .expression import LABEL_STYLE_DEFAULT as LABEL_STYLE_DEFAULT +from .expression import ( + LABEL_STYLE_DISAMBIGUATE_ONLY as LABEL_STYLE_DISAMBIGUATE_ONLY, +) +from .expression import LABEL_STYLE_NONE as LABEL_STYLE_NONE +from .expression import ( + LABEL_STYLE_TABLENAME_PLUS_COL as LABEL_STYLE_TABLENAME_PLUS_COL, +) +from .expression import lambda_stmt as lambda_stmt +from .expression import LambdaElement as LambdaElement +from .expression import lateral as lateral +from .expression import literal as literal +from .expression import literal_column as literal_column +from .expression import modifier as modifier +from .expression import not_ as not_ +from .expression import null as null +from .expression import nulls_first as nulls_first +from .expression import nulls_last as nulls_last +from .expression import nullsfirst as nullsfirst +from .expression import nullslast as nullslast +from .expression import or_ as or_ +from .expression import outerjoin as outerjoin +from .expression import outparam as outparam +from .expression import over as over +from .expression import quoted_name as quoted_name +from .expression import Select as Select +from .expression import select as select +from .expression import Selectable as Selectable +from .expression import SelectLabelStyle as SelectLabelStyle +from .expression import SQLColumnExpression as SQLColumnExpression +from .expression import StatementLambdaElement as StatementLambdaElement +from .expression import Subquery as Subquery +from .expression import table as table +from .expression import TableClause as TableClause +from .expression import TableSample as TableSample +from .expression import tablesample as tablesample +from .expression import text as text +from .expression import true as true +from .expression import True_ as True_ +from .expression import try_cast as try_cast +from .expression import tuple_ as tuple_ +from .expression import type_coerce as type_coerce +from .expression import union as union +from .expression import union_all as union_all +from .expression import Update as Update +from .expression import update as update +from .expression import Values as Values +from .expression import values as values +from .expression import within_group as within_group +from .visitors import ClauseVisitor as ClauseVisitor + + +def __go(lcls: Any) -> None: + from .. import util as _sa_util + + from . import base + from . import coercions + from . import elements + from . import lambdas + from . import selectable + from . import schema + from . import traversals + from . import type_api + + if not TYPE_CHECKING: + base.coercions = elements.coercions = coercions + base.elements = elements + base.type_api = type_api + coercions.elements = elements + coercions.lambdas = lambdas + coercions.schema = schema + coercions.selectable = selectable + + from .annotation import _prepare_annotations + from .annotation import Annotated + from .elements import AnnotatedColumnElement + from .elements import ClauseList + from .selectable import AnnotatedFromClause + + _prepare_annotations(ColumnElement, AnnotatedColumnElement) + _prepare_annotations(FromClause, AnnotatedFromClause) + _prepare_annotations(ClauseList, Annotated) + + _sa_util.preloaded.import_prefix("sqlalchemy.sql") + + +__go(locals()) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..16135d0 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/__init__.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_dml_constructors.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_dml_constructors.cpython-311.pyc new file mode 100644 index 0000000..0525743 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_dml_constructors.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_elements_constructors.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_elements_constructors.cpython-311.pyc new file mode 100644 index 0000000..493de78 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_elements_constructors.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_orm_types.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_orm_types.cpython-311.pyc new file mode 100644 index 0000000..76e4a97 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_orm_types.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_py_util.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_py_util.cpython-311.pyc new file mode 100644 index 0000000..ab0a578 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_py_util.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_selectable_constructors.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_selectable_constructors.cpython-311.pyc new file mode 100644 index 0000000..29ea597 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_selectable_constructors.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_typing.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_typing.cpython-311.pyc new file mode 100644 index 0000000..d5f60fb Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/_typing.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/annotation.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/annotation.cpython-311.pyc new file mode 100644 index 0000000..d52f144 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/annotation.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000..efe2b16 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/base.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/cache_key.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/cache_key.cpython-311.pyc new file mode 100644 index 0000000..fa315ca Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/cache_key.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/coercions.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/coercions.cpython-311.pyc new file mode 100644 index 0000000..132dce4 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/coercions.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/compiler.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/compiler.cpython-311.pyc new file mode 100644 index 0000000..1f1a5fc Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/compiler.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/crud.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/crud.cpython-311.pyc new file mode 100644 index 0000000..021ffb3 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/crud.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/ddl.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/ddl.cpython-311.pyc new file mode 100644 index 0000000..3f5a4e9 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/ddl.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/default_comparator.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/default_comparator.cpython-311.pyc new file mode 100644 index 0000000..ae92d63 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/default_comparator.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/dml.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/dml.cpython-311.pyc new file mode 100644 index 0000000..14985ca Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/dml.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/elements.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/elements.cpython-311.pyc new file mode 100644 index 0000000..3eda846 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/elements.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/events.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/events.cpython-311.pyc new file mode 100644 index 0000000..6ff8a38 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/events.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/expression.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/expression.cpython-311.pyc new file mode 100644 index 0000000..e3d2b68 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/expression.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/functions.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/functions.cpython-311.pyc new file mode 100644 index 0000000..f03311a Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/functions.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/lambdas.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/lambdas.cpython-311.pyc new file mode 100644 index 0000000..9eea091 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/lambdas.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/naming.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/naming.cpython-311.pyc new file mode 100644 index 0000000..86ebcd9 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/naming.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/operators.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/operators.cpython-311.pyc new file mode 100644 index 0000000..36ffdc3 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/operators.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/roles.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/roles.cpython-311.pyc new file mode 100644 index 0000000..d3dab5d Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/roles.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/schema.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/schema.cpython-311.pyc new file mode 100644 index 0000000..ca5ea38 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/schema.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/selectable.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/selectable.cpython-311.pyc new file mode 100644 index 0000000..fea6f8f Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/selectable.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/sqltypes.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/sqltypes.cpython-311.pyc new file mode 100644 index 0000000..1214a08 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/sqltypes.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/traversals.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/traversals.cpython-311.pyc new file mode 100644 index 0000000..f1c5425 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/traversals.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/type_api.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/type_api.cpython-311.pyc new file mode 100644 index 0000000..47f73cd Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/type_api.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/util.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/util.cpython-311.pyc new file mode 100644 index 0000000..1a4cda5 Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/util.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/visitors.cpython-311.pyc b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/visitors.cpython-311.pyc new file mode 100644 index 0000000..63c33db Binary files /dev/null and b/venv/lib/python3.11/site-packages/sqlalchemy/sql/__pycache__/visitors.cpython-311.pyc differ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/_dml_constructors.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/_dml_constructors.py new file mode 100644 index 0000000..a7ead52 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/_dml_constructors.py @@ -0,0 +1,140 @@ +# sql/_dml_constructors.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .dml import Delete +from .dml import Insert +from .dml import Update + +if TYPE_CHECKING: + from ._typing import _DMLTableArgument + + +def insert(table: _DMLTableArgument) -> Insert: + """Construct an :class:`_expression.Insert` object. + + E.g.:: + + from sqlalchemy import insert + + stmt = ( + insert(user_table). + values(name='username', fullname='Full Username') + ) + + Similar functionality is available via the + :meth:`_expression.TableClause.insert` method on + :class:`_schema.Table`. + + .. seealso:: + + :ref:`tutorial_core_insert` - in the :ref:`unified_tutorial` + + + :param table: :class:`_expression.TableClause` + which is the subject of the + insert. + + :param values: collection of values to be inserted; see + :meth:`_expression.Insert.values` + for a description of allowed formats here. + Can be omitted entirely; a :class:`_expression.Insert` construct + will also dynamically render the VALUES clause at execution time + based on the parameters passed to :meth:`_engine.Connection.execute`. + + :param inline: if True, no attempt will be made to retrieve the + SQL-generated default values to be provided within the statement; + in particular, + this allows SQL expressions to be rendered 'inline' within the + statement without the need to pre-execute them beforehand; for + backends that support "returning", this turns off the "implicit + returning" feature for the statement. + + If both :paramref:`_expression.insert.values` and compile-time bind + parameters are present, the compile-time bind parameters override the + information specified within :paramref:`_expression.insert.values` on a + per-key basis. + + The keys within :paramref:`_expression.Insert.values` can be either + :class:`~sqlalchemy.schema.Column` objects or their string + identifiers. Each key may reference one of: + + * a literal data value (i.e. string, number, etc.); + * a Column object; + * a SELECT statement. + + If a ``SELECT`` statement is specified which references this + ``INSERT`` statement's table, the statement will be correlated + against the ``INSERT`` statement. + + .. seealso:: + + :ref:`tutorial_core_insert` - in the :ref:`unified_tutorial` + + """ + return Insert(table) + + +def update(table: _DMLTableArgument) -> Update: + r"""Construct an :class:`_expression.Update` object. + + E.g.:: + + from sqlalchemy import update + + stmt = ( + update(user_table). + where(user_table.c.id == 5). + values(name='user #5') + ) + + Similar functionality is available via the + :meth:`_expression.TableClause.update` method on + :class:`_schema.Table`. + + :param table: A :class:`_schema.Table` + object representing the database + table to be updated. + + + .. seealso:: + + :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial` + + + """ + return Update(table) + + +def delete(table: _DMLTableArgument) -> Delete: + r"""Construct :class:`_expression.Delete` object. + + E.g.:: + + from sqlalchemy import delete + + stmt = ( + delete(user_table). + where(user_table.c.id == 5) + ) + + Similar functionality is available via the + :meth:`_expression.TableClause.delete` method on + :class:`_schema.Table`. + + :param table: The table to delete rows from. + + .. seealso:: + + :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial` + + + """ + return Delete(table) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/_elements_constructors.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/_elements_constructors.py new file mode 100644 index 0000000..77cc2a8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/_elements_constructors.py @@ -0,0 +1,1840 @@ +# sql/_elements_constructors.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple as typing_Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import coercions +from . import roles +from .base import _NoArg +from .coercions import _document_text_coercion +from .elements import BindParameter +from .elements import BooleanClauseList +from .elements import Case +from .elements import Cast +from .elements import CollationClause +from .elements import CollectionAggregate +from .elements import ColumnClause +from .elements import ColumnElement +from .elements import Extract +from .elements import False_ +from .elements import FunctionFilter +from .elements import Label +from .elements import Null +from .elements import Over +from .elements import TextClause +from .elements import True_ +from .elements import TryCast +from .elements import Tuple +from .elements import TypeCoerce +from .elements import UnaryExpression +from .elements import WithinGroup +from .functions import FunctionElement +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from ._typing import _ByArgument + from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnExpressionOrLiteralArgument + from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _TypeEngineArgument + from .elements import BinaryExpression + from .selectable import FromClause + from .type_api import TypeEngine + +_T = TypeVar("_T") + + +def all_(expr: _ColumnExpressionArgument[_T]) -> CollectionAggregate[bool]: + """Produce an ALL expression. + + For dialects such as that of PostgreSQL, this operator applies + to usage of the :class:`_types.ARRAY` datatype, for that of + MySQL, it may apply to a subquery. e.g.:: + + # renders on PostgreSQL: + # '5 = ALL (somearray)' + expr = 5 == all_(mytable.c.somearray) + + # renders on MySQL: + # '5 = ALL (SELECT value FROM table)' + expr = 5 == all_(select(table.c.value)) + + Comparison to NULL may work using ``None``:: + + None == all_(mytable.c.somearray) + + The any_() / all_() operators also feature a special "operand flipping" + behavior such that if any_() / all_() are used on the left side of a + comparison using a standalone operator such as ``==``, ``!=``, etc. + (not including operator methods such as + :meth:`_sql.ColumnOperators.is_`) the rendered expression is flipped:: + + # would render '5 = ALL (column)` + all_(mytable.c.column) == 5 + + Or with ``None``, which note will not perform + the usual step of rendering "IS" as is normally the case for NULL:: + + # would render 'NULL = ALL(somearray)' + all_(mytable.c.somearray) == None + + .. versionchanged:: 1.4.26 repaired the use of any_() / all_() + comparing to NULL on the right side to be flipped to the left. + + The column-level :meth:`_sql.ColumnElement.all_` method (not to be + confused with :class:`_types.ARRAY` level + :meth:`_types.ARRAY.Comparator.all`) is shorthand for + ``all_(col)``:: + + 5 == mytable.c.somearray.all_() + + .. seealso:: + + :meth:`_sql.ColumnOperators.all_` + + :func:`_expression.any_` + + """ + return CollectionAggregate._create_all(expr) + + +def and_( # type: ignore[empty-body] + initial_clause: Union[Literal[True], _ColumnExpressionArgument[bool]], + *clauses: _ColumnExpressionArgument[bool], +) -> ColumnElement[bool]: + r"""Produce a conjunction of expressions joined by ``AND``. + + E.g.:: + + from sqlalchemy import and_ + + stmt = select(users_table).where( + and_( + users_table.c.name == 'wendy', + users_table.c.enrolled == True + ) + ) + + The :func:`.and_` conjunction is also available using the + Python ``&`` operator (though note that compound expressions + need to be parenthesized in order to function with Python + operator precedence behavior):: + + stmt = select(users_table).where( + (users_table.c.name == 'wendy') & + (users_table.c.enrolled == True) + ) + + The :func:`.and_` operation is also implicit in some cases; + the :meth:`_expression.Select.where` + method for example can be invoked multiple + times against a statement, which will have the effect of each + clause being combined using :func:`.and_`:: + + stmt = select(users_table).\ + where(users_table.c.name == 'wendy').\ + where(users_table.c.enrolled == True) + + The :func:`.and_` construct must be given at least one positional + argument in order to be valid; a :func:`.and_` construct with no + arguments is ambiguous. To produce an "empty" or dynamically + generated :func:`.and_` expression, from a given list of expressions, + a "default" element of :func:`_sql.true` (or just ``True``) should be + specified:: + + from sqlalchemy import true + criteria = and_(true(), *expressions) + + The above expression will compile to SQL as the expression ``true`` + or ``1 = 1``, depending on backend, if no other expressions are + present. If expressions are present, then the :func:`_sql.true` value is + ignored as it does not affect the outcome of an AND expression that + has other elements. + + .. deprecated:: 1.4 The :func:`.and_` element now requires that at + least one argument is passed; creating the :func:`.and_` construct + with no arguments is deprecated, and will emit a deprecation warning + while continuing to produce a blank SQL string. + + .. seealso:: + + :func:`.or_` + + """ + ... + + +if not TYPE_CHECKING: + # handle deprecated case which allows zero-arguments + def and_(*clauses): # noqa: F811 + r"""Produce a conjunction of expressions joined by ``AND``. + + E.g.:: + + from sqlalchemy import and_ + + stmt = select(users_table).where( + and_( + users_table.c.name == 'wendy', + users_table.c.enrolled == True + ) + ) + + The :func:`.and_` conjunction is also available using the + Python ``&`` operator (though note that compound expressions + need to be parenthesized in order to function with Python + operator precedence behavior):: + + stmt = select(users_table).where( + (users_table.c.name == 'wendy') & + (users_table.c.enrolled == True) + ) + + The :func:`.and_` operation is also implicit in some cases; + the :meth:`_expression.Select.where` + method for example can be invoked multiple + times against a statement, which will have the effect of each + clause being combined using :func:`.and_`:: + + stmt = select(users_table).\ + where(users_table.c.name == 'wendy').\ + where(users_table.c.enrolled == True) + + The :func:`.and_` construct must be given at least one positional + argument in order to be valid; a :func:`.and_` construct with no + arguments is ambiguous. To produce an "empty" or dynamically + generated :func:`.and_` expression, from a given list of expressions, + a "default" element of :func:`_sql.true` (or just ``True``) should be + specified:: + + from sqlalchemy import true + criteria = and_(true(), *expressions) + + The above expression will compile to SQL as the expression ``true`` + or ``1 = 1``, depending on backend, if no other expressions are + present. If expressions are present, then the :func:`_sql.true` value + is ignored as it does not affect the outcome of an AND expression that + has other elements. + + .. deprecated:: 1.4 The :func:`.and_` element now requires that at + least one argument is passed; creating the :func:`.and_` construct + with no arguments is deprecated, and will emit a deprecation warning + while continuing to produce a blank SQL string. + + .. seealso:: + + :func:`.or_` + + """ + return BooleanClauseList.and_(*clauses) + + +def any_(expr: _ColumnExpressionArgument[_T]) -> CollectionAggregate[bool]: + """Produce an ANY expression. + + For dialects such as that of PostgreSQL, this operator applies + to usage of the :class:`_types.ARRAY` datatype, for that of + MySQL, it may apply to a subquery. e.g.:: + + # renders on PostgreSQL: + # '5 = ANY (somearray)' + expr = 5 == any_(mytable.c.somearray) + + # renders on MySQL: + # '5 = ANY (SELECT value FROM table)' + expr = 5 == any_(select(table.c.value)) + + Comparison to NULL may work using ``None`` or :func:`_sql.null`:: + + None == any_(mytable.c.somearray) + + The any_() / all_() operators also feature a special "operand flipping" + behavior such that if any_() / all_() are used on the left side of a + comparison using a standalone operator such as ``==``, ``!=``, etc. + (not including operator methods such as + :meth:`_sql.ColumnOperators.is_`) the rendered expression is flipped:: + + # would render '5 = ANY (column)` + any_(mytable.c.column) == 5 + + Or with ``None``, which note will not perform + the usual step of rendering "IS" as is normally the case for NULL:: + + # would render 'NULL = ANY(somearray)' + any_(mytable.c.somearray) == None + + .. versionchanged:: 1.4.26 repaired the use of any_() / all_() + comparing to NULL on the right side to be flipped to the left. + + The column-level :meth:`_sql.ColumnElement.any_` method (not to be + confused with :class:`_types.ARRAY` level + :meth:`_types.ARRAY.Comparator.any`) is shorthand for + ``any_(col)``:: + + 5 = mytable.c.somearray.any_() + + .. seealso:: + + :meth:`_sql.ColumnOperators.any_` + + :func:`_expression.all_` + + """ + return CollectionAggregate._create_any(expr) + + +def asc( + column: _ColumnExpressionOrStrLabelArgument[_T], +) -> UnaryExpression[_T]: + """Produce an ascending ``ORDER BY`` clause element. + + e.g.:: + + from sqlalchemy import asc + stmt = select(users_table).order_by(asc(users_table.c.name)) + + will produce SQL as:: + + SELECT id, name FROM user ORDER BY name ASC + + The :func:`.asc` function is a standalone version of the + :meth:`_expression.ColumnElement.asc` + method available on all SQL expressions, + e.g.:: + + + stmt = select(users_table).order_by(users_table.c.name.asc()) + + :param column: A :class:`_expression.ColumnElement` (e.g. + scalar SQL expression) + with which to apply the :func:`.asc` operation. + + .. seealso:: + + :func:`.desc` + + :func:`.nulls_first` + + :func:`.nulls_last` + + :meth:`_expression.Select.order_by` + + """ + return UnaryExpression._create_asc(column) + + +def collate( + expression: _ColumnExpressionArgument[str], collation: str +) -> BinaryExpression[str]: + """Return the clause ``expression COLLATE collation``. + + e.g.:: + + collate(mycolumn, 'utf8_bin') + + produces:: + + mycolumn COLLATE utf8_bin + + The collation expression is also quoted if it is a case sensitive + identifier, e.g. contains uppercase characters. + + .. versionchanged:: 1.2 quoting is automatically applied to COLLATE + expressions if they are case sensitive. + + """ + return CollationClause._create_collation_expression(expression, collation) + + +def between( + expr: _ColumnExpressionOrLiteralArgument[_T], + lower_bound: Any, + upper_bound: Any, + symmetric: bool = False, +) -> BinaryExpression[bool]: + """Produce a ``BETWEEN`` predicate clause. + + E.g.:: + + from sqlalchemy import between + stmt = select(users_table).where(between(users_table.c.id, 5, 7)) + + Would produce SQL resembling:: + + SELECT id, name FROM user WHERE id BETWEEN :id_1 AND :id_2 + + The :func:`.between` function is a standalone version of the + :meth:`_expression.ColumnElement.between` method available on all + SQL expressions, as in:: + + stmt = select(users_table).where(users_table.c.id.between(5, 7)) + + All arguments passed to :func:`.between`, including the left side + column expression, are coerced from Python scalar values if a + the value is not a :class:`_expression.ColumnElement` subclass. + For example, + three fixed values can be compared as in:: + + print(between(5, 3, 7)) + + Which would produce:: + + :param_1 BETWEEN :param_2 AND :param_3 + + :param expr: a column expression, typically a + :class:`_expression.ColumnElement` + instance or alternatively a Python scalar expression to be coerced + into a column expression, serving as the left side of the ``BETWEEN`` + expression. + + :param lower_bound: a column or Python scalar expression serving as the + lower bound of the right side of the ``BETWEEN`` expression. + + :param upper_bound: a column or Python scalar expression serving as the + upper bound of the right side of the ``BETWEEN`` expression. + + :param symmetric: if True, will render " BETWEEN SYMMETRIC ". Note + that not all databases support this syntax. + + .. seealso:: + + :meth:`_expression.ColumnElement.between` + + """ + col_expr = coercions.expect(roles.ExpressionElementRole, expr) + return col_expr.between(lower_bound, upper_bound, symmetric=symmetric) + + +def outparam( + key: str, type_: Optional[TypeEngine[_T]] = None +) -> BindParameter[_T]: + """Create an 'OUT' parameter for usage in functions (stored procedures), + for databases which support them. + + The ``outparam`` can be used like a regular function parameter. + The "output" value will be available from the + :class:`~sqlalchemy.engine.CursorResult` object via its ``out_parameters`` + attribute, which returns a dictionary containing the values. + + """ + return BindParameter(key, None, type_=type_, unique=False, isoutparam=True) + + +@overload +def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: ... + + +@overload +def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: ... + + +def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: + """Return a negation of the given clause, i.e. ``NOT(clause)``. + + The ``~`` operator is also overloaded on all + :class:`_expression.ColumnElement` subclasses to produce the + same result. + + """ + + return coercions.expect(roles.ExpressionElementRole, clause).__invert__() + + +def bindparam( + key: Optional[str], + value: Any = _NoArg.NO_ARG, + type_: Optional[_TypeEngineArgument[_T]] = None, + unique: bool = False, + required: Union[bool, Literal[_NoArg.NO_ARG]] = _NoArg.NO_ARG, + quote: Optional[bool] = None, + callable_: Optional[Callable[[], Any]] = None, + expanding: bool = False, + isoutparam: bool = False, + literal_execute: bool = False, +) -> BindParameter[_T]: + r"""Produce a "bound expression". + + The return value is an instance of :class:`.BindParameter`; this + is a :class:`_expression.ColumnElement` + subclass which represents a so-called + "placeholder" value in a SQL expression, the value of which is + supplied at the point at which the statement in executed against a + database connection. + + In SQLAlchemy, the :func:`.bindparam` construct has + the ability to carry along the actual value that will be ultimately + used at expression time. In this way, it serves not just as + a "placeholder" for eventual population, but also as a means of + representing so-called "unsafe" values which should not be rendered + directly in a SQL statement, but rather should be passed along + to the :term:`DBAPI` as values which need to be correctly escaped + and potentially handled for type-safety. + + When using :func:`.bindparam` explicitly, the use case is typically + one of traditional deferment of parameters; the :func:`.bindparam` + construct accepts a name which can then be referred to at execution + time:: + + from sqlalchemy import bindparam + + stmt = select(users_table).where( + users_table.c.name == bindparam("username") + ) + + The above statement, when rendered, will produce SQL similar to:: + + SELECT id, name FROM user WHERE name = :username + + In order to populate the value of ``:username`` above, the value + would typically be applied at execution time to a method + like :meth:`_engine.Connection.execute`:: + + result = connection.execute(stmt, {"username": "wendy"}) + + Explicit use of :func:`.bindparam` is also common when producing + UPDATE or DELETE statements that are to be invoked multiple times, + where the WHERE criterion of the statement is to change on each + invocation, such as:: + + stmt = ( + users_table.update() + .where(user_table.c.name == bindparam("username")) + .values(fullname=bindparam("fullname")) + ) + + connection.execute( + stmt, + [ + {"username": "wendy", "fullname": "Wendy Smith"}, + {"username": "jack", "fullname": "Jack Jones"}, + ], + ) + + SQLAlchemy's Core expression system makes wide use of + :func:`.bindparam` in an implicit sense. It is typical that Python + literal values passed to virtually all SQL expression functions are + coerced into fixed :func:`.bindparam` constructs. For example, given + a comparison operation such as:: + + expr = users_table.c.name == 'Wendy' + + The above expression will produce a :class:`.BinaryExpression` + construct, where the left side is the :class:`_schema.Column` object + representing the ``name`` column, and the right side is a + :class:`.BindParameter` representing the literal value:: + + print(repr(expr.right)) + BindParameter('%(4327771088 name)s', 'Wendy', type_=String()) + + The expression above will render SQL such as:: + + user.name = :name_1 + + Where the ``:name_1`` parameter name is an anonymous name. The + actual string ``Wendy`` is not in the rendered string, but is carried + along where it is later used within statement execution. If we + invoke a statement like the following:: + + stmt = select(users_table).where(users_table.c.name == 'Wendy') + result = connection.execute(stmt) + + We would see SQL logging output as:: + + SELECT "user".id, "user".name + FROM "user" + WHERE "user".name = %(name_1)s + {'name_1': 'Wendy'} + + Above, we see that ``Wendy`` is passed as a parameter to the database, + while the placeholder ``:name_1`` is rendered in the appropriate form + for the target database, in this case the PostgreSQL database. + + Similarly, :func:`.bindparam` is invoked automatically when working + with :term:`CRUD` statements as far as the "VALUES" portion is + concerned. The :func:`_expression.insert` construct produces an + ``INSERT`` expression which will, at statement execution time, generate + bound placeholders based on the arguments passed, as in:: + + stmt = users_table.insert() + result = connection.execute(stmt, {"name": "Wendy"}) + + The above will produce SQL output as:: + + INSERT INTO "user" (name) VALUES (%(name)s) + {'name': 'Wendy'} + + The :class:`_expression.Insert` construct, at + compilation/execution time, rendered a single :func:`.bindparam` + mirroring the column name ``name`` as a result of the single ``name`` + parameter we passed to the :meth:`_engine.Connection.execute` method. + + :param key: + the key (e.g. the name) for this bind param. + Will be used in the generated + SQL statement for dialects that use named parameters. This + value may be modified when part of a compilation operation, + if other :class:`BindParameter` objects exist with the same + key, or if its length is too long and truncation is + required. + + If omitted, an "anonymous" name is generated for the bound parameter; + when given a value to bind, the end result is equivalent to calling upon + the :func:`.literal` function with a value to bind, particularly + if the :paramref:`.bindparam.unique` parameter is also provided. + + :param value: + Initial value for this bind param. Will be used at statement + execution time as the value for this parameter passed to the + DBAPI, if no other value is indicated to the statement execution + method for this particular parameter name. Defaults to ``None``. + + :param callable\_: + A callable function that takes the place of "value". The function + will be called at statement execution time to determine the + ultimate value. Used for scenarios where the actual bind + value cannot be determined at the point at which the clause + construct is created, but embedded bind values are still desirable. + + :param type\_: + A :class:`.TypeEngine` class or instance representing an optional + datatype for this :func:`.bindparam`. If not passed, a type + may be determined automatically for the bind, based on the given + value; for example, trivial Python types such as ``str``, + ``int``, ``bool`` + may result in the :class:`.String`, :class:`.Integer` or + :class:`.Boolean` types being automatically selected. + + The type of a :func:`.bindparam` is significant especially in that + the type will apply pre-processing to the value before it is + passed to the database. For example, a :func:`.bindparam` which + refers to a datetime value, and is specified as holding the + :class:`.DateTime` type, may apply conversion needed to the + value (such as stringification on SQLite) before passing the value + to the database. + + :param unique: + if True, the key name of this :class:`.BindParameter` will be + modified if another :class:`.BindParameter` of the same name + already has been located within the containing + expression. This flag is used generally by the internals + when producing so-called "anonymous" bound expressions, it + isn't generally applicable to explicitly-named :func:`.bindparam` + constructs. + + :param required: + If ``True``, a value is required at execution time. If not passed, + it defaults to ``True`` if neither :paramref:`.bindparam.value` + or :paramref:`.bindparam.callable` were passed. If either of these + parameters are present, then :paramref:`.bindparam.required` + defaults to ``False``. + + :param quote: + True if this parameter name requires quoting and is not + currently known as a SQLAlchemy reserved word; this currently + only applies to the Oracle backend, where bound names must + sometimes be quoted. + + :param isoutparam: + if True, the parameter should be treated like a stored procedure + "OUT" parameter. This applies to backends such as Oracle which + support OUT parameters. + + :param expanding: + if True, this parameter will be treated as an "expanding" parameter + at execution time; the parameter value is expected to be a sequence, + rather than a scalar value, and the string SQL statement will + be transformed on a per-execution basis to accommodate the sequence + with a variable number of parameter slots passed to the DBAPI. + This is to allow statement caching to be used in conjunction with + an IN clause. + + .. seealso:: + + :meth:`.ColumnOperators.in_` + + :ref:`baked_in` - with baked queries + + .. note:: The "expanding" feature does not support "executemany"- + style parameter sets. + + .. versionadded:: 1.2 + + .. versionchanged:: 1.3 the "expanding" bound parameter feature now + supports empty lists. + + :param literal_execute: + if True, the bound parameter will be rendered in the compile phase + with a special "POSTCOMPILE" token, and the SQLAlchemy compiler will + render the final value of the parameter into the SQL statement at + statement execution time, omitting the value from the parameter + dictionary / list passed to DBAPI ``cursor.execute()``. This + produces a similar effect as that of using the ``literal_binds``, + compilation flag, however takes place as the statement is sent to + the DBAPI ``cursor.execute()`` method, rather than when the statement + is compiled. The primary use of this + capability is for rendering LIMIT / OFFSET clauses for database + drivers that can't accommodate for bound parameters in these + contexts, while allowing SQL constructs to be cacheable at the + compilation level. + + .. versionadded:: 1.4 Added "post compile" bound parameters + + .. seealso:: + + :ref:`change_4808`. + + .. seealso:: + + :ref:`tutorial_sending_parameters` - in the + :ref:`unified_tutorial` + + + """ + return BindParameter( + key, + value, + type_, + unique, + required, + quote, + callable_, + expanding, + isoutparam, + literal_execute, + ) + + +def case( + *whens: Union[ + typing_Tuple[_ColumnExpressionArgument[bool], Any], Mapping[Any, Any] + ], + value: Optional[Any] = None, + else_: Optional[Any] = None, +) -> Case[Any]: + r"""Produce a ``CASE`` expression. + + The ``CASE`` construct in SQL is a conditional object that + acts somewhat analogously to an "if/then" construct in other + languages. It returns an instance of :class:`.Case`. + + :func:`.case` in its usual form is passed a series of "when" + constructs, that is, a list of conditions and results as tuples:: + + from sqlalchemy import case + + stmt = select(users_table).\ + where( + case( + (users_table.c.name == 'wendy', 'W'), + (users_table.c.name == 'jack', 'J'), + else_='E' + ) + ) + + The above statement will produce SQL resembling:: + + SELECT id, name FROM user + WHERE CASE + WHEN (name = :name_1) THEN :param_1 + WHEN (name = :name_2) THEN :param_2 + ELSE :param_3 + END + + When simple equality expressions of several values against a single + parent column are needed, :func:`.case` also has a "shorthand" format + used via the + :paramref:`.case.value` parameter, which is passed a column + expression to be compared. In this form, the :paramref:`.case.whens` + parameter is passed as a dictionary containing expressions to be + compared against keyed to result expressions. The statement below is + equivalent to the preceding statement:: + + stmt = select(users_table).\ + where( + case( + {"wendy": "W", "jack": "J"}, + value=users_table.c.name, + else_='E' + ) + ) + + The values which are accepted as result values in + :paramref:`.case.whens` as well as with :paramref:`.case.else_` are + coerced from Python literals into :func:`.bindparam` constructs. + SQL expressions, e.g. :class:`_expression.ColumnElement` constructs, + are accepted + as well. To coerce a literal string expression into a constant + expression rendered inline, use the :func:`_expression.literal_column` + construct, + as in:: + + from sqlalchemy import case, literal_column + + case( + ( + orderline.c.qty > 100, + literal_column("'greaterthan100'") + ), + ( + orderline.c.qty > 10, + literal_column("'greaterthan10'") + ), + else_=literal_column("'lessthan10'") + ) + + The above will render the given constants without using bound + parameters for the result values (but still for the comparison + values), as in:: + + CASE + WHEN (orderline.qty > :qty_1) THEN 'greaterthan100' + WHEN (orderline.qty > :qty_2) THEN 'greaterthan10' + ELSE 'lessthan10' + END + + :param \*whens: The criteria to be compared against, + :paramref:`.case.whens` accepts two different forms, based on + whether or not :paramref:`.case.value` is used. + + .. versionchanged:: 1.4 the :func:`_sql.case` + function now accepts the series of WHEN conditions positionally + + In the first form, it accepts multiple 2-tuples passed as positional + arguments; each 2-tuple consists of ``(, )``, + where the SQL expression is a boolean expression and "value" is a + resulting value, e.g.:: + + case( + (users_table.c.name == 'wendy', 'W'), + (users_table.c.name == 'jack', 'J') + ) + + In the second form, it accepts a Python dictionary of comparison + values mapped to a resulting value; this form requires + :paramref:`.case.value` to be present, and values will be compared + using the ``==`` operator, e.g.:: + + case( + {"wendy": "W", "jack": "J"}, + value=users_table.c.name + ) + + :param value: An optional SQL expression which will be used as a + fixed "comparison point" for candidate values within a dictionary + passed to :paramref:`.case.whens`. + + :param else\_: An optional SQL expression which will be the evaluated + result of the ``CASE`` construct if all expressions within + :paramref:`.case.whens` evaluate to false. When omitted, most + databases will produce a result of NULL if none of the "when" + expressions evaluate to true. + + + """ + return Case(*whens, value=value, else_=else_) + + +def cast( + expression: _ColumnExpressionOrLiteralArgument[Any], + type_: _TypeEngineArgument[_T], +) -> Cast[_T]: + r"""Produce a ``CAST`` expression. + + :func:`.cast` returns an instance of :class:`.Cast`. + + E.g.:: + + from sqlalchemy import cast, Numeric + + stmt = select(cast(product_table.c.unit_price, Numeric(10, 4))) + + The above statement will produce SQL resembling:: + + SELECT CAST(unit_price AS NUMERIC(10, 4)) FROM product + + The :func:`.cast` function performs two distinct functions when + used. The first is that it renders the ``CAST`` expression within + the resulting SQL string. The second is that it associates the given + type (e.g. :class:`.TypeEngine` class or instance) with the column + expression on the Python side, which means the expression will take + on the expression operator behavior associated with that type, + as well as the bound-value handling and result-row-handling behavior + of the type. + + An alternative to :func:`.cast` is the :func:`.type_coerce` function. + This function performs the second task of associating an expression + with a specific type, but does not render the ``CAST`` expression + in SQL. + + :param expression: A SQL expression, such as a + :class:`_expression.ColumnElement` + expression or a Python string which will be coerced into a bound + literal value. + + :param type\_: A :class:`.TypeEngine` class or instance indicating + the type to which the ``CAST`` should apply. + + .. seealso:: + + :ref:`tutorial_casts` + + :func:`.try_cast` - an alternative to CAST that results in + NULLs when the cast fails, instead of raising an error. + Only supported by some dialects. + + :func:`.type_coerce` - an alternative to CAST that coerces the type + on the Python side only, which is often sufficient to generate the + correct SQL and data coercion. + + + """ + return Cast(expression, type_) + + +def try_cast( + expression: _ColumnExpressionOrLiteralArgument[Any], + type_: _TypeEngineArgument[_T], +) -> TryCast[_T]: + """Produce a ``TRY_CAST`` expression for backends which support it; + this is a ``CAST`` which returns NULL for un-castable conversions. + + In SQLAlchemy, this construct is supported **only** by the SQL Server + dialect, and will raise a :class:`.CompileError` if used on other + included backends. However, third party backends may also support + this construct. + + .. tip:: As :func:`_sql.try_cast` originates from the SQL Server dialect, + it's importable both from ``sqlalchemy.`` as well as from + ``sqlalchemy.dialects.mssql``. + + :func:`_sql.try_cast` returns an instance of :class:`.TryCast` and + generally behaves similarly to the :class:`.Cast` construct; + at the SQL level, the difference between ``CAST`` and ``TRY_CAST`` + is that ``TRY_CAST`` returns NULL for an un-castable expression, + such as attempting to cast a string ``"hi"`` to an integer value. + + E.g.:: + + from sqlalchemy import select, try_cast, Numeric + + stmt = select( + try_cast(product_table.c.unit_price, Numeric(10, 4)) + ) + + The above would render on Microsoft SQL Server as:: + + SELECT TRY_CAST (product_table.unit_price AS NUMERIC(10, 4)) + FROM product_table + + .. versionadded:: 2.0.14 :func:`.try_cast` has been + generalized from the SQL Server dialect into a general use + construct that may be supported by additional dialects. + + """ + return TryCast(expression, type_) + + +def column( + text: str, + type_: Optional[_TypeEngineArgument[_T]] = None, + is_literal: bool = False, + _selectable: Optional[FromClause] = None, +) -> ColumnClause[_T]: + """Produce a :class:`.ColumnClause` object. + + The :class:`.ColumnClause` is a lightweight analogue to the + :class:`_schema.Column` class. The :func:`_expression.column` + function can + be invoked with just a name alone, as in:: + + from sqlalchemy import column + + id, name = column("id"), column("name") + stmt = select(id, name).select_from("user") + + The above statement would produce SQL like:: + + SELECT id, name FROM user + + Once constructed, :func:`_expression.column` + may be used like any other SQL + expression element such as within :func:`_expression.select` + constructs:: + + from sqlalchemy.sql import column + + id, name = column("id"), column("name") + stmt = select(id, name).select_from("user") + + The text handled by :func:`_expression.column` + is assumed to be handled + like the name of a database column; if the string contains mixed case, + special characters, or matches a known reserved word on the target + backend, the column expression will render using the quoting + behavior determined by the backend. To produce a textual SQL + expression that is rendered exactly without any quoting, + use :func:`_expression.literal_column` instead, + or pass ``True`` as the + value of :paramref:`_expression.column.is_literal`. Additionally, + full SQL + statements are best handled using the :func:`_expression.text` + construct. + + :func:`_expression.column` can be used in a table-like + fashion by combining it with the :func:`.table` function + (which is the lightweight analogue to :class:`_schema.Table` + ) to produce + a working table construct with minimal boilerplate:: + + from sqlalchemy import table, column, select + + user = table("user", + column("id"), + column("name"), + column("description"), + ) + + stmt = select(user.c.description).where(user.c.name == 'wendy') + + A :func:`_expression.column` / :func:`.table` + construct like that illustrated + above can be created in an + ad-hoc fashion and is not associated with any + :class:`_schema.MetaData`, DDL, or events, unlike its + :class:`_schema.Table` counterpart. + + :param text: the text of the element. + + :param type: :class:`_types.TypeEngine` object which can associate + this :class:`.ColumnClause` with a type. + + :param is_literal: if True, the :class:`.ColumnClause` is assumed to + be an exact expression that will be delivered to the output with no + quoting rules applied regardless of case sensitive settings. the + :func:`_expression.literal_column()` function essentially invokes + :func:`_expression.column` while passing ``is_literal=True``. + + .. seealso:: + + :class:`_schema.Column` + + :func:`_expression.literal_column` + + :func:`.table` + + :func:`_expression.text` + + :ref:`tutorial_select_arbitrary_text` + + """ + return ColumnClause(text, type_, is_literal, _selectable) + + +def desc( + column: _ColumnExpressionOrStrLabelArgument[_T], +) -> UnaryExpression[_T]: + """Produce a descending ``ORDER BY`` clause element. + + e.g.:: + + from sqlalchemy import desc + + stmt = select(users_table).order_by(desc(users_table.c.name)) + + will produce SQL as:: + + SELECT id, name FROM user ORDER BY name DESC + + The :func:`.desc` function is a standalone version of the + :meth:`_expression.ColumnElement.desc` + method available on all SQL expressions, + e.g.:: + + + stmt = select(users_table).order_by(users_table.c.name.desc()) + + :param column: A :class:`_expression.ColumnElement` (e.g. + scalar SQL expression) + with which to apply the :func:`.desc` operation. + + .. seealso:: + + :func:`.asc` + + :func:`.nulls_first` + + :func:`.nulls_last` + + :meth:`_expression.Select.order_by` + + """ + return UnaryExpression._create_desc(column) + + +def distinct(expr: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: + """Produce an column-expression-level unary ``DISTINCT`` clause. + + This applies the ``DISTINCT`` keyword to an individual column + expression, and is typically contained within an aggregate function, + as in:: + + from sqlalchemy import distinct, func + stmt = select(func.count(distinct(users_table.c.name))) + + The above would produce an expression resembling:: + + SELECT COUNT(DISTINCT name) FROM user + + The :func:`.distinct` function is also available as a column-level + method, e.g. :meth:`_expression.ColumnElement.distinct`, as in:: + + stmt = select(func.count(users_table.c.name.distinct())) + + The :func:`.distinct` operator is different from the + :meth:`_expression.Select.distinct` method of + :class:`_expression.Select`, + which produces a ``SELECT`` statement + with ``DISTINCT`` applied to the result set as a whole, + e.g. a ``SELECT DISTINCT`` expression. See that method for further + information. + + .. seealso:: + + :meth:`_expression.ColumnElement.distinct` + + :meth:`_expression.Select.distinct` + + :data:`.func` + + """ + return UnaryExpression._create_distinct(expr) + + +def bitwise_not(expr: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: + """Produce a unary bitwise NOT clause, typically via the ``~`` operator. + + Not to be confused with boolean negation :func:`_sql.not_`. + + .. versionadded:: 2.0.2 + + .. seealso:: + + :ref:`operators_bitwise` + + + """ + + return UnaryExpression._create_bitwise_not(expr) + + +def extract(field: str, expr: _ColumnExpressionArgument[Any]) -> Extract: + """Return a :class:`.Extract` construct. + + This is typically available as :func:`.extract` + as well as ``func.extract`` from the + :data:`.func` namespace. + + :param field: The field to extract. + + :param expr: A column or Python scalar expression serving as the + right side of the ``EXTRACT`` expression. + + E.g.:: + + from sqlalchemy import extract + from sqlalchemy import table, column + + logged_table = table("user", + column("id"), + column("date_created"), + ) + + stmt = select(logged_table.c.id).where( + extract("YEAR", logged_table.c.date_created) == 2021 + ) + + In the above example, the statement is used to select ids from the + database where the ``YEAR`` component matches a specific value. + + Similarly, one can also select an extracted component:: + + stmt = select( + extract("YEAR", logged_table.c.date_created) + ).where(logged_table.c.id == 1) + + The implementation of ``EXTRACT`` may vary across database backends. + Users are reminded to consult their database documentation. + """ + return Extract(field, expr) + + +def false() -> False_: + """Return a :class:`.False_` construct. + + E.g.: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import false + >>> print(select(t.c.x).where(false())) + {printsql}SELECT x FROM t WHERE false + + A backend which does not support true/false constants will render as + an expression against 1 or 0: + + .. sourcecode:: pycon+sql + + >>> print(select(t.c.x).where(false())) + {printsql}SELECT x FROM t WHERE 0 = 1 + + The :func:`.true` and :func:`.false` constants also feature + "short circuit" operation within an :func:`.and_` or :func:`.or_` + conjunction: + + .. sourcecode:: pycon+sql + + >>> print(select(t.c.x).where(or_(t.c.x > 5, true()))) + {printsql}SELECT x FROM t WHERE true{stop} + + >>> print(select(t.c.x).where(and_(t.c.x > 5, false()))) + {printsql}SELECT x FROM t WHERE false{stop} + + .. seealso:: + + :func:`.true` + + """ + + return False_._instance() + + +def funcfilter( + func: FunctionElement[_T], *criterion: _ColumnExpressionArgument[bool] +) -> FunctionFilter[_T]: + """Produce a :class:`.FunctionFilter` object against a function. + + Used against aggregate and window functions, + for database backends that support the "FILTER" clause. + + E.g.:: + + from sqlalchemy import funcfilter + funcfilter(func.count(1), MyClass.name == 'some name') + + Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')". + + This function is also available from the :data:`~.expression.func` + construct itself via the :meth:`.FunctionElement.filter` method. + + .. seealso:: + + :ref:`tutorial_functions_within_group` - in the + :ref:`unified_tutorial` + + :meth:`.FunctionElement.filter` + + """ + return FunctionFilter(func, *criterion) + + +def label( + name: str, + element: _ColumnExpressionArgument[_T], + type_: Optional[_TypeEngineArgument[_T]] = None, +) -> Label[_T]: + """Return a :class:`Label` object for the + given :class:`_expression.ColumnElement`. + + A label changes the name of an element in the columns clause of a + ``SELECT`` statement, typically via the ``AS`` SQL keyword. + + This functionality is more conveniently available via the + :meth:`_expression.ColumnElement.label` method on + :class:`_expression.ColumnElement`. + + :param name: label name + + :param obj: a :class:`_expression.ColumnElement`. + + """ + return Label(name, element, type_) + + +def null() -> Null: + """Return a constant :class:`.Null` construct.""" + + return Null._instance() + + +def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: + """Produce the ``NULLS FIRST`` modifier for an ``ORDER BY`` expression. + + :func:`.nulls_first` is intended to modify the expression produced + by :func:`.asc` or :func:`.desc`, and indicates how NULL values + should be handled when they are encountered during ordering:: + + + from sqlalchemy import desc, nulls_first + + stmt = select(users_table).order_by( + nulls_first(desc(users_table.c.name))) + + The SQL expression from the above would resemble:: + + SELECT id, name FROM user ORDER BY name DESC NULLS FIRST + + Like :func:`.asc` and :func:`.desc`, :func:`.nulls_first` is typically + invoked from the column expression itself using + :meth:`_expression.ColumnElement.nulls_first`, + rather than as its standalone + function version, as in:: + + stmt = select(users_table).order_by( + users_table.c.name.desc().nulls_first()) + + .. versionchanged:: 1.4 :func:`.nulls_first` is renamed from + :func:`.nullsfirst` in previous releases. + The previous name remains available for backwards compatibility. + + .. seealso:: + + :func:`.asc` + + :func:`.desc` + + :func:`.nulls_last` + + :meth:`_expression.Select.order_by` + + """ + return UnaryExpression._create_nulls_first(column) + + +def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: + """Produce the ``NULLS LAST`` modifier for an ``ORDER BY`` expression. + + :func:`.nulls_last` is intended to modify the expression produced + by :func:`.asc` or :func:`.desc`, and indicates how NULL values + should be handled when they are encountered during ordering:: + + + from sqlalchemy import desc, nulls_last + + stmt = select(users_table).order_by( + nulls_last(desc(users_table.c.name))) + + The SQL expression from the above would resemble:: + + SELECT id, name FROM user ORDER BY name DESC NULLS LAST + + Like :func:`.asc` and :func:`.desc`, :func:`.nulls_last` is typically + invoked from the column expression itself using + :meth:`_expression.ColumnElement.nulls_last`, + rather than as its standalone + function version, as in:: + + stmt = select(users_table).order_by( + users_table.c.name.desc().nulls_last()) + + .. versionchanged:: 1.4 :func:`.nulls_last` is renamed from + :func:`.nullslast` in previous releases. + The previous name remains available for backwards compatibility. + + .. seealso:: + + :func:`.asc` + + :func:`.desc` + + :func:`.nulls_first` + + :meth:`_expression.Select.order_by` + + """ + return UnaryExpression._create_nulls_last(column) + + +def or_( # type: ignore[empty-body] + initial_clause: Union[Literal[False], _ColumnExpressionArgument[bool]], + *clauses: _ColumnExpressionArgument[bool], +) -> ColumnElement[bool]: + """Produce a conjunction of expressions joined by ``OR``. + + E.g.:: + + from sqlalchemy import or_ + + stmt = select(users_table).where( + or_( + users_table.c.name == 'wendy', + users_table.c.name == 'jack' + ) + ) + + The :func:`.or_` conjunction is also available using the + Python ``|`` operator (though note that compound expressions + need to be parenthesized in order to function with Python + operator precedence behavior):: + + stmt = select(users_table).where( + (users_table.c.name == 'wendy') | + (users_table.c.name == 'jack') + ) + + The :func:`.or_` construct must be given at least one positional + argument in order to be valid; a :func:`.or_` construct with no + arguments is ambiguous. To produce an "empty" or dynamically + generated :func:`.or_` expression, from a given list of expressions, + a "default" element of :func:`_sql.false` (or just ``False``) should be + specified:: + + from sqlalchemy import false + or_criteria = or_(false(), *expressions) + + The above expression will compile to SQL as the expression ``false`` + or ``0 = 1``, depending on backend, if no other expressions are + present. If expressions are present, then the :func:`_sql.false` value is + ignored as it does not affect the outcome of an OR expression which + has other elements. + + .. deprecated:: 1.4 The :func:`.or_` element now requires that at + least one argument is passed; creating the :func:`.or_` construct + with no arguments is deprecated, and will emit a deprecation warning + while continuing to produce a blank SQL string. + + .. seealso:: + + :func:`.and_` + + """ + ... + + +if not TYPE_CHECKING: + # handle deprecated case which allows zero-arguments + def or_(*clauses): # noqa: F811 + """Produce a conjunction of expressions joined by ``OR``. + + E.g.:: + + from sqlalchemy import or_ + + stmt = select(users_table).where( + or_( + users_table.c.name == 'wendy', + users_table.c.name == 'jack' + ) + ) + + The :func:`.or_` conjunction is also available using the + Python ``|`` operator (though note that compound expressions + need to be parenthesized in order to function with Python + operator precedence behavior):: + + stmt = select(users_table).where( + (users_table.c.name == 'wendy') | + (users_table.c.name == 'jack') + ) + + The :func:`.or_` construct must be given at least one positional + argument in order to be valid; a :func:`.or_` construct with no + arguments is ambiguous. To produce an "empty" or dynamically + generated :func:`.or_` expression, from a given list of expressions, + a "default" element of :func:`_sql.false` (or just ``False``) should be + specified:: + + from sqlalchemy import false + or_criteria = or_(false(), *expressions) + + The above expression will compile to SQL as the expression ``false`` + or ``0 = 1``, depending on backend, if no other expressions are + present. If expressions are present, then the :func:`_sql.false` value + is ignored as it does not affect the outcome of an OR expression which + has other elements. + + .. deprecated:: 1.4 The :func:`.or_` element now requires that at + least one argument is passed; creating the :func:`.or_` construct + with no arguments is deprecated, and will emit a deprecation warning + while continuing to produce a blank SQL string. + + .. seealso:: + + :func:`.and_` + + """ + return BooleanClauseList.or_(*clauses) + + +def over( + element: FunctionElement[_T], + partition_by: Optional[_ByArgument] = None, + order_by: Optional[_ByArgument] = None, + range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, + rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, +) -> Over[_T]: + r"""Produce an :class:`.Over` object against a function. + + Used against aggregate or so-called "window" functions, + for database backends that support window functions. + + :func:`_expression.over` is usually called using + the :meth:`.FunctionElement.over` method, e.g.:: + + func.row_number().over(order_by=mytable.c.some_column) + + Would produce:: + + ROW_NUMBER() OVER(ORDER BY some_column) + + Ranges are also possible using the :paramref:`.expression.over.range_` + and :paramref:`.expression.over.rows` parameters. These + mutually-exclusive parameters each accept a 2-tuple, which contains + a combination of integers and None:: + + func.row_number().over( + order_by=my_table.c.some_column, range_=(None, 0)) + + The above would produce:: + + ROW_NUMBER() OVER(ORDER BY some_column + RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) + + A value of ``None`` indicates "unbounded", a + value of zero indicates "current row", and negative / positive + integers indicate "preceding" and "following": + + * RANGE BETWEEN 5 PRECEDING AND 10 FOLLOWING:: + + func.row_number().over(order_by='x', range_=(-5, 10)) + + * ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW:: + + func.row_number().over(order_by='x', rows=(None, 0)) + + * RANGE BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING:: + + func.row_number().over(order_by='x', range_=(-2, None)) + + * RANGE BETWEEN 1 FOLLOWING AND 3 FOLLOWING:: + + func.row_number().over(order_by='x', range_=(1, 3)) + + :param element: a :class:`.FunctionElement`, :class:`.WithinGroup`, + or other compatible construct. + :param partition_by: a column element or string, or a list + of such, that will be used as the PARTITION BY clause + of the OVER construct. + :param order_by: a column element or string, or a list + of such, that will be used as the ORDER BY clause + of the OVER construct. + :param range\_: optional range clause for the window. This is a + tuple value which can contain integer values or ``None``, + and will render a RANGE BETWEEN PRECEDING / FOLLOWING clause. + + :param rows: optional rows clause for the window. This is a tuple + value which can contain integer values or None, and will render + a ROWS BETWEEN PRECEDING / FOLLOWING clause. + + This function is also available from the :data:`~.expression.func` + construct itself via the :meth:`.FunctionElement.over` method. + + .. seealso:: + + :ref:`tutorial_window_functions` - in the :ref:`unified_tutorial` + + :data:`.expression.func` + + :func:`_expression.within_group` + + """ + return Over(element, partition_by, order_by, range_, rows) + + +@_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`") +def text(text: str) -> TextClause: + r"""Construct a new :class:`_expression.TextClause` clause, + representing + a textual SQL string directly. + + E.g.:: + + from sqlalchemy import text + + t = text("SELECT * FROM users") + result = connection.execute(t) + + The advantages :func:`_expression.text` + provides over a plain string are + backend-neutral support for bind parameters, per-statement + execution options, as well as + bind parameter and result-column typing behavior, allowing + SQLAlchemy type constructs to play a role when executing + a statement that is specified literally. The construct can also + be provided with a ``.c`` collection of column elements, allowing + it to be embedded in other SQL expression constructs as a subquery. + + Bind parameters are specified by name, using the format ``:name``. + E.g.:: + + t = text("SELECT * FROM users WHERE id=:user_id") + result = connection.execute(t, {"user_id": 12}) + + For SQL statements where a colon is required verbatim, as within + an inline string, use a backslash to escape:: + + t = text(r"SELECT * FROM users WHERE name='\:username'") + + The :class:`_expression.TextClause` + construct includes methods which can + provide information about the bound parameters as well as the column + values which would be returned from the textual statement, assuming + it's an executable SELECT type of statement. The + :meth:`_expression.TextClause.bindparams` + method is used to provide bound + parameter detail, and :meth:`_expression.TextClause.columns` + method allows + specification of return columns including names and types:: + + t = text("SELECT * FROM users WHERE id=:user_id").\ + bindparams(user_id=7).\ + columns(id=Integer, name=String) + + for id, name in connection.execute(t): + print(id, name) + + The :func:`_expression.text` construct is used in cases when + a literal string SQL fragment is specified as part of a larger query, + such as for the WHERE clause of a SELECT statement:: + + s = select(users.c.id, users.c.name).where(text("id=:user_id")) + result = connection.execute(s, {"user_id": 12}) + + :func:`_expression.text` is also used for the construction + of a full, standalone statement using plain text. + As such, SQLAlchemy refers + to it as an :class:`.Executable` object and may be used + like any other statement passed to an ``.execute()`` method. + + :param text: + the text of the SQL statement to be created. Use ``:`` + to specify bind parameters; they will be compiled to their + engine-specific format. + + .. seealso:: + + :ref:`tutorial_select_arbitrary_text` + + """ + return TextClause(text) + + +def true() -> True_: + """Return a constant :class:`.True_` construct. + + E.g.: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy import true + >>> print(select(t.c.x).where(true())) + {printsql}SELECT x FROM t WHERE true + + A backend which does not support true/false constants will render as + an expression against 1 or 0: + + .. sourcecode:: pycon+sql + + >>> print(select(t.c.x).where(true())) + {printsql}SELECT x FROM t WHERE 1 = 1 + + The :func:`.true` and :func:`.false` constants also feature + "short circuit" operation within an :func:`.and_` or :func:`.or_` + conjunction: + + .. sourcecode:: pycon+sql + + >>> print(select(t.c.x).where(or_(t.c.x > 5, true()))) + {printsql}SELECT x FROM t WHERE true{stop} + + >>> print(select(t.c.x).where(and_(t.c.x > 5, false()))) + {printsql}SELECT x FROM t WHERE false{stop} + + .. seealso:: + + :func:`.false` + + """ + + return True_._instance() + + +def tuple_( + *clauses: _ColumnExpressionArgument[Any], + types: Optional[Sequence[_TypeEngineArgument[Any]]] = None, +) -> Tuple: + """Return a :class:`.Tuple`. + + Main usage is to produce a composite IN construct using + :meth:`.ColumnOperators.in_` :: + + from sqlalchemy import tuple_ + + tuple_(table.c.col1, table.c.col2).in_( + [(1, 2), (5, 12), (10, 19)] + ) + + .. versionchanged:: 1.3.6 Added support for SQLite IN tuples. + + .. warning:: + + The composite IN construct is not supported by all backends, and is + currently known to work on PostgreSQL, MySQL, and SQLite. + Unsupported backends will raise a subclass of + :class:`~sqlalchemy.exc.DBAPIError` when such an expression is + invoked. + + """ + return Tuple(*clauses, types=types) + + +def type_coerce( + expression: _ColumnExpressionOrLiteralArgument[Any], + type_: _TypeEngineArgument[_T], +) -> TypeCoerce[_T]: + r"""Associate a SQL expression with a particular type, without rendering + ``CAST``. + + E.g.:: + + from sqlalchemy import type_coerce + + stmt = select(type_coerce(log_table.date_string, StringDateTime())) + + The above construct will produce a :class:`.TypeCoerce` object, which + does not modify the rendering in any way on the SQL side, with the + possible exception of a generated label if used in a columns clause + context: + + .. sourcecode:: sql + + SELECT date_string AS date_string FROM log + + When result rows are fetched, the ``StringDateTime`` type processor + will be applied to result rows on behalf of the ``date_string`` column. + + .. note:: the :func:`.type_coerce` construct does not render any + SQL syntax of its own, including that it does not imply + parenthesization. Please use :meth:`.TypeCoerce.self_group` + if explicit parenthesization is required. + + In order to provide a named label for the expression, use + :meth:`_expression.ColumnElement.label`:: + + stmt = select( + type_coerce(log_table.date_string, StringDateTime()).label('date') + ) + + + A type that features bound-value handling will also have that behavior + take effect when literal values or :func:`.bindparam` constructs are + passed to :func:`.type_coerce` as targets. + For example, if a type implements the + :meth:`.TypeEngine.bind_expression` + method or :meth:`.TypeEngine.bind_processor` method or equivalent, + these functions will take effect at statement compilation/execution + time when a literal value is passed, as in:: + + # bound-value handling of MyStringType will be applied to the + # literal value "some string" + stmt = select(type_coerce("some string", MyStringType)) + + When using :func:`.type_coerce` with composed expressions, note that + **parenthesis are not applied**. If :func:`.type_coerce` is being + used in an operator context where the parenthesis normally present from + CAST are necessary, use the :meth:`.TypeCoerce.self_group` method: + + .. sourcecode:: pycon+sql + + >>> some_integer = column("someint", Integer) + >>> some_string = column("somestr", String) + >>> expr = type_coerce(some_integer + 5, String) + some_string + >>> print(expr) + {printsql}someint + :someint_1 || somestr{stop} + >>> expr = type_coerce(some_integer + 5, String).self_group() + some_string + >>> print(expr) + {printsql}(someint + :someint_1) || somestr{stop} + + :param expression: A SQL expression, such as a + :class:`_expression.ColumnElement` + expression or a Python string which will be coerced into a bound + literal value. + + :param type\_: A :class:`.TypeEngine` class or instance indicating + the type to which the expression is coerced. + + .. seealso:: + + :ref:`tutorial_casts` + + :func:`.cast` + + """ # noqa + return TypeCoerce(expression, type_) + + +def within_group( + element: FunctionElement[_T], *order_by: _ColumnExpressionArgument[Any] +) -> WithinGroup[_T]: + r"""Produce a :class:`.WithinGroup` object against a function. + + Used against so-called "ordered set aggregate" and "hypothetical + set aggregate" functions, including :class:`.percentile_cont`, + :class:`.rank`, :class:`.dense_rank`, etc. + + :func:`_expression.within_group` is usually called using + the :meth:`.FunctionElement.within_group` method, e.g.:: + + from sqlalchemy import within_group + stmt = select( + department.c.id, + func.percentile_cont(0.5).within_group( + department.c.salary.desc() + ) + ) + + The above statement would produce SQL similar to + ``SELECT department.id, percentile_cont(0.5) + WITHIN GROUP (ORDER BY department.salary DESC)``. + + :param element: a :class:`.FunctionElement` construct, typically + generated by :data:`~.expression.func`. + :param \*order_by: one or more column elements that will be used + as the ORDER BY clause of the WITHIN GROUP construct. + + .. seealso:: + + :ref:`tutorial_functions_within_group` - in the + :ref:`unified_tutorial` + + :data:`.expression.func` + + :func:`_expression.over` + + """ + return WithinGroup(element, *order_by) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/_orm_types.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/_orm_types.py new file mode 100644 index 0000000..bccb533 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/_orm_types.py @@ -0,0 +1,20 @@ +# sql/_orm_types.py +# Copyright (C) 2022-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""ORM types that need to present specifically for **documentation only** of +the Executable.execution_options() method, which includes options that +are meaningful to the ORM. + +""" + + +from __future__ import annotations + +from ..util.typing import Literal + +SynchronizeSessionArgument = Literal[False, "auto", "evaluate", "fetch"] +DMLStrategyArgument = Literal["bulk", "raw", "orm", "auto"] diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/_py_util.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/_py_util.py new file mode 100644 index 0000000..df372bf --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/_py_util.py @@ -0,0 +1,75 @@ +# sql/_py_util.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import typing +from typing import Any +from typing import Dict +from typing import Tuple +from typing import Union + +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + from .cache_key import CacheConst + + +class prefix_anon_map(Dict[str, str]): + """A map that creates new keys for missing key access. + + Considers keys of the form " " to produce + new symbols "_", where "index" is an incrementing integer + corresponding to . + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + def __missing__(self, key: str) -> str: + (ident, derived) = key.split(" ", 1) + anonymous_counter = self.get(derived, 1) + self[derived] = anonymous_counter + 1 # type: ignore + value = f"{derived}_{anonymous_counter}" + self[key] = value + return value + + +class cache_anon_map( + Dict[Union[int, "Literal[CacheConst.NO_CACHE]"], Union[Literal[True], str]] +): + """A map that creates new keys for missing key access. + + Produces an incrementing sequence given a series of unique keys. + + This is similar to the compiler prefix_anon_map class although simpler. + + Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which + is otherwise usually used for this type of operation. + + """ + + _index = 0 + + def get_anon(self, object_: Any) -> Tuple[str, bool]: + idself = id(object_) + if idself in self: + s_val = self[idself] + assert s_val is not True + return s_val, True + else: + # inline of __missing__ + self[idself] = id_ = str(self._index) + self._index += 1 + + return id_, False + + def __missing__(self, key: int) -> str: + self[key] = val = str(self._index) + self._index += 1 + return val diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/_selectable_constructors.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/_selectable_constructors.py new file mode 100644 index 0000000..c2b5008 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/_selectable_constructors.py @@ -0,0 +1,635 @@ +# sql/_selectable_constructors.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +from typing import Any +from typing import Optional +from typing import overload +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import coercions +from . import roles +from ._typing import _ColumnsClauseArgument +from ._typing import _no_kw +from .elements import ColumnClause +from .selectable import Alias +from .selectable import CompoundSelect +from .selectable import Exists +from .selectable import FromClause +from .selectable import Join +from .selectable import Lateral +from .selectable import LateralFromClause +from .selectable import NamedFromClause +from .selectable import Select +from .selectable import TableClause +from .selectable import TableSample +from .selectable import Values + +if TYPE_CHECKING: + from ._typing import _FromClauseArgument + from ._typing import _OnClauseArgument + from ._typing import _SelectStatementForCompoundArgument + from ._typing import _T0 + from ._typing import _T1 + from ._typing import _T2 + from ._typing import _T3 + from ._typing import _T4 + from ._typing import _T5 + from ._typing import _T6 + from ._typing import _T7 + from ._typing import _T8 + from ._typing import _T9 + from ._typing import _TypedColumnClauseArgument as _TCCA + from .functions import Function + from .selectable import CTE + from .selectable import HasCTE + from .selectable import ScalarSelect + from .selectable import SelectBase + + +_T = TypeVar("_T", bound=Any) + + +def alias( + selectable: FromClause, name: Optional[str] = None, flat: bool = False +) -> NamedFromClause: + """Return a named alias of the given :class:`.FromClause`. + + For :class:`.Table` and :class:`.Join` objects, the return type is the + :class:`_expression.Alias` object. Other kinds of :class:`.NamedFromClause` + objects may be returned for other kinds of :class:`.FromClause` objects. + + The named alias represents any :class:`_expression.FromClause` with an + alternate name assigned within SQL, typically using the ``AS`` clause when + generated, e.g. ``SELECT * FROM table AS aliasname``. + + Equivalent functionality is available via the + :meth:`_expression.FromClause.alias` + method available on all :class:`_expression.FromClause` objects. + + :param selectable: any :class:`_expression.FromClause` subclass, + such as a table, select statement, etc. + + :param name: string name to be assigned as the alias. + If ``None``, a name will be deterministically generated at compile + time. Deterministic means the name is guaranteed to be unique against + other constructs used in the same statement, and will also be the same + name for each successive compilation of the same statement object. + + :param flat: Will be passed through to if the given selectable + is an instance of :class:`_expression.Join` - see + :meth:`_expression.Join.alias` for details. + + """ + return Alias._factory(selectable, name=name, flat=flat) + + +def cte( + selectable: HasCTE, name: Optional[str] = None, recursive: bool = False +) -> CTE: + r"""Return a new :class:`_expression.CTE`, + or Common Table Expression instance. + + Please see :meth:`_expression.HasCTE.cte` for detail on CTE usage. + + """ + return coercions.expect(roles.HasCTERole, selectable).cte( + name=name, recursive=recursive + ) + + +def except_( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: + r"""Return an ``EXCEPT`` of multiple selectables. + + The returned object is an instance of + :class:`_expression.CompoundSelect`. + + :param \*selects: + a list of :class:`_expression.Select` instances. + + """ + return CompoundSelect._create_except(*selects) + + +def except_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: + r"""Return an ``EXCEPT ALL`` of multiple selectables. + + The returned object is an instance of + :class:`_expression.CompoundSelect`. + + :param \*selects: + a list of :class:`_expression.Select` instances. + + """ + return CompoundSelect._create_except_all(*selects) + + +def exists( + __argument: Optional[ + Union[_ColumnsClauseArgument[Any], SelectBase, ScalarSelect[Any]] + ] = None, +) -> Exists: + """Construct a new :class:`_expression.Exists` construct. + + The :func:`_sql.exists` can be invoked by itself to produce an + :class:`_sql.Exists` construct, which will accept simple WHERE + criteria:: + + exists_criteria = exists().where(table1.c.col1 == table2.c.col2) + + However, for greater flexibility in constructing the SELECT, an + existing :class:`_sql.Select` construct may be converted to an + :class:`_sql.Exists`, most conveniently by making use of the + :meth:`_sql.SelectBase.exists` method:: + + exists_criteria = ( + select(table2.c.col2). + where(table1.c.col1 == table2.c.col2). + exists() + ) + + The EXISTS criteria is then used inside of an enclosing SELECT:: + + stmt = select(table1.c.col1).where(exists_criteria) + + The above statement will then be of the form:: + + SELECT col1 FROM table1 WHERE EXISTS + (SELECT table2.col2 FROM table2 WHERE table2.col2 = table1.col1) + + .. seealso:: + + :ref:`tutorial_exists` - in the :term:`2.0 style` tutorial. + + :meth:`_sql.SelectBase.exists` - method to transform a ``SELECT`` to an + ``EXISTS`` clause. + + """ # noqa: E501 + + return Exists(__argument) + + +def intersect( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: + r"""Return an ``INTERSECT`` of multiple selectables. + + The returned object is an instance of + :class:`_expression.CompoundSelect`. + + :param \*selects: + a list of :class:`_expression.Select` instances. + + """ + return CompoundSelect._create_intersect(*selects) + + +def intersect_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: + r"""Return an ``INTERSECT ALL`` of multiple selectables. + + The returned object is an instance of + :class:`_expression.CompoundSelect`. + + :param \*selects: + a list of :class:`_expression.Select` instances. + + + """ + return CompoundSelect._create_intersect_all(*selects) + + +def join( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + isouter: bool = False, + full: bool = False, +) -> Join: + """Produce a :class:`_expression.Join` object, given two + :class:`_expression.FromClause` + expressions. + + E.g.:: + + j = join(user_table, address_table, + user_table.c.id == address_table.c.user_id) + stmt = select(user_table).select_from(j) + + would emit SQL along the lines of:: + + SELECT user.id, user.name FROM user + JOIN address ON user.id = address.user_id + + Similar functionality is available given any + :class:`_expression.FromClause` object (e.g. such as a + :class:`_schema.Table`) using + the :meth:`_expression.FromClause.join` method. + + :param left: The left side of the join. + + :param right: the right side of the join; this is any + :class:`_expression.FromClause` object such as a + :class:`_schema.Table` object, and + may also be a selectable-compatible object such as an ORM-mapped + class. + + :param onclause: a SQL expression representing the ON clause of the + join. If left at ``None``, :meth:`_expression.FromClause.join` + will attempt to + join the two tables based on a foreign key relationship. + + :param isouter: if True, render a LEFT OUTER JOIN, instead of JOIN. + + :param full: if True, render a FULL OUTER JOIN, instead of JOIN. + + .. seealso:: + + :meth:`_expression.FromClause.join` - method form, + based on a given left side. + + :class:`_expression.Join` - the type of object produced. + + """ + + return Join(left, right, onclause, isouter, full) + + +def lateral( + selectable: Union[SelectBase, _FromClauseArgument], + name: Optional[str] = None, +) -> LateralFromClause: + """Return a :class:`_expression.Lateral` object. + + :class:`_expression.Lateral` is an :class:`_expression.Alias` + subclass that represents + a subquery with the LATERAL keyword applied to it. + + The special behavior of a LATERAL subquery is that it appears in the + FROM clause of an enclosing SELECT, but may correlate to other + FROM clauses of that SELECT. It is a special case of subquery + only supported by a small number of backends, currently more recent + PostgreSQL versions. + + .. seealso:: + + :ref:`tutorial_lateral_correlation` - overview of usage. + + """ + return Lateral._factory(selectable, name=name) + + +def outerjoin( + left: _FromClauseArgument, + right: _FromClauseArgument, + onclause: Optional[_OnClauseArgument] = None, + full: bool = False, +) -> Join: + """Return an ``OUTER JOIN`` clause element. + + The returned object is an instance of :class:`_expression.Join`. + + Similar functionality is also available via the + :meth:`_expression.FromClause.outerjoin` method on any + :class:`_expression.FromClause`. + + :param left: The left side of the join. + + :param right: The right side of the join. + + :param onclause: Optional criterion for the ``ON`` clause, is + derived from foreign key relationships established between + left and right otherwise. + + To chain joins together, use the :meth:`_expression.FromClause.join` + or + :meth:`_expression.FromClause.outerjoin` methods on the resulting + :class:`_expression.Join` object. + + """ + return Join(left, right, onclause, isouter=True, full=full) + + +# START OVERLOADED FUNCTIONS select Select 1-10 + +# code within this block is **programmatically, +# statically generated** by tools/generate_tuple_map_overloads.py + + +@overload +def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: ... + + +@overload +def select( + __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] +) -> Select[Tuple[_T0, _T1]]: ... + + +@overload +def select( + __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] +) -> Select[Tuple[_T0, _T1, _T2]]: ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], +) -> Select[Tuple[_T0, _T1, _T2, _T3]]: ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + __ent8: _TCCA[_T8], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]: ... + + +@overload +def select( + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + __ent8: _TCCA[_T8], + __ent9: _TCCA[_T9], +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]: ... + + +# END OVERLOADED FUNCTIONS select + + +@overload +def select( + *entities: _ColumnsClauseArgument[Any], **__kw: Any +) -> Select[Any]: ... + + +def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: + r"""Construct a new :class:`_expression.Select`. + + + .. versionadded:: 1.4 - The :func:`_sql.select` function now accepts + column arguments positionally. The top-level :func:`_sql.select` + function will automatically use the 1.x or 2.x style API based on + the incoming arguments; using :func:`_sql.select` from the + ``sqlalchemy.future`` module will enforce that only the 2.x style + constructor is used. + + Similar functionality is also available via the + :meth:`_expression.FromClause.select` method on any + :class:`_expression.FromClause`. + + .. seealso:: + + :ref:`tutorial_selecting_data` - in the :ref:`unified_tutorial` + + :param \*entities: + Entities to SELECT from. For Core usage, this is typically a series + of :class:`_expression.ColumnElement` and / or + :class:`_expression.FromClause` + objects which will form the columns clause of the resulting + statement. For those objects that are instances of + :class:`_expression.FromClause` (typically :class:`_schema.Table` + or :class:`_expression.Alias` + objects), the :attr:`_expression.FromClause.c` + collection is extracted + to form a collection of :class:`_expression.ColumnElement` objects. + + This parameter will also accept :class:`_expression.TextClause` + constructs as + given, as well as ORM-mapped classes. + + """ + # the keyword args are a necessary element in order for the typing + # to work out w/ the varargs vs. having named "keyword" arguments that + # aren't always present. + if __kw: + raise _no_kw() + return Select(*entities) + + +def table(name: str, *columns: ColumnClause[Any], **kw: Any) -> TableClause: + """Produce a new :class:`_expression.TableClause`. + + The object returned is an instance of + :class:`_expression.TableClause`, which + represents the "syntactical" portion of the schema-level + :class:`_schema.Table` object. + It may be used to construct lightweight table constructs. + + :param name: Name of the table. + + :param columns: A collection of :func:`_expression.column` constructs. + + :param schema: The schema name for this table. + + .. versionadded:: 1.3.18 :func:`_expression.table` can now + accept a ``schema`` argument. + """ + + return TableClause(name, *columns, **kw) + + +def tablesample( + selectable: _FromClauseArgument, + sampling: Union[float, Function[Any]], + name: Optional[str] = None, + seed: Optional[roles.ExpressionElementRole[Any]] = None, +) -> TableSample: + """Return a :class:`_expression.TableSample` object. + + :class:`_expression.TableSample` is an :class:`_expression.Alias` + subclass that represents + a table with the TABLESAMPLE clause applied to it. + :func:`_expression.tablesample` + is also available from the :class:`_expression.FromClause` + class via the + :meth:`_expression.FromClause.tablesample` method. + + The TABLESAMPLE clause allows selecting a randomly selected approximate + percentage of rows from a table. It supports multiple sampling methods, + most commonly BERNOULLI and SYSTEM. + + e.g.:: + + from sqlalchemy import func + + selectable = people.tablesample( + func.bernoulli(1), + name='alias', + seed=func.random()) + stmt = select(selectable.c.people_id) + + Assuming ``people`` with a column ``people_id``, the above + statement would render as:: + + SELECT alias.people_id FROM + people AS alias TABLESAMPLE bernoulli(:bernoulli_1) + REPEATABLE (random()) + + :param sampling: a ``float`` percentage between 0 and 100 or + :class:`_functions.Function`. + + :param name: optional alias name + + :param seed: any real-valued SQL expression. When specified, the + REPEATABLE sub-clause is also rendered. + + """ + return TableSample._factory(selectable, sampling, name=name, seed=seed) + + +def union( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: + r"""Return a ``UNION`` of multiple selectables. + + The returned object is an instance of + :class:`_expression.CompoundSelect`. + + A similar :func:`union()` method is available on all + :class:`_expression.FromClause` subclasses. + + :param \*selects: + a list of :class:`_expression.Select` instances. + + :param \**kwargs: + available keyword arguments are the same as those of + :func:`select`. + + """ + return CompoundSelect._create_union(*selects) + + +def union_all( + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: + r"""Return a ``UNION ALL`` of multiple selectables. + + The returned object is an instance of + :class:`_expression.CompoundSelect`. + + A similar :func:`union_all()` method is available on all + :class:`_expression.FromClause` subclasses. + + :param \*selects: + a list of :class:`_expression.Select` instances. + + """ + return CompoundSelect._create_union_all(*selects) + + +def values( + *columns: ColumnClause[Any], + name: Optional[str] = None, + literal_binds: bool = False, +) -> Values: + r"""Construct a :class:`_expression.Values` construct. + + The column expressions and the actual data for + :class:`_expression.Values` are given in two separate steps. The + constructor receives the column expressions typically as + :func:`_expression.column` constructs, + and the data is then passed via the + :meth:`_expression.Values.data` method as a list, + which can be called multiple + times to add more data, e.g.:: + + from sqlalchemy import column + from sqlalchemy import values + + value_expr = values( + column('id', Integer), + column('name', String), + name="my_values" + ).data( + [(1, 'name1'), (2, 'name2'), (3, 'name3')] + ) + + :param \*columns: column expressions, typically composed using + :func:`_expression.column` objects. + + :param name: the name for this VALUES construct. If omitted, the + VALUES construct will be unnamed in a SQL expression. Different + backends may have different requirements here. + + :param literal_binds: Defaults to False. Whether or not to render + the data values inline in the SQL output, rather than using bound + parameters. + + """ + return Values(*columns, literal_binds=literal_binds, name=name) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/_typing.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/_typing.py new file mode 100644 index 0000000..c861bae --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/_typing.py @@ -0,0 +1,457 @@ +# sql/_typing.py +# Copyright (C) 2022-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import operator +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generic +from typing import Iterable +from typing import Mapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import roles +from .. import exc +from .. import util +from ..inspection import Inspectable +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import TypeAlias + +if TYPE_CHECKING: + from datetime import date + from datetime import datetime + from datetime import time + from datetime import timedelta + from decimal import Decimal + from uuid import UUID + + from .base import Executable + from .compiler import Compiled + from .compiler import DDLCompiler + from .compiler import SQLCompiler + from .dml import UpdateBase + from .dml import ValuesBase + from .elements import ClauseElement + from .elements import ColumnElement + from .elements import KeyedColumnElement + from .elements import quoted_name + from .elements import SQLCoreOperations + from .elements import TextClause + from .lambdas import LambdaElement + from .roles import FromClauseRole + from .schema import Column + from .selectable import Alias + from .selectable import CTE + from .selectable import FromClause + from .selectable import Join + from .selectable import NamedFromClause + from .selectable import ReturnsRows + from .selectable import Select + from .selectable import Selectable + from .selectable import SelectBase + from .selectable import Subquery + from .selectable import TableClause + from .sqltypes import TableValueType + from .sqltypes import TupleType + from .type_api import TypeEngine + from ..engine import Dialect + from ..util.typing import TypeGuard + +_T = TypeVar("_T", bound=Any) +_T_co = TypeVar("_T_co", bound=Any, covariant=True) + + +_CE = TypeVar("_CE", bound="ColumnElement[Any]") + +_CLE = TypeVar("_CLE", bound="ClauseElement") + + +class _HasClauseElement(Protocol, Generic[_T_co]): + """indicates a class that has a __clause_element__() method""" + + def __clause_element__(self) -> roles.ExpressionElementRole[_T_co]: ... + + +class _CoreAdapterProto(Protocol): + """protocol for the ClauseAdapter/ColumnAdapter.traverse() method.""" + + def __call__(self, obj: _CE) -> _CE: ... + + +class _HasDialect(Protocol): + """protocol for Engine/Connection-like objects that have dialect + attribute. + """ + + @property + def dialect(self) -> Dialect: ... + + +# match column types that are not ORM entities +_NOT_ENTITY = TypeVar( + "_NOT_ENTITY", + int, + str, + bool, + "datetime", + "date", + "time", + "timedelta", + "UUID", + float, + "Decimal", +) + +_MAYBE_ENTITY = TypeVar( + "_MAYBE_ENTITY", + roles.ColumnsClauseRole, + Literal["*", 1], + Type[Any], + Inspectable[_HasClauseElement[Any]], + _HasClauseElement[Any], +) + + +# convention: +# XYZArgument - something that the end user is passing to a public API method +# XYZElement - the internal representation that we use for the thing. +# the coercions system is responsible for converting from XYZArgument to +# XYZElement. + +_TextCoercedExpressionArgument = Union[ + str, + "TextClause", + "ColumnElement[_T]", + _HasClauseElement[_T], + roles.ExpressionElementRole[_T], +] + +_ColumnsClauseArgument = Union[ + roles.TypedColumnsClauseRole[_T], + roles.ColumnsClauseRole, + "SQLCoreOperations[_T]", + Literal["*", 1], + Type[_T], + Inspectable[_HasClauseElement[_T]], + _HasClauseElement[_T], +] +"""open-ended SELECT columns clause argument. + +Includes column expressions, tables, ORM mapped entities, a few literal values. + +This type is used for lists of columns / entities to be returned in result +sets; select(...), insert().returning(...), etc. + + +""" + +_TypedColumnClauseArgument = Union[ + roles.TypedColumnsClauseRole[_T], + "SQLCoreOperations[_T]", + Type[_T], +] + +_TP = TypeVar("_TP", bound=Tuple[Any, ...]) + +_T0 = TypeVar("_T0", bound=Any) +_T1 = TypeVar("_T1", bound=Any) +_T2 = TypeVar("_T2", bound=Any) +_T3 = TypeVar("_T3", bound=Any) +_T4 = TypeVar("_T4", bound=Any) +_T5 = TypeVar("_T5", bound=Any) +_T6 = TypeVar("_T6", bound=Any) +_T7 = TypeVar("_T7", bound=Any) +_T8 = TypeVar("_T8", bound=Any) +_T9 = TypeVar("_T9", bound=Any) + + +_ColumnExpressionArgument = Union[ + "ColumnElement[_T]", + _HasClauseElement[_T], + "SQLCoreOperations[_T]", + roles.ExpressionElementRole[_T], + Callable[[], "ColumnElement[_T]"], + "LambdaElement", +] +"See docs in public alias ColumnExpressionArgument." + +ColumnExpressionArgument: TypeAlias = _ColumnExpressionArgument[_T] +"""Narrower "column expression" argument. + +This type is used for all the other "column" kinds of expressions that +typically represent a single SQL column expression, not a set of columns the +way a table or ORM entity does. + +This includes ColumnElement, or ORM-mapped attributes that will have a +``__clause_element__()`` method, it also has the ExpressionElementRole +overall which brings in the TextClause object also. + +.. versionadded:: 2.0.13 + +""" + +_ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]] + +_ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]] + +_ByArgument = Union[ + Iterable[_ColumnExpressionOrStrLabelArgument[Any]], + _ColumnExpressionOrStrLabelArgument[Any], +] +"""Used for keyword-based ``order_by`` and ``partition_by`` parameters.""" + + +_InfoType = Dict[Any, Any] +"""the .info dictionary accepted and used throughout Core /ORM""" + +_FromClauseArgument = Union[ + roles.FromClauseRole, + Type[Any], + Inspectable[_HasClauseElement[Any]], + _HasClauseElement[Any], +] +"""A FROM clause, like we would send to select().select_from(). + +Also accommodates ORM entities and related constructs. + +""" + +_JoinTargetArgument = Union[_FromClauseArgument, roles.JoinTargetRole] +"""target for join() builds on _FromClauseArgument to include additional +join target roles such as those which come from the ORM. + +""" + +_OnClauseArgument = Union[_ColumnExpressionArgument[Any], roles.OnClauseRole] +"""target for an ON clause, includes additional roles such as those which +come from the ORM. + +""" + +_SelectStatementForCompoundArgument = Union[ + "SelectBase", roles.CompoundElementRole +] +"""SELECT statement acceptable by ``union()`` and other SQL set operations""" + +_DMLColumnArgument = Union[ + str, + _HasClauseElement[Any], + roles.DMLColumnRole, + "SQLCoreOperations[Any]", +] +"""A DML column expression. This is a "key" inside of insert().values(), +update().values(), and related. + +These are usually strings or SQL table columns. + +There's also edge cases like JSON expression assignment, which we would want +the DMLColumnRole to be able to accommodate. + +""" + +_DMLKey = TypeVar("_DMLKey", bound=_DMLColumnArgument) +_DMLColumnKeyMapping = Mapping[_DMLKey, Any] + + +_DDLColumnArgument = Union[str, "Column[Any]", roles.DDLConstraintColumnRole] +"""DDL column. + +used for :class:`.PrimaryKeyConstraint`, :class:`.UniqueConstraint`, etc. + +""" + +_DMLTableArgument = Union[ + "TableClause", + "Join", + "Alias", + "CTE", + Type[Any], + Inspectable[_HasClauseElement[Any]], + _HasClauseElement[Any], +] + +_PropagateAttrsType = util.immutabledict[str, Any] + +_TypeEngineArgument = Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] + +_EquivalentColumnMap = Dict["ColumnElement[Any]", Set["ColumnElement[Any]"]] + +_LimitOffsetType = Union[int, _ColumnExpressionArgument[int], None] + +_AutoIncrementType = Union[bool, Literal["auto", "ignore_fk"]] + +if TYPE_CHECKING: + + def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: ... + + def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]: ... + + def is_named_from_clause( + t: FromClauseRole, + ) -> TypeGuard[NamedFromClause]: ... + + def is_column_element( + c: ClauseElement, + ) -> TypeGuard[ColumnElement[Any]]: ... + + def is_keyed_column_element( + c: ClauseElement, + ) -> TypeGuard[KeyedColumnElement[Any]]: ... + + def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: ... + + def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]: ... + + def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: ... + + def is_table_value_type( + t: TypeEngine[Any], + ) -> TypeGuard[TableValueType]: ... + + def is_selectable(t: Any) -> TypeGuard[Selectable]: ... + + def is_select_base( + t: Union[Executable, ReturnsRows] + ) -> TypeGuard[SelectBase]: ... + + def is_select_statement( + t: Union[Executable, ReturnsRows] + ) -> TypeGuard[Select[Any]]: ... + + def is_table(t: FromClause) -> TypeGuard[TableClause]: ... + + def is_subquery(t: FromClause) -> TypeGuard[Subquery]: ... + + def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]: ... + +else: + is_sql_compiler = operator.attrgetter("is_sql") + is_ddl_compiler = operator.attrgetter("is_ddl") + is_named_from_clause = operator.attrgetter("named_with_column") + is_column_element = operator.attrgetter("_is_column_element") + is_keyed_column_element = operator.attrgetter("_is_keyed_column_element") + is_text_clause = operator.attrgetter("_is_text_clause") + is_from_clause = operator.attrgetter("_is_from_clause") + is_tuple_type = operator.attrgetter("_is_tuple_type") + is_table_value_type = operator.attrgetter("_is_table_value") + is_selectable = operator.attrgetter("is_selectable") + is_select_base = operator.attrgetter("_is_select_base") + is_select_statement = operator.attrgetter("_is_select_statement") + is_table = operator.attrgetter("_is_table") + is_subquery = operator.attrgetter("_is_subquery") + is_dml = operator.attrgetter("is_dml") + + +def has_schema_attr(t: FromClauseRole) -> TypeGuard[TableClause]: + return hasattr(t, "schema") + + +def is_quoted_name(s: str) -> TypeGuard[quoted_name]: + return hasattr(s, "quote") + + +def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement[Any]]: + return hasattr(s, "__clause_element__") + + +def is_insert_update(c: ClauseElement) -> TypeGuard[ValuesBase]: + return c.is_dml and (c.is_insert or c.is_update) # type: ignore + + +def _no_kw() -> exc.ArgumentError: + return exc.ArgumentError( + "Additional keyword arguments are not accepted by this " + "function/method. The presence of **kw is for pep-484 typing purposes" + ) + + +def _unexpected_kw(methname: str, kw: Dict[str, Any]) -> NoReturn: + k = list(kw)[0] + raise TypeError(f"{methname} got an unexpected keyword argument '{k}'") + + +@overload +def Nullable( + val: "SQLCoreOperations[_T]", +) -> "SQLCoreOperations[Optional[_T]]": ... + + +@overload +def Nullable( + val: roles.ExpressionElementRole[_T], +) -> roles.ExpressionElementRole[Optional[_T]]: ... + + +@overload +def Nullable(val: Type[_T]) -> Type[Optional[_T]]: ... + + +def Nullable( + val: _TypedColumnClauseArgument[_T], +) -> _TypedColumnClauseArgument[Optional[_T]]: + """Types a column or ORM class as nullable. + + This can be used in select and other contexts to express that the value of + a column can be null, for example due to an outer join:: + + stmt1 = select(A, Nullable(B)).outerjoin(A.bs) + stmt2 = select(A.data, Nullable(B.data)).outerjoin(A.bs) + + At runtime this method returns the input unchanged. + + .. versionadded:: 2.0.20 + """ + return val + + +@overload +def NotNullable( + val: "SQLCoreOperations[Optional[_T]]", +) -> "SQLCoreOperations[_T]": ... + + +@overload +def NotNullable( + val: roles.ExpressionElementRole[Optional[_T]], +) -> roles.ExpressionElementRole[_T]: ... + + +@overload +def NotNullable(val: Type[Optional[_T]]) -> Type[_T]: ... + + +@overload +def NotNullable(val: Optional[Type[_T]]) -> Type[_T]: ... + + +def NotNullable( + val: Union[_TypedColumnClauseArgument[Optional[_T]], Optional[Type[_T]]], +) -> _TypedColumnClauseArgument[_T]: + """Types a column or ORM class as not nullable. + + This can be used in select and other contexts to express that the value of + a column cannot be null, for example due to a where condition on a + nullable column:: + + stmt = select(NotNullable(A.value)).where(A.value.is_not(None)) + + At runtime this method returns the input unchanged. + + .. versionadded:: 2.0.20 + """ + return val # type: ignore diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/annotation.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/annotation.py new file mode 100644 index 0000000..db382b8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/annotation.py @@ -0,0 +1,585 @@ +# sql/annotation.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""The :class:`.Annotated` class and related routines; creates hash-equivalent +copies of SQL constructs which contain context-specific markers and +associations. + +Note that the :class:`.Annotated` concept as implemented in this module is not +related in any way to the pep-593 concept of "Annotated". + + +""" + +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar + +from . import operators +from .cache_key import HasCacheKey +from .visitors import anon_map +from .visitors import ExternallyTraversible +from .visitors import InternalTraversal +from .. import util +from ..util.typing import Literal +from ..util.typing import Self + +if TYPE_CHECKING: + from .base import _EntityNamespace + from .visitors import _TraverseInternalsType + +_AnnotationDict = Mapping[str, Any] + +EMPTY_ANNOTATIONS: util.immutabledict[str, Any] = util.EMPTY_DICT + + +class SupportsAnnotations(ExternallyTraversible): + __slots__ = () + + _annotations: util.immutabledict[str, Any] = EMPTY_ANNOTATIONS + + proxy_set: util.generic_fn_descriptor[FrozenSet[Any]] + + _is_immutable: bool + + def _annotate(self, values: _AnnotationDict) -> Self: + raise NotImplementedError() + + @overload + def _deannotate( + self, + values: Literal[None] = ..., + clone: bool = ..., + ) -> Self: ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: + raise NotImplementedError() + + @util.memoized_property + def _annotations_cache_key(self) -> Tuple[Any, ...]: + anon_map_ = anon_map() + + return self._gen_annotations_cache_key(anon_map_) + + def _gen_annotations_cache_key( + self, anon_map: anon_map + ) -> Tuple[Any, ...]: + return ( + "_annotations", + tuple( + ( + key, + ( + value._gen_cache_key(anon_map, []) + if isinstance(value, HasCacheKey) + else value + ), + ) + for key, value in [ + (key, self._annotations[key]) + for key in sorted(self._annotations) + ] + ), + ) + + +class SupportsWrappingAnnotations(SupportsAnnotations): + __slots__ = () + + _constructor: Callable[..., SupportsWrappingAnnotations] + + if TYPE_CHECKING: + + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: ... + + def _annotate(self, values: _AnnotationDict) -> Self: + """return a copy of this ClauseElement with annotations + updated by the given dictionary. + + """ + return Annotated._as_annotated_instance(self, values) # type: ignore + + def _with_annotations(self, values: _AnnotationDict) -> Self: + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. + + """ + return Annotated._as_annotated_instance(self, values) # type: ignore + + @overload + def _deannotate( + self, + values: Literal[None] = ..., + clone: bool = ..., + ) -> Self: ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: + """return a copy of this :class:`_expression.ClauseElement` + with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone: + s = self._clone() + return s + else: + return self + + +class SupportsCloneAnnotations(SupportsWrappingAnnotations): + # SupportsCloneAnnotations extends from SupportsWrappingAnnotations + # to support the structure of having the base ClauseElement + # be a subclass of SupportsWrappingAnnotations. Any ClauseElement + # subclass that wants to extend from SupportsCloneAnnotations + # will inherently also be subclassing SupportsWrappingAnnotations, so + # make that specific here. + + if not typing.TYPE_CHECKING: + __slots__ = () + + _clone_annotations_traverse_internals: _TraverseInternalsType = [ + ("_annotations", InternalTraversal.dp_annotations_key) + ] + + def _annotate(self, values: _AnnotationDict) -> Self: + """return a copy of this ClauseElement with annotations + updated by the given dictionary. + + """ + new = self._clone() + new._annotations = new._annotations.union(values) + new.__dict__.pop("_annotations_cache_key", None) + new.__dict__.pop("_generate_cache_key", None) + return new + + def _with_annotations(self, values: _AnnotationDict) -> Self: + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. + + """ + new = self._clone() + new._annotations = util.immutabledict(values) + new.__dict__.pop("_annotations_cache_key", None) + new.__dict__.pop("_generate_cache_key", None) + return new + + @overload + def _deannotate( + self, + values: Literal[None] = ..., + clone: bool = ..., + ) -> Self: ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> SupportsAnnotations: ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = False, + ) -> SupportsAnnotations: + """return a copy of this :class:`_expression.ClauseElement` + with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone or self._annotations: + # clone is used when we are also copying + # the expression for a deep deannotation + new = self._clone() + new._annotations = util.immutabledict() + new.__dict__.pop("_annotations_cache_key", None) + return new + else: + return self + + +class Annotated(SupportsAnnotations): + """clones a SupportsAnnotations and applies an 'annotations' dictionary. + + Unlike regular clones, this clone also mimics __hash__() and + __eq__() of the original element so that it takes its place + in hashed collections. + + A reference to the original element is maintained, for the important + reason of keeping its hash value current. When GC'ed, the + hash value may be reused, causing conflicts. + + .. note:: The rationale for Annotated producing a brand new class, + rather than placing the functionality directly within ClauseElement, + is **performance**. The __hash__() method is absent on plain + ClauseElement which leads to significantly reduced function call + overhead, as the use of sets and dictionaries against ClauseElement + objects is prevalent, but most are not "annotated". + + """ + + _is_column_operators = False + + @classmethod + def _as_annotated_instance( + cls, element: SupportsWrappingAnnotations, values: _AnnotationDict + ) -> Annotated: + try: + cls = annotated_classes[element.__class__] + except KeyError: + cls = _new_annotation_type(element.__class__, cls) + return cls(element, values) + + _annotations: util.immutabledict[str, Any] + __element: SupportsWrappingAnnotations + _hash: int + + def __new__(cls: Type[Self], *args: Any) -> Self: + return object.__new__(cls) + + def __init__( + self, element: SupportsWrappingAnnotations, values: _AnnotationDict + ): + self.__dict__ = element.__dict__.copy() + self.__dict__.pop("_annotations_cache_key", None) + self.__dict__.pop("_generate_cache_key", None) + self.__element = element + self._annotations = util.immutabledict(values) + self._hash = hash(element) + + def _annotate(self, values: _AnnotationDict) -> Self: + _values = self._annotations.union(values) + new = self._with_annotations(_values) + return new + + def _with_annotations(self, values: _AnnotationDict) -> Self: + clone = self.__class__.__new__(self.__class__) + clone.__dict__ = self.__dict__.copy() + clone.__dict__.pop("_annotations_cache_key", None) + clone.__dict__.pop("_generate_cache_key", None) + clone._annotations = util.immutabledict(values) + return clone + + @overload + def _deannotate( + self, + values: Literal[None] = ..., + clone: bool = ..., + ) -> Self: ... + + @overload + def _deannotate( + self, + values: Sequence[str] = ..., + clone: bool = ..., + ) -> Annotated: ... + + def _deannotate( + self, + values: Optional[Sequence[str]] = None, + clone: bool = True, + ) -> SupportsAnnotations: + if values is None: + return self.__element + else: + return self._with_annotations( + util.immutabledict( + { + key: value + for key, value in self._annotations.items() + if key not in values + } + ) + ) + + if not typing.TYPE_CHECKING: + # manually proxy some methods that need extra attention + def _compiler_dispatch(self, visitor: Any, **kw: Any) -> Any: + return self.__element.__class__._compiler_dispatch( + self, visitor, **kw + ) + + @property + def _constructor(self): + return self.__element._constructor + + def _clone(self, **kw: Any) -> Self: + clone = self.__element._clone(**kw) + if clone is self.__element: + # detect immutable, don't change anything + return self + else: + # update the clone with any changes that have occurred + # to this object's __dict__. + clone.__dict__.update(self.__dict__) + return self.__class__(clone, self._annotations) + + def __reduce__(self) -> Tuple[Type[Annotated], Tuple[Any, ...]]: + return self.__class__, (self.__element, self._annotations) + + def __hash__(self) -> int: + return self._hash + + def __eq__(self, other: Any) -> bool: + if self._is_column_operators: + return self.__element.__class__.__eq__(self, other) + else: + return hash(other) == hash(self) + + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: + if "entity_namespace" in self._annotations: + return cast( + SupportsWrappingAnnotations, + self._annotations["entity_namespace"], + ).entity_namespace + else: + return self.__element.entity_namespace + + +# hard-generate Annotated subclasses. this technique +# is used instead of on-the-fly types (i.e. type.__new__()) +# so that the resulting objects are pickleable; additionally, other +# decisions can be made up front about the type of object being annotated +# just once per class rather than per-instance. +annotated_classes: Dict[Type[SupportsWrappingAnnotations], Type[Annotated]] = ( + {} +) + +_SA = TypeVar("_SA", bound="SupportsAnnotations") + + +def _safe_annotate(to_annotate: _SA, annotations: _AnnotationDict) -> _SA: + try: + _annotate = to_annotate._annotate + except AttributeError: + # skip objects that don't actually have an `_annotate` + # attribute, namely QueryableAttribute inside of a join + # condition + return to_annotate + else: + return _annotate(annotations) + + +def _deep_annotate( + element: _SA, + annotations: _AnnotationDict, + exclude: Optional[Sequence[SupportsAnnotations]] = None, + *, + detect_subquery_cols: bool = False, + ind_cols_on_fromclause: bool = False, + annotate_callable: Optional[ + Callable[[SupportsAnnotations, _AnnotationDict], SupportsAnnotations] + ] = None, +) -> _SA: + """Deep copy the given ClauseElement, annotating each element + with the given annotations dictionary. + + Elements within the exclude collection will be cloned but not annotated. + + """ + + # annotated objects hack the __hash__() method so if we want to + # uniquely process them we have to use id() + + cloned_ids: Dict[int, SupportsAnnotations] = {} + + def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations: + # ind_cols_on_fromclause means make sure an AnnotatedFromClause + # has its own .c collection independent of that which its proxying. + # this is used specifically by orm.LoaderCriteriaOption to break + # a reference cycle that it's otherwise prone to building, + # see test_relationship_criteria-> + # test_loader_criteria_subquery_w_same_entity. logic here was + # changed for #8796 and made explicit; previously it occurred + # by accident + + kw["detect_subquery_cols"] = detect_subquery_cols + id_ = id(elem) + + if id_ in cloned_ids: + return cloned_ids[id_] + + if ( + exclude + and hasattr(elem, "proxy_set") + and elem.proxy_set.intersection(exclude) + ): + newelem = elem._clone(clone=clone, **kw) + elif annotations != elem._annotations: + if detect_subquery_cols and elem._is_immutable: + to_annotate = elem._clone(clone=clone, **kw) + else: + to_annotate = elem + if annotate_callable: + newelem = annotate_callable(to_annotate, annotations) + else: + newelem = _safe_annotate(to_annotate, annotations) + else: + newelem = elem + + newelem._copy_internals( + clone=clone, ind_cols_on_fromclause=ind_cols_on_fromclause + ) + + cloned_ids[id_] = newelem + return newelem + + if element is not None: + element = cast(_SA, clone(element)) + clone = None # type: ignore # remove gc cycles + return element + + +@overload +def _deep_deannotate( + element: Literal[None], values: Optional[Sequence[str]] = None +) -> Literal[None]: ... + + +@overload +def _deep_deannotate( + element: _SA, values: Optional[Sequence[str]] = None +) -> _SA: ... + + +def _deep_deannotate( + element: Optional[_SA], values: Optional[Sequence[str]] = None +) -> Optional[_SA]: + """Deep copy the given element, removing annotations.""" + + cloned: Dict[Any, SupportsAnnotations] = {} + + def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations: + key: Any + if values: + key = id(elem) + else: + key = elem + + if key not in cloned: + newelem = elem._deannotate(values=values, clone=True) + newelem._copy_internals(clone=clone) + cloned[key] = newelem + return newelem + else: + return cloned[key] + + if element is not None: + element = cast(_SA, clone(element)) + clone = None # type: ignore # remove gc cycles + return element + + +def _shallow_annotate(element: _SA, annotations: _AnnotationDict) -> _SA: + """Annotate the given ClauseElement and copy its internals so that + internal objects refer to the new annotated object. + + Basically used to apply a "don't traverse" annotation to a + selectable, without digging throughout the whole + structure wasting time. + """ + element = element._annotate(annotations) + element._copy_internals() + return element + + +def _new_annotation_type( + cls: Type[SupportsWrappingAnnotations], base_cls: Type[Annotated] +) -> Type[Annotated]: + """Generates a new class that subclasses Annotated and proxies a given + element type. + + """ + if issubclass(cls, Annotated): + return cls + elif cls in annotated_classes: + return annotated_classes[cls] + + for super_ in cls.__mro__: + # check if an Annotated subclass more specific than + # the given base_cls is already registered, such + # as AnnotatedColumnElement. + if super_ in annotated_classes: + base_cls = annotated_classes[super_] + break + + annotated_classes[cls] = anno_cls = cast( + Type[Annotated], + type("Annotated%s" % cls.__name__, (base_cls, cls), {}), + ) + globals()["Annotated%s" % cls.__name__] = anno_cls + + if "_traverse_internals" in cls.__dict__: + anno_cls._traverse_internals = list(cls._traverse_internals) + [ + ("_annotations", InternalTraversal.dp_annotations_key) + ] + elif cls.__dict__.get("inherit_cache", False): + anno_cls._traverse_internals = list(cls._traverse_internals) + [ + ("_annotations", InternalTraversal.dp_annotations_key) + ] + + # some classes include this even if they have traverse_internals + # e.g. BindParameter, add it if present. + if cls.__dict__.get("inherit_cache", False): + anno_cls.inherit_cache = True # type: ignore + elif "inherit_cache" in cls.__dict__: + anno_cls.inherit_cache = cls.__dict__["inherit_cache"] # type: ignore + + anno_cls._is_column_operators = issubclass(cls, operators.ColumnOperators) + + return anno_cls + + +def _prepare_annotations( + target_hierarchy: Type[SupportsWrappingAnnotations], + base_cls: Type[Annotated], +) -> None: + for cls in util.walk_subclasses(target_hierarchy): + _new_annotation_type(cls, base_cls) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/base.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/base.py new file mode 100644 index 0000000..5eb32e3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/base.py @@ -0,0 +1,2180 @@ +# sql/base.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Foundational utilities common to many sql modules. + +""" + + +from __future__ import annotations + +import collections +from enum import Enum +import itertools +from itertools import zip_longest +import operator +import re +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import NamedTuple +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import roles +from . import visitors +from .cache_key import HasCacheKey # noqa +from .cache_key import MemoizedHasCacheKey # noqa +from .traversals import HasCopyInternals # noqa +from .visitors import ClauseVisitor +from .visitors import ExtendedInternalTraversal +from .visitors import ExternallyTraversible +from .visitors import InternalTraversal +from .. import event +from .. import exc +from .. import util +from ..util import HasMemoized as HasMemoized +from ..util import hybridmethod +from ..util import typing as compat_typing +from ..util.typing import Protocol +from ..util.typing import Self +from ..util.typing import TypeGuard + +if TYPE_CHECKING: + from . import coercions + from . import elements + from . import type_api + from ._orm_types import DMLStrategyArgument + from ._orm_types import SynchronizeSessionArgument + from ._typing import _CLE + from .elements import BindParameter + from .elements import ClauseList + from .elements import ColumnClause # noqa + from .elements import ColumnElement + from .elements import KeyedColumnElement + from .elements import NamedColumn + from .elements import SQLCoreOperations + from .elements import TextClause + from .schema import Column + from .schema import DefaultGenerator + from .selectable import _JoinTargetElement + from .selectable import _SelectIterable + from .selectable import FromClause + from ..engine import Connection + from ..engine import CursorResult + from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _ImmutableExecuteOptions + from ..engine.interfaces import CacheStats + from ..engine.interfaces import Compiled + from ..engine.interfaces import CompiledCacheType + from ..engine.interfaces import CoreExecuteOptionsParameter + from ..engine.interfaces import Dialect + from ..engine.interfaces import IsolationLevel + from ..engine.interfaces import SchemaTranslateMapType + from ..event import dispatcher + +if not TYPE_CHECKING: + coercions = None # noqa + elements = None # noqa + type_api = None # noqa + + +class _NoArg(Enum): + NO_ARG = 0 + + def __repr__(self): + return f"_NoArg.{self.name}" + + +NO_ARG = _NoArg.NO_ARG + + +class _NoneName(Enum): + NONE_NAME = 0 + """indicate a 'deferred' name that was ultimately the value None.""" + + +_NONE_NAME = _NoneName.NONE_NAME + +_T = TypeVar("_T", bound=Any) + +_Fn = TypeVar("_Fn", bound=Callable[..., Any]) + +_AmbiguousTableNameMap = MutableMapping[str, str] + + +class _DefaultDescriptionTuple(NamedTuple): + arg: Any + is_scalar: Optional[bool] + is_callable: Optional[bool] + is_sentinel: Optional[bool] + + @classmethod + def _from_column_default( + cls, default: Optional[DefaultGenerator] + ) -> _DefaultDescriptionTuple: + return ( + _DefaultDescriptionTuple( + default.arg, # type: ignore + default.is_scalar, + default.is_callable, + default.is_sentinel, + ) + if default + and ( + default.has_arg + or (not default.for_update and default.is_sentinel) + ) + else _DefaultDescriptionTuple(None, None, None, None) + ) + + +_never_select_column = operator.attrgetter("_omit_from_statements") + + +class _EntityNamespace(Protocol): + def __getattr__(self, key: str) -> SQLCoreOperations[Any]: ... + + +class _HasEntityNamespace(Protocol): + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: ... + + +def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]: + return hasattr(element, "entity_namespace") + + +# Remove when https://github.com/python/mypy/issues/14640 will be fixed +_Self = TypeVar("_Self", bound=Any) + + +class Immutable: + """mark a ClauseElement as 'immutable' when expressions are cloned. + + "immutable" objects refers to the "mutability" of an object in the + context of SQL DQL and DML generation. Such as, in DQL, one can + compose a SELECT or subquery of varied forms, but one cannot modify + the structure of a specific table or column within DQL. + :class:`.Immutable` is mostly intended to follow this concept, and as + such the primary "immutable" objects are :class:`.ColumnClause`, + :class:`.Column`, :class:`.TableClause`, :class:`.Table`. + + """ + + __slots__ = () + + _is_immutable = True + + def unique_params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def params(self, *optionaldict, **kwargs): + raise NotImplementedError("Immutable objects do not support copying") + + def _clone(self: _Self, **kw: Any) -> _Self: + return self + + def _copy_internals( + self, *, omit_attrs: Iterable[str] = (), **kw: Any + ) -> None: + pass + + +class SingletonConstant(Immutable): + """Represent SQL constants like NULL, TRUE, FALSE""" + + _is_singleton_constant = True + + _singleton: SingletonConstant + + def __new__(cls: _T, *arg: Any, **kw: Any) -> _T: + return cast(_T, cls._singleton) + + @util.non_memoized_property + def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: + raise NotImplementedError() + + @classmethod + def _create_singleton(cls): + obj = object.__new__(cls) + obj.__init__() # type: ignore + + # for a long time this was an empty frozenset, meaning + # a SingletonConstant would never be a "corresponding column" in + # a statement. This referred to #6259. However, in #7154 we see + # that we do in fact need "correspondence" to work when matching cols + # in result sets, so the non-correspondence was moved to a more + # specific level when we are actually adapting expressions for SQL + # render only. + obj.proxy_set = frozenset([obj]) + cls._singleton = obj + + +def _from_objects( + *elements: Union[ + ColumnElement[Any], FromClause, TextClause, _JoinTargetElement + ] +) -> Iterator[FromClause]: + return itertools.chain.from_iterable( + [element._from_objects for element in elements] + ) + + +def _select_iterables( + elements: Iterable[roles.ColumnsClauseRole], +) -> _SelectIterable: + """expand tables into individual columns in the + given list of column expressions. + + """ + return itertools.chain.from_iterable( + [c._select_iterable for c in elements] + ) + + +_SelfGenerativeType = TypeVar("_SelfGenerativeType", bound="_GenerativeType") + + +class _GenerativeType(compat_typing.Protocol): + def _generate(self) -> Self: ... + + +def _generative(fn: _Fn) -> _Fn: + """non-caching _generative() decorator. + + This is basically the legacy decorator that copies the object and + runs a method on the new copy. + + """ + + @util.decorator + def _generative( + fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any + ) -> _SelfGenerativeType: + """Mark a method as generative.""" + + self = self._generate() + x = fn(self, *args, **kw) + assert x is self, "generative methods must return self" + return self + + decorated = _generative(fn) + decorated.non_generative = fn # type: ignore + return decorated + + +def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]: + msgs = kw.pop("msgs", {}) + + defaults = kw.pop("defaults", {}) + + getters = [ + (name, operator.attrgetter(name), defaults.get(name, None)) + for name in names + ] + + @util.decorator + def check(fn, *args, **kw): + # make pylance happy by not including "self" in the argument + # list + self = args[0] + args = args[1:] + for name, getter, default_ in getters: + if getter(self) is not default_: + msg = msgs.get( + name, + "Method %s() has already been invoked on this %s construct" + % (fn.__name__, self.__class__), + ) + raise exc.InvalidRequestError(msg) + return fn(self, *args, **kw) + + return check + + +def _clone(element, **kw): + return element._clone(**kw) + + +def _expand_cloned( + elements: Iterable[_CLE], +) -> Iterable[_CLE]: + """expand the given set of ClauseElements to be the set of all 'cloned' + predecessors. + + """ + # TODO: cython candidate + return itertools.chain(*[x._cloned_set for x in elements]) + + +def _de_clone( + elements: Iterable[_CLE], +) -> Iterable[_CLE]: + for x in elements: + while x._is_clone_of is not None: + x = x._is_clone_of + yield x + + +def _cloned_intersection(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]: + """return the intersection of sets a and b, counting + any overlap between 'cloned' predecessors. + + The returned set is in terms of the entities present within 'a'. + + """ + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return {elem for elem in a if all_overlap.intersection(elem._cloned_set)} + + +def _cloned_difference(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]: + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) + return { + elem for elem in a if not all_overlap.intersection(elem._cloned_set) + } + + +class _DialectArgView(MutableMapping[str, Any]): + """A dictionary view of dialect-level arguments in the form + _. + + """ + + def __init__(self, obj): + self.obj = obj + + def _key(self, key): + try: + dialect, value_key = key.split("_", 1) + except ValueError as err: + raise KeyError(key) from err + else: + return dialect, value_key + + def __getitem__(self, key): + dialect, value_key = self._key(key) + + try: + opt = self.obj.dialect_options[dialect] + except exc.NoSuchModuleError as err: + raise KeyError(key) from err + else: + return opt[value_key] + + def __setitem__(self, key, value): + try: + dialect, value_key = self._key(key) + except KeyError as err: + raise exc.ArgumentError( + "Keys must be of the form _" + ) from err + else: + self.obj.dialect_options[dialect][value_key] = value + + def __delitem__(self, key): + dialect, value_key = self._key(key) + del self.obj.dialect_options[dialect][value_key] + + def __len__(self): + return sum( + len(args._non_defaults) + for args in self.obj.dialect_options.values() + ) + + def __iter__(self): + return ( + "%s_%s" % (dialect_name, value_name) + for dialect_name in self.obj.dialect_options + for value_name in self.obj.dialect_options[ + dialect_name + ]._non_defaults + ) + + +class _DialectArgDict(MutableMapping[str, Any]): + """A dictionary view of dialect-level arguments for a specific + dialect. + + Maintains a separate collection of user-specified arguments + and dialect-specified default arguments. + + """ + + def __init__(self): + self._non_defaults = {} + self._defaults = {} + + def __len__(self): + return len(set(self._non_defaults).union(self._defaults)) + + def __iter__(self): + return iter(set(self._non_defaults).union(self._defaults)) + + def __getitem__(self, key): + if key in self._non_defaults: + return self._non_defaults[key] + else: + return self._defaults[key] + + def __setitem__(self, key, value): + self._non_defaults[key] = value + + def __delitem__(self, key): + del self._non_defaults[key] + + +@util.preload_module("sqlalchemy.dialects") +def _kw_reg_for_dialect(dialect_name): + dialect_cls = util.preloaded.dialects.registry.load(dialect_name) + if dialect_cls.construct_arguments is None: + return None + return dict(dialect_cls.construct_arguments) + + +class DialectKWArgs: + """Establish the ability for a class to have dialect-specific arguments + with defaults and constructor validation. + + The :class:`.DialectKWArgs` interacts with the + :attr:`.DefaultDialect.construct_arguments` present on a dialect. + + .. seealso:: + + :attr:`.DefaultDialect.construct_arguments` + + """ + + __slots__ = () + + _dialect_kwargs_traverse_internals = [ + ("dialect_options", InternalTraversal.dp_dialect_options) + ] + + @classmethod + def argument_for(cls, dialect_name, argument_name, default): + """Add a new kind of dialect-specific keyword argument for this class. + + E.g.:: + + Index.argument_for("mydialect", "length", None) + + some_index = Index('a', 'b', mydialect_length=5) + + The :meth:`.DialectKWArgs.argument_for` method is a per-argument + way adding extra arguments to the + :attr:`.DefaultDialect.construct_arguments` dictionary. This + dictionary provides a list of argument names accepted by various + schema-level constructs on behalf of a dialect. + + New dialects should typically specify this dictionary all at once as a + data member of the dialect class. The use case for ad-hoc addition of + argument names is typically for end-user code that is also using + a custom compilation scheme which consumes the additional arguments. + + :param dialect_name: name of a dialect. The dialect must be + locatable, else a :class:`.NoSuchModuleError` is raised. The + dialect must also include an existing + :attr:`.DefaultDialect.construct_arguments` collection, indicating + that it participates in the keyword-argument validation and default + system, else :class:`.ArgumentError` is raised. If the dialect does + not include this collection, then any keyword argument can be + specified on behalf of this dialect already. All dialects packaged + within SQLAlchemy include this collection, however for third party + dialects, support may vary. + + :param argument_name: name of the parameter. + + :param default: default value of the parameter. + + """ + + construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] + if construct_arg_dictionary is None: + raise exc.ArgumentError( + "Dialect '%s' does have keyword-argument " + "validation and defaults enabled configured" % dialect_name + ) + if cls not in construct_arg_dictionary: + construct_arg_dictionary[cls] = {} + construct_arg_dictionary[cls][argument_name] = default + + @util.memoized_property + def dialect_kwargs(self): + """A collection of keyword arguments specified as dialect-specific + options to this construct. + + The arguments are present here in their original ``_`` + format. Only arguments that were actually passed are included; + unlike the :attr:`.DialectKWArgs.dialect_options` collection, which + contains all options known by this dialect including defaults. + + The collection is also writable; keys are accepted of the + form ``_`` where the value will be assembled + into the list of options. + + .. seealso:: + + :attr:`.DialectKWArgs.dialect_options` - nested dictionary form + + """ + return _DialectArgView(self) + + @property + def kwargs(self): + """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`.""" + return self.dialect_kwargs + + _kw_registry = util.PopulateDict(_kw_reg_for_dialect) + + def _kw_reg_for_dialect_cls(self, dialect_name): + construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] + d = _DialectArgDict() + + if construct_arg_dictionary is None: + d._defaults.update({"*": None}) + else: + for cls in reversed(self.__class__.__mro__): + if cls in construct_arg_dictionary: + d._defaults.update(construct_arg_dictionary[cls]) + return d + + @util.memoized_property + def dialect_options(self): + """A collection of keyword arguments specified as dialect-specific + options to this construct. + + This is a two-level nested registry, keyed to ```` + and ````. For example, the ``postgresql_where`` + argument would be locatable as:: + + arg = my_object.dialect_options['postgresql']['where'] + + .. versionadded:: 0.9.2 + + .. seealso:: + + :attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form + + """ + + return util.PopulateDict( + util.portable_instancemethod(self._kw_reg_for_dialect_cls) + ) + + def _validate_dialect_kwargs(self, kwargs: Dict[str, Any]) -> None: + # validate remaining kwargs that they all specify DB prefixes + + if not kwargs: + return + + for k in kwargs: + m = re.match("^(.+?)_(.+)$", k) + if not m: + raise TypeError( + "Additional arguments should be " + "named _, got '%s'" % k + ) + dialect_name, arg_name = m.group(1, 2) + + try: + construct_arg_dictionary = self.dialect_options[dialect_name] + except exc.NoSuchModuleError: + util.warn( + "Can't validate argument %r; can't " + "locate any SQLAlchemy dialect named %r" + % (k, dialect_name) + ) + self.dialect_options[dialect_name] = d = _DialectArgDict() + d._defaults.update({"*": None}) + d._non_defaults[arg_name] = kwargs[k] + else: + if ( + "*" not in construct_arg_dictionary + and arg_name not in construct_arg_dictionary + ): + raise exc.ArgumentError( + "Argument %r is not accepted by " + "dialect %r on behalf of %r" + % (k, dialect_name, self.__class__) + ) + else: + construct_arg_dictionary[arg_name] = kwargs[k] + + +class CompileState: + """Produces additional object state necessary for a statement to be + compiled. + + the :class:`.CompileState` class is at the base of classes that assemble + state for a particular statement object that is then used by the + compiler. This process is essentially an extension of the process that + the SQLCompiler.visit_XYZ() method takes, however there is an emphasis + on converting raw user intent into more organized structures rather than + producing string output. The top-level :class:`.CompileState` for the + statement being executed is also accessible when the execution context + works with invoking the statement and collecting results. + + The production of :class:`.CompileState` is specific to the compiler, such + as within the :meth:`.SQLCompiler.visit_insert`, + :meth:`.SQLCompiler.visit_select` etc. methods. These methods are also + responsible for associating the :class:`.CompileState` with the + :class:`.SQLCompiler` itself, if the statement is the "toplevel" statement, + i.e. the outermost SQL statement that's actually being executed. + There can be other :class:`.CompileState` objects that are not the + toplevel, such as when a SELECT subquery or CTE-nested + INSERT/UPDATE/DELETE is generated. + + .. versionadded:: 1.4 + + """ + + __slots__ = ("statement", "_ambiguous_table_name_map") + + plugins: Dict[Tuple[str, str], Type[CompileState]] = {} + + _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] + + @classmethod + def create_for_statement(cls, statement, compiler, **kw): + # factory construction. + + if statement._propagate_attrs: + plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", "default" + ) + klass = cls.plugins.get( + (plugin_name, statement._effective_plugin_target), None + ) + if klass is None: + klass = cls.plugins[ + ("default", statement._effective_plugin_target) + ] + + else: + klass = cls.plugins[ + ("default", statement._effective_plugin_target) + ] + + if klass is cls: + return cls(statement, compiler, **kw) + else: + return klass.create_for_statement(statement, compiler, **kw) + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + @classmethod + def get_plugin_class( + cls, statement: Executable + ) -> Optional[Type[CompileState]]: + plugin_name = statement._propagate_attrs.get( + "compile_state_plugin", None + ) + + if plugin_name: + key = (plugin_name, statement._effective_plugin_target) + if key in cls.plugins: + return cls.plugins[key] + + # there's no case where we call upon get_plugin_class() and want + # to get None back, there should always be a default. return that + # if there was no plugin-specific class (e.g. Insert with "orm" + # plugin) + try: + return cls.plugins[("default", statement._effective_plugin_target)] + except KeyError: + return None + + @classmethod + def _get_plugin_class_for_plugin( + cls, statement: Executable, plugin_name: str + ) -> Optional[Type[CompileState]]: + try: + return cls.plugins[ + (plugin_name, statement._effective_plugin_target) + ] + except KeyError: + return None + + @classmethod + def plugin_for( + cls, plugin_name: str, visit_name: str + ) -> Callable[[_Fn], _Fn]: + def decorate(cls_to_decorate): + cls.plugins[(plugin_name, visit_name)] = cls_to_decorate + return cls_to_decorate + + return decorate + + +class Generative(HasMemoized): + """Provide a method-chaining pattern in conjunction with the + @_generative decorator.""" + + def _generate(self) -> Self: + skip = self._memoized_keys + cls = self.__class__ + s = cls.__new__(cls) + if skip: + # ensure this iteration remains atomic + s.__dict__ = { + k: v for k, v in self.__dict__.copy().items() if k not in skip + } + else: + s.__dict__ = self.__dict__.copy() + return s + + +class InPlaceGenerative(HasMemoized): + """Provide a method-chaining pattern in conjunction with the + @_generative decorator that mutates in place.""" + + __slots__ = () + + def _generate(self): + skip = self._memoized_keys + # note __dict__ needs to be in __slots__ if this is used + for k in skip: + self.__dict__.pop(k, None) + return self + + +class HasCompileState(Generative): + """A class that has a :class:`.CompileState` associated with it.""" + + _compile_state_plugin: Optional[Type[CompileState]] = None + + _attributes: util.immutabledict[str, Any] = util.EMPTY_DICT + + _compile_state_factory = CompileState.create_for_statement + + +class _MetaOptions(type): + """metaclass for the Options class. + + This metaclass is actually necessary despite the availability of the + ``__init_subclass__()`` hook as this type also provides custom class-level + behavior for the ``__add__()`` method. + + """ + + _cache_attrs: Tuple[str, ...] + + def __add__(self, other): + o1 = self() + + if set(other).difference(self._cache_attrs): + raise TypeError( + "dictionary contains attributes not covered by " + "Options class %s: %r" + % (self, set(other).difference(self._cache_attrs)) + ) + + o1.__dict__.update(other) + return o1 + + if TYPE_CHECKING: + + def __getattr__(self, key: str) -> Any: ... + + def __setattr__(self, key: str, value: Any) -> None: ... + + def __delattr__(self, key: str) -> None: ... + + +class Options(metaclass=_MetaOptions): + """A cacheable option dictionary with defaults.""" + + __slots__ = () + + _cache_attrs: Tuple[str, ...] + + def __init_subclass__(cls) -> None: + dict_ = cls.__dict__ + cls._cache_attrs = tuple( + sorted( + d + for d in dict_ + if not d.startswith("__") + and d not in ("_cache_key_traversal",) + ) + ) + super().__init_subclass__() + + def __init__(self, **kw): + self.__dict__.update(kw) + + def __add__(self, other): + o1 = self.__class__.__new__(self.__class__) + o1.__dict__.update(self.__dict__) + + if set(other).difference(self._cache_attrs): + raise TypeError( + "dictionary contains attributes not covered by " + "Options class %s: %r" + % (self, set(other).difference(self._cache_attrs)) + ) + + o1.__dict__.update(other) + return o1 + + def __eq__(self, other): + # TODO: very inefficient. This is used only in test suites + # right now. + for a, b in zip_longest(self._cache_attrs, other._cache_attrs): + if getattr(self, a) != getattr(other, b): + return False + return True + + def __repr__(self): + # TODO: fairly inefficient, used only in debugging right now. + + return "%s(%s)" % ( + self.__class__.__name__, + ", ".join( + "%s=%r" % (k, self.__dict__[k]) + for k in self._cache_attrs + if k in self.__dict__ + ), + ) + + @classmethod + def isinstance(cls, klass: Type[Any]) -> bool: + return issubclass(cls, klass) + + @hybridmethod + def add_to_element(self, name, value): + return self + {name: getattr(self, name) + value} + + @hybridmethod + def _state_dict_inst(self) -> Mapping[str, Any]: + return self.__dict__ + + _state_dict_const: util.immutabledict[str, Any] = util.EMPTY_DICT + + @_state_dict_inst.classlevel + def _state_dict(cls) -> Mapping[str, Any]: + return cls._state_dict_const + + @classmethod + def safe_merge(cls, other): + d = other._state_dict() + + # only support a merge with another object of our class + # and which does not have attrs that we don't. otherwise + # we risk having state that might not be part of our cache + # key strategy + + if ( + cls is not other.__class__ + and other._cache_attrs + and set(other._cache_attrs).difference(cls._cache_attrs) + ): + raise TypeError( + "other element %r is not empty, is not of type %s, " + "and contains attributes not covered here %r" + % ( + other, + cls, + set(other._cache_attrs).difference(cls._cache_attrs), + ) + ) + return cls + d + + @classmethod + def from_execution_options( + cls, key, attrs, exec_options, statement_exec_options + ): + """process Options argument in terms of execution options. + + + e.g.:: + + ( + load_options, + execution_options, + ) = QueryContext.default_load_options.from_execution_options( + "_sa_orm_load_options", + { + "populate_existing", + "autoflush", + "yield_per" + }, + execution_options, + statement._execution_options, + ) + + get back the Options and refresh "_sa_orm_load_options" in the + exec options dict w/ the Options as well + + """ + + # common case is that no options we are looking for are + # in either dictionary, so cancel for that first + check_argnames = attrs.intersection( + set(exec_options).union(statement_exec_options) + ) + + existing_options = exec_options.get(key, cls) + + if check_argnames: + result = {} + for argname in check_argnames: + local = "_" + argname + if argname in exec_options: + result[local] = exec_options[argname] + elif argname in statement_exec_options: + result[local] = statement_exec_options[argname] + + new_options = existing_options + result + exec_options = util.immutabledict().merge_with( + exec_options, {key: new_options} + ) + return new_options, exec_options + + else: + return existing_options, exec_options + + if TYPE_CHECKING: + + def __getattr__(self, key: str) -> Any: ... + + def __setattr__(self, key: str, value: Any) -> None: ... + + def __delattr__(self, key: str) -> None: ... + + +class CacheableOptions(Options, HasCacheKey): + __slots__ = () + + @hybridmethod + def _gen_cache_key_inst(self, anon_map, bindparams): + return HasCacheKey._gen_cache_key(self, anon_map, bindparams) + + @_gen_cache_key_inst.classlevel + def _gen_cache_key(cls, anon_map, bindparams): + return (cls, ()) + + @hybridmethod + def _generate_cache_key(self): + return HasCacheKey._generate_cache_key_for_object(self) + + +class ExecutableOption(HasCopyInternals): + __slots__ = () + + _annotations = util.EMPTY_DICT + + __visit_name__ = "executable_option" + + _is_has_cache_key = False + + _is_core = True + + def _clone(self, **kw): + """Create a shallow copy of this ExecutableOption.""" + c = self.__class__.__new__(self.__class__) + c.__dict__ = dict(self.__dict__) # type: ignore + return c + + +class Executable(roles.StatementRole): + """Mark a :class:`_expression.ClauseElement` as supporting execution. + + :class:`.Executable` is a superclass for all "statement" types + of objects, including :func:`select`, :func:`delete`, :func:`update`, + :func:`insert`, :func:`text`. + + """ + + supports_execution: bool = True + _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT + _is_default_generator = False + _with_options: Tuple[ExecutableOption, ...] = () + _with_context_options: Tuple[ + Tuple[Callable[[CompileState], None], Any], ... + ] = () + _compile_options: Optional[Union[Type[CacheableOptions], CacheableOptions]] + + _executable_traverse_internals = [ + ("_with_options", InternalTraversal.dp_executable_options), + ( + "_with_context_options", + ExtendedInternalTraversal.dp_with_context_options, + ), + ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs), + ] + + is_select = False + is_update = False + is_insert = False + is_text = False + is_delete = False + is_dml = False + + if TYPE_CHECKING: + __visit_name__: str + + def _compile_w_cache( + self, + dialect: Dialect, + *, + compiled_cache: Optional[CompiledCacheType], + column_keys: List[str], + for_executemany: bool = False, + schema_translate_map: Optional[SchemaTranslateMapType] = None, + **kw: Any, + ) -> Tuple[ + Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats + ]: ... + + def _execute_on_connection( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> CursorResult[Any]: ... + + def _execute_on_scalar( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> Any: ... + + @util.ro_non_memoized_property + def _all_selected_columns(self): + raise NotImplementedError() + + @property + def _effective_plugin_target(self) -> str: + return self.__visit_name__ + + @_generative + def options(self, *options: ExecutableOption) -> Self: + """Apply options to this statement. + + In the general sense, options are any kind of Python object + that can be interpreted by the SQL compiler for the statement. + These options can be consumed by specific dialects or specific kinds + of compilers. + + The most commonly known kind of option are the ORM level options + that apply "eager load" and other loading behaviors to an ORM + query. However, options can theoretically be used for many other + purposes. + + For background on specific kinds of options for specific kinds of + statements, refer to the documentation for those option objects. + + .. versionchanged:: 1.4 - added :meth:`.Executable.options` to + Core statement objects towards the goal of allowing unified + Core / ORM querying capabilities. + + .. seealso:: + + :ref:`loading_columns` - refers to options specific to the usage + of ORM queries + + :ref:`relationship_loader_options` - refers to options specific + to the usage of ORM queries + + """ + self._with_options += tuple( + coercions.expect(roles.ExecutableOptionRole, opt) + for opt in options + ) + return self + + @_generative + def _set_compile_options(self, compile_options: CacheableOptions) -> Self: + """Assign the compile options to a new value. + + :param compile_options: appropriate CacheableOptions structure + + """ + + self._compile_options = compile_options + return self + + @_generative + def _update_compile_options(self, options: CacheableOptions) -> Self: + """update the _compile_options with new keys.""" + + assert self._compile_options is not None + self._compile_options += options + return self + + @_generative + def _add_context_option( + self, + callable_: Callable[[CompileState], None], + cache_args: Any, + ) -> Self: + """Add a context option to this statement. + + These are callable functions that will + be given the CompileState object upon compilation. + + A second argument cache_args is required, which will be combined with + the ``__code__`` identity of the function itself in order to produce a + cache key. + + """ + self._with_context_options += ((callable_, cache_args),) + return self + + @overload + def execution_options( + self, + *, + compiled_cache: Optional[CompiledCacheType] = ..., + logging_token: str = ..., + isolation_level: IsolationLevel = ..., + no_parameters: bool = False, + stream_results: bool = False, + max_row_buffer: int = ..., + yield_per: int = ..., + insertmanyvalues_page_size: int = ..., + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + populate_existing: bool = False, + autoflush: bool = False, + synchronize_session: SynchronizeSessionArgument = ..., + dml_strategy: DMLStrategyArgument = ..., + render_nulls: bool = ..., + is_delete_using: bool = ..., + is_update_from: bool = ..., + preserve_rowcount: bool = False, + **opt: Any, + ) -> Self: ... + + @overload + def execution_options(self, **opt: Any) -> Self: ... + + @_generative + def execution_options(self, **kw: Any) -> Self: + """Set non-SQL options for the statement which take effect during + execution. + + Execution options can be set at many scopes, including per-statement, + per-connection, or per execution, using methods such as + :meth:`_engine.Connection.execution_options` and parameters which + accept a dictionary of options such as + :paramref:`_engine.Connection.execute.execution_options` and + :paramref:`_orm.Session.execute.execution_options`. + + The primary characteristic of an execution option, as opposed to + other kinds of options such as ORM loader options, is that + **execution options never affect the compiled SQL of a query, only + things that affect how the SQL statement itself is invoked or how + results are fetched**. That is, execution options are not part of + what's accommodated by SQL compilation nor are they considered part of + the cached state of a statement. + + The :meth:`_sql.Executable.execution_options` method is + :term:`generative`, as + is the case for the method as applied to the :class:`_engine.Engine` + and :class:`_orm.Query` objects, which means when the method is called, + a copy of the object is returned, which applies the given parameters to + that new copy, but leaves the original unchanged:: + + statement = select(table.c.x, table.c.y) + new_statement = statement.execution_options(my_option=True) + + An exception to this behavior is the :class:`_engine.Connection` + object, where the :meth:`_engine.Connection.execution_options` method + is explicitly **not** generative. + + The kinds of options that may be passed to + :meth:`_sql.Executable.execution_options` and other related methods and + parameter dictionaries include parameters that are explicitly consumed + by SQLAlchemy Core or ORM, as well as arbitrary keyword arguments not + defined by SQLAlchemy, which means the methods and/or parameter + dictionaries may be used for user-defined parameters that interact with + custom code, which may access the parameters using methods such as + :meth:`_sql.Executable.get_execution_options` and + :meth:`_engine.Connection.get_execution_options`, or within selected + event hooks using a dedicated ``execution_options`` event parameter + such as + :paramref:`_events.ConnectionEvents.before_execute.execution_options` + or :attr:`_orm.ORMExecuteState.execution_options`, e.g.:: + + from sqlalchemy import event + + @event.listens_for(some_engine, "before_execute") + def _process_opt(conn, statement, multiparams, params, execution_options): + "run a SQL function before invoking a statement" + + if execution_options.get("do_special_thing", False): + conn.exec_driver_sql("run_special_function()") + + Within the scope of options that are explicitly recognized by + SQLAlchemy, most apply to specific classes of objects and not others. + The most common execution options include: + + * :paramref:`_engine.Connection.execution_options.isolation_level` - + sets the isolation level for a connection or a class of connections + via an :class:`_engine.Engine`. This option is accepted only + by :class:`_engine.Connection` or :class:`_engine.Engine`. + + * :paramref:`_engine.Connection.execution_options.stream_results` - + indicates results should be fetched using a server side cursor; + this option is accepted by :class:`_engine.Connection`, by the + :paramref:`_engine.Connection.execute.execution_options` parameter + on :meth:`_engine.Connection.execute`, and additionally by + :meth:`_sql.Executable.execution_options` on a SQL statement object, + as well as by ORM constructs like :meth:`_orm.Session.execute`. + + * :paramref:`_engine.Connection.execution_options.compiled_cache` - + indicates a dictionary that will serve as the + :ref:`SQL compilation cache ` + for a :class:`_engine.Connection` or :class:`_engine.Engine`, as + well as for ORM methods like :meth:`_orm.Session.execute`. + Can be passed as ``None`` to disable caching for statements. + This option is not accepted by + :meth:`_sql.Executable.execution_options` as it is inadvisable to + carry along a compilation cache within a statement object. + + * :paramref:`_engine.Connection.execution_options.schema_translate_map` + - a mapping of schema names used by the + :ref:`Schema Translate Map ` feature, accepted + by :class:`_engine.Connection`, :class:`_engine.Engine`, + :class:`_sql.Executable`, as well as by ORM constructs + like :meth:`_orm.Session.execute`. + + .. seealso:: + + :meth:`_engine.Connection.execution_options` + + :paramref:`_engine.Connection.execute.execution_options` + + :paramref:`_orm.Session.execute.execution_options` + + :ref:`orm_queryguide_execution_options` - documentation on all + ORM-specific execution options + + """ # noqa: E501 + if "isolation_level" in kw: + raise exc.ArgumentError( + "'isolation_level' execution option may only be specified " + "on Connection.execution_options(), or " + "per-engine using the isolation_level " + "argument to create_engine()." + ) + if "compiled_cache" in kw: + raise exc.ArgumentError( + "'compiled_cache' execution option may only be specified " + "on Connection.execution_options(), not per statement." + ) + self._execution_options = self._execution_options.union(kw) + return self + + def get_execution_options(self) -> _ExecuteOptions: + """Get the non-SQL options which will take effect during execution. + + .. versionadded:: 1.3 + + .. seealso:: + + :meth:`.Executable.execution_options` + """ + return self._execution_options + + +class SchemaEventTarget(event.EventTarget): + """Base class for elements that are the targets of :class:`.DDLEvents` + events. + + This includes :class:`.SchemaItem` as well as :class:`.SchemaType`. + + """ + + dispatch: dispatcher[SchemaEventTarget] + + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + """Associate with this SchemaEvent's parent object.""" + + def _set_parent_with_dispatch( + self, parent: SchemaEventTarget, **kw: Any + ) -> None: + self.dispatch.before_parent_attach(self, parent) + self._set_parent(parent, **kw) + self.dispatch.after_parent_attach(self, parent) + + +class SchemaVisitor(ClauseVisitor): + """Define the visiting for ``SchemaItem`` objects.""" + + __traverse_options__ = {"schema_visitor": True} + + +class _SentinelDefaultCharacterization(Enum): + NONE = "none" + UNKNOWN = "unknown" + CLIENTSIDE = "clientside" + SENTINEL_DEFAULT = "sentinel_default" + SERVERSIDE = "serverside" + IDENTITY = "identity" + SEQUENCE = "sequence" + + +class _SentinelColumnCharacterization(NamedTuple): + columns: Optional[Sequence[Column[Any]]] = None + is_explicit: bool = False + is_autoinc: bool = False + default_characterization: _SentinelDefaultCharacterization = ( + _SentinelDefaultCharacterization.NONE + ) + + +_COLKEY = TypeVar("_COLKEY", Union[None, str], str) + +_COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True) +_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]") + + +class _ColumnMetrics(Generic[_COL_co]): + __slots__ = ("column",) + + column: _COL_co + + def __init__( + self, collection: ColumnCollection[Any, _COL_co], col: _COL_co + ): + self.column = col + + # proxy_index being non-empty means it was initialized. + # so we need to update it + pi = collection._proxy_index + if pi: + for eps_col in col._expanded_proxy_set: + pi[eps_col].add(self) + + def get_expanded_proxy_set(self): + return self.column._expanded_proxy_set + + def dispose(self, collection): + pi = collection._proxy_index + if not pi: + return + for col in self.column._expanded_proxy_set: + colset = pi.get(col, None) + if colset: + colset.discard(self) + if colset is not None and not colset: + del pi[col] + + def embedded( + self, + target_set: Union[ + Set[ColumnElement[Any]], FrozenSet[ColumnElement[Any]] + ], + ) -> bool: + expanded_proxy_set = self.column._expanded_proxy_set + for t in target_set.difference(expanded_proxy_set): + if not expanded_proxy_set.intersection(_expand_cloned([t])): + return False + return True + + +class ColumnCollection(Generic[_COLKEY, _COL_co]): + """Collection of :class:`_expression.ColumnElement` instances, + typically for + :class:`_sql.FromClause` objects. + + The :class:`_sql.ColumnCollection` object is most commonly available + as the :attr:`_schema.Table.c` or :attr:`_schema.Table.columns` collection + on the :class:`_schema.Table` object, introduced at + :ref:`metadata_tables_and_columns`. + + The :class:`_expression.ColumnCollection` has both mapping- and sequence- + like behaviors. A :class:`_expression.ColumnCollection` usually stores + :class:`_schema.Column` objects, which are then accessible both via mapping + style access as well as attribute access style. + + To access :class:`_schema.Column` objects using ordinary attribute-style + access, specify the name like any other object attribute, such as below + a column named ``employee_name`` is accessed:: + + >>> employee_table.c.employee_name + + To access columns that have names with special characters or spaces, + index-style access is used, such as below which illustrates a column named + ``employee ' payment`` is accessed:: + + >>> employee_table.c["employee ' payment"] + + As the :class:`_sql.ColumnCollection` object provides a Python dictionary + interface, common dictionary method names like + :meth:`_sql.ColumnCollection.keys`, :meth:`_sql.ColumnCollection.values`, + and :meth:`_sql.ColumnCollection.items` are available, which means that + database columns that are keyed under these names also need to use indexed + access:: + + >>> employee_table.c["values"] + + + The name for which a :class:`_schema.Column` would be present is normally + that of the :paramref:`_schema.Column.key` parameter. In some contexts, + such as a :class:`_sql.Select` object that uses a label style set + using the :meth:`_sql.Select.set_label_style` method, a column of a certain + key may instead be represented under a particular label name such + as ``tablename_columnname``:: + + >>> from sqlalchemy import select, column, table + >>> from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL + >>> t = table("t", column("c")) + >>> stmt = select(t).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + >>> subq = stmt.subquery() + >>> subq.c.t_c + + + :class:`.ColumnCollection` also indexes the columns in order and allows + them to be accessible by their integer position:: + + >>> cc[0] + Column('x', Integer(), table=None) + >>> cc[1] + Column('y', Integer(), table=None) + + .. versionadded:: 1.4 :class:`_expression.ColumnCollection` + allows integer-based + index access to the collection. + + Iterating the collection yields the column expressions in order:: + + >>> list(cc) + [Column('x', Integer(), table=None), + Column('y', Integer(), table=None)] + + The base :class:`_expression.ColumnCollection` object can store + duplicates, which can + mean either two columns with the same key, in which case the column + returned by key access is **arbitrary**:: + + >>> x1, x2 = Column('x', Integer), Column('x', Integer) + >>> cc = ColumnCollection(columns=[(x1.name, x1), (x2.name, x2)]) + >>> list(cc) + [Column('x', Integer(), table=None), + Column('x', Integer(), table=None)] + >>> cc['x'] is x1 + False + >>> cc['x'] is x2 + True + + Or it can also mean the same column multiple times. These cases are + supported as :class:`_expression.ColumnCollection` + is used to represent the columns in + a SELECT statement which may include duplicates. + + A special subclass :class:`.DedupeColumnCollection` exists which instead + maintains SQLAlchemy's older behavior of not allowing duplicates; this + collection is used for schema level objects like :class:`_schema.Table` + and + :class:`.PrimaryKeyConstraint` where this deduping is helpful. The + :class:`.DedupeColumnCollection` class also has additional mutation methods + as the schema constructs have more use cases that require removal and + replacement of columns. + + .. versionchanged:: 1.4 :class:`_expression.ColumnCollection` + now stores duplicate + column keys as well as the same column in multiple positions. The + :class:`.DedupeColumnCollection` class is added to maintain the + former behavior in those cases where deduplication as well as + additional replace/remove operations are needed. + + + """ + + __slots__ = "_collection", "_index", "_colset", "_proxy_index" + + _collection: List[Tuple[_COLKEY, _COL_co, _ColumnMetrics[_COL_co]]] + _index: Dict[Union[None, str, int], Tuple[_COLKEY, _COL_co]] + _proxy_index: Dict[ColumnElement[Any], Set[_ColumnMetrics[_COL_co]]] + _colset: Set[_COL_co] + + def __init__( + self, columns: Optional[Iterable[Tuple[_COLKEY, _COL_co]]] = None + ): + object.__setattr__(self, "_colset", set()) + object.__setattr__(self, "_index", {}) + object.__setattr__( + self, "_proxy_index", collections.defaultdict(util.OrderedSet) + ) + object.__setattr__(self, "_collection", []) + if columns: + self._initial_populate(columns) + + @util.preload_module("sqlalchemy.sql.elements") + def __clause_element__(self) -> ClauseList: + elements = util.preloaded.sql_elements + + return elements.ClauseList( + _literal_as_text_role=roles.ColumnsClauseRole, + group=False, + *self._all_columns, + ) + + def _initial_populate( + self, iter_: Iterable[Tuple[_COLKEY, _COL_co]] + ) -> None: + self._populate_separate_keys(iter_) + + @property + def _all_columns(self) -> List[_COL_co]: + return [col for (_, col, _) in self._collection] + + def keys(self) -> List[_COLKEY]: + """Return a sequence of string key names for all columns in this + collection.""" + return [k for (k, _, _) in self._collection] + + def values(self) -> List[_COL_co]: + """Return a sequence of :class:`_sql.ColumnClause` or + :class:`_schema.Column` objects for all columns in this + collection.""" + return [col for (_, col, _) in self._collection] + + def items(self) -> List[Tuple[_COLKEY, _COL_co]]: + """Return a sequence of (key, column) tuples for all columns in this + collection each consisting of a string key name and a + :class:`_sql.ColumnClause` or + :class:`_schema.Column` object. + """ + + return [(k, col) for (k, col, _) in self._collection] + + def __bool__(self) -> bool: + return bool(self._collection) + + def __len__(self) -> int: + return len(self._collection) + + def __iter__(self) -> Iterator[_COL_co]: + # turn to a list first to maintain over a course of changes + return iter([col for _, col, _ in self._collection]) + + @overload + def __getitem__(self, key: Union[str, int]) -> _COL_co: ... + + @overload + def __getitem__( + self, key: Tuple[Union[str, int], ...] + ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: ... + + @overload + def __getitem__( + self, key: slice + ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: ... + + def __getitem__( + self, key: Union[str, int, slice, Tuple[Union[str, int], ...]] + ) -> Union[ReadOnlyColumnCollection[_COLKEY, _COL_co], _COL_co]: + try: + if isinstance(key, (tuple, slice)): + if isinstance(key, slice): + cols = ( + (sub_key, col) + for (sub_key, col, _) in self._collection[key] + ) + else: + cols = (self._index[sub_key] for sub_key in key) + + return ColumnCollection(cols).as_readonly() + else: + return self._index[key][1] + except KeyError as err: + if isinstance(err.args[0], int): + raise IndexError(err.args[0]) from err + else: + raise + + def __getattr__(self, key: str) -> _COL_co: + try: + return self._index[key][1] + except KeyError as err: + raise AttributeError(key) from err + + def __contains__(self, key: str) -> bool: + if key not in self._index: + if not isinstance(key, str): + raise exc.ArgumentError( + "__contains__ requires a string argument" + ) + return False + else: + return True + + def compare(self, other: ColumnCollection[Any, Any]) -> bool: + """Compare this :class:`_expression.ColumnCollection` to another + based on the names of the keys""" + + for l, r in zip_longest(self, other): + if l is not r: + return False + else: + return True + + def __eq__(self, other: Any) -> bool: + return self.compare(other) + + def get( + self, key: str, default: Optional[_COL_co] = None + ) -> Optional[_COL_co]: + """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object + based on a string key name from this + :class:`_expression.ColumnCollection`.""" + + if key in self._index: + return self._index[key][1] + else: + return default + + def __str__(self) -> str: + return "%s(%s)" % ( + self.__class__.__name__, + ", ".join(str(c) for c in self), + ) + + def __setitem__(self, key: str, value: Any) -> NoReturn: + raise NotImplementedError() + + def __delitem__(self, key: str) -> NoReturn: + raise NotImplementedError() + + def __setattr__(self, key: str, obj: Any) -> NoReturn: + raise NotImplementedError() + + def clear(self) -> NoReturn: + """Dictionary clear() is not implemented for + :class:`_sql.ColumnCollection`.""" + raise NotImplementedError() + + def remove(self, column: Any) -> None: + raise NotImplementedError() + + def update(self, iter_: Any) -> NoReturn: + """Dictionary update() is not implemented for + :class:`_sql.ColumnCollection`.""" + raise NotImplementedError() + + # https://github.com/python/mypy/issues/4266 + __hash__ = None # type: ignore + + def _populate_separate_keys( + self, iter_: Iterable[Tuple[_COLKEY, _COL_co]] + ) -> None: + """populate from an iterator of (key, column)""" + + self._collection[:] = collection = [ + (k, c, _ColumnMetrics(self, c)) for k, c in iter_ + ] + self._colset.update(c._deannotate() for _, c, _ in collection) + self._index.update( + {idx: (k, c) for idx, (k, c, _) in enumerate(collection)} + ) + self._index.update({k: (k, col) for k, col, _ in reversed(collection)}) + + def add( + self, column: ColumnElement[Any], key: Optional[_COLKEY] = None + ) -> None: + """Add a column to this :class:`_sql.ColumnCollection`. + + .. note:: + + This method is **not normally used by user-facing code**, as the + :class:`_sql.ColumnCollection` is usually part of an existing + object such as a :class:`_schema.Table`. To add a + :class:`_schema.Column` to an existing :class:`_schema.Table` + object, use the :meth:`_schema.Table.append_column` method. + + """ + colkey: _COLKEY + + if key is None: + colkey = column.key # type: ignore + else: + colkey = key + + l = len(self._collection) + + # don't really know how this part is supposed to work w/ the + # covariant thing + + _column = cast(_COL_co, column) + + self._collection.append( + (colkey, _column, _ColumnMetrics(self, _column)) + ) + self._colset.add(_column._deannotate()) + self._index[l] = (colkey, _column) + if colkey not in self._index: + self._index[colkey] = (colkey, _column) + + def __getstate__(self) -> Dict[str, Any]: + return { + "_collection": [(k, c) for k, c, _ in self._collection], + "_index": self._index, + } + + def __setstate__(self, state: Dict[str, Any]) -> None: + object.__setattr__(self, "_index", state["_index"]) + object.__setattr__( + self, "_proxy_index", collections.defaultdict(util.OrderedSet) + ) + object.__setattr__( + self, + "_collection", + [ + (k, c, _ColumnMetrics(self, c)) + for (k, c) in state["_collection"] + ], + ) + object.__setattr__( + self, "_colset", {col for k, col, _ in self._collection} + ) + + def contains_column(self, col: ColumnElement[Any]) -> bool: + """Checks if a column object exists in this collection""" + if col not in self._colset: + if isinstance(col, str): + raise exc.ArgumentError( + "contains_column cannot be used with string arguments. " + "Use ``col_name in table.c`` instead." + ) + return False + else: + return True + + def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: + """Return a "read only" form of this + :class:`_sql.ColumnCollection`.""" + + return ReadOnlyColumnCollection(self) + + def _init_proxy_index(self): + """populate the "proxy index", if empty. + + proxy index is added in 2.0 to provide more efficient operation + for the corresponding_column() method. + + For reasons of both time to construct new .c collections as well as + memory conservation for large numbers of large .c collections, the + proxy_index is only filled if corresponding_column() is called. once + filled it stays that way, and new _ColumnMetrics objects created after + that point will populate it with new data. Note this case would be + unusual, if not nonexistent, as it means a .c collection is being + mutated after corresponding_column() were used, however it is tested in + test/base/test_utils.py. + + """ + pi = self._proxy_index + if pi: + return + + for _, _, metrics in self._collection: + eps = metrics.column._expanded_proxy_set + + for eps_col in eps: + pi[eps_col].add(metrics) + + def corresponding_column( + self, column: _COL, require_embedded: bool = False + ) -> Optional[Union[_COL, _COL_co]]: + """Given a :class:`_expression.ColumnElement`, return the exported + :class:`_expression.ColumnElement` object from this + :class:`_expression.ColumnCollection` + which corresponds to that original :class:`_expression.ColumnElement` + via a common + ancestor column. + + :param column: the target :class:`_expression.ColumnElement` + to be matched. + + :param require_embedded: only return corresponding columns for + the given :class:`_expression.ColumnElement`, if the given + :class:`_expression.ColumnElement` + is actually present within a sub-element + of this :class:`_expression.Selectable`. + Normally the column will match if + it merely shares a common ancestor with one of the exported + columns of this :class:`_expression.Selectable`. + + .. seealso:: + + :meth:`_expression.Selectable.corresponding_column` + - invokes this method + against the collection returned by + :attr:`_expression.Selectable.exported_columns`. + + .. versionchanged:: 1.4 the implementation for ``corresponding_column`` + was moved onto the :class:`_expression.ColumnCollection` itself. + + """ + # TODO: cython candidate + + # don't dig around if the column is locally present + if column in self._colset: + return column + + selected_intersection, selected_metrics = None, None + target_set = column.proxy_set + + pi = self._proxy_index + if not pi: + self._init_proxy_index() + + for current_metrics in ( + mm for ts in target_set if ts in pi for mm in pi[ts] + ): + if not require_embedded or current_metrics.embedded(target_set): + if selected_metrics is None: + # no corresponding column yet, pick this one. + selected_metrics = current_metrics + continue + + current_intersection = target_set.intersection( + current_metrics.column._expanded_proxy_set + ) + if selected_intersection is None: + selected_intersection = target_set.intersection( + selected_metrics.column._expanded_proxy_set + ) + + if len(current_intersection) > len(selected_intersection): + # 'current' has a larger field of correspondence than + # 'selected'. i.e. selectable.c.a1_x->a1.c.x->table.c.x + # matches a1.c.x->table.c.x better than + # selectable.c.x->table.c.x does. + + selected_metrics = current_metrics + selected_intersection = current_intersection + elif current_intersection == selected_intersection: + # they have the same field of correspondence. see + # which proxy_set has fewer columns in it, which + # indicates a closer relationship with the root + # column. Also take into account the "weight" + # attribute which CompoundSelect() uses to give + # higher precedence to columns based on vertical + # position in the compound statement, and discard + # columns that have no reference to the target + # column (also occurs with CompoundSelect) + + selected_col_distance = sum( + [ + sc._annotations.get("weight", 1) + for sc in ( + selected_metrics.column._uncached_proxy_list() + ) + if sc.shares_lineage(column) + ], + ) + current_col_distance = sum( + [ + sc._annotations.get("weight", 1) + for sc in ( + current_metrics.column._uncached_proxy_list() + ) + if sc.shares_lineage(column) + ], + ) + if current_col_distance < selected_col_distance: + selected_metrics = current_metrics + selected_intersection = current_intersection + + return selected_metrics.column if selected_metrics else None + + +_NAMEDCOL = TypeVar("_NAMEDCOL", bound="NamedColumn[Any]") + + +class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): + """A :class:`_expression.ColumnCollection` + that maintains deduplicating behavior. + + This is useful by schema level objects such as :class:`_schema.Table` and + :class:`.PrimaryKeyConstraint`. The collection includes more + sophisticated mutator methods as well to suit schema objects which + require mutable column collections. + + .. versionadded:: 1.4 + + """ + + def add( + self, column: ColumnElement[Any], key: Optional[str] = None + ) -> None: + named_column = cast(_NAMEDCOL, column) + if key is not None and named_column.key != key: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + key = named_column.key + + if key is None: + raise exc.ArgumentError( + "Can't add unnamed column to column collection" + ) + + if key in self._index: + existing = self._index[key][1] + + if existing is named_column: + return + + self.replace(named_column) + + # pop out memoized proxy_set as this + # operation may very well be occurring + # in a _make_proxy operation + util.memoized_property.reset(named_column, "proxy_set") + else: + self._append_new_column(key, named_column) + + def _append_new_column(self, key: str, named_column: _NAMEDCOL) -> None: + l = len(self._collection) + self._collection.append( + (key, named_column, _ColumnMetrics(self, named_column)) + ) + self._colset.add(named_column._deannotate()) + self._index[l] = (key, named_column) + self._index[key] = (key, named_column) + + def _populate_separate_keys( + self, iter_: Iterable[Tuple[str, _NAMEDCOL]] + ) -> None: + """populate from an iterator of (key, column)""" + cols = list(iter_) + + replace_col = [] + for k, col in cols: + if col.key != k: + raise exc.ArgumentError( + "DedupeColumnCollection requires columns be under " + "the same key as their .key" + ) + if col.name in self._index and col.key != col.name: + replace_col.append(col) + elif col.key in self._index: + replace_col.append(col) + else: + self._index[k] = (k, col) + self._collection.append((k, col, _ColumnMetrics(self, col))) + self._colset.update(c._deannotate() for (k, c, _) in self._collection) + + self._index.update( + (idx, (k, c)) for idx, (k, c, _) in enumerate(self._collection) + ) + for col in replace_col: + self.replace(col) + + def extend(self, iter_: Iterable[_NAMEDCOL]) -> None: + self._populate_separate_keys((col.key, col) for col in iter_) + + def remove(self, column: _NAMEDCOL) -> None: + if column not in self._colset: + raise ValueError( + "Can't remove column %r; column is not in this collection" + % column + ) + del self._index[column.key] + self._colset.remove(column) + self._collection[:] = [ + (k, c, metrics) + for (k, c, metrics) in self._collection + if c is not column + ] + for metrics in self._proxy_index.get(column, ()): + metrics.dispose(self) + + self._index.update( + {idx: (k, col) for idx, (k, col, _) in enumerate(self._collection)} + ) + # delete higher index + del self._index[len(self._collection)] + + def replace( + self, + column: _NAMEDCOL, + extra_remove: Optional[Iterable[_NAMEDCOL]] = None, + ) -> None: + """add the given column to this collection, removing unaliased + versions of this column as well as existing columns with the + same key. + + e.g.:: + + t = Table('sometable', metadata, Column('col1', Integer)) + t.columns.replace(Column('col1', Integer, key='columnone')) + + will remove the original 'col1' from the collection, and add + the new column under the name 'columnname'. + + Used by schema.Column to override columns during table reflection. + + """ + + if extra_remove: + remove_col = set(extra_remove) + else: + remove_col = set() + # remove up to two columns based on matches of name as well as key + if column.name in self._index and column.key != column.name: + other = self._index[column.name][1] + if other.name == other.key: + remove_col.add(other) + + if column.key in self._index: + remove_col.add(self._index[column.key][1]) + + if not remove_col: + self._append_new_column(column.key, column) + return + new_cols: List[Tuple[str, _NAMEDCOL, _ColumnMetrics[_NAMEDCOL]]] = [] + replaced = False + for k, col, metrics in self._collection: + if col in remove_col: + if not replaced: + replaced = True + new_cols.append( + (column.key, column, _ColumnMetrics(self, column)) + ) + else: + new_cols.append((k, col, metrics)) + + if remove_col: + self._colset.difference_update(remove_col) + + for rc in remove_col: + for metrics in self._proxy_index.get(rc, ()): + metrics.dispose(self) + + if not replaced: + new_cols.append((column.key, column, _ColumnMetrics(self, column))) + + self._colset.add(column._deannotate()) + self._collection[:] = new_cols + + self._index.clear() + + self._index.update( + {idx: (k, col) for idx, (k, col, _) in enumerate(self._collection)} + ) + self._index.update({k: (k, col) for (k, col, _) in self._collection}) + + +class ReadOnlyColumnCollection( + util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL_co] +): + __slots__ = ("_parent",) + + def __init__(self, collection): + object.__setattr__(self, "_parent", collection) + object.__setattr__(self, "_colset", collection._colset) + object.__setattr__(self, "_index", collection._index) + object.__setattr__(self, "_collection", collection._collection) + object.__setattr__(self, "_proxy_index", collection._proxy_index) + + def __getstate__(self): + return {"_parent": self._parent} + + def __setstate__(self, state): + parent = state["_parent"] + self.__init__(parent) # type: ignore + + def add(self, column: Any, key: Any = ...) -> Any: + self._readonly() + + def extend(self, elements: Any) -> NoReturn: + self._readonly() + + def remove(self, item: Any) -> NoReturn: + self._readonly() + + +class ColumnSet(util.OrderedSet["ColumnClause[Any]"]): + def contains_column(self, col): + return col in self + + def extend(self, cols): + for col in cols: + self.add(col) + + def __eq__(self, other): + l = [] + for c in other: + for local in self: + if c.shares_lineage(local): + l.append(c == local) + return elements.and_(*l) + + def __hash__(self): + return hash(tuple(x for x in self)) + + +def _entity_namespace( + entity: Union[_HasEntityNamespace, ExternallyTraversible] +) -> _EntityNamespace: + """Return the nearest .entity_namespace for the given entity. + + If not immediately available, does an iterate to find a sub-element + that has one, if any. + + """ + try: + return cast(_HasEntityNamespace, entity).entity_namespace + except AttributeError: + for elem in visitors.iterate(cast(ExternallyTraversible, entity)): + if _is_has_entity_namespace(elem): + return elem.entity_namespace + else: + raise + + +def _entity_namespace_key( + entity: Union[_HasEntityNamespace, ExternallyTraversible], + key: str, + default: Union[SQLCoreOperations[Any], _NoArg] = NO_ARG, +) -> SQLCoreOperations[Any]: + """Return an entry from an entity_namespace. + + + Raises :class:`_exc.InvalidRequestError` rather than attribute error + on not found. + + """ + + try: + ns = _entity_namespace(entity) + if default is not NO_ARG: + return getattr(ns, key, default) + else: + return getattr(ns, key) # type: ignore + except AttributeError as err: + raise exc.InvalidRequestError( + 'Entity namespace for "%s" has no property "%s"' % (entity, key) + ) from err diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/cache_key.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/cache_key.py new file mode 100644 index 0000000..1172d3c --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/cache_key.py @@ -0,0 +1,1057 @@ +# sql/cache_key.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +from __future__ import annotations + +import enum +from itertools import zip_longest +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import MutableMapping +from typing import NamedTuple +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +from .visitors import anon_map +from .visitors import HasTraversalDispatch +from .visitors import HasTraverseInternals +from .visitors import InternalTraversal +from .visitors import prefix_anon_map +from .. import util +from ..inspection import inspect +from ..util import HasMemoized +from ..util.typing import Literal +from ..util.typing import Protocol + +if typing.TYPE_CHECKING: + from .elements import BindParameter + from .elements import ClauseElement + from .elements import ColumnElement + from .visitors import _TraverseInternalsType + from ..engine.interfaces import _CoreSingleExecuteParams + + +class _CacheKeyTraversalDispatchType(Protocol): + def __call__( + s, self: HasCacheKey, visitor: _CacheKeyTraversal + ) -> _CacheKeyTraversalDispatchTypeReturn: ... + + +class CacheConst(enum.Enum): + NO_CACHE = 0 + + +NO_CACHE = CacheConst.NO_CACHE + + +_CacheKeyTraversalType = Union[ + "_TraverseInternalsType", Literal[CacheConst.NO_CACHE], Literal[None] +] + + +class CacheTraverseTarget(enum.Enum): + CACHE_IN_PLACE = 0 + CALL_GEN_CACHE_KEY = 1 + STATIC_CACHE_KEY = 2 + PROPAGATE_ATTRS = 3 + ANON_NAME = 4 + + +( + CACHE_IN_PLACE, + CALL_GEN_CACHE_KEY, + STATIC_CACHE_KEY, + PROPAGATE_ATTRS, + ANON_NAME, +) = tuple(CacheTraverseTarget) + +_CacheKeyTraversalDispatchTypeReturn = Sequence[ + Tuple[ + str, + Any, + Union[ + Callable[..., Tuple[Any, ...]], + CacheTraverseTarget, + InternalTraversal, + ], + ] +] + + +class HasCacheKey: + """Mixin for objects which can produce a cache key. + + This class is usually in a hierarchy that starts with the + :class:`.HasTraverseInternals` base, but this is optional. Currently, + the class should be able to work on its own without including + :class:`.HasTraverseInternals`. + + .. seealso:: + + :class:`.CacheKey` + + :ref:`sql_caching` + + """ + + __slots__ = () + + _cache_key_traversal: _CacheKeyTraversalType = NO_CACHE + + _is_has_cache_key = True + + _hierarchy_supports_caching = True + """private attribute which may be set to False to prevent the + inherit_cache warning from being emitted for a hierarchy of subclasses. + + Currently applies to the :class:`.ExecutableDDLElement` hierarchy which + does not implement caching. + + """ + + inherit_cache: Optional[bool] = None + """Indicate if this :class:`.HasCacheKey` instance should make use of the + cache key generation scheme used by its immediate superclass. + + The attribute defaults to ``None``, which indicates that a construct has + not yet taken into account whether or not its appropriate for it to + participate in caching; this is functionally equivalent to setting the + value to ``False``, except that a warning is also emitted. + + This flag can be set to ``True`` on a particular class, if the SQL that + corresponds to the object does not change based on attributes which + are local to this class, and not its superclass. + + .. seealso:: + + :ref:`compilerext_caching` - General guideslines for setting the + :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user + defined SQL constructs. + + """ + + __slots__ = () + + _generated_cache_key_traversal: Any + + @classmethod + def _generate_cache_attrs( + cls, + ) -> Union[_CacheKeyTraversalDispatchType, Literal[CacheConst.NO_CACHE]]: + """generate cache key dispatcher for a new class. + + This sets the _generated_cache_key_traversal attribute once called + so should only be called once per class. + + """ + inherit_cache = cls.__dict__.get("inherit_cache", None) + inherit = bool(inherit_cache) + + if inherit: + _cache_key_traversal = getattr(cls, "_cache_key_traversal", None) + if _cache_key_traversal is None: + try: + assert issubclass(cls, HasTraverseInternals) + _cache_key_traversal = cls._traverse_internals + except AttributeError: + cls._generated_cache_key_traversal = NO_CACHE + return NO_CACHE + + assert _cache_key_traversal is not NO_CACHE, ( + f"class {cls} has _cache_key_traversal=NO_CACHE, " + "which conflicts with inherit_cache=True" + ) + + # TODO: wouldn't we instead get this from our superclass? + # also, our superclass may not have this yet, but in any case, + # we'd generate for the superclass that has it. this is a little + # more complicated, so for the moment this is a little less + # efficient on startup but simpler. + return _cache_key_traversal_visitor.generate_dispatch( + cls, + _cache_key_traversal, + "_generated_cache_key_traversal", + ) + else: + _cache_key_traversal = cls.__dict__.get( + "_cache_key_traversal", None + ) + if _cache_key_traversal is None: + _cache_key_traversal = cls.__dict__.get( + "_traverse_internals", None + ) + if _cache_key_traversal is None: + cls._generated_cache_key_traversal = NO_CACHE + if ( + inherit_cache is None + and cls._hierarchy_supports_caching + ): + util.warn( + "Class %s will not make use of SQL compilation " + "caching as it does not set the 'inherit_cache' " + "attribute to ``True``. This can have " + "significant performance implications including " + "some performance degradations in comparison to " + "prior SQLAlchemy versions. Set this attribute " + "to True if this object can make use of the cache " + "key generated by the superclass. Alternatively, " + "this attribute may be set to False which will " + "disable this warning." % (cls.__name__), + code="cprf", + ) + return NO_CACHE + + return _cache_key_traversal_visitor.generate_dispatch( + cls, + _cache_key_traversal, + "_generated_cache_key_traversal", + ) + + @util.preload_module("sqlalchemy.sql.elements") + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Optional[Tuple[Any, ...]]: + """return an optional cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the structures that would affect the SQL string or the type handlers + should result in a different cache key. + + If a structure cannot produce a useful cache key, the NO_CACHE + symbol should be added to the anon_map and the method should + return None. + + """ + + cls = self.__class__ + + id_, found = anon_map.get_anon(self) + if found: + return (id_, cls) + + dispatcher: Union[ + Literal[CacheConst.NO_CACHE], + _CacheKeyTraversalDispatchType, + ] + + try: + dispatcher = cls.__dict__["_generated_cache_key_traversal"] + except KeyError: + # traversals.py -> _preconfigure_traversals() + # may be used to run these ahead of time, but + # is not enabled right now. + # this block will generate any remaining dispatchers. + dispatcher = cls._generate_cache_attrs() + + if dispatcher is NO_CACHE: + anon_map[NO_CACHE] = True + return None + + result: Tuple[Any, ...] = (id_, cls) + + # inline of _cache_key_traversal_visitor.run_generated_dispatch() + + for attrname, obj, meth in dispatcher( + self, _cache_key_traversal_visitor + ): + if obj is not None: + # TODO: see if C code can help here as Python lacks an + # efficient switch construct + + if meth is STATIC_CACHE_KEY: + sck = obj._static_cache_key + if sck is NO_CACHE: + anon_map[NO_CACHE] = True + return None + result += (attrname, sck) + elif meth is ANON_NAME: + elements = util.preloaded.sql_elements + if isinstance(obj, elements._anonymous_label): + obj = obj.apply_map(anon_map) # type: ignore + result += (attrname, obj) + elif meth is CALL_GEN_CACHE_KEY: + result += ( + attrname, + obj._gen_cache_key(anon_map, bindparams), + ) + + # remaining cache functions are against + # Python tuples, dicts, lists, etc. so we can skip + # if they are empty + elif obj: + if meth is CACHE_IN_PLACE: + result += (attrname, obj) + elif meth is PROPAGATE_ATTRS: + result += ( + attrname, + obj["compile_state_plugin"], + ( + obj["plugin_subject"]._gen_cache_key( + anon_map, bindparams + ) + if obj["plugin_subject"] + else None + ), + ) + elif meth is InternalTraversal.dp_annotations_key: + # obj is here is the _annotations dict. Table uses + # a memoized version of it. however in other cases, + # we generate it given anon_map as we may be from a + # Join, Aliased, etc. + # see #8790 + + if self._gen_static_annotations_cache_key: # type: ignore # noqa: E501 + result += self._annotations_cache_key # type: ignore # noqa: E501 + else: + result += self._gen_annotations_cache_key(anon_map) # type: ignore # noqa: E501 + + elif ( + meth is InternalTraversal.dp_clauseelement_list + or meth is InternalTraversal.dp_clauseelement_tuple + or meth + is InternalTraversal.dp_memoized_select_entities + ): + result += ( + attrname, + tuple( + [ + elem._gen_cache_key(anon_map, bindparams) + for elem in obj + ] + ), + ) + else: + result += meth( # type: ignore + attrname, obj, self, anon_map, bindparams + ) + return result + + def _generate_cache_key(self) -> Optional[CacheKey]: + """return a cache key. + + The cache key is a tuple which can contain any series of + objects that are hashable and also identifies + this object uniquely within the presence of a larger SQL expression + or statement, for the purposes of caching the resulting query. + + The cache key should be based on the SQL compiled structure that would + ultimately be produced. That is, two structures that are composed in + exactly the same way should produce the same cache key; any difference + in the structures that would affect the SQL string or the type handlers + should result in a different cache key. + + The cache key returned by this method is an instance of + :class:`.CacheKey`, which consists of a tuple representing the + cache key, as well as a list of :class:`.BindParameter` objects + which are extracted from the expression. While two expressions + that produce identical cache key tuples will themselves generate + identical SQL strings, the list of :class:`.BindParameter` objects + indicates the bound values which may have different values in + each one; these bound parameters must be consulted in order to + execute the statement with the correct parameters. + + a :class:`_expression.ClauseElement` structure that does not implement + a :meth:`._gen_cache_key` method and does not implement a + :attr:`.traverse_internals` attribute will not be cacheable; when + such an element is embedded into a larger structure, this method + will return None, indicating no cache key is available. + + """ + + bindparams: List[BindParameter[Any]] = [] + + _anon_map = anon_map() + key = self._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + assert key is not None + return CacheKey(key, bindparams) + + @classmethod + def _generate_cache_key_for_object( + cls, obj: HasCacheKey + ) -> Optional[CacheKey]: + bindparams: List[BindParameter[Any]] = [] + + _anon_map = anon_map() + key = obj._gen_cache_key(_anon_map, bindparams) + if NO_CACHE in _anon_map: + return None + else: + assert key is not None + return CacheKey(key, bindparams) + + +class HasCacheKeyTraverse(HasTraverseInternals, HasCacheKey): + pass + + +class MemoizedHasCacheKey(HasCacheKey, HasMemoized): + __slots__ = () + + @HasMemoized.memoized_instancemethod + def _generate_cache_key(self) -> Optional[CacheKey]: + return HasCacheKey._generate_cache_key(self) + + +class SlotsMemoizedHasCacheKey(HasCacheKey, util.MemoizedSlots): + __slots__ = () + + def _memoized_method__generate_cache_key(self) -> Optional[CacheKey]: + return HasCacheKey._generate_cache_key(self) + + +class CacheKey(NamedTuple): + """The key used to identify a SQL statement construct in the + SQL compilation cache. + + .. seealso:: + + :ref:`sql_caching` + + """ + + key: Tuple[Any, ...] + bindparams: Sequence[BindParameter[Any]] + + # can't set __hash__ attribute because it interferes + # with namedtuple + # can't use "if not TYPE_CHECKING" because mypy rejects it + # inside of a NamedTuple + def __hash__(self) -> Optional[int]: # type: ignore + """CacheKey itself is not hashable - hash the .key portion""" + return None + + def to_offline_string( + self, + statement_cache: MutableMapping[Any, str], + statement: ClauseElement, + parameters: _CoreSingleExecuteParams, + ) -> str: + """Generate an "offline string" form of this :class:`.CacheKey` + + The "offline string" is basically the string SQL for the + statement plus a repr of the bound parameter values in series. + Whereas the :class:`.CacheKey` object is dependent on in-memory + identities in order to work as a cache key, the "offline" version + is suitable for a cache that will work for other processes as well. + + The given ``statement_cache`` is a dictionary-like object where the + string form of the statement itself will be cached. This dictionary + should be in a longer lived scope in order to reduce the time spent + stringifying statements. + + + """ + if self.key not in statement_cache: + statement_cache[self.key] = sql_str = str(statement) + else: + sql_str = statement_cache[self.key] + + if not self.bindparams: + param_tuple = tuple(parameters[key] for key in sorted(parameters)) + else: + param_tuple = tuple( + parameters.get(bindparam.key, bindparam.value) + for bindparam in self.bindparams + ) + + return repr((sql_str, param_tuple)) + + def __eq__(self, other: Any) -> bool: + return bool(self.key == other.key) + + def __ne__(self, other: Any) -> bool: + return not (self.key == other.key) + + @classmethod + def _diff_tuples(cls, left: CacheKey, right: CacheKey) -> str: + ck1 = CacheKey(left, []) + ck2 = CacheKey(right, []) + return ck1._diff(ck2) + + def _whats_different(self, other: CacheKey) -> Iterator[str]: + k1 = self.key + k2 = other.key + + stack: List[int] = [] + pickup_index = 0 + while True: + s1, s2 = k1, k2 + for idx in stack: + s1 = s1[idx] + s2 = s2[idx] + + for idx, (e1, e2) in enumerate(zip_longest(s1, s2)): + if idx < pickup_index: + continue + if e1 != e2: + if isinstance(e1, tuple) and isinstance(e2, tuple): + stack.append(idx) + break + else: + yield "key%s[%d]: %s != %s" % ( + "".join("[%d]" % id_ for id_ in stack), + idx, + e1, + e2, + ) + else: + pickup_index = stack.pop(-1) + break + + def _diff(self, other: CacheKey) -> str: + return ", ".join(self._whats_different(other)) + + def __str__(self) -> str: + stack: List[Union[Tuple[Any, ...], HasCacheKey]] = [self.key] + + output = [] + sentinel = object() + indent = -1 + while stack: + elem = stack.pop(0) + if elem is sentinel: + output.append((" " * (indent * 2)) + "),") + indent -= 1 + elif isinstance(elem, tuple): + if not elem: + output.append((" " * ((indent + 1) * 2)) + "()") + else: + indent += 1 + stack = list(elem) + [sentinel] + stack + output.append((" " * (indent * 2)) + "(") + else: + if isinstance(elem, HasCacheKey): + repr_ = "<%s object at %s>" % ( + type(elem).__name__, + hex(id(elem)), + ) + else: + repr_ = repr(elem) + output.append((" " * (indent * 2)) + " " + repr_ + ", ") + + return "CacheKey(key=%s)" % ("\n".join(output),) + + def _generate_param_dict(self) -> Dict[str, Any]: + """used for testing""" + + _anon_map = prefix_anon_map() + return {b.key % _anon_map: b.effective_value for b in self.bindparams} + + @util.preload_module("sqlalchemy.sql.elements") + def _apply_params_to_element( + self, original_cache_key: CacheKey, target_element: ColumnElement[Any] + ) -> ColumnElement[Any]: + if target_element._is_immutable or original_cache_key is self: + return target_element + + elements = util.preloaded.sql_elements + return elements._OverrideBinds( + target_element, self.bindparams, original_cache_key.bindparams + ) + + +def _ad_hoc_cache_key_from_args( + tokens: Tuple[Any, ...], + traverse_args: Iterable[Tuple[str, InternalTraversal]], + args: Iterable[Any], +) -> Tuple[Any, ...]: + """a quick cache key generator used by reflection.flexi_cache.""" + bindparams: List[BindParameter[Any]] = [] + + _anon_map = anon_map() + + tup = tokens + + for (attrname, sym), arg in zip(traverse_args, args): + key = sym.name + visit_key = key.replace("dp_", "visit_") + + if arg is None: + tup += (attrname, None) + continue + + meth = getattr(_cache_key_traversal_visitor, visit_key) + if meth is CACHE_IN_PLACE: + tup += (attrname, arg) + elif meth in ( + CALL_GEN_CACHE_KEY, + STATIC_CACHE_KEY, + ANON_NAME, + PROPAGATE_ATTRS, + ): + raise NotImplementedError( + f"Haven't implemented symbol {meth} for ad-hoc key from args" + ) + else: + tup += meth(attrname, arg, None, _anon_map, bindparams) + return tup + + +class _CacheKeyTraversal(HasTraversalDispatch): + # very common elements are inlined into the main _get_cache_key() method + # to produce a dramatic savings in Python function call overhead + + visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY + visit_clauseelement_list = InternalTraversal.dp_clauseelement_list + visit_annotations_key = InternalTraversal.dp_annotations_key + visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple + visit_memoized_select_entities = ( + InternalTraversal.dp_memoized_select_entities + ) + + visit_string = visit_boolean = visit_operator = visit_plain_obj = ( + CACHE_IN_PLACE + ) + visit_statement_hint_list = CACHE_IN_PLACE + visit_type = STATIC_CACHE_KEY + visit_anon_name = ANON_NAME + + visit_propagate_attrs = PROPAGATE_ATTRS + + def visit_with_context_options( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return tuple((fn.__code__, c_key) for fn, c_key in obj) + + def visit_inspectable( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams)) + + def visit_string_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return tuple(obj) + + def visit_multi( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return ( + attrname, + ( + obj._gen_cache_key(anon_map, bindparams) + if isinstance(obj, HasCacheKey) + else obj + ), + ) + + def visit_multi_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return ( + attrname, + tuple( + ( + elem._gen_cache_key(anon_map, bindparams) + if isinstance(elem, HasCacheKey) + else elem + ) + for elem in obj + ), + ) + + def visit_has_cache_key_tuples( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + return ( + attrname, + tuple( + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in tup_elem + ) + for tup_elem in obj + ), + ) + + def visit_has_cache_key_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + return ( + attrname, + tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj), + ) + + def visit_executable_options( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + return ( + attrname, + tuple( + elem._gen_cache_key(anon_map, bindparams) + for elem in obj + if elem._is_has_cache_key + ), + ) + + def visit_inspectable_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return self.visit_has_cache_key_list( + attrname, [inspect(o) for o in obj], parent, anon_map, bindparams + ) + + def visit_clauseelement_tuples( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return self.visit_has_cache_key_tuples( + attrname, obj, parent, anon_map, bindparams + ) + + def visit_fromclause_ordered_set( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + return ( + attrname, + tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]), + ) + + def visit_clauseelement_unordered_set( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + cache_keys = [ + elem._gen_cache_key(anon_map, bindparams) for elem in obj + ] + return ( + attrname, + tuple( + sorted(cache_keys) + ), # cache keys all start with (id_, class) + ) + + def visit_named_ddl_element( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return (attrname, obj.name) + + def visit_prefix_sequence( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + + return ( + attrname, + tuple( + [ + (clause._gen_cache_key(anon_map, bindparams), strval) + for clause, strval in obj + ] + ), + ) + + def visit_setup_join_tuple( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return tuple( + ( + target._gen_cache_key(anon_map, bindparams), + ( + onclause._gen_cache_key(anon_map, bindparams) + if onclause is not None + else None + ), + ( + from_._gen_cache_key(anon_map, bindparams) + if from_ is not None + else None + ), + tuple([(key, flags[key]) for key in sorted(flags)]), + ) + for (target, onclause, from_, flags) in obj + ) + + def visit_table_hint_list( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + if not obj: + return () + + return ( + attrname, + tuple( + [ + ( + clause._gen_cache_key(anon_map, bindparams), + dialect_name, + text, + ) + for (clause, dialect_name), text in obj.items() + ] + ), + ) + + def visit_plain_dict( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return (attrname, tuple([(key, obj[key]) for key in sorted(obj)])) + + def visit_dialect_options( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return ( + attrname, + tuple( + ( + dialect_name, + tuple( + [ + (key, obj[dialect_name][key]) + for key in sorted(obj[dialect_name]) + ] + ), + ) + for dialect_name in sorted(obj) + ), + ) + + def visit_string_clauseelement_dict( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return ( + attrname, + tuple( + (key, obj[key]._gen_cache_key(anon_map, bindparams)) + for key in sorted(obj) + ), + ) + + def visit_string_multi_dict( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return ( + attrname, + tuple( + ( + key, + ( + value._gen_cache_key(anon_map, bindparams) + if isinstance(value, HasCacheKey) + else value + ), + ) + for key, value in [(key, obj[key]) for key in sorted(obj)] + ), + ) + + def visit_fromclause_canonical_column_collection( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + # inlining into the internals of ColumnCollection + return ( + attrname, + tuple( + col._gen_cache_key(anon_map, bindparams) + for k, col, _ in obj._collection + ), + ) + + def visit_unknown_structure( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + anon_map[NO_CACHE] = True + return () + + def visit_dml_ordered_values( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + return ( + attrname, + tuple( + ( + ( + key._gen_cache_key(anon_map, bindparams) + if hasattr(key, "__clause_element__") + else key + ), + value._gen_cache_key(anon_map, bindparams), + ) + for key, value in obj + ), + ) + + def visit_dml_values( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + # in py37 we can assume two dictionaries created in the same + # insert ordering will retain that sorting + return ( + attrname, + tuple( + ( + ( + k._gen_cache_key(anon_map, bindparams) + if hasattr(k, "__clause_element__") + else k + ), + obj[k]._gen_cache_key(anon_map, bindparams), + ) + for k in obj + ), + ) + + def visit_dml_multi_values( + self, + attrname: str, + obj: Any, + parent: Any, + anon_map: anon_map, + bindparams: List[BindParameter[Any]], + ) -> Tuple[Any, ...]: + # multivalues are simply not cacheable right now + anon_map[NO_CACHE] = True + return () + + +_cache_key_traversal_visitor = _CacheKeyTraversal() diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/coercions.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/coercions.py new file mode 100644 index 0000000..22d6091 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/coercions.py @@ -0,0 +1,1389 @@ +# sql/coercions.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +from __future__ import annotations + +import collections.abc as collections_abc +import numbers +import re +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import operators +from . import roles +from . import visitors +from ._typing import is_from_clause +from .base import ExecutableOption +from .base import Options +from .cache_key import HasCacheKey +from .visitors import Visitable +from .. import exc +from .. import inspection +from .. import util +from ..util.typing import Literal + +if typing.TYPE_CHECKING: + # elements lambdas schema selectable are set by __init__ + from . import elements + from . import lambdas + from . import schema + from . import selectable + from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnsClauseArgument + from ._typing import _DDLColumnArgument + from ._typing import _DMLTableArgument + from ._typing import _FromClauseArgument + from .dml import _DMLTableElement + from .elements import BindParameter + from .elements import ClauseElement + from .elements import ColumnClause + from .elements import ColumnElement + from .elements import DQLDMLClauseElement + from .elements import NamedColumn + from .elements import SQLCoreOperations + from .schema import Column + from .selectable import _ColumnsClauseElement + from .selectable import _JoinTargetProtocol + from .selectable import FromClause + from .selectable import HasCTE + from .selectable import SelectBase + from .selectable import Subquery + from .visitors import _TraverseCallableType + +_SR = TypeVar("_SR", bound=roles.SQLRole) +_F = TypeVar("_F", bound=Callable[..., Any]) +_StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole) +_T = TypeVar("_T", bound=Any) + + +def _is_literal(element): + """Return whether or not the element is a "literal" in the context + of a SQL expression construct. + + """ + + return not isinstance( + element, + (Visitable, schema.SchemaEventTarget), + ) and not hasattr(element, "__clause_element__") + + +def _deep_is_literal(element): + """Return whether or not the element is a "literal" in the context + of a SQL expression construct. + + does a deeper more esoteric check than _is_literal. is used + for lambda elements that have to distinguish values that would + be bound vs. not without any context. + + """ + + if isinstance(element, collections_abc.Sequence) and not isinstance( + element, str + ): + for elem in element: + if not _deep_is_literal(elem): + return False + else: + return True + + return ( + not isinstance( + element, + ( + Visitable, + schema.SchemaEventTarget, + HasCacheKey, + Options, + util.langhelpers.symbol, + ), + ) + and not hasattr(element, "__clause_element__") + and ( + not isinstance(element, type) + or not issubclass(element, HasCacheKey) + ) + ) + + +def _document_text_coercion( + paramname: str, meth_rst: str, param_rst: str +) -> Callable[[_F], _F]: + return util.add_parameter_text( + paramname, + ( + ".. warning:: " + "The %s argument to %s can be passed as a Python string argument, " + "which will be treated " + "as **trusted SQL text** and rendered as given. **DO NOT PASS " + "UNTRUSTED INPUT TO THIS PARAMETER**." + ) + % (param_rst, meth_rst), + ) + + +def _expression_collection_was_a_list( + attrname: str, + fnname: str, + args: Union[Sequence[_T], Sequence[Sequence[_T]]], +) -> Sequence[_T]: + if args and isinstance(args[0], (list, set, dict)) and len(args) == 1: + if isinstance(args[0], list): + raise exc.ArgumentError( + f'The "{attrname}" argument to {fnname}(), when ' + "referring to a sequence " + "of items, is now passed as a series of positional " + "elements, rather than as a list. " + ) + return cast("Sequence[_T]", args[0]) + + return cast("Sequence[_T]", args) + + +@overload +def expect( + role: Type[roles.TruncatedLabelRole], + element: Any, + **kw: Any, +) -> str: ... + + +@overload +def expect( + role: Type[roles.DMLColumnRole], + element: Any, + *, + as_key: Literal[True] = ..., + **kw: Any, +) -> str: ... + + +@overload +def expect( + role: Type[roles.LiteralValueRole], + element: Any, + **kw: Any, +) -> BindParameter[Any]: ... + + +@overload +def expect( + role: Type[roles.DDLReferredColumnRole], + element: Any, + **kw: Any, +) -> Column[Any]: ... + + +@overload +def expect( + role: Type[roles.DDLConstraintColumnRole], + element: Any, + **kw: Any, +) -> Union[Column[Any], str]: ... + + +@overload +def expect( + role: Type[roles.StatementOptionRole], + element: Any, + **kw: Any, +) -> DQLDMLClauseElement: ... + + +@overload +def expect( + role: Type[roles.LabeledColumnExprRole[Any]], + element: _ColumnExpressionArgument[_T], + **kw: Any, +) -> NamedColumn[_T]: ... + + +@overload +def expect( + role: Union[ + Type[roles.ExpressionElementRole[Any]], + Type[roles.LimitOffsetRole], + Type[roles.WhereHavingRole], + ], + element: _ColumnExpressionArgument[_T], + **kw: Any, +) -> ColumnElement[_T]: ... + + +@overload +def expect( + role: Union[ + Type[roles.ExpressionElementRole[Any]], + Type[roles.LimitOffsetRole], + Type[roles.WhereHavingRole], + Type[roles.OnClauseRole], + Type[roles.ColumnArgumentRole], + ], + element: Any, + **kw: Any, +) -> ColumnElement[Any]: ... + + +@overload +def expect( + role: Type[roles.DMLTableRole], + element: _DMLTableArgument, + **kw: Any, +) -> _DMLTableElement: ... + + +@overload +def expect( + role: Type[roles.HasCTERole], + element: HasCTE, + **kw: Any, +) -> HasCTE: ... + + +@overload +def expect( + role: Type[roles.SelectStatementRole], + element: SelectBase, + **kw: Any, +) -> SelectBase: ... + + +@overload +def expect( + role: Type[roles.FromClauseRole], + element: _FromClauseArgument, + **kw: Any, +) -> FromClause: ... + + +@overload +def expect( + role: Type[roles.FromClauseRole], + element: SelectBase, + *, + explicit_subquery: Literal[True] = ..., + **kw: Any, +) -> Subquery: ... + + +@overload +def expect( + role: Type[roles.ColumnsClauseRole], + element: _ColumnsClauseArgument[Any], + **kw: Any, +) -> _ColumnsClauseElement: ... + + +@overload +def expect( + role: Type[roles.JoinTargetRole], + element: _JoinTargetProtocol, + **kw: Any, +) -> _JoinTargetProtocol: ... + + +# catchall for not-yet-implemented overloads +@overload +def expect( + role: Type[_SR], + element: Any, + **kw: Any, +) -> Any: ... + + +def expect( + role: Type[_SR], + element: Any, + *, + apply_propagate_attrs: Optional[ClauseElement] = None, + argname: Optional[str] = None, + post_inspect: bool = False, + disable_inspection: bool = False, + **kw: Any, +) -> Any: + if ( + role.allows_lambda + # note callable() will not invoke a __getattr__() method, whereas + # hasattr(obj, "__call__") will. by keeping the callable() check here + # we prevent most needless calls to hasattr() and therefore + # __getattr__(), which is present on ColumnElement. + and callable(element) + and hasattr(element, "__code__") + ): + return lambdas.LambdaElement( + element, + role, + lambdas.LambdaOptions(**kw), + apply_propagate_attrs=apply_propagate_attrs, + ) + + # major case is that we are given a ClauseElement already, skip more + # elaborate logic up front if possible + impl = _impl_lookup[role] + + original_element = element + + if not isinstance( + element, + ( + elements.CompilerElement, + schema.SchemaItem, + schema.FetchedValue, + lambdas.PyWrapper, + ), + ): + resolved = None + + if impl._resolve_literal_only: + resolved = impl._literal_coercion(element, **kw) + else: + original_element = element + + is_clause_element = False + + # this is a special performance optimization for ORM + # joins used by JoinTargetImpl that we don't go through the + # work of creating __clause_element__() when we only need the + # original QueryableAttribute, as the former will do clause + # adaption and all that which is just thrown away here. + if ( + impl._skip_clauseelement_for_target_match + and isinstance(element, role) + and hasattr(element, "__clause_element__") + ): + is_clause_element = True + else: + while hasattr(element, "__clause_element__"): + is_clause_element = True + + if not getattr(element, "is_clause_element", False): + element = element.__clause_element__() + else: + break + + if not is_clause_element: + if impl._use_inspection and not disable_inspection: + insp = inspection.inspect(element, raiseerr=False) + if insp is not None: + if post_inspect: + insp._post_inspect + try: + resolved = insp.__clause_element__() + except AttributeError: + impl._raise_for_expected(original_element, argname) + + if resolved is None: + resolved = impl._literal_coercion( + element, argname=argname, **kw + ) + else: + resolved = element + elif isinstance(element, lambdas.PyWrapper): + resolved = element._sa__py_wrapper_literal(**kw) + else: + resolved = element + + if apply_propagate_attrs is not None: + if typing.TYPE_CHECKING: + assert isinstance(resolved, (SQLCoreOperations, ClauseElement)) + + if not apply_propagate_attrs._propagate_attrs and getattr( + resolved, "_propagate_attrs", None + ): + apply_propagate_attrs._propagate_attrs = resolved._propagate_attrs + + if impl._role_class in resolved.__class__.__mro__: + if impl._post_coercion: + resolved = impl._post_coercion( + resolved, + argname=argname, + original_element=original_element, + **kw, + ) + return resolved + else: + return impl._implicit_coercions( + original_element, resolved, argname=argname, **kw + ) + + +def expect_as_key( + role: Type[roles.DMLColumnRole], element: Any, **kw: Any +) -> str: + kw.pop("as_key", None) + return expect(role, element, as_key=True, **kw) + + +def expect_col_expression_collection( + role: Type[roles.DDLConstraintColumnRole], + expressions: Iterable[_DDLColumnArgument], +) -> Iterator[ + Tuple[ + Union[str, Column[Any]], + Optional[ColumnClause[Any]], + Optional[str], + Optional[Union[Column[Any], str]], + ] +]: + for expr in expressions: + strname = None + column = None + + resolved: Union[Column[Any], str] = expect(role, expr) + if isinstance(resolved, str): + assert isinstance(expr, str) + strname = resolved = expr + else: + cols: List[Column[Any]] = [] + col_append: _TraverseCallableType[Column[Any]] = cols.append + visitors.traverse(resolved, {}, {"column": col_append}) + if cols: + column = cols[0] + add_element = column if column is not None else strname + + yield resolved, column, strname, add_element + + +class RoleImpl: + __slots__ = ("_role_class", "name", "_use_inspection") + + def _literal_coercion(self, element, **kw): + raise NotImplementedError() + + _post_coercion: Any = None + _resolve_literal_only = False + _skip_clauseelement_for_target_match = False + + def __init__(self, role_class): + self._role_class = role_class + self.name = role_class._role_name + self._use_inspection = issubclass(role_class, roles.UsesInspection) + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + self._raise_for_expected(element, argname, resolved) + + def _raise_for_expected( + self, + element: Any, + argname: Optional[str] = None, + resolved: Optional[Any] = None, + advice: Optional[str] = None, + code: Optional[str] = None, + err: Optional[Exception] = None, + **kw: Any, + ) -> NoReturn: + if resolved is not None and resolved is not element: + got = "%r object resolved from %r object" % (resolved, element) + else: + got = repr(element) + + if argname: + msg = "%s expected for argument %r; got %s." % ( + self.name, + argname, + got, + ) + else: + msg = "%s expected, got %s." % (self.name, got) + + if advice: + msg += " " + advice + + raise exc.ArgumentError(msg, code=code) from err + + +class _Deannotate: + __slots__ = () + + def _post_coercion(self, resolved, **kw): + from .util import _deep_deannotate + + return _deep_deannotate(resolved) + + +class _StringOnly: + __slots__ = () + + _resolve_literal_only = True + + +class _ReturnsStringKey(RoleImpl): + __slots__ = () + + def _implicit_coercions(self, element, resolved, argname=None, **kw): + if isinstance(element, str): + return element + else: + self._raise_for_expected(element, argname, resolved) + + def _literal_coercion(self, element, **kw): + return element + + +class _ColumnCoercions(RoleImpl): + __slots__ = () + + def _warn_for_scalar_subquery_coercion(self): + util.warn( + "implicitly coercing SELECT object to scalar subquery; " + "please use the .scalar_subquery() method to produce a scalar " + "subquery.", + ) + + def _implicit_coercions(self, element, resolved, argname=None, **kw): + original_element = element + if not getattr(resolved, "is_clause_element", False): + self._raise_for_expected(original_element, argname, resolved) + elif resolved._is_select_base: + self._warn_for_scalar_subquery_coercion() + return resolved.scalar_subquery() + elif resolved._is_from_clause and isinstance( + resolved, selectable.Subquery + ): + self._warn_for_scalar_subquery_coercion() + return resolved.element.scalar_subquery() + elif self._role_class.allows_lambda and resolved._is_lambda_element: + return resolved + else: + self._raise_for_expected(original_element, argname, resolved) + + +def _no_text_coercion( + element: Any, + argname: Optional[str] = None, + exc_cls: Type[exc.SQLAlchemyError] = exc.ArgumentError, + extra: Optional[str] = None, + err: Optional[Exception] = None, +) -> NoReturn: + raise exc_cls( + "%(extra)sTextual SQL expression %(expr)r %(argname)sshould be " + "explicitly declared as text(%(expr)r)" + % { + "expr": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "extra": "%s " % extra if extra else "", + } + ) from err + + +class _NoTextCoercion(RoleImpl): + __slots__ = () + + def _literal_coercion(self, element, argname=None, **kw): + if isinstance(element, str) and issubclass( + elements.TextClause, self._role_class + ): + _no_text_coercion(element, argname) + else: + self._raise_for_expected(element, argname) + + +class _CoerceLiterals(RoleImpl): + __slots__ = () + _coerce_consts = False + _coerce_star = False + _coerce_numerics = False + + def _text_coercion(self, element, argname=None): + return _no_text_coercion(element, argname) + + def _literal_coercion(self, element, argname=None, **kw): + if isinstance(element, str): + if self._coerce_star and element == "*": + return elements.ColumnClause("*", is_literal=True) + else: + return self._text_coercion(element, argname, **kw) + + if self._coerce_consts: + if element is None: + return elements.Null() + elif element is False: + return elements.False_() + elif element is True: + return elements.True_() + + if self._coerce_numerics and isinstance(element, (numbers.Number)): + return elements.ColumnClause(str(element), is_literal=True) + + self._raise_for_expected(element, argname) + + +class LiteralValueImpl(RoleImpl): + _resolve_literal_only = True + + def _implicit_coercions( + self, + element, + resolved, + argname, + type_=None, + literal_execute=False, + **kw, + ): + if not _is_literal(resolved): + self._raise_for_expected( + element, resolved=resolved, argname=argname, **kw + ) + + return elements.BindParameter( + None, + element, + type_=type_, + unique=True, + literal_execute=literal_execute, + ) + + def _literal_coercion(self, element, argname=None, type_=None, **kw): + return element + + +class _SelectIsNotFrom(RoleImpl): + __slots__ = () + + def _raise_for_expected( + self, + element: Any, + argname: Optional[str] = None, + resolved: Optional[Any] = None, + advice: Optional[str] = None, + code: Optional[str] = None, + err: Optional[Exception] = None, + **kw: Any, + ) -> NoReturn: + if ( + not advice + and isinstance(element, roles.SelectStatementRole) + or isinstance(resolved, roles.SelectStatementRole) + ): + advice = ( + "To create a " + "FROM clause from a %s object, use the .subquery() method." + % (resolved.__class__ if resolved is not None else element,) + ) + code = "89ve" + else: + code = None + + super()._raise_for_expected( + element, + argname=argname, + resolved=resolved, + advice=advice, + code=code, + err=err, + **kw, + ) + # never reached + assert False + + +class HasCacheKeyImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(element, HasCacheKey): + return element + else: + self._raise_for_expected(element, argname, resolved) + + def _literal_coercion(self, element, **kw): + return element + + +class ExecutableOptionImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(element, ExecutableOption): + return element + else: + self._raise_for_expected(element, argname, resolved) + + def _literal_coercion(self, element, **kw): + return element + + +class ExpressionElementImpl(_ColumnCoercions, RoleImpl): + __slots__ = () + + def _literal_coercion( + self, element, name=None, type_=None, argname=None, is_crud=False, **kw + ): + if ( + element is None + and not is_crud + and (type_ is None or not type_.should_evaluate_none) + ): + # TODO: there's no test coverage now for the + # "should_evaluate_none" part of this, as outside of "crud" this + # codepath is not normally used except in some special cases + return elements.Null() + else: + try: + return elements.BindParameter( + name, element, type_, unique=True, _is_crud=is_crud + ) + except exc.ArgumentError as err: + self._raise_for_expected(element, err=err) + + def _raise_for_expected(self, element, argname=None, resolved=None, **kw): + # select uses implicit coercion with warning instead of raising + if isinstance(element, selectable.Values): + advice = ( + "To create a column expression from a VALUES clause, " + "use the .scalar_values() method." + ) + elif isinstance(element, roles.AnonymizedFromClauseRole): + advice = ( + "To create a column expression from a FROM clause row " + "as a whole, use the .table_valued() method." + ) + else: + advice = None + + return super()._raise_for_expected( + element, argname=argname, resolved=resolved, advice=advice, **kw + ) + + +class BinaryElementImpl(ExpressionElementImpl, RoleImpl): + __slots__ = () + + def _literal_coercion( + self, element, expr, operator, bindparam_type=None, argname=None, **kw + ): + try: + return expr._bind_param(operator, element, type_=bindparam_type) + except exc.ArgumentError as err: + self._raise_for_expected(element, err=err) + + def _post_coercion(self, resolved, expr, bindparam_type=None, **kw): + if resolved.type._isnull and not expr.type._isnull: + resolved = resolved._with_binary_element_type( + bindparam_type if bindparam_type is not None else expr.type + ) + return resolved + + +class InElementImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if resolved._is_from_clause: + if ( + isinstance(resolved, selectable.Alias) + and resolved.element._is_select_base + ): + self._warn_for_implicit_coercion(resolved) + return self._post_coercion(resolved.element, **kw) + else: + self._warn_for_implicit_coercion(resolved) + return self._post_coercion(resolved.select(), **kw) + else: + self._raise_for_expected(element, argname, resolved) + + def _warn_for_implicit_coercion(self, elem): + util.warn( + "Coercing %s object into a select() for use in IN(); " + "please pass a select() construct explicitly" + % (elem.__class__.__name__) + ) + + def _literal_coercion(self, element, expr, operator, **kw): + if util.is_non_string_iterable(element): + non_literal_expressions: Dict[ + Optional[operators.ColumnOperators], + operators.ColumnOperators, + ] = {} + element = list(element) + for o in element: + if not _is_literal(o): + if not isinstance(o, operators.ColumnOperators): + self._raise_for_expected(element, **kw) + + else: + non_literal_expressions[o] = o + elif o is None: + non_literal_expressions[o] = elements.Null() + + if non_literal_expressions: + return elements.ClauseList( + *[ + ( + non_literal_expressions[o] + if o in non_literal_expressions + else expr._bind_param(operator, o) + ) + for o in element + ] + ) + else: + return expr._bind_param(operator, element, expanding=True) + + else: + self._raise_for_expected(element, **kw) + + def _post_coercion(self, element, expr, operator, **kw): + if element._is_select_base: + # for IN, we are doing scalar_subquery() coercion without + # a warning + return element.scalar_subquery() + elif isinstance(element, elements.ClauseList): + assert not len(element.clauses) == 0 + return element.self_group(against=operator) + + elif isinstance(element, elements.BindParameter): + element = element._clone(maintain_key=True) + element.expanding = True + element.expand_op = operator + + return element + elif isinstance(element, selectable.Values): + return element.scalar_values() + else: + return element + + +class OnClauseImpl(_ColumnCoercions, RoleImpl): + __slots__ = () + + _coerce_consts = True + + def _literal_coercion( + self, element, name=None, type_=None, argname=None, is_crud=False, **kw + ): + self._raise_for_expected(element) + + def _post_coercion(self, resolved, original_element=None, **kw): + # this is a hack right now as we want to use coercion on an + # ORM InstrumentedAttribute, but we want to return the object + # itself if it is one, not its clause element. + # ORM context _join and _legacy_join() would need to be improved + # to look for annotations in a clause element form. + if isinstance(original_element, roles.JoinTargetRole): + return original_element + return resolved + + +class WhereHavingImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl): + __slots__ = () + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return _no_text_coercion(element, argname) + + +class StatementOptionImpl(_CoerceLiterals, RoleImpl): + __slots__ = () + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements.TextClause(element) + + +class ColumnArgumentImpl(_NoTextCoercion, RoleImpl): + __slots__ = () + + +class ColumnArgumentOrKeyImpl(_ReturnsStringKey, RoleImpl): + __slots__ = () + + +class StrAsPlainColumnImpl(_CoerceLiterals, RoleImpl): + __slots__ = () + + def _text_coercion(self, element, argname=None): + return elements.ColumnClause(element) + + +class ByOfImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl, roles.ByOfRole): + __slots__ = () + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + return elements._textual_label_reference(element) + + +class OrderByImpl(ByOfImpl, RoleImpl): + __slots__ = () + + def _post_coercion(self, resolved, **kw): + if ( + isinstance(resolved, self._role_class) + and resolved._order_by_label_element is not None + ): + return elements._label_reference(resolved) + else: + return resolved + + +class GroupByImpl(ByOfImpl, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if is_from_clause(resolved): + return elements.ClauseList(*resolved.c) + else: + return resolved + + +class DMLColumnImpl(_ReturnsStringKey, RoleImpl): + __slots__ = () + + def _post_coercion(self, element, as_key=False, **kw): + if as_key: + return element.key + else: + return element + + +class ConstExprImpl(RoleImpl): + __slots__ = () + + def _literal_coercion(self, element, argname=None, **kw): + if element is None: + return elements.Null() + elif element is False: + return elements.False_() + elif element is True: + return elements.True_() + else: + self._raise_for_expected(element, argname) + + +class TruncatedLabelImpl(_StringOnly, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(element, str): + return resolved + else: + self._raise_for_expected(element, argname, resolved) + + def _literal_coercion(self, element, argname=None, **kw): + """coerce the given value to :class:`._truncated_label`. + + Existing :class:`._truncated_label` and + :class:`._anonymous_label` objects are passed + unchanged. + """ + + if isinstance(element, elements._truncated_label): + return element + else: + return elements._truncated_label(element) + + +class DDLExpressionImpl(_Deannotate, _CoerceLiterals, RoleImpl): + __slots__ = () + + _coerce_consts = True + + def _text_coercion(self, element, argname=None): + # see #5754 for why we can't easily deprecate this coercion. + # essentially expressions like postgresql_where would have to be + # text() as they come back from reflection and we don't want to + # have text() elements wired into the inspection dictionaries. + return elements.TextClause(element) + + +class DDLConstraintColumnImpl(_Deannotate, _ReturnsStringKey, RoleImpl): + __slots__ = () + + +class DDLReferredColumnImpl(DDLConstraintColumnImpl): + __slots__ = () + + +class LimitOffsetImpl(RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if resolved is None: + return None + else: + self._raise_for_expected(element, argname, resolved) + + def _literal_coercion(self, element, name, type_, **kw): + if element is None: + return None + else: + value = util.asint(element) + return selectable._OffsetLimitParam( + name, value, type_=type_, unique=True + ) + + +class LabeledColumnExprImpl(ExpressionElementImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if isinstance(resolved, roles.ExpressionElementRole): + return resolved.label(None) + else: + new = super()._implicit_coercions( + element, resolved, argname=argname, **kw + ) + if isinstance(new, roles.ExpressionElementRole): + return new.label(None) + else: + self._raise_for_expected(element, argname, resolved) + + +class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl): + __slots__ = () + + _coerce_consts = True + _coerce_numerics = True + _coerce_star = True + + _guess_straight_column = re.compile(r"^\w\S*$", re.I) + + def _raise_for_expected( + self, element, argname=None, resolved=None, advice=None, **kw + ): + if not advice and isinstance(element, list): + advice = ( + f"Did you mean to say select(" + f"{', '.join(repr(e) for e in element)})?" + ) + + return super()._raise_for_expected( + element, argname=argname, resolved=resolved, advice=advice, **kw + ) + + def _text_coercion(self, element, argname=None): + element = str(element) + + guess_is_literal = not self._guess_straight_column.match(element) + raise exc.ArgumentError( + "Textual column expression %(column)r %(argname)sshould be " + "explicitly declared with text(%(column)r), " + "or use %(literal_column)s(%(column)r) " + "for more specificity" + % { + "column": util.ellipses_string(element), + "argname": "for argument %s" % (argname,) if argname else "", + "literal_column": ( + "literal_column" if guess_is_literal else "column" + ), + } + ) + + +class ReturnsRowsImpl(RoleImpl): + __slots__ = () + + +class StatementImpl(_CoerceLiterals, RoleImpl): + __slots__ = () + + def _post_coercion(self, resolved, original_element, argname=None, **kw): + if resolved is not original_element and not isinstance( + original_element, str + ): + # use same method as Connection uses; this will later raise + # ObjectNotExecutableError + try: + original_element._execute_on_connection + except AttributeError: + util.warn_deprecated( + "Object %r should not be used directly in a SQL statement " + "context, such as passing to methods such as " + "session.execute(). This usage will be disallowed in a " + "future release. " + "Please use Core select() / update() / delete() etc. " + "with Session.execute() and other statement execution " + "methods." % original_element, + "1.4", + ) + + return resolved + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if resolved._is_lambda_element: + return resolved + else: + return super()._implicit_coercions( + element, resolved, argname=argname, **kw + ) + + +class SelectStatementImpl(_NoTextCoercion, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if resolved._is_text_clause: + return resolved.columns() + else: + self._raise_for_expected(element, argname, resolved) + + +class HasCTEImpl(ReturnsRowsImpl): + __slots__ = () + + +class IsCTEImpl(RoleImpl): + __slots__ = () + + +class JoinTargetImpl(RoleImpl): + __slots__ = () + + _skip_clauseelement_for_target_match = True + + def _literal_coercion(self, element, argname=None, **kw): + self._raise_for_expected(element, argname) + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + legacy: bool = False, + **kw: Any, + ) -> Any: + if isinstance(element, roles.JoinTargetRole): + # note that this codepath no longer occurs as of + # #6550, unless JoinTargetImpl._skip_clauseelement_for_target_match + # were set to False. + return element + elif legacy and resolved._is_select_base: + util.warn_deprecated( + "Implicit coercion of SELECT and textual SELECT " + "constructs into FROM clauses is deprecated; please call " + ".subquery() on any Core select or ORM Query object in " + "order to produce a subquery object.", + version="1.4", + ) + # TODO: doing _implicit_subquery here causes tests to fail, + # how was this working before? probably that ORM + # join logic treated it as a select and subquery would happen + # in _ORMJoin->Join + return resolved + else: + self._raise_for_expected(element, argname, resolved) + + +class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + explicit_subquery: bool = False, + allow_select: bool = True, + **kw: Any, + ) -> Any: + if resolved._is_select_base: + if explicit_subquery: + return resolved.subquery() + elif allow_select: + util.warn_deprecated( + "Implicit coercion of SELECT and textual SELECT " + "constructs into FROM clauses is deprecated; please call " + ".subquery() on any Core select or ORM Query object in " + "order to produce a subquery object.", + version="1.4", + ) + return resolved._implicit_subquery + elif resolved._is_text_clause: + return resolved + else: + self._raise_for_expected(element, argname, resolved) + + def _post_coercion(self, element, deannotate=False, **kw): + if deannotate: + return element._deannotate() + else: + return element + + +class StrictFromClauseImpl(FromClauseImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + explicit_subquery: bool = False, + allow_select: bool = False, + **kw: Any, + ) -> Any: + if resolved._is_select_base and allow_select: + util.warn_deprecated( + "Implicit coercion of SELECT and textual SELECT constructs " + "into FROM clauses is deprecated; please call .subquery() " + "on any Core select or ORM Query object in order to produce a " + "subquery object.", + version="1.4", + ) + return resolved._implicit_subquery + else: + self._raise_for_expected(element, argname, resolved) + + +class AnonymizedFromClauseImpl(StrictFromClauseImpl): + __slots__ = () + + def _post_coercion(self, element, flat=False, name=None, **kw): + assert name is None + + return element._anonymous_fromclause(flat=flat) + + +class DMLTableImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): + __slots__ = () + + def _post_coercion(self, element, **kw): + if "dml_table" in element._annotations: + return element._annotations["dml_table"] + else: + return element + + +class DMLSelectImpl(_NoTextCoercion, RoleImpl): + __slots__ = () + + def _implicit_coercions( + self, + element: Any, + resolved: Any, + argname: Optional[str] = None, + **kw: Any, + ) -> Any: + if resolved._is_from_clause: + if ( + isinstance(resolved, selectable.Alias) + and resolved.element._is_select_base + ): + return resolved.element + else: + return resolved.select() + else: + self._raise_for_expected(element, argname, resolved) + + +class CompoundElementImpl(_NoTextCoercion, RoleImpl): + __slots__ = () + + def _raise_for_expected(self, element, argname=None, resolved=None, **kw): + if isinstance(element, roles.FromClauseRole): + if element._is_subquery: + advice = ( + "Use the plain select() object without " + "calling .subquery() or .alias()." + ) + else: + advice = ( + "To SELECT from any FROM clause, use the .select() method." + ) + else: + advice = None + return super()._raise_for_expected( + element, argname=argname, resolved=resolved, advice=advice, **kw + ) + + +_impl_lookup = {} + + +for name in dir(roles): + cls = getattr(roles, name) + if name.endswith("Role"): + name = name.replace("Role", "Impl") + if name in globals(): + impl = globals()[name](cls) + _impl_lookup[cls] = impl + +if not TYPE_CHECKING: + ee_impl = _impl_lookup[roles.ExpressionElementRole] + + for py_type in (int, bool, str, float): + _impl_lookup[roles.ExpressionElementRole[py_type]] = ee_impl diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/compiler.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/compiler.py new file mode 100644 index 0000000..c354ba8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/compiler.py @@ -0,0 +1,7811 @@ +# sql/compiler.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Base SQL and DDL compiler implementations. + +Classes provided include: + +:class:`.compiler.SQLCompiler` - renders SQL +strings + +:class:`.compiler.DDLCompiler` - renders DDL +(data definition language) strings + +:class:`.compiler.GenericTypeCompiler` - renders +type specification strings. + +To generate user-defined SQL strings, see +:doc:`/ext/compiler`. + +""" +from __future__ import annotations + +import collections +import collections.abc as collections_abc +import contextlib +from enum import IntEnum +import functools +import itertools +import operator +import re +from time import perf_counter +import typing +from typing import Any +from typing import Callable +from typing import cast +from typing import ClassVar +from typing import Dict +from typing import FrozenSet +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import MutableMapping +from typing import NamedTuple +from typing import NoReturn +from typing import Optional +from typing import Pattern +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import Union + +from . import base +from . import coercions +from . import crud +from . import elements +from . import functions +from . import operators +from . import roles +from . import schema +from . import selectable +from . import sqltypes +from . import util as sql_util +from ._typing import is_column_element +from ._typing import is_dml +from .base import _de_clone +from .base import _from_objects +from .base import _NONE_NAME +from .base import _SentinelDefaultCharacterization +from .base import Executable +from .base import NO_ARG +from .elements import ClauseElement +from .elements import quoted_name +from .schema import Column +from .sqltypes import TupleType +from .type_api import TypeEngine +from .visitors import prefix_anon_map +from .visitors import Visitable +from .. import exc +from .. import util +from ..util import FastIntFlag +from ..util.typing import Literal +from ..util.typing import Protocol +from ..util.typing import TypedDict + +if typing.TYPE_CHECKING: + from .annotation import _AnnotationDict + from .base import _AmbiguousTableNameMap + from .base import CompileState + from .cache_key import CacheKey + from .ddl import ExecutableDDLElement + from .dml import Insert + from .dml import UpdateBase + from .dml import ValuesBase + from .elements import _truncated_label + from .elements import BindParameter + from .elements import ColumnClause + from .elements import ColumnElement + from .elements import Label + from .functions import Function + from .schema import Table + from .selectable import AliasedReturnsRows + from .selectable import CompoundSelectState + from .selectable import CTE + from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import ReturnsRows + from .selectable import Select + from .selectable import SelectState + from .type_api import _BindProcessorType + from ..engine.cursor import CursorResultMetaData + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _DBAPIAnyExecuteParams + from ..engine.interfaces import _DBAPIMultiExecuteParams + from ..engine.interfaces import _DBAPISingleExecuteParams + from ..engine.interfaces import _ExecuteOptions + from ..engine.interfaces import _GenericSetInputSizesType + from ..engine.interfaces import _MutableCoreSingleExecuteParams + from ..engine.interfaces import Dialect + from ..engine.interfaces import SchemaTranslateMapType + +_FromHintsType = Dict["FromClause", str] + +RESERVED_WORDS = { + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "authorization", + "between", + "binary", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "cross", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "for", + "foreign", + "freeze", + "from", + "full", + "grant", + "group", + "having", + "ilike", + "in", + "initially", + "inner", + "intersect", + "into", + "is", + "isnull", + "join", + "leading", + "left", + "like", + "limit", + "localtime", + "localtimestamp", + "natural", + "new", + "not", + "notnull", + "null", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "outer", + "overlaps", + "placing", + "primary", + "references", + "right", + "select", + "session_user", + "set", + "similar", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "verbose", + "when", + "where", +} + +LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I) +LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I) +ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"]) + +FK_ON_DELETE = re.compile( + r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I +) +FK_ON_UPDATE = re.compile( + r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I +) +FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I) +BIND_PARAMS = re.compile(r"(? ", + operators.ge: " >= ", + operators.eq: " = ", + operators.is_distinct_from: " IS DISTINCT FROM ", + operators.is_not_distinct_from: " IS NOT DISTINCT FROM ", + operators.concat_op: " || ", + operators.match_op: " MATCH ", + operators.not_match_op: " NOT MATCH ", + operators.in_op: " IN ", + operators.not_in_op: " NOT IN ", + operators.comma_op: ", ", + operators.from_: " FROM ", + operators.as_: " AS ", + operators.is_: " IS ", + operators.is_not: " IS NOT ", + operators.collate: " COLLATE ", + # unary + operators.exists: "EXISTS ", + operators.distinct_op: "DISTINCT ", + operators.inv: "NOT ", + operators.any_op: "ANY ", + operators.all_op: "ALL ", + # modifiers + operators.desc_op: " DESC", + operators.asc_op: " ASC", + operators.nulls_first_op: " NULLS FIRST", + operators.nulls_last_op: " NULLS LAST", + # bitwise + operators.bitwise_xor_op: " ^ ", + operators.bitwise_or_op: " | ", + operators.bitwise_and_op: " & ", + operators.bitwise_not_op: "~", + operators.bitwise_lshift_op: " << ", + operators.bitwise_rshift_op: " >> ", +} + +FUNCTIONS: Dict[Type[Function[Any]], str] = { + functions.coalesce: "coalesce", + functions.current_date: "CURRENT_DATE", + functions.current_time: "CURRENT_TIME", + functions.current_timestamp: "CURRENT_TIMESTAMP", + functions.current_user: "CURRENT_USER", + functions.localtime: "LOCALTIME", + functions.localtimestamp: "LOCALTIMESTAMP", + functions.random: "random", + functions.sysdate: "sysdate", + functions.session_user: "SESSION_USER", + functions.user: "USER", + functions.cube: "CUBE", + functions.rollup: "ROLLUP", + functions.grouping_sets: "GROUPING SETS", +} + + +EXTRACT_MAP = { + "month": "month", + "day": "day", + "year": "year", + "second": "second", + "hour": "hour", + "doy": "doy", + "minute": "minute", + "quarter": "quarter", + "dow": "dow", + "week": "week", + "epoch": "epoch", + "milliseconds": "milliseconds", + "microseconds": "microseconds", + "timezone_hour": "timezone_hour", + "timezone_minute": "timezone_minute", +} + +COMPOUND_KEYWORDS = { + selectable._CompoundSelectKeyword.UNION: "UNION", + selectable._CompoundSelectKeyword.UNION_ALL: "UNION ALL", + selectable._CompoundSelectKeyword.EXCEPT: "EXCEPT", + selectable._CompoundSelectKeyword.EXCEPT_ALL: "EXCEPT ALL", + selectable._CompoundSelectKeyword.INTERSECT: "INTERSECT", + selectable._CompoundSelectKeyword.INTERSECT_ALL: "INTERSECT ALL", +} + + +class ResultColumnsEntry(NamedTuple): + """Tracks a column expression that is expected to be represented + in the result rows for this statement. + + This normally refers to the columns clause of a SELECT statement + but may also refer to a RETURNING clause, as well as for dialect-specific + emulations. + + """ + + keyname: str + """string name that's expected in cursor.description""" + + name: str + """column name, may be labeled""" + + objects: Tuple[Any, ...] + """sequence of objects that should be able to locate this column + in a RowMapping. This is typically string names and aliases + as well as Column objects. + + """ + + type: TypeEngine[Any] + """Datatype to be associated with this column. This is where + the "result processing" logic directly links the compiled statement + to the rows that come back from the cursor. + + """ + + +class _ResultMapAppender(Protocol): + def __call__( + self, + keyname: str, + name: str, + objects: Sequence[Any], + type_: TypeEngine[Any], + ) -> None: ... + + +# integer indexes into ResultColumnsEntry used by cursor.py. +# some profiling showed integer access faster than named tuple +RM_RENDERED_NAME: Literal[0] = 0 +RM_NAME: Literal[1] = 1 +RM_OBJECTS: Literal[2] = 2 +RM_TYPE: Literal[3] = 3 + + +class _BaseCompilerStackEntry(TypedDict): + asfrom_froms: Set[FromClause] + correlate_froms: Set[FromClause] + selectable: ReturnsRows + + +class _CompilerStackEntry(_BaseCompilerStackEntry, total=False): + compile_state: CompileState + need_result_map_for_nested: bool + need_result_map_for_compound: bool + select_0: ReturnsRows + insert_from_select: Select[Any] + + +class ExpandedState(NamedTuple): + """represents state to use when producing "expanded" and + "post compile" bound parameters for a statement. + + "expanded" parameters are parameters that are generated at + statement execution time to suit a number of parameters passed, the most + prominent example being the individual elements inside of an IN expression. + + "post compile" parameters are parameters where the SQL literal value + will be rendered into the SQL statement at execution time, rather than + being passed as separate parameters to the driver. + + To create an :class:`.ExpandedState` instance, use the + :meth:`.SQLCompiler.construct_expanded_state` method on any + :class:`.SQLCompiler` instance. + + """ + + statement: str + """String SQL statement with parameters fully expanded""" + + parameters: _CoreSingleExecuteParams + """Parameter dictionary with parameters fully expanded. + + For a statement that uses named parameters, this dictionary will map + exactly to the names in the statement. For a statement that uses + positional parameters, the :attr:`.ExpandedState.positional_parameters` + will yield a tuple with the positional parameter set. + + """ + + processors: Mapping[str, _BindProcessorType[Any]] + """mapping of bound value processors""" + + positiontup: Optional[Sequence[str]] + """Sequence of string names indicating the order of positional + parameters""" + + parameter_expansion: Mapping[str, List[str]] + """Mapping representing the intermediary link from original parameter + name to list of "expanded" parameter names, for those parameters that + were expanded.""" + + @property + def positional_parameters(self) -> Tuple[Any, ...]: + """Tuple of positional parameters, for statements that were compiled + using a positional paramstyle. + + """ + if self.positiontup is None: + raise exc.InvalidRequestError( + "statement does not use a positional paramstyle" + ) + return tuple(self.parameters[key] for key in self.positiontup) + + @property + def additional_parameters(self) -> _CoreSingleExecuteParams: + """synonym for :attr:`.ExpandedState.parameters`.""" + return self.parameters + + +class _InsertManyValues(NamedTuple): + """represents state to use for executing an "insertmanyvalues" statement. + + The primary consumers of this object are the + :meth:`.SQLCompiler._deliver_insertmanyvalues_batches` and + :meth:`.DefaultDialect._deliver_insertmanyvalues_batches` methods. + + .. versionadded:: 2.0 + + """ + + is_default_expr: bool + """if True, the statement is of the form + ``INSERT INTO TABLE DEFAULT VALUES``, and can't be rewritten as a "batch" + + """ + + single_values_expr: str + """The rendered "values" clause of the INSERT statement. + + This is typically the parenthesized section e.g. "(?, ?, ?)" or similar. + The insertmanyvalues logic uses this string as a search and replace + target. + + """ + + insert_crud_params: List[crud._CrudParamElementStr] + """List of Column / bind names etc. used while rewriting the statement""" + + num_positional_params_counted: int + """the number of bound parameters in a single-row statement. + + This count may be larger or smaller than the actual number of columns + targeted in the INSERT, as it accommodates for SQL expressions + in the values list that may have zero or more parameters embedded + within them. + + This count is part of what's used to organize rewritten parameter lists + when batching. + + """ + + sort_by_parameter_order: bool = False + """if the deterministic_returnined_order parameter were used on the + insert. + + All of the attributes following this will only be used if this is True. + + """ + + includes_upsert_behaviors: bool = False + """if True, we have to accommodate for upsert behaviors. + + This will in some cases downgrade "insertmanyvalues" that requests + deterministic ordering. + + """ + + sentinel_columns: Optional[Sequence[Column[Any]]] = None + """List of sentinel columns that were located. + + This list is only here if the INSERT asked for + sort_by_parameter_order=True, + and dialect-appropriate sentinel columns were located. + + .. versionadded:: 2.0.10 + + """ + + num_sentinel_columns: int = 0 + """how many sentinel columns are in the above list, if any. + + This is the same as + ``len(sentinel_columns) if sentinel_columns is not None else 0`` + + """ + + sentinel_param_keys: Optional[Sequence[str]] = None + """parameter str keys in each param dictionary / tuple + that would link to the client side "sentinel" values for that row, which + we can use to match up parameter sets to result rows. + + This is only present if sentinel_columns is present and the INSERT + statement actually refers to client side values for these sentinel + columns. + + .. versionadded:: 2.0.10 + + .. versionchanged:: 2.0.29 - the sequence is now string dictionary keys + only, used against the "compiled parameteters" collection before + the parameters were converted by bound parameter processors + + """ + + implicit_sentinel: bool = False + """if True, we have exactly one sentinel column and it uses a server side + value, currently has to generate an incrementing integer value. + + The dialect in question would have asserted that it supports receiving + these values back and sorting on that value as a means of guaranteeing + correlation with the incoming parameter list. + + .. versionadded:: 2.0.10 + + """ + + embed_values_counter: bool = False + """Whether to embed an incrementing integer counter in each parameter + set within the VALUES clause as parameters are batched over. + + This is only used for a specific INSERT..SELECT..VALUES..RETURNING syntax + where a subquery is used to produce value tuples. Current support + includes PostgreSQL, Microsoft SQL Server. + + .. versionadded:: 2.0.10 + + """ + + +class _InsertManyValuesBatch(NamedTuple): + """represents an individual batch SQL statement for insertmanyvalues. + + This is passed through the + :meth:`.SQLCompiler._deliver_insertmanyvalues_batches` and + :meth:`.DefaultDialect._deliver_insertmanyvalues_batches` methods out + to the :class:`.Connection` within the + :meth:`.Connection._exec_insertmany_context` method. + + .. versionadded:: 2.0.10 + + """ + + replaced_statement: str + replaced_parameters: _DBAPIAnyExecuteParams + processed_setinputsizes: Optional[_GenericSetInputSizesType] + batch: Sequence[_DBAPISingleExecuteParams] + sentinel_values: Sequence[Tuple[Any, ...]] + current_batch_size: int + batchnum: int + total_batches: int + rows_sorted: bool + is_downgraded: bool + + +class InsertmanyvaluesSentinelOpts(FastIntFlag): + """bitflag enum indicating styles of PK defaults + which can work as implicit sentinel columns + + """ + + NOT_SUPPORTED = 1 + AUTOINCREMENT = 2 + IDENTITY = 4 + SEQUENCE = 8 + + ANY_AUTOINCREMENT = AUTOINCREMENT | IDENTITY | SEQUENCE + _SUPPORTED_OR_NOT = NOT_SUPPORTED | ANY_AUTOINCREMENT + + USE_INSERT_FROM_SELECT = 16 + RENDER_SELECT_COL_CASTS = 64 + + +class CompilerState(IntEnum): + COMPILING = 0 + """statement is present, compilation phase in progress""" + + STRING_APPLIED = 1 + """statement is present, string form of the statement has been applied. + + Additional processors by subclasses may still be pending. + + """ + + NO_STATEMENT = 2 + """compiler does not have a statement to compile, is used + for method access""" + + +class Linting(IntEnum): + """represent preferences for the 'SQL linting' feature. + + this feature currently includes support for flagging cartesian products + in SQL statements. + + """ + + NO_LINTING = 0 + "Disable all linting." + + COLLECT_CARTESIAN_PRODUCTS = 1 + """Collect data on FROMs and cartesian products and gather into + 'self.from_linter'""" + + WARN_LINTING = 2 + "Emit warnings for linters that find problems" + + FROM_LINTING = COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING + """Warn for cartesian products; combines COLLECT_CARTESIAN_PRODUCTS + and WARN_LINTING""" + + +NO_LINTING, COLLECT_CARTESIAN_PRODUCTS, WARN_LINTING, FROM_LINTING = tuple( + Linting +) + + +class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): + """represents current state for the "cartesian product" detection + feature.""" + + def lint(self, start=None): + froms = self.froms + if not froms: + return None, None + + edges = set(self.edges) + the_rest = set(froms) + + if start is not None: + start_with = start + the_rest.remove(start_with) + else: + start_with = the_rest.pop() + + stack = collections.deque([start_with]) + + while stack and the_rest: + node = stack.popleft() + the_rest.discard(node) + + # comparison of nodes in edges here is based on hash equality, as + # there are "annotated" elements that match the non-annotated ones. + # to remove the need for in-python hash() calls, use native + # containment routines (e.g. "node in edge", "edge.index(node)") + to_remove = {edge for edge in edges if node in edge} + + # appendleft the node in each edge that is not + # the one that matched. + stack.extendleft(edge[not edge.index(node)] for edge in to_remove) + edges.difference_update(to_remove) + + # FROMS left over? boom + if the_rest: + return the_rest, start_with + else: + return None, None + + def warn(self, stmt_type="SELECT"): + the_rest, start_with = self.lint() + + # FROMS left over? boom + if the_rest: + froms = the_rest + if froms: + template = ( + "{stmt_type} statement has a cartesian product between " + "FROM element(s) {froms} and " + 'FROM element "{start}". Apply join condition(s) ' + "between each element to resolve." + ) + froms_str = ", ".join( + f'"{self.froms[from_]}"' for from_ in froms + ) + message = template.format( + stmt_type=stmt_type, + froms=froms_str, + start=self.froms[start_with], + ) + + util.warn(message) + + +class Compiled: + """Represent a compiled SQL or DDL expression. + + The ``__str__`` method of the ``Compiled`` object should produce + the actual text of the statement. ``Compiled`` objects are + specific to their underlying database dialect, and also may + or may not be specific to the columns referenced within a + particular set of bind parameters. In no case should the + ``Compiled`` object be dependent on the actual values of those + bind parameters, even though it may reference those values as + defaults. + """ + + statement: Optional[ClauseElement] = None + "The statement to compile." + string: str = "" + "The string representation of the ``statement``" + + state: CompilerState + """description of the compiler's state""" + + is_sql = False + is_ddl = False + + _cached_metadata: Optional[CursorResultMetaData] = None + + _result_columns: Optional[List[ResultColumnsEntry]] = None + + schema_translate_map: Optional[SchemaTranslateMapType] = None + + execution_options: _ExecuteOptions = util.EMPTY_DICT + """ + Execution options propagated from the statement. In some cases, + sub-elements of the statement can modify these. + """ + + preparer: IdentifierPreparer + + _annotations: _AnnotationDict = util.EMPTY_DICT + + compile_state: Optional[CompileState] = None + """Optional :class:`.CompileState` object that maintains additional + state used by the compiler. + + Major executable objects such as :class:`_expression.Insert`, + :class:`_expression.Update`, :class:`_expression.Delete`, + :class:`_expression.Select` will generate this + state when compiled in order to calculate additional information about the + object. For the top level object that is to be executed, the state can be + stored here where it can also have applicability towards result set + processing. + + .. versionadded:: 1.4 + + """ + + dml_compile_state: Optional[CompileState] = None + """Optional :class:`.CompileState` assigned at the same point that + .isinsert, .isupdate, or .isdelete is assigned. + + This will normally be the same object as .compile_state, with the + exception of cases like the :class:`.ORMFromStatementCompileState` + object. + + .. versionadded:: 1.4.40 + + """ + + cache_key: Optional[CacheKey] = None + """The :class:`.CacheKey` that was generated ahead of creating this + :class:`.Compiled` object. + + This is used for routines that need access to the original + :class:`.CacheKey` instance generated when the :class:`.Compiled` + instance was first cached, typically in order to reconcile + the original list of :class:`.BindParameter` objects with a + per-statement list that's generated on each call. + + """ + + _gen_time: float + """Generation time of this :class:`.Compiled`, used for reporting + cache stats.""" + + def __init__( + self, + dialect: Dialect, + statement: Optional[ClauseElement], + schema_translate_map: Optional[SchemaTranslateMapType] = None, + render_schema_translate: bool = False, + compile_kwargs: Mapping[str, Any] = util.immutabledict(), + ): + """Construct a new :class:`.Compiled` object. + + :param dialect: :class:`.Dialect` to compile against. + + :param statement: :class:`_expression.ClauseElement` to be compiled. + + :param schema_translate_map: dictionary of schema names to be + translated when forming the resultant SQL + + .. seealso:: + + :ref:`schema_translating` + + :param compile_kwargs: additional kwargs that will be + passed to the initial call to :meth:`.Compiled.process`. + + + """ + self.dialect = dialect + self.preparer = self.dialect.identifier_preparer + if schema_translate_map: + self.schema_translate_map = schema_translate_map + self.preparer = self.preparer._with_schema_translate( + schema_translate_map + ) + + if statement is not None: + self.state = CompilerState.COMPILING + self.statement = statement + self.can_execute = statement.supports_execution + self._annotations = statement._annotations + if self.can_execute: + if TYPE_CHECKING: + assert isinstance(statement, Executable) + self.execution_options = statement._execution_options + self.string = self.process(self.statement, **compile_kwargs) + + if render_schema_translate: + self.string = self.preparer._render_schema_translates( + self.string, schema_translate_map + ) + + self.state = CompilerState.STRING_APPLIED + else: + self.state = CompilerState.NO_STATEMENT + + self._gen_time = perf_counter() + + def __init_subclass__(cls) -> None: + cls._init_compiler_cls() + return super().__init_subclass__() + + @classmethod + def _init_compiler_cls(cls): + pass + + def _execute_on_connection( + self, connection, distilled_params, execution_options + ): + if self.can_execute: + return connection._execute_compiled( + self, distilled_params, execution_options + ) + else: + raise exc.ObjectNotExecutableError(self.statement) + + def visit_unsupported_compilation(self, element, err, **kw): + raise exc.UnsupportedCompilationError(self, type(element)) from err + + @property + def sql_compiler(self): + """Return a Compiled that is capable of processing SQL expressions. + + If this compiler is one, it would likely just return 'self'. + + """ + + raise NotImplementedError() + + def process(self, obj: Visitable, **kwargs: Any) -> str: + return obj._compiler_dispatch(self, **kwargs) + + def __str__(self) -> str: + """Return the string text of the generated SQL or DDL.""" + + if self.state is CompilerState.STRING_APPLIED: + return self.string + else: + return "" + + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + escape_names: bool = True, + ) -> Optional[_MutableCoreSingleExecuteParams]: + """Return the bind params for this compiled object. + + :param params: a dict of string/object pairs whose values will + override bind values compiled in to the + statement. + """ + + raise NotImplementedError() + + @property + def params(self): + """Return the bind params for this compiled object.""" + return self.construct_params() + + +class TypeCompiler(util.EnsureKWArg): + """Produces DDL specification for TypeEngine objects.""" + + ensure_kwarg = r"visit_\w+" + + def __init__(self, dialect: Dialect): + self.dialect = dialect + + def process(self, type_: TypeEngine[Any], **kw: Any) -> str: + if ( + type_._variant_mapping + and self.dialect.name in type_._variant_mapping + ): + type_ = type_._variant_mapping[self.dialect.name] + return type_._compiler_dispatch(self, **kw) + + def visit_unsupported_compilation( + self, element: Any, err: Exception, **kw: Any + ) -> NoReturn: + raise exc.UnsupportedCompilationError(self, element) from err + + +# this was a Visitable, but to allow accurate detection of +# column elements this is actually a column element +class _CompileLabel( + roles.BinaryElementRole[Any], elements.CompilerColumnElement +): + """lightweight label object which acts as an expression.Label.""" + + __visit_name__ = "label" + __slots__ = "element", "name", "_alt_names" + + def __init__(self, col, name, alt_names=()): + self.element = col + self.name = name + self._alt_names = (col,) + alt_names + + @property + def proxy_set(self): + return self.element.proxy_set + + @property + def type(self): + return self.element.type + + def self_group(self, **kw): + return self + + +class ilike_case_insensitive( + roles.BinaryElementRole[Any], elements.CompilerColumnElement +): + """produce a wrapping element for a case-insensitive portion of + an ILIKE construct. + + The construct usually renders the ``lower()`` function, but on + PostgreSQL will pass silently with the assumption that "ILIKE" + is being used. + + .. versionadded:: 2.0 + + """ + + __visit_name__ = "ilike_case_insensitive_operand" + __slots__ = "element", "comparator" + + def __init__(self, element): + self.element = element + self.comparator = element.comparator + + @property + def proxy_set(self): + return self.element.proxy_set + + @property + def type(self): + return self.element.type + + def self_group(self, **kw): + return self + + def _with_binary_element_type(self, type_): + return ilike_case_insensitive( + self.element._with_binary_element_type(type_) + ) + + +class SQLCompiler(Compiled): + """Default implementation of :class:`.Compiled`. + + Compiles :class:`_expression.ClauseElement` objects into SQL strings. + + """ + + extract_map = EXTRACT_MAP + + bindname_escape_characters: ClassVar[Mapping[str, str]] = ( + util.immutabledict( + { + "%": "P", + "(": "A", + ")": "Z", + ":": "C", + ".": "_", + "[": "_", + "]": "_", + " ": "_", + } + ) + ) + """A mapping (e.g. dict or similar) containing a lookup of + characters keyed to replacement characters which will be applied to all + 'bind names' used in SQL statements as a form of 'escaping'; the given + characters are replaced entirely with the 'replacement' character when + rendered in the SQL statement, and a similar translation is performed + on the incoming names used in parameter dictionaries passed to methods + like :meth:`_engine.Connection.execute`. + + This allows bound parameter names used in :func:`_sql.bindparam` and + other constructs to have any arbitrary characters present without any + concern for characters that aren't allowed at all on the target database. + + Third party dialects can establish their own dictionary here to replace the + default mapping, which will ensure that the particular characters in the + mapping will never appear in a bound parameter name. + + The dictionary is evaluated at **class creation time**, so cannot be + modified at runtime; it must be present on the class when the class + is first declared. + + Note that for dialects that have additional bound parameter rules such + as additional restrictions on leading characters, the + :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented. + See the cx_Oracle compiler for an example of this. + + .. versionadded:: 2.0.0rc1 + + """ + + _bind_translate_re: ClassVar[Pattern[str]] + _bind_translate_chars: ClassVar[Mapping[str, str]] + + is_sql = True + + compound_keywords = COMPOUND_KEYWORDS + + isdelete: bool = False + isinsert: bool = False + isupdate: bool = False + """class-level defaults which can be set at the instance + level to define if this Compiled instance represents + INSERT/UPDATE/DELETE + """ + + postfetch: Optional[List[Column[Any]]] + """list of columns that can be post-fetched after INSERT or UPDATE to + receive server-updated values""" + + insert_prefetch: Sequence[Column[Any]] = () + """list of columns for which default values should be evaluated before + an INSERT takes place""" + + update_prefetch: Sequence[Column[Any]] = () + """list of columns for which onupdate default values should be evaluated + before an UPDATE takes place""" + + implicit_returning: Optional[Sequence[ColumnElement[Any]]] = None + """list of "implicit" returning columns for a toplevel INSERT or UPDATE + statement, used to receive newly generated values of columns. + + .. versionadded:: 2.0 ``implicit_returning`` replaces the previous + ``returning`` collection, which was not a generalized RETURNING + collection and instead was in fact specific to the "implicit returning" + feature. + + """ + + isplaintext: bool = False + + binds: Dict[str, BindParameter[Any]] + """a dictionary of bind parameter keys to BindParameter instances.""" + + bind_names: Dict[BindParameter[Any], str] + """a dictionary of BindParameter instances to "compiled" names + that are actually present in the generated SQL""" + + stack: List[_CompilerStackEntry] + """major statements such as SELECT, INSERT, UPDATE, DELETE are + tracked in this stack using an entry format.""" + + returning_precedes_values: bool = False + """set to True classwide to generate RETURNING + clauses before the VALUES or WHERE clause (i.e. MSSQL) + """ + + render_table_with_column_in_update_from: bool = False + """set to True classwide to indicate the SET clause + in a multi-table UPDATE statement should qualify + columns with the table name (i.e. MySQL only) + """ + + ansi_bind_rules: bool = False + """SQL 92 doesn't allow bind parameters to be used + in the columns clause of a SELECT, nor does it allow + ambiguous expressions like "? = ?". A compiler + subclass can set this flag to False if the target + driver/DB enforces this + """ + + bindtemplate: str + """template to render bound parameters based on paramstyle.""" + + compilation_bindtemplate: str + """template used by compiler to render parameters before positional + paramstyle application""" + + _numeric_binds_identifier_char: str + """Character that's used to as the identifier of a numerical bind param. + For example if this char is set to ``$``, numerical binds will be rendered + in the form ``$1, $2, $3``. + """ + + _result_columns: List[ResultColumnsEntry] + """relates label names in the final SQL to a tuple of local + column/label name, ColumnElement object (if any) and + TypeEngine. CursorResult uses this for type processing and + column targeting""" + + _textual_ordered_columns: bool = False + """tell the result object that the column names as rendered are important, + but they are also "ordered" vs. what is in the compiled object here. + + As of 1.4.42 this condition is only present when the statement is a + TextualSelect, e.g. text("....").columns(...), where it is required + that the columns are considered positionally and not by name. + + """ + + _ad_hoc_textual: bool = False + """tell the result that we encountered text() or '*' constructs in the + middle of the result columns, but we also have compiled columns, so + if the number of columns in cursor.description does not match how many + expressions we have, that means we can't rely on positional at all and + should match on name. + + """ + + _ordered_columns: bool = True + """ + if False, means we can't be sure the list of entries + in _result_columns is actually the rendered order. Usually + True unless using an unordered TextualSelect. + """ + + _loose_column_name_matching: bool = False + """tell the result object that the SQL statement is textual, wants to match + up to Column objects, and may be using the ._tq_label in the SELECT rather + than the base name. + + """ + + _numeric_binds: bool = False + """ + True if paramstyle is "numeric". This paramstyle is trickier than + all the others. + + """ + + _render_postcompile: bool = False + """ + whether to render out POSTCOMPILE params during the compile phase. + + This attribute is used only for end-user invocation of stmt.compile(); + it's never used for actual statement execution, where instead the + dialect internals access and render the internal postcompile structure + directly. + + """ + + _post_compile_expanded_state: Optional[ExpandedState] = None + """When render_postcompile is used, the ``ExpandedState`` used to create + the "expanded" SQL is assigned here, and then used by the ``.params`` + accessor and ``.construct_params()`` methods for their return values. + + .. versionadded:: 2.0.0rc1 + + """ + + _pre_expanded_string: Optional[str] = None + """Stores the original string SQL before 'post_compile' is applied, + for cases where 'post_compile' were used. + + """ + + _pre_expanded_positiontup: Optional[List[str]] = None + + _insertmanyvalues: Optional[_InsertManyValues] = None + + _insert_crud_params: Optional[crud._CrudParamSequence] = None + + literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset() + """bindparameter objects that are rendered as literal values at statement + execution time. + + """ + + post_compile_params: FrozenSet[BindParameter[Any]] = frozenset() + """bindparameter objects that are rendered as bound parameter placeholders + at statement execution time. + + """ + + escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT + """Late escaping of bound parameter names that has to be converted + to the original name when looking in the parameter dictionary. + + """ + + has_out_parameters = False + """if True, there are bindparam() objects that have the isoutparam + flag set.""" + + postfetch_lastrowid = False + """if True, and this in insert, use cursor.lastrowid to populate + result.inserted_primary_key. """ + + _cache_key_bind_match: Optional[ + Tuple[ + Dict[ + BindParameter[Any], + List[BindParameter[Any]], + ], + Dict[ + str, + BindParameter[Any], + ], + ] + ] = None + """a mapping that will relate the BindParameter object we compile + to those that are part of the extracted collection of parameters + in the cache key, if we were given a cache key. + + """ + + positiontup: Optional[List[str]] = None + """for a compiled construct that uses a positional paramstyle, will be + a sequence of strings, indicating the names of bound parameters in order. + + This is used in order to render bound parameters in their correct order, + and is combined with the :attr:`_sql.Compiled.params` dictionary to + render parameters. + + This sequence always contains the unescaped name of the parameters. + + .. seealso:: + + :ref:`faq_sql_expression_string` - includes a usage example for + debugging use cases. + + """ + _values_bindparam: Optional[List[str]] = None + + _visited_bindparam: Optional[List[str]] = None + + inline: bool = False + + ctes: Optional[MutableMapping[CTE, str]] + + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting + ctes_by_level_name: Dict[Tuple[int, str], CTE] + + # To retrieve key/level in ctes_by_level_name - + # Dict[cte_reference, (level, cte_name, cte_opts)] + level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]] + + ctes_recursive: bool + + _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]") + _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s") + _positional_pattern = re.compile( + f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}" + ) + + @classmethod + def _init_compiler_cls(cls): + cls._init_bind_translate() + + @classmethod + def _init_bind_translate(cls): + reg = re.escape("".join(cls.bindname_escape_characters)) + cls._bind_translate_re = re.compile(f"[{reg}]") + cls._bind_translate_chars = cls.bindname_escape_characters + + def __init__( + self, + dialect: Dialect, + statement: Optional[ClauseElement], + cache_key: Optional[CacheKey] = None, + column_keys: Optional[Sequence[str]] = None, + for_executemany: bool = False, + linting: Linting = NO_LINTING, + _supporting_against: Optional[SQLCompiler] = None, + **kwargs: Any, + ): + """Construct a new :class:`.SQLCompiler` object. + + :param dialect: :class:`.Dialect` to be used + + :param statement: :class:`_expression.ClauseElement` to be compiled + + :param column_keys: a list of column names to be compiled into an + INSERT or UPDATE statement. + + :param for_executemany: whether INSERT / UPDATE statements should + expect that they are to be invoked in an "executemany" style, + which may impact how the statement will be expected to return the + values of defaults and autoincrement / sequences and similar. + Depending on the backend and driver in use, support for retrieving + these values may be disabled which means SQL expressions may + be rendered inline, RETURNING may not be rendered, etc. + + :param kwargs: additional keyword arguments to be consumed by the + superclass. + + """ + self.column_keys = column_keys + + self.cache_key = cache_key + + if cache_key: + cksm = {b.key: b for b in cache_key[1]} + ckbm = {b: [b] for b in cache_key[1]} + self._cache_key_bind_match = (ckbm, cksm) + + # compile INSERT/UPDATE defaults/sequences to expect executemany + # style execution, which may mean no pre-execute of defaults, + # or no RETURNING + self.for_executemany = for_executemany + + self.linting = linting + + # a dictionary of bind parameter keys to BindParameter + # instances. + self.binds = {} + + # a dictionary of BindParameter instances to "compiled" names + # that are actually present in the generated SQL + self.bind_names = util.column_dict() + + # stack which keeps track of nested SELECT statements + self.stack = [] + + self._result_columns = [] + + # true if the paramstyle is positional + self.positional = dialect.positional + if self.positional: + self._numeric_binds = nb = dialect.paramstyle.startswith("numeric") + if nb: + self._numeric_binds_identifier_char = ( + "$" if dialect.paramstyle == "numeric_dollar" else ":" + ) + + self.compilation_bindtemplate = _pyformat_template + else: + self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + + self.ctes = None + + self.label_length = ( + dialect.label_length or dialect.max_identifier_length + ) + + # a map which tracks "anonymous" identifiers that are created on + # the fly here + self.anon_map = prefix_anon_map() + + # a map which tracks "truncated" names based on + # dialect.label_length or dialect.max_identifier_length + self.truncated_names: Dict[Tuple[str, str], str] = {} + self._truncated_counters: Dict[str, int] = {} + + Compiled.__init__(self, dialect, statement, **kwargs) + + if self.isinsert or self.isupdate or self.isdelete: + if TYPE_CHECKING: + assert isinstance(statement, UpdateBase) + + if self.isinsert or self.isupdate: + if TYPE_CHECKING: + assert isinstance(statement, ValuesBase) + if statement._inline: + self.inline = True + elif self.for_executemany and ( + not self.isinsert + or ( + self.dialect.insert_executemany_returning + and statement._return_defaults + ) + ): + self.inline = True + + self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] + + if _supporting_against: + self.__dict__.update( + { + k: v + for k, v in _supporting_against.__dict__.items() + if k + not in { + "state", + "dialect", + "preparer", + "positional", + "_numeric_binds", + "compilation_bindtemplate", + "bindtemplate", + } + } + ) + + if self.state is CompilerState.STRING_APPLIED: + if self.positional: + if self._numeric_binds: + self._process_numeric() + else: + self._process_positional() + + if self._render_postcompile: + parameters = self.construct_params( + escape_names=False, + _no_postcompile=True, + ) + + self._process_parameters_for_postcompile( + parameters, _populate_self=True + ) + + @property + def insert_single_values_expr(self) -> Optional[str]: + """When an INSERT is compiled with a single set of parameters inside + a VALUES expression, the string is assigned here, where it can be + used for insert batching schemes to rewrite the VALUES expression. + + .. versionadded:: 1.3.8 + + .. versionchanged:: 2.0 This collection is no longer used by + SQLAlchemy's built-in dialects, in favor of the currently + internal ``_insertmanyvalues`` collection that is used only by + :class:`.SQLCompiler`. + + """ + if self._insertmanyvalues is None: + return None + else: + return self._insertmanyvalues.single_values_expr + + @util.ro_memoized_property + def effective_returning(self) -> Optional[Sequence[ColumnElement[Any]]]: + """The effective "returning" columns for INSERT, UPDATE or DELETE. + + This is either the so-called "implicit returning" columns which are + calculated by the compiler on the fly, or those present based on what's + present in ``self.statement._returning`` (expanded into individual + columns using the ``._all_selected_columns`` attribute) i.e. those set + explicitly using the :meth:`.UpdateBase.returning` method. + + .. versionadded:: 2.0 + + """ + if self.implicit_returning: + return self.implicit_returning + elif self.statement is not None and is_dml(self.statement): + return [ + c + for c in self.statement._all_selected_columns + if is_column_element(c) + ] + + else: + return None + + @property + def returning(self): + """backwards compatibility; returns the + effective_returning collection. + + """ + return self.effective_returning + + @property + def current_executable(self): + """Return the current 'executable' that is being compiled. + + This is currently the :class:`_sql.Select`, :class:`_sql.Insert`, + :class:`_sql.Update`, :class:`_sql.Delete`, + :class:`_sql.CompoundSelect` object that is being compiled. + Specifically it's assigned to the ``self.stack`` list of elements. + + When a statement like the above is being compiled, it normally + is also assigned to the ``.statement`` attribute of the + :class:`_sql.Compiler` object. However, all SQL constructs are + ultimately nestable, and this attribute should never be consulted + by a ``visit_`` method, as it is not guaranteed to be assigned + nor guaranteed to correspond to the current statement being compiled. + + .. versionadded:: 1.3.21 + + For compatibility with previous versions, use the following + recipe:: + + statement = getattr(self, "current_executable", False) + if statement is False: + statement = self.stack[-1]["selectable"] + + For versions 1.4 and above, ensure only .current_executable + is used; the format of "self.stack" may change. + + + """ + try: + return self.stack[-1]["selectable"] + except IndexError as ie: + raise IndexError("Compiler does not have a stack entry") from ie + + @property + def prefetch(self): + return list(self.insert_prefetch) + list(self.update_prefetch) + + @util.memoized_property + def _global_attributes(self) -> Dict[Any, Any]: + return {} + + @util.memoized_instancemethod + def _init_cte_state(self) -> MutableMapping[CTE, str]: + """Initialize collections related to CTEs only if + a CTE is located, to save on the overhead of + these collections otherwise. + + """ + # collect CTEs to tack on top of a SELECT + # To store the query to print - Dict[cte, text_query] + ctes: MutableMapping[CTE, str] = util.OrderedDict() + self.ctes = ctes + + # Detect same CTE references - Dict[(level, name), cte] + # Level is required for supporting nesting + self.ctes_by_level_name = {} + + # To retrieve key/level in ctes_by_level_name - + # Dict[cte_reference, (level, cte_name, cte_opts)] + self.level_name_by_cte = {} + + self.ctes_recursive = False + + return ctes + + @contextlib.contextmanager + def _nested_result(self): + """special API to support the use case of 'nested result sets'""" + result_columns, ordered_columns = ( + self._result_columns, + self._ordered_columns, + ) + self._result_columns, self._ordered_columns = [], False + + try: + if self.stack: + entry = self.stack[-1] + entry["need_result_map_for_nested"] = True + else: + entry = None + yield self._result_columns, self._ordered_columns + finally: + if entry: + entry.pop("need_result_map_for_nested") + self._result_columns, self._ordered_columns = ( + result_columns, + ordered_columns, + ) + + def _process_positional(self): + assert not self.positiontup + assert self.state is CompilerState.STRING_APPLIED + assert not self._numeric_binds + + if self.dialect.paramstyle == "format": + placeholder = "%s" + else: + assert self.dialect.paramstyle == "qmark" + placeholder = "?" + + positions = [] + + def find_position(m: re.Match[str]) -> str: + normal_bind = m.group(1) + if normal_bind: + positions.append(normal_bind) + return placeholder + else: + # this a post-compile bind + positions.append(m.group(2)) + return m.group(0) + + self.string = re.sub( + self._positional_pattern, find_position, self.string + ) + + if self.escaped_bind_names: + reverse_escape = {v: k for k, v in self.escaped_bind_names.items()} + assert len(self.escaped_bind_names) == len(reverse_escape) + self.positiontup = [ + reverse_escape.get(name, name) for name in positions + ] + else: + self.positiontup = positions + + if self._insertmanyvalues: + positions = [] + + single_values_expr = re.sub( + self._positional_pattern, + find_position, + self._insertmanyvalues.single_values_expr, + ) + insert_crud_params = [ + ( + v[0], + v[1], + re.sub(self._positional_pattern, find_position, v[2]), + v[3], + ) + for v in self._insertmanyvalues.insert_crud_params + ] + + self._insertmanyvalues = self._insertmanyvalues._replace( + single_values_expr=single_values_expr, + insert_crud_params=insert_crud_params, + ) + + def _process_numeric(self): + assert self._numeric_binds + assert self.state is CompilerState.STRING_APPLIED + + num = 1 + param_pos: Dict[str, str] = {} + order: Iterable[str] + if self._insertmanyvalues and self._values_bindparam is not None: + # bindparams that are not in values are always placed first. + # this avoids the need of changing them when using executemany + # values () () + order = itertools.chain( + ( + name + for name in self.bind_names.values() + if name not in self._values_bindparam + ), + self.bind_names.values(), + ) + else: + order = self.bind_names.values() + + for bind_name in order: + if bind_name in param_pos: + continue + bind = self.binds[bind_name] + if ( + bind in self.post_compile_params + or bind in self.literal_execute_params + ): + # set to None to just mark the in positiontup, it will not + # be replaced below. + param_pos[bind_name] = None # type: ignore + else: + ph = f"{self._numeric_binds_identifier_char}{num}" + num += 1 + param_pos[bind_name] = ph + + self.next_numeric_pos = num + + self.positiontup = list(param_pos) + if self.escaped_bind_names: + len_before = len(param_pos) + param_pos = { + self.escaped_bind_names.get(name, name): pos + for name, pos in param_pos.items() + } + assert len(param_pos) == len_before + + # Can't use format here since % chars are not escaped. + self.string = self._pyformat_pattern.sub( + lambda m: param_pos[m.group(1)], self.string + ) + + if self._insertmanyvalues: + single_values_expr = ( + # format is ok here since single_values_expr includes only + # place-holders + self._insertmanyvalues.single_values_expr + % param_pos + ) + insert_crud_params = [ + (v[0], v[1], "%s", v[3]) + for v in self._insertmanyvalues.insert_crud_params + ] + + self._insertmanyvalues = self._insertmanyvalues._replace( + # This has the numbers (:1, :2) + single_values_expr=single_values_expr, + # The single binds are instead %s so they can be formatted + insert_crud_params=insert_crud_params, + ) + + @util.memoized_property + def _bind_processors( + self, + ) -> MutableMapping[ + str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]] + ]: + # mypy is not able to see the two value types as the above Union, + # it just sees "object". don't know how to resolve + return { + key: value # type: ignore + for key, value in ( + ( + self.bind_names[bindparam], + ( + bindparam.type._cached_bind_processor(self.dialect) + if not bindparam.type._is_tuple_type + else tuple( + elem_type._cached_bind_processor(self.dialect) + for elem_type in cast( + TupleType, bindparam.type + ).types + ) + ), + ) + for bindparam in self.bind_names + ) + if value is not None + } + + def is_subquery(self): + return len(self.stack) > 1 + + @property + def sql_compiler(self): + return self + + def construct_expanded_state( + self, + params: Optional[_CoreSingleExecuteParams] = None, + escape_names: bool = True, + ) -> ExpandedState: + """Return a new :class:`.ExpandedState` for a given parameter set. + + For queries that use "expanding" or other late-rendered parameters, + this method will provide for both the finalized SQL string as well + as the parameters that would be used for a particular parameter set. + + .. versionadded:: 2.0.0rc1 + + """ + parameters = self.construct_params( + params, + escape_names=escape_names, + _no_postcompile=True, + ) + return self._process_parameters_for_postcompile( + parameters, + ) + + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + escape_names: bool = True, + _group_number: Optional[int] = None, + _check: bool = True, + _no_postcompile: bool = False, + ) -> _MutableCoreSingleExecuteParams: + """return a dictionary of bind parameter keys and values""" + + if self._render_postcompile and not _no_postcompile: + assert self._post_compile_expanded_state is not None + if not params: + return dict(self._post_compile_expanded_state.parameters) + else: + raise exc.InvalidRequestError( + "can't construct new parameters when render_postcompile " + "is used; the statement is hard-linked to the original " + "parameters. Use construct_expanded_state to generate a " + "new statement and parameters." + ) + + has_escaped_names = escape_names and bool(self.escaped_bind_names) + + if extracted_parameters: + # related the bound parameters collected in the original cache key + # to those collected in the incoming cache key. They will not have + # matching names but they will line up positionally in the same + # way. The parameters present in self.bind_names may be clones of + # these original cache key params in the case of DML but the .key + # will be guaranteed to match. + if self.cache_key is None: + raise exc.CompileError( + "This compiled object has no original cache key; " + "can't pass extracted_parameters to construct_params" + ) + else: + orig_extracted = self.cache_key[1] + + ckbm_tuple = self._cache_key_bind_match + assert ckbm_tuple is not None + ckbm, _ = ckbm_tuple + resolved_extracted = { + bind: extracted + for b, extracted in zip(orig_extracted, extracted_parameters) + for bind in ckbm[b] + } + else: + resolved_extracted = None + + if params: + pd = {} + for bindparam, name in self.bind_names.items(): + escaped_name = ( + self.escaped_bind_names.get(name, name) + if has_escaped_names + else name + ) + + if bindparam.key in params: + pd[escaped_name] = params[bindparam.key] + elif name in params: + pd[escaped_name] = params[name] + + elif _check and bindparam.required: + if _group_number: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r, " + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) + else: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r" + % bindparam.key, + code="cd3x", + ) + else: + if resolved_extracted: + value_param = resolved_extracted.get( + bindparam, bindparam + ) + else: + value_param = bindparam + + if bindparam.callable: + pd[escaped_name] = value_param.effective_value + else: + pd[escaped_name] = value_param.value + return pd + else: + pd = {} + for bindparam, name in self.bind_names.items(): + escaped_name = ( + self.escaped_bind_names.get(name, name) + if has_escaped_names + else name + ) + + if _check and bindparam.required: + if _group_number: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r, " + "in parameter group %d" + % (bindparam.key, _group_number), + code="cd3x", + ) + else: + raise exc.InvalidRequestError( + "A value is required for bind parameter %r" + % bindparam.key, + code="cd3x", + ) + + if resolved_extracted: + value_param = resolved_extracted.get(bindparam, bindparam) + else: + value_param = bindparam + + if bindparam.callable: + pd[escaped_name] = value_param.effective_value + else: + pd[escaped_name] = value_param.value + + return pd + + @util.memoized_instancemethod + def _get_set_input_sizes_lookup(self): + dialect = self.dialect + + include_types = dialect.include_set_input_sizes + exclude_types = dialect.exclude_set_input_sizes + + dbapi = dialect.dbapi + + def lookup_type(typ): + dbtype = typ._unwrapped_dialect_impl(dialect).get_dbapi_type(dbapi) + + if ( + dbtype is not None + and (exclude_types is None or dbtype not in exclude_types) + and (include_types is None or dbtype in include_types) + ): + return dbtype + else: + return None + + inputsizes = {} + + literal_execute_params = self.literal_execute_params + + for bindparam in self.bind_names: + if bindparam in literal_execute_params: + continue + + if bindparam.type._is_tuple_type: + inputsizes[bindparam] = [ + lookup_type(typ) + for typ in cast(TupleType, bindparam.type).types + ] + else: + inputsizes[bindparam] = lookup_type(bindparam.type) + + return inputsizes + + @property + def params(self): + """Return the bind param dictionary embedded into this + compiled object, for those values that are present. + + .. seealso:: + + :ref:`faq_sql_expression_string` - includes a usage example for + debugging use cases. + + """ + return self.construct_params(_check=False) + + def _process_parameters_for_postcompile( + self, + parameters: _MutableCoreSingleExecuteParams, + _populate_self: bool = False, + ) -> ExpandedState: + """handle special post compile parameters. + + These include: + + * "expanding" parameters -typically IN tuples that are rendered + on a per-parameter basis for an otherwise fixed SQL statement string. + + * literal_binds compiled with the literal_execute flag. Used for + things like SQL Server "TOP N" where the driver does not accommodate + N as a bound parameter. + + """ + + expanded_parameters = {} + new_positiontup: Optional[List[str]] + + pre_expanded_string = self._pre_expanded_string + if pre_expanded_string is None: + pre_expanded_string = self.string + + if self.positional: + new_positiontup = [] + + pre_expanded_positiontup = self._pre_expanded_positiontup + if pre_expanded_positiontup is None: + pre_expanded_positiontup = self.positiontup + + else: + new_positiontup = pre_expanded_positiontup = None + + processors = self._bind_processors + single_processors = cast( + "Mapping[str, _BindProcessorType[Any]]", processors + ) + tuple_processors = cast( + "Mapping[str, Sequence[_BindProcessorType[Any]]]", processors + ) + + new_processors: Dict[str, _BindProcessorType[Any]] = {} + + replacement_expressions: Dict[str, Any] = {} + to_update_sets: Dict[str, Any] = {} + + # notes: + # *unescaped* parameter names in: + # self.bind_names, self.binds, self._bind_processors, self.positiontup + # + # *escaped* parameter names in: + # construct_params(), replacement_expressions + + numeric_positiontup: Optional[List[str]] = None + + if self.positional and pre_expanded_positiontup is not None: + names: Iterable[str] = pre_expanded_positiontup + if self._numeric_binds: + numeric_positiontup = [] + else: + names = self.bind_names.values() + + ebn = self.escaped_bind_names + for name in names: + escaped_name = ebn.get(name, name) if ebn else name + parameter = self.binds[name] + + if parameter in self.literal_execute_params: + if escaped_name not in replacement_expressions: + replacement_expressions[escaped_name] = ( + self.render_literal_bindparam( + parameter, + render_literal_value=parameters.pop(escaped_name), + ) + ) + continue + + if parameter in self.post_compile_params: + if escaped_name in replacement_expressions: + to_update = to_update_sets[escaped_name] + values = None + else: + # we are removing the parameter from parameters + # because it is a list value, which is not expected by + # TypeEngine objects that would otherwise be asked to + # process it. the single name is being replaced with + # individual numbered parameters for each value in the + # param. + # + # note we are also inserting *escaped* parameter names + # into the given dictionary. default dialect will + # use these param names directly as they will not be + # in the escaped_bind_names dictionary. + values = parameters.pop(name) + + leep_res = self._literal_execute_expanding_parameter( + escaped_name, parameter, values + ) + (to_update, replacement_expr) = leep_res + + to_update_sets[escaped_name] = to_update + replacement_expressions[escaped_name] = replacement_expr + + if not parameter.literal_execute: + parameters.update(to_update) + if parameter.type._is_tuple_type: + assert values is not None + new_processors.update( + ( + "%s_%s_%s" % (name, i, j), + tuple_processors[name][j - 1], + ) + for i, tuple_element in enumerate(values, 1) + for j, _ in enumerate(tuple_element, 1) + if name in tuple_processors + and tuple_processors[name][j - 1] is not None + ) + else: + new_processors.update( + (key, single_processors[name]) + for key, _ in to_update + if name in single_processors + ) + if numeric_positiontup is not None: + numeric_positiontup.extend( + name for name, _ in to_update + ) + elif new_positiontup is not None: + # to_update has escaped names, but that's ok since + # these are new names, that aren't in the + # escaped_bind_names dict. + new_positiontup.extend(name for name, _ in to_update) + expanded_parameters[name] = [ + expand_key for expand_key, _ in to_update + ] + elif new_positiontup is not None: + new_positiontup.append(name) + + def process_expanding(m): + key = m.group(1) + expr = replacement_expressions[key] + + # if POSTCOMPILE included a bind_expression, render that + # around each element + if m.group(2): + tok = m.group(2).split("~~") + be_left, be_right = tok[1], tok[3] + expr = ", ".join( + "%s%s%s" % (be_left, exp, be_right) + for exp in expr.split(", ") + ) + return expr + + statement = re.sub( + self._post_compile_pattern, process_expanding, pre_expanded_string + ) + + if numeric_positiontup is not None: + assert new_positiontup is not None + param_pos = { + key: f"{self._numeric_binds_identifier_char}{num}" + for num, key in enumerate( + numeric_positiontup, self.next_numeric_pos + ) + } + # Can't use format here since % chars are not escaped. + statement = self._pyformat_pattern.sub( + lambda m: param_pos[m.group(1)], statement + ) + new_positiontup.extend(numeric_positiontup) + + expanded_state = ExpandedState( + statement, + parameters, + new_processors, + new_positiontup, + expanded_parameters, + ) + + if _populate_self: + # this is for the "render_postcompile" flag, which is not + # otherwise used internally and is for end-user debugging and + # special use cases. + self._pre_expanded_string = pre_expanded_string + self._pre_expanded_positiontup = pre_expanded_positiontup + self.string = expanded_state.statement + self.positiontup = ( + list(expanded_state.positiontup or ()) + if self.positional + else None + ) + self._post_compile_expanded_state = expanded_state + + return expanded_state + + @util.preload_module("sqlalchemy.engine.cursor") + def _create_result_map(self): + """utility method used for unit tests only.""" + cursor = util.preloaded.engine_cursor + return cursor.CursorResultMetaData._create_description_match_map( + self._result_columns + ) + + # assigned by crud.py for insert/update statements + _get_bind_name_for_col: _BindNameForColProtocol + + @util.memoized_property + def _within_exec_param_key_getter(self) -> Callable[[Any], str]: + getter = self._get_bind_name_for_col + return getter + + @util.memoized_property + @util.preload_module("sqlalchemy.engine.result") + def _inserted_primary_key_from_lastrowid_getter(self): + result = util.preloaded.engine_result + + param_key_getter = self._within_exec_param_key_getter + + assert self.compile_state is not None + statement = self.compile_state.statement + + if TYPE_CHECKING: + assert isinstance(statement, Insert) + + table = statement.table + + getters = [ + (operator.methodcaller("get", param_key_getter(col), None), col) + for col in table.primary_key + ] + + autoinc_getter = None + autoinc_col = table._autoincrement_column + if autoinc_col is not None: + # apply type post processors to the lastrowid + lastrowid_processor = autoinc_col.type._cached_result_processor( + self.dialect, None + ) + autoinc_key = param_key_getter(autoinc_col) + + # if a bind value is present for the autoincrement column + # in the parameters, we need to do the logic dictated by + # #7998; honor a non-None user-passed parameter over lastrowid. + # previously in the 1.4 series we weren't fetching lastrowid + # at all if the key were present in the parameters + if autoinc_key in self.binds: + + def _autoinc_getter(lastrowid, parameters): + param_value = parameters.get(autoinc_key, lastrowid) + if param_value is not None: + # they supplied non-None parameter, use that. + # SQLite at least is observed to return the wrong + # cursor.lastrowid for INSERT..ON CONFLICT so it + # can't be used in all cases + return param_value + else: + # use lastrowid + return lastrowid + + # work around mypy https://github.com/python/mypy/issues/14027 + autoinc_getter = _autoinc_getter + + else: + lastrowid_processor = None + + row_fn = result.result_tuple([col.key for col in table.primary_key]) + + def get(lastrowid, parameters): + """given cursor.lastrowid value and the parameters used for INSERT, + return a "row" that represents the primary key, either by + using the "lastrowid" or by extracting values from the parameters + that were sent along with the INSERT. + + """ + if lastrowid_processor is not None: + lastrowid = lastrowid_processor(lastrowid) + + if lastrowid is None: + return row_fn(getter(parameters) for getter, col in getters) + else: + return row_fn( + ( + ( + autoinc_getter(lastrowid, parameters) + if autoinc_getter is not None + else lastrowid + ) + if col is autoinc_col + else getter(parameters) + ) + for getter, col in getters + ) + + return get + + @util.memoized_property + @util.preload_module("sqlalchemy.engine.result") + def _inserted_primary_key_from_returning_getter(self): + if typing.TYPE_CHECKING: + from ..engine import result + else: + result = util.preloaded.engine_result + + assert self.compile_state is not None + statement = self.compile_state.statement + + if TYPE_CHECKING: + assert isinstance(statement, Insert) + + param_key_getter = self._within_exec_param_key_getter + table = statement.table + + returning = self.implicit_returning + assert returning is not None + ret = {col: idx for idx, col in enumerate(returning)} + + getters = cast( + "List[Tuple[Callable[[Any], Any], bool]]", + [ + ( + (operator.itemgetter(ret[col]), True) + if col in ret + else ( + operator.methodcaller( + "get", param_key_getter(col), None + ), + False, + ) + ) + for col in table.primary_key + ], + ) + + row_fn = result.result_tuple([col.key for col in table.primary_key]) + + def get(row, parameters): + return row_fn( + getter(row) if use_row else getter(parameters) + for getter, use_row in getters + ) + + return get + + def default_from(self): + """Called when a SELECT statement has no froms, and no FROM clause is + to be appended. + + Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output. + + """ + return "" + + def visit_override_binds(self, override_binds, **kw): + """SQL compile the nested element of an _OverrideBinds with + bindparams swapped out. + + The _OverrideBinds is not normally expected to be compiled; it + is meant to be used when an already cached statement is to be used, + the compilation was already performed, and only the bound params should + be swapped in at execution time. + + However, there are test cases that exericise this object, and + additionally the ORM subquery loader is known to feed in expressions + which include this construct into new queries (discovered in #11173), + so it has to do the right thing at compile time as well. + + """ + + # get SQL text first + sqltext = override_binds.element._compiler_dispatch(self, **kw) + + # for a test compile that is not for caching, change binds after the + # fact. note that we don't try to + # swap the bindparam as we compile, because our element may be + # elsewhere in the statement already (e.g. a subquery or perhaps a + # CTE) and was already visited / compiled. See + # test_relationship_criteria.py -> + # test_selectinload_local_criteria_subquery + for k in override_binds.translate: + if k not in self.binds: + continue + bp = self.binds[k] + + # so this would work, just change the value of bp in place. + # but we dont want to mutate things outside. + # bp.value = override_binds.translate[bp.key] + # continue + + # instead, need to replace bp with new_bp or otherwise accommodate + # in all internal collections + new_bp = bp._with_value( + override_binds.translate[bp.key], + maintain_key=True, + required=False, + ) + + name = self.bind_names[bp] + self.binds[k] = self.binds[name] = new_bp + self.bind_names[new_bp] = name + self.bind_names.pop(bp, None) + + if bp in self.post_compile_params: + self.post_compile_params |= {new_bp} + if bp in self.literal_execute_params: + self.literal_execute_params |= {new_bp} + + ckbm_tuple = self._cache_key_bind_match + if ckbm_tuple: + ckbm, cksm = ckbm_tuple + for bp in bp._cloned_set: + if bp.key in cksm: + cb = cksm[bp.key] + ckbm[cb].append(new_bp) + + return sqltext + + def visit_grouping(self, grouping, asfrom=False, **kwargs): + return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" + + def visit_select_statement_grouping(self, grouping, **kwargs): + return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" + + def visit_label_reference( + self, element, within_columns_clause=False, **kwargs + ): + if self.stack and self.dialect.supports_simple_order_by_label: + try: + compile_state = cast( + "Union[SelectState, CompoundSelectState]", + self.stack[-1]["compile_state"], + ) + except KeyError as ke: + raise exc.CompileError( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ) from ke + + ( + with_cols, + only_froms, + only_cols, + ) = compile_state._label_resolve_dict + if within_columns_clause: + resolve_dict = only_froms + else: + resolve_dict = only_cols + + # this can be None in the case that a _label_reference() + # were subject to a replacement operation, in which case + # the replacement of the Label element may have changed + # to something else like a ColumnClause expression. + order_by_elem = element.element._order_by_label_element + + if ( + order_by_elem is not None + and order_by_elem.name in resolve_dict + and order_by_elem.shares_lineage( + resolve_dict[order_by_elem.name] + ) + ): + kwargs["render_label_as_label"] = ( + element.element._order_by_label_element + ) + return self.process( + element.element, + within_columns_clause=within_columns_clause, + **kwargs, + ) + + def visit_textual_label_reference( + self, element, within_columns_clause=False, **kwargs + ): + if not self.stack: + # compiling the element outside of the context of a SELECT + return self.process(element._text_clause) + + try: + compile_state = cast( + "Union[SelectState, CompoundSelectState]", + self.stack[-1]["compile_state"], + ) + except KeyError as ke: + coercions._no_text_coercion( + element.element, + extra=( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ), + exc_cls=exc.CompileError, + err=ke, + ) + + with_cols, only_froms, only_cols = compile_state._label_resolve_dict + try: + if within_columns_clause: + col = only_froms[element.element] + else: + col = with_cols[element.element] + except KeyError as err: + coercions._no_text_coercion( + element.element, + extra=( + "Can't resolve label reference for ORDER BY / " + "GROUP BY / DISTINCT etc." + ), + exc_cls=exc.CompileError, + err=err, + ) + else: + kwargs["render_label_as_label"] = col + return self.process( + col, within_columns_clause=within_columns_clause, **kwargs + ) + + def visit_label( + self, + label, + add_to_result_map=None, + within_label_clause=False, + within_columns_clause=False, + render_label_as_label=None, + result_map_targets=(), + **kw, + ): + # only render labels within the columns clause + # or ORDER BY clause of a select. dialect-specific compilers + # can modify this behavior. + render_label_with_as = ( + within_columns_clause and not within_label_clause + ) + render_label_only = render_label_as_label is label + + if render_label_only or render_label_with_as: + if isinstance(label.name, elements._truncated_label): + labelname = self._truncated_identifier("colident", label.name) + else: + labelname = label.name + + if render_label_with_as: + if add_to_result_map is not None: + add_to_result_map( + labelname, + label.name, + (label, labelname) + label._alt_names + result_map_targets, + label.type, + ) + return ( + label.element._compiler_dispatch( + self, + within_columns_clause=True, + within_label_clause=True, + **kw, + ) + + OPERATORS[operators.as_] + + self.preparer.format_label(label, labelname) + ) + elif render_label_only: + return self.preparer.format_label(label, labelname) + else: + return label.element._compiler_dispatch( + self, within_columns_clause=False, **kw + ) + + def _fallback_column_name(self, column): + raise exc.CompileError( + "Cannot compile Column object until its 'name' is assigned." + ) + + def visit_lambda_element(self, element, **kw): + sql_element = element._resolved + return self.process(sql_element, **kw) + + def visit_column( + self, + column: ColumnClause[Any], + add_to_result_map: Optional[_ResultMapAppender] = None, + include_table: bool = True, + result_map_targets: Tuple[Any, ...] = (), + ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None, + **kwargs: Any, + ) -> str: + name = orig_name = column.name + if name is None: + name = self._fallback_column_name(column) + + is_literal = column.is_literal + if not is_literal and isinstance(name, elements._truncated_label): + name = self._truncated_identifier("colident", name) + + if add_to_result_map is not None: + targets = (column, name, column.key) + result_map_targets + if column._tq_label: + targets += (column._tq_label,) + + add_to_result_map(name, orig_name, targets, column.type) + + if is_literal: + # note we are not currently accommodating for + # literal_column(quoted_name('ident', True)) here + name = self.escape_literal_column(name) + else: + name = self.preparer.quote(name) + table = column.table + if table is None or not include_table or not table.named_with_column: + return name + else: + effective_schema = self.preparer.schema_for_object(table) + + if effective_schema: + schema_prefix = ( + self.preparer.quote_schema(effective_schema) + "." + ) + else: + schema_prefix = "" + + if TYPE_CHECKING: + assert isinstance(table, NamedFromClause) + tablename = table.name + + if ( + not effective_schema + and ambiguous_table_name_map + and tablename in ambiguous_table_name_map + ): + tablename = ambiguous_table_name_map[tablename] + + if isinstance(tablename, elements._truncated_label): + tablename = self._truncated_identifier("alias", tablename) + + return schema_prefix + self.preparer.quote(tablename) + "." + name + + def visit_collation(self, element, **kw): + return self.preparer.format_collation(element.collation) + + def visit_fromclause(self, fromclause, **kwargs): + return fromclause.name + + def visit_index(self, index, **kwargs): + return index.name + + def visit_typeclause(self, typeclause, **kw): + kw["type_expression"] = typeclause + kw["identifier_preparer"] = self.preparer + return self.dialect.type_compiler_instance.process( + typeclause.type, **kw + ) + + def post_process_text(self, text): + if self.preparer._double_percents: + text = text.replace("%", "%%") + return text + + def escape_literal_column(self, text): + if self.preparer._double_percents: + text = text.replace("%", "%%") + return text + + def visit_textclause(self, textclause, add_to_result_map=None, **kw): + def do_bindparam(m): + name = m.group(1) + if name in textclause._bindparams: + return self.process(textclause._bindparams[name], **kw) + else: + return self.bindparam_string(name, **kw) + + if not self.stack: + self.isplaintext = True + + if add_to_result_map: + # text() object is present in the columns clause of a + # select(). Add a no-name entry to the result map so that + # row[text()] produces a result + add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE) + + # un-escape any \:params + return BIND_PARAMS_ESC.sub( + lambda m: m.group(1), + BIND_PARAMS.sub( + do_bindparam, self.post_process_text(textclause.text) + ), + ) + + def visit_textual_select( + self, taf, compound_index=None, asfrom=False, **kw + ): + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + new_entry: _CompilerStackEntry = { + "correlate_froms": set(), + "asfrom_froms": set(), + "selectable": taf, + } + self.stack.append(new_entry) + + if taf._independent_ctes: + self._dispatch_independent_ctes(taf, kw) + + populate_result_map = ( + toplevel + or ( + compound_index == 0 + and entry.get("need_result_map_for_compound", False) + ) + or entry.get("need_result_map_for_nested", False) + ) + + if populate_result_map: + self._ordered_columns = self._textual_ordered_columns = ( + taf.positional + ) + + # enable looser result column matching when the SQL text links to + # Column objects by name only + self._loose_column_name_matching = not taf.positional and bool( + taf.column_args + ) + + for c in taf.column_args: + self.process( + c, + within_columns_clause=True, + add_to_result_map=self._add_to_result_map, + ) + + text = self.process(taf.element, **kw) + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + self.stack.pop(-1) + + return text + + def visit_null(self, expr, **kw): + return "NULL" + + def visit_true(self, expr, **kw): + if self.dialect.supports_native_boolean: + return "true" + else: + return "1" + + def visit_false(self, expr, **kw): + if self.dialect.supports_native_boolean: + return "false" + else: + return "0" + + def _generate_delimited_list(self, elements, separator, **kw): + return separator.join( + s + for s in (c._compiler_dispatch(self, **kw) for c in elements) + if s + ) + + def _generate_delimited_and_list(self, clauses, **kw): + lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean( + operators.and_, + elements.True_._singleton, + elements.False_._singleton, + clauses, + ) + if lcc == 1: + return clauses[0]._compiler_dispatch(self, **kw) + else: + separator = OPERATORS[operators.and_] + return separator.join( + s + for s in (c._compiler_dispatch(self, **kw) for c in clauses) + if s + ) + + def visit_tuple(self, clauselist, **kw): + return "(%s)" % self.visit_clauselist(clauselist, **kw) + + def visit_clauselist(self, clauselist, **kw): + sep = clauselist.operator + if sep is None: + sep = " " + else: + sep = OPERATORS[clauselist.operator] + + return self._generate_delimited_list(clauselist.clauses, sep, **kw) + + def visit_expression_clauselist(self, clauselist, **kw): + operator_ = clauselist.operator + + disp = self._get_operator_dispatch( + operator_, "expression_clauselist", None + ) + if disp: + return disp(clauselist, operator_, **kw) + + try: + opstring = OPERATORS[operator_] + except KeyError as err: + raise exc.UnsupportedCompilationError(self, operator_) from err + else: + kw["_in_operator_expression"] = True + return self._generate_delimited_list( + clauselist.clauses, opstring, **kw + ) + + def visit_case(self, clause, **kwargs): + x = "CASE " + if clause.value is not None: + x += clause.value._compiler_dispatch(self, **kwargs) + " " + for cond, result in clause.whens: + x += ( + "WHEN " + + cond._compiler_dispatch(self, **kwargs) + + " THEN " + + result._compiler_dispatch(self, **kwargs) + + " " + ) + if clause.else_ is not None: + x += ( + "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " " + ) + x += "END" + return x + + def visit_type_coerce(self, type_coerce, **kw): + return type_coerce.typed_expression._compiler_dispatch(self, **kw) + + def visit_cast(self, cast, **kwargs): + type_clause = cast.typeclause._compiler_dispatch(self, **kwargs) + match = re.match("(.*)( COLLATE .*)", type_clause) + return "CAST(%s AS %s)%s" % ( + cast.clause._compiler_dispatch(self, **kwargs), + match.group(1) if match else type_clause, + match.group(2) if match else "", + ) + + def _format_frame_clause(self, range_, **kw): + return "%s AND %s" % ( + ( + "UNBOUNDED PRECEDING" + if range_[0] is elements.RANGE_UNBOUNDED + else ( + "CURRENT ROW" + if range_[0] is elements.RANGE_CURRENT + else ( + "%s PRECEDING" + % ( + self.process( + elements.literal(abs(range_[0])), **kw + ), + ) + if range_[0] < 0 + else "%s FOLLOWING" + % (self.process(elements.literal(range_[0]), **kw),) + ) + ) + ), + ( + "UNBOUNDED FOLLOWING" + if range_[1] is elements.RANGE_UNBOUNDED + else ( + "CURRENT ROW" + if range_[1] is elements.RANGE_CURRENT + else ( + "%s PRECEDING" + % ( + self.process( + elements.literal(abs(range_[1])), **kw + ), + ) + if range_[1] < 0 + else "%s FOLLOWING" + % (self.process(elements.literal(range_[1]), **kw),) + ) + ) + ), + ) + + def visit_over(self, over, **kwargs): + text = over.element._compiler_dispatch(self, **kwargs) + if over.range_: + range_ = "RANGE BETWEEN %s" % self._format_frame_clause( + over.range_, **kwargs + ) + elif over.rows: + range_ = "ROWS BETWEEN %s" % self._format_frame_clause( + over.rows, **kwargs + ) + else: + range_ = None + + return "%s OVER (%s)" % ( + text, + " ".join( + [ + "%s BY %s" + % (word, clause._compiler_dispatch(self, **kwargs)) + for word, clause in ( + ("PARTITION", over.partition_by), + ("ORDER", over.order_by), + ) + if clause is not None and len(clause) + ] + + ([range_] if range_ else []) + ), + ) + + def visit_withingroup(self, withingroup, **kwargs): + return "%s WITHIN GROUP (ORDER BY %s)" % ( + withingroup.element._compiler_dispatch(self, **kwargs), + withingroup.order_by._compiler_dispatch(self, **kwargs), + ) + + def visit_funcfilter(self, funcfilter, **kwargs): + return "%s FILTER (WHERE %s)" % ( + funcfilter.func._compiler_dispatch(self, **kwargs), + funcfilter.criterion._compiler_dispatch(self, **kwargs), + ) + + def visit_extract(self, extract, **kwargs): + field = self.extract_map.get(extract.field, extract.field) + return "EXTRACT(%s FROM %s)" % ( + field, + extract.expr._compiler_dispatch(self, **kwargs), + ) + + def visit_scalar_function_column(self, element, **kw): + compiled_fn = self.visit_function(element.fn, **kw) + compiled_col = self.visit_column(element, **kw) + return "(%s).%s" % (compiled_fn, compiled_col) + + def visit_function( + self, + func: Function[Any], + add_to_result_map: Optional[_ResultMapAppender] = None, + **kwargs: Any, + ) -> str: + if add_to_result_map is not None: + add_to_result_map(func.name, func.name, (), func.type) + + disp = getattr(self, "visit_%s_func" % func.name.lower(), None) + + text: str + + if disp: + text = disp(func, **kwargs) + else: + name = FUNCTIONS.get(func._deannotate().__class__, None) + if name: + if func._has_args: + name += "%(expr)s" + else: + name = func.name + name = ( + self.preparer.quote(name) + if self.preparer._requires_quotes_illegal_chars(name) + or isinstance(name, elements.quoted_name) + else name + ) + name = name + "%(expr)s" + text = ".".join( + [ + ( + self.preparer.quote(tok) + if self.preparer._requires_quotes_illegal_chars(tok) + or isinstance(name, elements.quoted_name) + else tok + ) + for tok in func.packagenames + ] + + [name] + ) % {"expr": self.function_argspec(func, **kwargs)} + + if func._with_ordinality: + text += " WITH ORDINALITY" + return text + + def visit_next_value_func(self, next_value, **kw): + return self.visit_sequence(next_value.sequence) + + def visit_sequence(self, sequence, **kw): + raise NotImplementedError( + "Dialect '%s' does not support sequence increments." + % self.dialect.name + ) + + def function_argspec(self, func, **kwargs): + return func.clause_expr._compiler_dispatch(self, **kwargs) + + def visit_compound_select( + self, cs, asfrom=False, compound_index=None, **kwargs + ): + toplevel = not self.stack + + compile_state = cs._compile_state_factory(cs, self, **kwargs) + + if toplevel and not self.compile_state: + self.compile_state = compile_state + + compound_stmt = compile_state.statement + + entry = self._default_stack_entry if toplevel else self.stack[-1] + need_result_map = toplevel or ( + not compound_index + and entry.get("need_result_map_for_compound", False) + ) + + # indicates there is already a CompoundSelect in play + if compound_index == 0: + entry["select_0"] = cs + + self.stack.append( + { + "correlate_froms": entry["correlate_froms"], + "asfrom_froms": entry["asfrom_froms"], + "selectable": cs, + "compile_state": compile_state, + "need_result_map_for_compound": need_result_map, + } + ) + + if compound_stmt._independent_ctes: + self._dispatch_independent_ctes(compound_stmt, kwargs) + + keyword = self.compound_keywords[cs.keyword] + + text = (" " + keyword + " ").join( + ( + c._compiler_dispatch( + self, asfrom=asfrom, compound_index=i, **kwargs + ) + for i, c in enumerate(cs.selects) + ) + ) + + kwargs["include_table"] = False + text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs)) + text += self.order_by_clause(cs, **kwargs) + if cs._has_row_limiting_clause: + text += self._row_limit_clause(cs, **kwargs) + + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + include_following_stack=True, + ) + + text + ) + + self.stack.pop(-1) + return text + + def _row_limit_clause(self, cs, **kwargs): + if cs._fetch_clause is not None: + return self.fetch_clause(cs, **kwargs) + else: + return self.limit_clause(cs, **kwargs) + + def _get_operator_dispatch(self, operator_, qualifier1, qualifier2): + attrname = "visit_%s_%s%s" % ( + operator_.__name__, + qualifier1, + "_" + qualifier2 if qualifier2 else "", + ) + return getattr(self, attrname, None) + + def visit_unary( + self, unary, add_to_result_map=None, result_map_targets=(), **kw + ): + if add_to_result_map is not None: + result_map_targets += (unary,) + kw["add_to_result_map"] = add_to_result_map + kw["result_map_targets"] = result_map_targets + + if unary.operator: + if unary.modifier: + raise exc.CompileError( + "Unary expression does not support operator " + "and modifier simultaneously" + ) + disp = self._get_operator_dispatch( + unary.operator, "unary", "operator" + ) + if disp: + return disp(unary, unary.operator, **kw) + else: + return self._generate_generic_unary_operator( + unary, OPERATORS[unary.operator], **kw + ) + elif unary.modifier: + disp = self._get_operator_dispatch( + unary.modifier, "unary", "modifier" + ) + if disp: + return disp(unary, unary.modifier, **kw) + else: + return self._generate_generic_unary_modifier( + unary, OPERATORS[unary.modifier], **kw + ) + else: + raise exc.CompileError( + "Unary expression has no operator or modifier" + ) + + def visit_truediv_binary(self, binary, operator, **kw): + if self.dialect.div_is_floordiv: + return ( + self.process(binary.left, **kw) + + " / " + # TODO: would need a fast cast again here, + # unless we want to use an implicit cast like "+ 0.0" + + self.process( + elements.Cast( + binary.right, + ( + binary.right.type + if binary.right.type._type_affinity + is sqltypes.Numeric + else sqltypes.Numeric() + ), + ), + **kw, + ) + ) + else: + return ( + self.process(binary.left, **kw) + + " / " + + self.process(binary.right, **kw) + ) + + def visit_floordiv_binary(self, binary, operator, **kw): + if ( + self.dialect.div_is_floordiv + and binary.right.type._type_affinity is sqltypes.Integer + ): + return ( + self.process(binary.left, **kw) + + " / " + + self.process(binary.right, **kw) + ) + else: + return "FLOOR(%s)" % ( + self.process(binary.left, **kw) + + " / " + + self.process(binary.right, **kw) + ) + + def visit_is_true_unary_operator(self, element, operator, **kw): + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): + return self.process(element.element, **kw) + else: + return "%s = 1" % self.process(element.element, **kw) + + def visit_is_false_unary_operator(self, element, operator, **kw): + if ( + element._is_implicitly_boolean + or self.dialect.supports_native_boolean + ): + return "NOT %s" % self.process(element.element, **kw) + else: + return "%s = 0" % self.process(element.element, **kw) + + def visit_not_match_op_binary(self, binary, operator, **kw): + return "NOT %s" % self.visit_binary( + binary, override_operator=operators.match_op + ) + + def visit_not_in_op_binary(self, binary, operator, **kw): + # The brackets are required in the NOT IN operation because the empty + # case is handled using the form "(col NOT IN (null) OR 1 = 1)". + # The presence of the OR makes the brackets required. + return "(%s)" % self._generate_generic_binary( + binary, OPERATORS[operator], **kw + ) + + def visit_empty_set_op_expr(self, type_, expand_op, **kw): + if expand_op is operators.not_in_op: + if len(type_) > 1: + return "(%s)) OR (1 = 1" % ( + ", ".join("NULL" for element in type_) + ) + else: + return "NULL) OR (1 = 1" + elif expand_op is operators.in_op: + if len(type_) > 1: + return "(%s)) AND (1 != 1" % ( + ", ".join("NULL" for element in type_) + ) + else: + return "NULL) AND (1 != 1" + else: + return self.visit_empty_set_expr(type_) + + def visit_empty_set_expr(self, element_types, **kw): + raise NotImplementedError( + "Dialect '%s' does not support empty set expression." + % self.dialect.name + ) + + def _literal_execute_expanding_parameter_literal_binds( + self, parameter, values, bind_expression_template=None + ): + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect) + + if not values: + # empty IN expression. note we don't need to use + # bind_expression_template here because there are no + # expressions to render. + + if typ_dialect_impl._is_tuple_type: + replacement_expression = ( + "VALUES " if self.dialect.tuple_in_values else "" + ) + self.visit_empty_set_op_expr( + parameter.type.types, parameter.expand_op + ) + + else: + replacement_expression = self.visit_empty_set_op_expr( + [parameter.type], parameter.expand_op + ) + + elif typ_dialect_impl._is_tuple_type or ( + typ_dialect_impl._isnull + and isinstance(values[0], collections_abc.Sequence) + and not isinstance(values[0], (str, bytes)) + ): + if typ_dialect_impl._has_bind_expression: + raise NotImplementedError( + "bind_expression() on TupleType not supported with " + "literal_binds" + ) + + replacement_expression = ( + "VALUES " if self.dialect.tuple_in_values else "" + ) + ", ".join( + "(%s)" + % ( + ", ".join( + self.render_literal_value(value, param_type) + for value, param_type in zip( + tuple_element, parameter.type.types + ) + ) + ) + for i, tuple_element in enumerate(values) + ) + else: + if bind_expression_template: + post_compile_pattern = self._post_compile_pattern + m = post_compile_pattern.search(bind_expression_template) + assert m and m.group( + 2 + ), "unexpected format for expanding parameter" + + tok = m.group(2).split("~~") + be_left, be_right = tok[1], tok[3] + replacement_expression = ", ".join( + "%s%s%s" + % ( + be_left, + self.render_literal_value(value, parameter.type), + be_right, + ) + for value in values + ) + else: + replacement_expression = ", ".join( + self.render_literal_value(value, parameter.type) + for value in values + ) + + return (), replacement_expression + + def _literal_execute_expanding_parameter(self, name, parameter, values): + if parameter.literal_execute: + return self._literal_execute_expanding_parameter_literal_binds( + parameter, values + ) + + dialect = self.dialect + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect) + + if self._numeric_binds: + bind_template = self.compilation_bindtemplate + else: + bind_template = self.bindtemplate + + if ( + self.dialect._bind_typing_render_casts + and typ_dialect_impl.render_bind_cast + ): + + def _render_bindtemplate(name): + return self.render_bind_cast( + parameter.type, + typ_dialect_impl, + bind_template % {"name": name}, + ) + + else: + + def _render_bindtemplate(name): + return bind_template % {"name": name} + + if not values: + to_update = [] + if typ_dialect_impl._is_tuple_type: + replacement_expression = self.visit_empty_set_op_expr( + parameter.type.types, parameter.expand_op + ) + else: + replacement_expression = self.visit_empty_set_op_expr( + [parameter.type], parameter.expand_op + ) + + elif typ_dialect_impl._is_tuple_type or ( + typ_dialect_impl._isnull + and isinstance(values[0], collections_abc.Sequence) + and not isinstance(values[0], (str, bytes)) + ): + assert not typ_dialect_impl._is_array + to_update = [ + ("%s_%s_%s" % (name, i, j), value) + for i, tuple_element in enumerate(values, 1) + for j, value in enumerate(tuple_element, 1) + ] + + replacement_expression = ( + "VALUES " if dialect.tuple_in_values else "" + ) + ", ".join( + "(%s)" + % ( + ", ".join( + _render_bindtemplate( + to_update[i * len(tuple_element) + j][0] + ) + for j, value in enumerate(tuple_element) + ) + ) + for i, tuple_element in enumerate(values) + ) + else: + to_update = [ + ("%s_%s" % (name, i), value) + for i, value in enumerate(values, 1) + ] + replacement_expression = ", ".join( + _render_bindtemplate(key) for key, value in to_update + ) + + return to_update, replacement_expression + + def visit_binary( + self, + binary, + override_operator=None, + eager_grouping=False, + from_linter=None, + lateral_from_linter=None, + **kw, + ): + if from_linter and operators.is_comparison(binary.operator): + if lateral_from_linter is not None: + enclosing_lateral = kw["enclosing_lateral"] + lateral_from_linter.edges.update( + itertools.product( + _de_clone( + binary.left._from_objects + [enclosing_lateral] + ), + _de_clone( + binary.right._from_objects + [enclosing_lateral] + ), + ) + ) + else: + from_linter.edges.update( + itertools.product( + _de_clone(binary.left._from_objects), + _de_clone(binary.right._from_objects), + ) + ) + + # don't allow "? = ?" to render + if ( + self.ansi_bind_rules + and isinstance(binary.left, elements.BindParameter) + and isinstance(binary.right, elements.BindParameter) + ): + kw["literal_execute"] = True + + operator_ = override_operator or binary.operator + disp = self._get_operator_dispatch(operator_, "binary", None) + if disp: + return disp(binary, operator_, **kw) + else: + try: + opstring = OPERATORS[operator_] + except KeyError as err: + raise exc.UnsupportedCompilationError(self, operator_) from err + else: + return self._generate_generic_binary( + binary, + opstring, + from_linter=from_linter, + lateral_from_linter=lateral_from_linter, + **kw, + ) + + def visit_function_as_comparison_op_binary(self, element, operator, **kw): + return self.process(element.sql_function, **kw) + + def visit_mod_binary(self, binary, operator, **kw): + if self.preparer._double_percents: + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) + else: + return ( + self.process(binary.left, **kw) + + " % " + + self.process(binary.right, **kw) + ) + + def visit_custom_op_binary(self, element, operator, **kw): + kw["eager_grouping"] = operator.eager_grouping + return self._generate_generic_binary( + element, + " " + self.escape_literal_column(operator.opstring) + " ", + **kw, + ) + + def visit_custom_op_unary_operator(self, element, operator, **kw): + return self._generate_generic_unary_operator( + element, self.escape_literal_column(operator.opstring) + " ", **kw + ) + + def visit_custom_op_unary_modifier(self, element, operator, **kw): + return self._generate_generic_unary_modifier( + element, " " + self.escape_literal_column(operator.opstring), **kw + ) + + def _generate_generic_binary( + self, binary, opstring, eager_grouping=False, **kw + ): + _in_operator_expression = kw.get("_in_operator_expression", False) + + kw["_in_operator_expression"] = True + kw["_binary_op"] = binary.operator + text = ( + binary.left._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + + opstring + + binary.right._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + ) + + if _in_operator_expression and eager_grouping: + text = "(%s)" % text + return text + + def _generate_generic_unary_operator(self, unary, opstring, **kw): + return opstring + unary.element._compiler_dispatch(self, **kw) + + def _generate_generic_unary_modifier(self, unary, opstring, **kw): + return unary.element._compiler_dispatch(self, **kw) + opstring + + @util.memoized_property + def _like_percent_literal(self): + return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE) + + def visit_ilike_case_insensitive_operand(self, element, **kw): + return f"lower({element.element._compiler_dispatch(self, **kw)})" + + def visit_contains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.concat(binary.right).concat(percent) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_not_contains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.concat(binary.right).concat(percent) + return self.visit_not_like_op_binary(binary, operator, **kw) + + def visit_icontains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.left = ilike_case_insensitive(binary.left) + binary.right = percent.concat( + ilike_case_insensitive(binary.right) + ).concat(percent) + return self.visit_ilike_op_binary(binary, operator, **kw) + + def visit_not_icontains_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.left = ilike_case_insensitive(binary.left) + binary.right = percent.concat( + ilike_case_insensitive(binary.right) + ).concat(percent) + return self.visit_not_ilike_op_binary(binary, operator, **kw) + + def visit_startswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent._rconcat(binary.right) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_not_startswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent._rconcat(binary.right) + return self.visit_not_like_op_binary(binary, operator, **kw) + + def visit_istartswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.left = ilike_case_insensitive(binary.left) + binary.right = percent._rconcat(ilike_case_insensitive(binary.right)) + return self.visit_ilike_op_binary(binary, operator, **kw) + + def visit_not_istartswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.left = ilike_case_insensitive(binary.left) + binary.right = percent._rconcat(ilike_case_insensitive(binary.right)) + return self.visit_not_ilike_op_binary(binary, operator, **kw) + + def visit_endswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.concat(binary.right) + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_not_endswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.right = percent.concat(binary.right) + return self.visit_not_like_op_binary(binary, operator, **kw) + + def visit_iendswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.left = ilike_case_insensitive(binary.left) + binary.right = percent.concat(ilike_case_insensitive(binary.right)) + return self.visit_ilike_op_binary(binary, operator, **kw) + + def visit_not_iendswith_op_binary(self, binary, operator, **kw): + binary = binary._clone() + percent = self._like_percent_literal + binary.left = ilike_case_insensitive(binary.left) + binary.right = percent.concat(ilike_case_insensitive(binary.right)) + return self.visit_not_ilike_op_binary(binary, operator, **kw) + + def visit_like_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + + return "%s LIKE %s" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape is not None + else "" + ) + + def visit_not_like_op_binary(self, binary, operator, **kw): + escape = binary.modifiers.get("escape", None) + return "%s NOT LIKE %s" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape is not None + else "" + ) + + def visit_ilike_op_binary(self, binary, operator, **kw): + if operator is operators.ilike_op: + binary = binary._clone() + binary.left = ilike_case_insensitive(binary.left) + binary.right = ilike_case_insensitive(binary.right) + # else we assume ilower() has been applied + + return self.visit_like_op_binary(binary, operator, **kw) + + def visit_not_ilike_op_binary(self, binary, operator, **kw): + if operator is operators.not_ilike_op: + binary = binary._clone() + binary.left = ilike_case_insensitive(binary.left) + binary.right = ilike_case_insensitive(binary.right) + # else we assume ilower() has been applied + + return self.visit_not_like_op_binary(binary, operator, **kw) + + def visit_between_op_binary(self, binary, operator, **kw): + symmetric = binary.modifiers.get("symmetric", False) + return self._generate_generic_binary( + binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw + ) + + def visit_not_between_op_binary(self, binary, operator, **kw): + symmetric = binary.modifiers.get("symmetric", False) + return self._generate_generic_binary( + binary, + " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ", + **kw, + ) + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + raise exc.CompileError( + "%s dialect does not support regular expressions" + % self.dialect.name + ) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + raise exc.CompileError( + "%s dialect does not support regular expressions" + % self.dialect.name + ) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + raise exc.CompileError( + "%s dialect does not support regular expression replacements" + % self.dialect.name + ) + + def visit_bindparam( + self, + bindparam, + within_columns_clause=False, + literal_binds=False, + skip_bind_expression=False, + literal_execute=False, + render_postcompile=False, + **kwargs, + ): + + if not skip_bind_expression: + impl = bindparam.type.dialect_impl(self.dialect) + if impl._has_bind_expression: + bind_expression = impl.bind_expression(bindparam) + wrapped = self.process( + bind_expression, + skip_bind_expression=True, + within_columns_clause=within_columns_clause, + literal_binds=literal_binds and not bindparam.expanding, + literal_execute=literal_execute, + render_postcompile=render_postcompile, + **kwargs, + ) + if bindparam.expanding: + # for postcompile w/ expanding, move the "wrapped" part + # of this into the inside + + m = re.match( + r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped + ) + assert m, "unexpected format for expanding parameter" + wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % ( + m.group(2), + m.group(1), + m.group(3), + ) + + if literal_binds: + ret = self.render_literal_bindparam( + bindparam, + within_columns_clause=True, + bind_expression_template=wrapped, + **kwargs, + ) + return "(%s)" % ret + + return wrapped + + if not literal_binds: + literal_execute = ( + literal_execute + or bindparam.literal_execute + or (within_columns_clause and self.ansi_bind_rules) + ) + post_compile = literal_execute or bindparam.expanding + else: + post_compile = False + + if literal_binds: + ret = self.render_literal_bindparam( + bindparam, within_columns_clause=True, **kwargs + ) + if bindparam.expanding: + ret = "(%s)" % ret + return ret + + name = self._truncate_bindparam(bindparam) + + if name in self.binds: + existing = self.binds[name] + if existing is not bindparam: + if ( + (existing.unique or bindparam.unique) + and not existing.proxy_set.intersection( + bindparam.proxy_set + ) + and not existing._cloned_set.intersection( + bindparam._cloned_set + ) + ): + raise exc.CompileError( + "Bind parameter '%s' conflicts with " + "unique bind parameter of the same name" % name + ) + elif existing.expanding != bindparam.expanding: + raise exc.CompileError( + "Can't reuse bound parameter name '%s' in both " + "'expanding' (e.g. within an IN expression) and " + "non-expanding contexts. If this parameter is to " + "receive a list/array value, set 'expanding=True' on " + "it for expressions that aren't IN, otherwise use " + "a different parameter name." % (name,) + ) + elif existing._is_crud or bindparam._is_crud: + if existing._is_crud and bindparam._is_crud: + # TODO: this condition is not well understood. + # see tests in test/sql/test_update.py + raise exc.CompileError( + "Encountered unsupported case when compiling an " + "INSERT or UPDATE statement. If this is a " + "multi-table " + "UPDATE statement, please provide string-named " + "arguments to the " + "values() method with distinct names; support for " + "multi-table UPDATE statements that " + "target multiple tables for UPDATE is very " + "limited", + ) + else: + raise exc.CompileError( + f"bindparam() name '{bindparam.key}' is reserved " + "for automatic usage in the VALUES or SET " + "clause of this " + "insert/update statement. Please use a " + "name other than column name when using " + "bindparam() " + "with insert() or update() (for example, " + f"'b_{bindparam.key}')." + ) + + self.binds[bindparam.key] = self.binds[name] = bindparam + + # if we are given a cache key that we're going to match against, + # relate the bindparam here to one that is most likely present + # in the "extracted params" portion of the cache key. this is used + # to set up a positional mapping that is used to determine the + # correct parameters for a subsequent use of this compiled with + # a different set of parameter values. here, we accommodate for + # parameters that may have been cloned both before and after the cache + # key was been generated. + ckbm_tuple = self._cache_key_bind_match + + if ckbm_tuple: + ckbm, cksm = ckbm_tuple + for bp in bindparam._cloned_set: + if bp.key in cksm: + cb = cksm[bp.key] + ckbm[cb].append(bindparam) + + if bindparam.isoutparam: + self.has_out_parameters = True + + if post_compile: + if render_postcompile: + self._render_postcompile = True + + if literal_execute: + self.literal_execute_params |= {bindparam} + else: + self.post_compile_params |= {bindparam} + + ret = self.bindparam_string( + name, + post_compile=post_compile, + expanding=bindparam.expanding, + bindparam_type=bindparam.type, + **kwargs, + ) + + if bindparam.expanding: + ret = "(%s)" % ret + + return ret + + def render_bind_cast(self, type_, dbapi_type, sqltext): + raise NotImplementedError() + + def render_literal_bindparam( + self, + bindparam, + render_literal_value=NO_ARG, + bind_expression_template=None, + **kw, + ): + if render_literal_value is not NO_ARG: + value = render_literal_value + else: + if bindparam.value is None and bindparam.callable is None: + op = kw.get("_binary_op", None) + if op and op not in (operators.is_, operators.is_not): + util.warn_limited( + "Bound parameter '%s' rendering literal NULL in a SQL " + "expression; comparisons to NULL should not use " + "operators outside of 'is' or 'is not'", + (bindparam.key,), + ) + return self.process(sqltypes.NULLTYPE, **kw) + value = bindparam.effective_value + + if bindparam.expanding: + leep = self._literal_execute_expanding_parameter_literal_binds + to_update, replacement_expr = leep( + bindparam, + value, + bind_expression_template=bind_expression_template, + ) + return replacement_expr + else: + return self.render_literal_value(value, bindparam.type) + + def render_literal_value(self, value, type_): + """Render the value of a bind parameter as a quoted literal. + + This is used for statement sections that do not accept bind parameters + on the target driver/database. + + This should be implemented by subclasses using the quoting services + of the DBAPI. + + """ + + if value is None and not type_.should_evaluate_none: + # issue #10535 - handle NULL in the compiler without placing + # this onto each type, except for "evaluate None" types + # (e.g. JSON) + return self.process(elements.Null._instance()) + + processor = type_._cached_literal_processor(self.dialect) + if processor: + try: + return processor(value) + except Exception as e: + raise exc.CompileError( + f"Could not render literal value " + f'"{sql_util._repr_single_value(value)}" ' + f"with datatype " + f"{type_}; see parent stack trace for " + "more detail." + ) from e + + else: + raise exc.CompileError( + f"No literal value renderer is available for literal value " + f'"{sql_util._repr_single_value(value)}" ' + f"with datatype {type_}" + ) + + def _truncate_bindparam(self, bindparam): + if bindparam in self.bind_names: + return self.bind_names[bindparam] + + bind_name = bindparam.key + if isinstance(bind_name, elements._truncated_label): + bind_name = self._truncated_identifier("bindparam", bind_name) + + # add to bind_names for translation + self.bind_names[bindparam] = bind_name + + return bind_name + + def _truncated_identifier( + self, ident_class: str, name: _truncated_label + ) -> str: + if (ident_class, name) in self.truncated_names: + return self.truncated_names[(ident_class, name)] + + anonname = name.apply_map(self.anon_map) + + if len(anonname) > self.label_length - 6: + counter = self._truncated_counters.get(ident_class, 1) + truncname = ( + anonname[0 : max(self.label_length - 6, 0)] + + "_" + + hex(counter)[2:] + ) + self._truncated_counters[ident_class] = counter + 1 + else: + truncname = anonname + self.truncated_names[(ident_class, name)] = truncname + return truncname + + def _anonymize(self, name: str) -> str: + return name % self.anon_map + + def bindparam_string( + self, + name: str, + post_compile: bool = False, + expanding: bool = False, + escaped_from: Optional[str] = None, + bindparam_type: Optional[TypeEngine[Any]] = None, + accumulate_bind_names: Optional[Set[str]] = None, + visited_bindparam: Optional[List[str]] = None, + **kw: Any, + ) -> str: + # TODO: accumulate_bind_names is passed by crud.py to gather + # names on a per-value basis, visited_bindparam is passed by + # visit_insert() to collect all parameters in the statement. + # see if this gathering can be simplified somehow + if accumulate_bind_names is not None: + accumulate_bind_names.add(name) + if visited_bindparam is not None: + visited_bindparam.append(name) + + if not escaped_from: + if self._bind_translate_re.search(name): + # not quite the translate use case as we want to + # also get a quick boolean if we even found + # unusual characters in the name + new_name = self._bind_translate_re.sub( + lambda m: self._bind_translate_chars[m.group(0)], + name, + ) + escaped_from = name + name = new_name + + if escaped_from: + self.escaped_bind_names = self.escaped_bind_names.union( + {escaped_from: name} + ) + if post_compile: + ret = "__[POSTCOMPILE_%s]" % name + if expanding: + # for expanding, bound parameters or literal values will be + # rendered per item + return ret + + # otherwise, for non-expanding "literal execute", apply + # bind casts as determined by the datatype + if bindparam_type is not None: + type_impl = bindparam_type._unwrapped_dialect_impl( + self.dialect + ) + if type_impl.render_literal_cast: + ret = self.render_bind_cast(bindparam_type, type_impl, ret) + return ret + elif self.state is CompilerState.COMPILING: + ret = self.compilation_bindtemplate % {"name": name} + else: + ret = self.bindtemplate % {"name": name} + + if ( + bindparam_type is not None + and self.dialect._bind_typing_render_casts + ): + type_impl = bindparam_type._unwrapped_dialect_impl(self.dialect) + if type_impl.render_bind_cast: + ret = self.render_bind_cast(bindparam_type, type_impl, ret) + + return ret + + def _dispatch_independent_ctes(self, stmt, kw): + local_kw = kw.copy() + local_kw.pop("cte_opts", None) + for cte, opt in zip( + stmt._independent_ctes, stmt._independent_ctes_opts + ): + cte._compiler_dispatch(self, cte_opts=opt, **local_kw) + + def visit_cte( + self, + cte: CTE, + asfrom: bool = False, + ashint: bool = False, + fromhints: Optional[_FromHintsType] = None, + visiting_cte: Optional[CTE] = None, + from_linter: Optional[FromLinter] = None, + cte_opts: selectable._CTEOpts = selectable._CTEOpts(False), + **kwargs: Any, + ) -> Optional[str]: + self_ctes = self._init_cte_state() + assert self_ctes is self.ctes + + kwargs["visiting_cte"] = cte + + cte_name = cte.name + + if isinstance(cte_name, elements._truncated_label): + cte_name = self._truncated_identifier("alias", cte_name) + + is_new_cte = True + embedded_in_current_named_cte = False + + _reference_cte = cte._get_reference_cte() + + nesting = cte.nesting or cte_opts.nesting + + # check for CTE already encountered + if _reference_cte in self.level_name_by_cte: + cte_level, _, existing_cte_opts = self.level_name_by_cte[ + _reference_cte + ] + assert _ == cte_name + + cte_level_name = (cte_level, cte_name) + existing_cte = self.ctes_by_level_name[cte_level_name] + + # check if we are receiving it here with a specific + # "nest_here" location; if so, move it to this location + + if cte_opts.nesting: + if existing_cte_opts.nesting: + raise exc.CompileError( + "CTE is stated as 'nest_here' in " + "more than one location" + ) + + old_level_name = (cte_level, cte_name) + cte_level = len(self.stack) if nesting else 1 + cte_level_name = new_level_name = (cte_level, cte_name) + + del self.ctes_by_level_name[old_level_name] + self.ctes_by_level_name[new_level_name] = existing_cte + self.level_name_by_cte[_reference_cte] = new_level_name + ( + cte_opts, + ) + + else: + cte_level = len(self.stack) if nesting else 1 + cte_level_name = (cte_level, cte_name) + + if cte_level_name in self.ctes_by_level_name: + existing_cte = self.ctes_by_level_name[cte_level_name] + else: + existing_cte = None + + if existing_cte is not None: + embedded_in_current_named_cte = visiting_cte is existing_cte + + # we've generated a same-named CTE that we are enclosed in, + # or this is the same CTE. just return the name. + if cte is existing_cte._restates or cte is existing_cte: + is_new_cte = False + elif existing_cte is cte._restates: + # we've generated a same-named CTE that is + # enclosed in us - we take precedence, so + # discard the text for the "inner". + del self_ctes[existing_cte] + + existing_cte_reference_cte = existing_cte._get_reference_cte() + + assert existing_cte_reference_cte is _reference_cte + assert existing_cte_reference_cte is existing_cte + + del self.level_name_by_cte[existing_cte_reference_cte] + else: + # if the two CTEs are deep-copy identical, consider them + # the same, **if** they are clones, that is, they came from + # the ORM or other visit method + if ( + cte._is_clone_of is not None + or existing_cte._is_clone_of is not None + ) and cte.compare(existing_cte): + is_new_cte = False + else: + raise exc.CompileError( + "Multiple, unrelated CTEs found with " + "the same name: %r" % cte_name + ) + + if not asfrom and not is_new_cte: + return None + + if cte._cte_alias is not None: + pre_alias_cte = cte._cte_alias + cte_pre_alias_name = cte._cte_alias.name + if isinstance(cte_pre_alias_name, elements._truncated_label): + cte_pre_alias_name = self._truncated_identifier( + "alias", cte_pre_alias_name + ) + else: + pre_alias_cte = cte + cte_pre_alias_name = None + + if is_new_cte: + self.ctes_by_level_name[cte_level_name] = cte + self.level_name_by_cte[_reference_cte] = cte_level_name + ( + cte_opts, + ) + + if pre_alias_cte not in self.ctes: + self.visit_cte(pre_alias_cte, **kwargs) + + if not cte_pre_alias_name and cte not in self_ctes: + if cte.recursive: + self.ctes_recursive = True + text = self.preparer.format_alias(cte, cte_name) + if cte.recursive: + col_source = cte.element + + # TODO: can we get at the .columns_plus_names collection + # that is already (or will be?) generated for the SELECT + # rather than calling twice? + recur_cols = [ + # TODO: proxy_name is not technically safe, + # see test_cte-> + # test_with_recursive_no_name_currently_buggy. not + # clear what should be done with such a case + fallback_label_name or proxy_name + for ( + _, + proxy_name, + fallback_label_name, + c, + repeated, + ) in (col_source._generate_columns_plus_names(True)) + if not repeated + ] + + text += "(%s)" % ( + ", ".join( + self.preparer.format_label_name( + ident, anon_map=self.anon_map + ) + for ident in recur_cols + ) + ) + + assert kwargs.get("subquery", False) is False + + if not self.stack: + # toplevel, this is a stringify of the + # cte directly. just compile the inner + # the way alias() does. + return cte.element._compiler_dispatch( + self, asfrom=asfrom, **kwargs + ) + else: + prefixes = self._generate_prefixes( + cte, cte._prefixes, **kwargs + ) + inner = cte.element._compiler_dispatch( + self, asfrom=True, **kwargs + ) + + text += " AS %s\n(%s)" % (prefixes, inner) + + if cte._suffixes: + text += " " + self._generate_prefixes( + cte, cte._suffixes, **kwargs + ) + + self_ctes[cte] = text + + if asfrom: + if from_linter: + from_linter.froms[cte._de_clone()] = cte_name + + if not is_new_cte and embedded_in_current_named_cte: + return self.preparer.format_alias(cte, cte_name) + + if cte_pre_alias_name: + text = self.preparer.format_alias(cte, cte_pre_alias_name) + if self.preparer._requires_quotes(cte_name): + cte_name = self.preparer.quote(cte_name) + text += self.get_render_as_alias_suffix(cte_name) + return text + else: + return self.preparer.format_alias(cte, cte_name) + + return None + + def visit_table_valued_alias(self, element, **kw): + if element.joins_implicitly: + kw["from_linter"] = None + if element._is_lateral: + return self.visit_lateral(element, **kw) + else: + return self.visit_alias(element, **kw) + + def visit_table_valued_column(self, element, **kw): + return self.visit_column(element, **kw) + + def visit_alias( + self, + alias, + asfrom=False, + ashint=False, + iscrud=False, + fromhints=None, + subquery=False, + lateral=False, + enclosing_alias=None, + from_linter=None, + **kwargs, + ): + if lateral: + if "enclosing_lateral" not in kwargs: + # if lateral is set and enclosing_lateral is not + # present, we assume we are being called directly + # from visit_lateral() and we need to set enclosing_lateral. + assert alias._is_lateral + kwargs["enclosing_lateral"] = alias + + # for lateral objects, we track a second from_linter that is... + # lateral! to the level above us. + if ( + from_linter + and "lateral_from_linter" not in kwargs + and "enclosing_lateral" in kwargs + ): + kwargs["lateral_from_linter"] = from_linter + + if enclosing_alias is not None and enclosing_alias.element is alias: + inner = alias.element._compiler_dispatch( + self, + asfrom=asfrom, + ashint=ashint, + iscrud=iscrud, + fromhints=fromhints, + lateral=lateral, + enclosing_alias=alias, + **kwargs, + ) + if subquery and (asfrom or lateral): + inner = "(%s)" % (inner,) + return inner + else: + enclosing_alias = kwargs["enclosing_alias"] = alias + + if asfrom or ashint: + if isinstance(alias.name, elements._truncated_label): + alias_name = self._truncated_identifier("alias", alias.name) + else: + alias_name = alias.name + + if ashint: + return self.preparer.format_alias(alias, alias_name) + elif asfrom: + if from_linter: + from_linter.froms[alias._de_clone()] = alias_name + + inner = alias.element._compiler_dispatch( + self, asfrom=True, lateral=lateral, **kwargs + ) + if subquery: + inner = "(%s)" % (inner,) + + ret = inner + self.get_render_as_alias_suffix( + self.preparer.format_alias(alias, alias_name) + ) + + if alias._supports_derived_columns and alias._render_derived: + ret += "(%s)" % ( + ", ".join( + "%s%s" + % ( + self.preparer.quote(col.name), + ( + " %s" + % self.dialect.type_compiler_instance.process( + col.type, **kwargs + ) + if alias._render_derived_w_types + else "" + ), + ) + for col in alias.c + ) + ) + + if fromhints and alias in fromhints: + ret = self.format_from_hint_text( + ret, alias, fromhints[alias], iscrud + ) + + return ret + else: + # note we cancel the "subquery" flag here as well + return alias.element._compiler_dispatch( + self, lateral=lateral, **kwargs + ) + + def visit_subquery(self, subquery, **kw): + kw["subquery"] = True + return self.visit_alias(subquery, **kw) + + def visit_lateral(self, lateral_, **kw): + kw["lateral"] = True + return "LATERAL %s" % self.visit_alias(lateral_, **kw) + + def visit_tablesample(self, tablesample, asfrom=False, **kw): + text = "%s TABLESAMPLE %s" % ( + self.visit_alias(tablesample, asfrom=True, **kw), + tablesample._get_method()._compiler_dispatch(self, **kw), + ) + + if tablesample.seed is not None: + text += " REPEATABLE (%s)" % ( + tablesample.seed._compiler_dispatch(self, **kw) + ) + + return text + + def _render_values(self, element, **kw): + kw.setdefault("literal_binds", element.literal_binds) + tuples = ", ".join( + self.process( + elements.Tuple( + types=element._column_types, *elem + ).self_group(), + **kw, + ) + for chunk in element._data + for elem in chunk + ) + return f"VALUES {tuples}" + + def visit_values(self, element, asfrom=False, from_linter=None, **kw): + v = self._render_values(element, **kw) + + if element._unnamed: + name = None + elif isinstance(element.name, elements._truncated_label): + name = self._truncated_identifier("values", element.name) + else: + name = element.name + + if element._is_lateral: + lateral = "LATERAL " + else: + lateral = "" + + if asfrom: + if from_linter: + from_linter.froms[element._de_clone()] = ( + name if name is not None else "(unnamed VALUES element)" + ) + + if name: + kw["include_table"] = False + v = "%s(%s)%s (%s)" % ( + lateral, + v, + self.get_render_as_alias_suffix(self.preparer.quote(name)), + ( + ", ".join( + c._compiler_dispatch(self, **kw) + for c in element.columns + ) + ), + ) + else: + v = "%s(%s)" % (lateral, v) + return v + + def visit_scalar_values(self, element, **kw): + return f"({self._render_values(element, **kw)})" + + def get_render_as_alias_suffix(self, alias_name_text): + return " AS " + alias_name_text + + def _add_to_result_map( + self, + keyname: str, + name: str, + objects: Tuple[Any, ...], + type_: TypeEngine[Any], + ) -> None: + if keyname is None or keyname == "*": + self._ordered_columns = False + self._ad_hoc_textual = True + if type_._is_tuple_type: + raise exc.CompileError( + "Most backends don't support SELECTing " + "from a tuple() object. If this is an ORM query, " + "consider using the Bundle object." + ) + self._result_columns.append( + ResultColumnsEntry(keyname, name, objects, type_) + ) + + def _label_returning_column( + self, stmt, column, populate_result_map, column_clause_args=None, **kw + ): + """Render a column with necessary labels inside of a RETURNING clause. + + This method is provided for individual dialects in place of calling + the _label_select_column method directly, so that the two use cases + of RETURNING vs. SELECT can be disambiguated going forward. + + .. versionadded:: 1.4.21 + + """ + return self._label_select_column( + None, + column, + populate_result_map, + False, + {} if column_clause_args is None else column_clause_args, + **kw, + ) + + def _label_select_column( + self, + select, + column, + populate_result_map, + asfrom, + column_clause_args, + name=None, + proxy_name=None, + fallback_label_name=None, + within_columns_clause=True, + column_is_repeated=False, + need_column_expressions=False, + include_table=True, + ): + """produce labeled columns present in a select().""" + impl = column.type.dialect_impl(self.dialect) + + if impl._has_column_expression and ( + need_column_expressions or populate_result_map + ): + col_expr = impl.column_expression(column) + else: + col_expr = column + + if populate_result_map: + # pass an "add_to_result_map" callable into the compilation + # of embedded columns. this collects information about the + # column as it will be fetched in the result and is coordinated + # with cursor.description when the query is executed. + add_to_result_map = self._add_to_result_map + + # if the SELECT statement told us this column is a repeat, + # wrap the callable with one that prevents the addition of the + # targets + if column_is_repeated: + _add_to_result_map = add_to_result_map + + def add_to_result_map(keyname, name, objects, type_): + _add_to_result_map(keyname, name, (), type_) + + # if we redefined col_expr for type expressions, wrap the + # callable with one that adds the original column to the targets + elif col_expr is not column: + _add_to_result_map = add_to_result_map + + def add_to_result_map(keyname, name, objects, type_): + _add_to_result_map( + keyname, name, (column,) + objects, type_ + ) + + else: + add_to_result_map = None + + # this method is used by some of the dialects for RETURNING, + # which has different inputs. _label_returning_column was added + # as the better target for this now however for 1.4 we will keep + # _label_select_column directly compatible with this use case. + # these assertions right now set up the current expected inputs + assert within_columns_clause, ( + "_label_select_column is only relevant within " + "the columns clause of a SELECT or RETURNING" + ) + if isinstance(column, elements.Label): + if col_expr is not column: + result_expr = _CompileLabel( + col_expr, column.name, alt_names=(column.element,) + ) + else: + result_expr = col_expr + + elif name: + # here, _columns_plus_names has determined there's an explicit + # label name we need to use. this is the default for + # tablenames_plus_columnnames as well as when columns are being + # deduplicated on name + + assert ( + proxy_name is not None + ), "proxy_name is required if 'name' is passed" + + result_expr = _CompileLabel( + col_expr, + name, + alt_names=( + proxy_name, + # this is a hack to allow legacy result column lookups + # to work as they did before; this goes away in 2.0. + # TODO: this only seems to be tested indirectly + # via test/orm/test_deprecations.py. should be a + # resultset test for this + column._tq_label, + ), + ) + else: + # determine here whether this column should be rendered in + # a labelled context or not, as we were given no required label + # name from the caller. Here we apply heuristics based on the kind + # of SQL expression involved. + + if col_expr is not column: + # type-specific expression wrapping the given column, + # so we render a label + render_with_label = True + elif isinstance(column, elements.ColumnClause): + # table-bound column, we render its name as a label if we are + # inside of a subquery only + render_with_label = ( + asfrom + and not column.is_literal + and column.table is not None + ) + elif isinstance(column, elements.TextClause): + render_with_label = False + elif isinstance(column, elements.UnaryExpression): + render_with_label = column.wraps_column_expression or asfrom + elif ( + # general class of expressions that don't have a SQL-column + # addressible name. includes scalar selects, bind parameters, + # SQL functions, others + not isinstance(column, elements.NamedColumn) + # deeper check that indicates there's no natural "name" to + # this element, which accommodates for custom SQL constructs + # that might have a ".name" attribute (but aren't SQL + # functions) but are not implementing this more recently added + # base class. in theory the "NamedColumn" check should be + # enough, however here we seek to maintain legacy behaviors + # as well. + and column._non_anon_label is None + ): + render_with_label = True + else: + render_with_label = False + + if render_with_label: + if not fallback_label_name: + # used by the RETURNING case right now. we generate it + # here as 3rd party dialects may be referring to + # _label_select_column method directly instead of the + # just-added _label_returning_column method + assert not column_is_repeated + fallback_label_name = column._anon_name_label + + fallback_label_name = ( + elements._truncated_label(fallback_label_name) + if not isinstance( + fallback_label_name, elements._truncated_label + ) + else fallback_label_name + ) + + result_expr = _CompileLabel( + col_expr, fallback_label_name, alt_names=(proxy_name,) + ) + else: + result_expr = col_expr + + column_clause_args.update( + within_columns_clause=within_columns_clause, + add_to_result_map=add_to_result_map, + include_table=include_table, + ) + return result_expr._compiler_dispatch(self, **column_clause_args) + + def format_from_hint_text(self, sqltext, table, hint, iscrud): + hinttext = self.get_from_hint_text(table, hint) + if hinttext: + sqltext += " " + hinttext + return sqltext + + def get_select_hint_text(self, byfroms): + return None + + def get_from_hint_text(self, table, text): + return None + + def get_crud_hint_text(self, table, text): + return None + + def get_statement_hint_text(self, hint_texts): + return " ".join(hint_texts) + + _default_stack_entry: _CompilerStackEntry + + if not typing.TYPE_CHECKING: + _default_stack_entry = util.immutabledict( + [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())] + ) + + def _display_froms_for_select( + self, select_stmt, asfrom, lateral=False, **kw + ): + # utility method to help external dialects + # get the correct from list for a select. + # specifically the oracle dialect needs this feature + # right now. + toplevel = not self.stack + entry = self._default_stack_entry if toplevel else self.stack[-1] + + compile_state = select_stmt._compile_state_factory(select_stmt, self) + + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] + + if asfrom and not lateral: + froms = compile_state._get_display_froms( + explicit_correlate_froms=correlate_froms.difference( + asfrom_froms + ), + implicit_correlate_froms=(), + ) + else: + froms = compile_state._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms, + ) + return froms + + translate_select_structure: Any = None + """if not ``None``, should be a callable which accepts ``(select_stmt, + **kw)`` and returns a select object. this is used for structural changes + mostly to accommodate for LIMIT/OFFSET schemes + + """ + + def visit_select( + self, + select_stmt, + asfrom=False, + insert_into=False, + fromhints=None, + compound_index=None, + select_wraps_for=None, + lateral=False, + from_linter=None, + **kwargs, + ): + assert select_wraps_for is None, ( + "SQLAlchemy 1.4 requires use of " + "the translate_select_structure hook for structural " + "translations of SELECT objects" + ) + + # initial setup of SELECT. the compile_state_factory may now + # be creating a totally different SELECT from the one that was + # passed in. for ORM use this will convert from an ORM-state + # SELECT to a regular "Core" SELECT. other composed operations + # such as computation of joins will be performed. + + kwargs["within_columns_clause"] = False + + compile_state = select_stmt._compile_state_factory( + select_stmt, self, **kwargs + ) + kwargs["ambiguous_table_name_map"] = ( + compile_state._ambiguous_table_name_map + ) + + select_stmt = compile_state.statement + + toplevel = not self.stack + + if toplevel and not self.compile_state: + self.compile_state = compile_state + + is_embedded_select = compound_index is not None or insert_into + + # translate step for Oracle, SQL Server which often need to + # restructure the SELECT to allow for LIMIT/OFFSET and possibly + # other conditions + if self.translate_select_structure: + new_select_stmt = self.translate_select_structure( + select_stmt, asfrom=asfrom, **kwargs + ) + + # if SELECT was restructured, maintain a link to the originals + # and assemble a new compile state + if new_select_stmt is not select_stmt: + compile_state_wraps_for = compile_state + select_wraps_for = select_stmt + select_stmt = new_select_stmt + + compile_state = select_stmt._compile_state_factory( + select_stmt, self, **kwargs + ) + select_stmt = compile_state.statement + + entry = self._default_stack_entry if toplevel else self.stack[-1] + + populate_result_map = need_column_expressions = ( + toplevel + or entry.get("need_result_map_for_compound", False) + or entry.get("need_result_map_for_nested", False) + ) + + # indicates there is a CompoundSelect in play and we are not the + # first select + if compound_index: + populate_result_map = False + + # this was first proposed as part of #3372; however, it is not + # reached in current tests and could possibly be an assertion + # instead. + if not populate_result_map and "add_to_result_map" in kwargs: + del kwargs["add_to_result_map"] + + froms = self._setup_select_stack( + select_stmt, compile_state, entry, asfrom, lateral, compound_index + ) + + column_clause_args = kwargs.copy() + column_clause_args.update( + {"within_label_clause": False, "within_columns_clause": False} + ) + + text = "SELECT " # we're off to a good start ! + + if select_stmt._hints: + hint_text, byfrom = self._setup_select_hints(select_stmt) + if hint_text: + text += hint_text + " " + else: + byfrom = None + + if select_stmt._independent_ctes: + self._dispatch_independent_ctes(select_stmt, kwargs) + + if select_stmt._prefixes: + text += self._generate_prefixes( + select_stmt, select_stmt._prefixes, **kwargs + ) + + text += self.get_select_precolumns(select_stmt, **kwargs) + # the actual list of columns to print in the SELECT column list. + inner_columns = [ + c + for c in [ + self._label_select_column( + select_stmt, + column, + populate_result_map, + asfrom, + column_clause_args, + name=name, + proxy_name=proxy_name, + fallback_label_name=fallback_label_name, + column_is_repeated=repeated, + need_column_expressions=need_column_expressions, + ) + for ( + name, + proxy_name, + fallback_label_name, + column, + repeated, + ) in compile_state.columns_plus_names + ] + if c is not None + ] + + if populate_result_map and select_wraps_for is not None: + # if this select was generated from translate_select, + # rewrite the targeted columns in the result map + + translate = dict( + zip( + [ + name + for ( + key, + proxy_name, + fallback_label_name, + name, + repeated, + ) in compile_state.columns_plus_names + ], + [ + name + for ( + key, + proxy_name, + fallback_label_name, + name, + repeated, + ) in compile_state_wraps_for.columns_plus_names + ], + ) + ) + + self._result_columns = [ + ResultColumnsEntry( + key, name, tuple(translate.get(o, o) for o in obj), type_ + ) + for key, name, obj, type_ in self._result_columns + ] + + text = self._compose_select_body( + text, + select_stmt, + compile_state, + inner_columns, + froms, + byfrom, + toplevel, + kwargs, + ) + + if select_stmt._statement_hints: + per_dialect = [ + ht + for (dialect_name, ht) in select_stmt._statement_hints + if dialect_name in ("*", self.dialect.name) + ] + if per_dialect: + text += " " + self.get_statement_hint_text(per_dialect) + + # In compound query, CTEs are shared at the compound level + if self.ctes and (not is_embedded_select or toplevel): + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + if select_stmt._suffixes: + text += " " + self._generate_prefixes( + select_stmt, select_stmt._suffixes, **kwargs + ) + + self.stack.pop(-1) + + return text + + def _setup_select_hints( + self, select: Select[Any] + ) -> Tuple[str, _FromHintsType]: + byfrom = { + from_: hinttext + % {"name": from_._compiler_dispatch(self, ashint=True)} + for (from_, dialect), hinttext in select._hints.items() + if dialect in ("*", self.dialect.name) + } + hint_text = self.get_select_hint_text(byfrom) + return hint_text, byfrom + + def _setup_select_stack( + self, select, compile_state, entry, asfrom, lateral, compound_index + ): + correlate_froms = entry["correlate_froms"] + asfrom_froms = entry["asfrom_froms"] + + if compound_index == 0: + entry["select_0"] = select + elif compound_index: + select_0 = entry["select_0"] + numcols = len(select_0._all_selected_columns) + + if len(compile_state.columns_plus_names) != numcols: + raise exc.CompileError( + "All selectables passed to " + "CompoundSelect must have identical numbers of " + "columns; select #%d has %d columns, select " + "#%d has %d" + % ( + 1, + numcols, + compound_index + 1, + len(select._all_selected_columns), + ) + ) + + if asfrom and not lateral: + froms = compile_state._get_display_froms( + explicit_correlate_froms=correlate_froms.difference( + asfrom_froms + ), + implicit_correlate_froms=(), + ) + else: + froms = compile_state._get_display_froms( + explicit_correlate_froms=correlate_froms, + implicit_correlate_froms=asfrom_froms, + ) + + new_correlate_froms = set(_from_objects(*froms)) + all_correlate_froms = new_correlate_froms.union(correlate_froms) + + new_entry: _CompilerStackEntry = { + "asfrom_froms": new_correlate_froms, + "correlate_froms": all_correlate_froms, + "selectable": select, + "compile_state": compile_state, + } + self.stack.append(new_entry) + + return froms + + def _compose_select_body( + self, + text, + select, + compile_state, + inner_columns, + froms, + byfrom, + toplevel, + kwargs, + ): + text += ", ".join(inner_columns) + + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + + # adjust the whitespace for no inner columns, part of #9440, + # so that a no-col SELECT comes out as "SELECT WHERE..." or + # "SELECT FROM ...". + # while it would be better to have built the SELECT starting string + # without trailing whitespace first, then add whitespace only if inner + # cols were present, this breaks compatibility with various custom + # compilation schemes that are currently being tested. + if not inner_columns: + text = text.rstrip() + + if froms: + text += " \nFROM " + + if select._hints: + text += ", ".join( + [ + f._compiler_dispatch( + self, + asfrom=True, + fromhints=byfrom, + from_linter=from_linter, + **kwargs, + ) + for f in froms + ] + ) + else: + text += ", ".join( + [ + f._compiler_dispatch( + self, + asfrom=True, + from_linter=from_linter, + **kwargs, + ) + for f in froms + ] + ) + else: + text += self.default_from() + + if select._where_criteria: + t = self._generate_delimited_and_list( + select._where_criteria, from_linter=from_linter, **kwargs + ) + if t: + text += " \nWHERE " + t + + if warn_linting: + assert from_linter is not None + from_linter.warn() + + if select._group_by_clauses: + text += self.group_by_clause(select, **kwargs) + + if select._having_criteria: + t = self._generate_delimited_and_list( + select._having_criteria, **kwargs + ) + if t: + text += " \nHAVING " + t + + if select._order_by_clauses: + text += self.order_by_clause(select, **kwargs) + + if select._has_row_limiting_clause: + text += self._row_limit_clause(select, **kwargs) + + if select._for_update_arg is not None: + text += self.for_update_clause(select, **kwargs) + + return text + + def _generate_prefixes(self, stmt, prefixes, **kw): + clause = " ".join( + prefix._compiler_dispatch(self, **kw) + for prefix, dialect_name in prefixes + if dialect_name in (None, "*") or dialect_name == self.dialect.name + ) + if clause: + clause += " " + return clause + + def _render_cte_clause( + self, + nesting_level=None, + include_following_stack=False, + ): + """ + include_following_stack + Also render the nesting CTEs on the next stack. Useful for + SQL structures like UNION or INSERT that can wrap SELECT + statements containing nesting CTEs. + """ + if not self.ctes: + return "" + + ctes: MutableMapping[CTE, str] + + if nesting_level and nesting_level > 1: + ctes = util.OrderedDict() + for cte in list(self.ctes.keys()): + cte_level, cte_name, cte_opts = self.level_name_by_cte[ + cte._get_reference_cte() + ] + nesting = cte.nesting or cte_opts.nesting + is_rendered_level = cte_level == nesting_level or ( + include_following_stack and cte_level == nesting_level + 1 + ) + if not (nesting and is_rendered_level): + continue + + ctes[cte] = self.ctes[cte] + + else: + ctes = self.ctes + + if not ctes: + return "" + ctes_recursive = any([cte.recursive for cte in ctes]) + + cte_text = self.get_cte_preamble(ctes_recursive) + " " + cte_text += ", \n".join([txt for txt in ctes.values()]) + cte_text += "\n " + + if nesting_level and nesting_level > 1: + for cte in list(ctes.keys()): + cte_level, cte_name, cte_opts = self.level_name_by_cte[ + cte._get_reference_cte() + ] + del self.ctes[cte] + del self.ctes_by_level_name[(cte_level, cte_name)] + del self.level_name_by_cte[cte._get_reference_cte()] + + return cte_text + + def get_cte_preamble(self, recursive): + if recursive: + return "WITH RECURSIVE" + else: + return "WITH" + + def get_select_precolumns(self, select, **kw): + """Called when building a ``SELECT`` statement, position is just + before column list. + + """ + if select._distinct_on: + util.warn_deprecated( + "DISTINCT ON is currently supported only by the PostgreSQL " + "dialect. Use of DISTINCT ON for other backends is currently " + "silently ignored, however this usage is deprecated, and will " + "raise CompileError in a future release for all backends " + "that do not support this syntax.", + version="1.4", + ) + return "DISTINCT " if select._distinct else "" + + def group_by_clause(self, select, **kw): + """allow dialects to customize how GROUP BY is rendered.""" + + group_by = self._generate_delimited_list( + select._group_by_clauses, OPERATORS[operators.comma_op], **kw + ) + if group_by: + return " GROUP BY " + group_by + else: + return "" + + def order_by_clause(self, select, **kw): + """allow dialects to customize how ORDER BY is rendered.""" + + order_by = self._generate_delimited_list( + select._order_by_clauses, OPERATORS[operators.comma_op], **kw + ) + + if order_by: + return " ORDER BY " + order_by + else: + return "" + + def for_update_clause(self, select, **kw): + return " FOR UPDATE" + + def returning_clause( + self, + stmt: UpdateBase, + returning_cols: Sequence[ColumnElement[Any]], + *, + populate_result_map: bool, + **kw: Any, + ) -> str: + columns = [ + self._label_returning_column( + stmt, + column, + populate_result_map, + fallback_label_name=fallback_label_name, + column_is_repeated=repeated, + name=name, + proxy_name=proxy_name, + **kw, + ) + for ( + name, + proxy_name, + fallback_label_name, + column, + repeated, + ) in stmt._generate_columns_plus_names( + True, cols=base._select_iterables(returning_cols) + ) + ] + + return "RETURNING " + ", ".join(columns) + + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + text += "\n LIMIT " + self.process(select._limit_clause, **kw) + if select._offset_clause is not None: + if select._limit_clause is None: + text += "\n LIMIT -1" + text += " OFFSET " + self.process(select._offset_clause, **kw) + return text + + def fetch_clause( + self, + select, + fetch_clause=None, + require_offset=False, + use_literal_execute_for_simple_int=False, + **kw, + ): + if fetch_clause is None: + fetch_clause = select._fetch_clause + fetch_clause_options = select._fetch_clause_options + else: + fetch_clause_options = {"percent": False, "with_ties": False} + + text = "" + + if select._offset_clause is not None: + offset_clause = select._offset_clause + if ( + use_literal_execute_for_simple_int + and select._simple_int_clause(offset_clause) + ): + offset_clause = offset_clause.render_literal_execute() + offset_str = self.process(offset_clause, **kw) + text += "\n OFFSET %s ROWS" % offset_str + elif require_offset: + text += "\n OFFSET 0 ROWS" + + if fetch_clause is not None: + if ( + use_literal_execute_for_simple_int + and select._simple_int_clause(fetch_clause) + ): + fetch_clause = fetch_clause.render_literal_execute() + text += "\n FETCH FIRST %s%s ROWS %s" % ( + self.process(fetch_clause, **kw), + " PERCENT" if fetch_clause_options["percent"] else "", + "WITH TIES" if fetch_clause_options["with_ties"] else "ONLY", + ) + return text + + def visit_table( + self, + table, + asfrom=False, + iscrud=False, + ashint=False, + fromhints=None, + use_schema=True, + from_linter=None, + ambiguous_table_name_map=None, + **kwargs, + ): + if from_linter: + from_linter.froms[table] = table.fullname + + if asfrom or ashint: + effective_schema = self.preparer.schema_for_object(table) + + if use_schema and effective_schema: + ret = ( + self.preparer.quote_schema(effective_schema) + + "." + + self.preparer.quote(table.name) + ) + else: + ret = self.preparer.quote(table.name) + + if ( + not effective_schema + and ambiguous_table_name_map + and table.name in ambiguous_table_name_map + ): + anon_name = self._truncated_identifier( + "alias", ambiguous_table_name_map[table.name] + ) + + ret = ret + self.get_render_as_alias_suffix( + self.preparer.format_alias(None, anon_name) + ) + + if fromhints and table in fromhints: + ret = self.format_from_hint_text( + ret, table, fromhints[table], iscrud + ) + return ret + else: + return "" + + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.update( + itertools.product( + _de_clone(join.left._from_objects), + _de_clone(join.right._from_objects), + ) + ) + + if join.full: + join_type = " FULL OUTER JOIN " + elif join.isouter: + join_type = " LEFT OUTER JOIN " + else: + join_type = " JOIN " + return ( + join.left._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + + join_type + + join.right._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + + " ON " + # TODO: likely need asfrom=True here? + + join.onclause._compiler_dispatch( + self, from_linter=from_linter, **kwargs + ) + ) + + def _setup_crud_hints(self, stmt, table_text): + dialect_hints = { + table: hint_text + for (table, dialect), hint_text in stmt._hints.items() + if dialect in ("*", self.dialect.name) + } + if stmt.table in dialect_hints: + table_text = self.format_from_hint_text( + table_text, stmt.table, dialect_hints[stmt.table], True + ) + return dialect_hints, table_text + + # within the realm of "insertmanyvalues sentinel columns", + # these lookups match different kinds of Column() configurations + # to specific backend capabilities. they are broken into two + # lookups, one for autoincrement columns and the other for non + # autoincrement columns + _sentinel_col_non_autoinc_lookup = util.immutabledict( + { + _SentinelDefaultCharacterization.CLIENTSIDE: ( + InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT + ), + _SentinelDefaultCharacterization.SENTINEL_DEFAULT: ( + InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT + ), + _SentinelDefaultCharacterization.NONE: ( + InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT + ), + _SentinelDefaultCharacterization.IDENTITY: ( + InsertmanyvaluesSentinelOpts.IDENTITY + ), + _SentinelDefaultCharacterization.SEQUENCE: ( + InsertmanyvaluesSentinelOpts.SEQUENCE + ), + } + ) + _sentinel_col_autoinc_lookup = _sentinel_col_non_autoinc_lookup.union( + { + _SentinelDefaultCharacterization.NONE: ( + InsertmanyvaluesSentinelOpts.AUTOINCREMENT + ), + } + ) + + def _get_sentinel_column_for_table( + self, table: Table + ) -> Optional[Sequence[Column[Any]]]: + """given a :class:`.Table`, return a usable sentinel column or + columns for this dialect if any. + + Return None if no sentinel columns could be identified, or raise an + error if a column was marked as a sentinel explicitly but isn't + compatible with this dialect. + + """ + + sentinel_opts = self.dialect.insertmanyvalues_implicit_sentinel + sentinel_characteristics = table._sentinel_column_characteristics + + sent_cols = sentinel_characteristics.columns + + if sent_cols is None: + return None + + if sentinel_characteristics.is_autoinc: + bitmask = self._sentinel_col_autoinc_lookup.get( + sentinel_characteristics.default_characterization, 0 + ) + else: + bitmask = self._sentinel_col_non_autoinc_lookup.get( + sentinel_characteristics.default_characterization, 0 + ) + + if sentinel_opts & bitmask: + return sent_cols + + if sentinel_characteristics.is_explicit: + # a column was explicitly marked as insert_sentinel=True, + # however it is not compatible with this dialect. they should + # not indicate this column as a sentinel if they need to include + # this dialect. + + # TODO: do we want non-primary key explicit sentinel cols + # that can gracefully degrade for some backends? + # insert_sentinel="degrade" perhaps. not for the initial release. + # I am hoping people are generally not dealing with this sentinel + # business at all. + + # if is_explicit is True, there will be only one sentinel column. + + raise exc.InvalidRequestError( + f"Column {sent_cols[0]} can't be explicitly " + "marked as a sentinel column when using the " + f"{self.dialect.name} dialect, as the " + "particular type of default generation on this column is " + "not currently compatible with this dialect's specific " + f"INSERT..RETURNING syntax which can receive the " + "server-generated value in " + "a deterministic way. To remove this error, remove " + "insert_sentinel=True from primary key autoincrement " + "columns; these columns are automatically used as " + "sentinels for supported dialects in any case." + ) + + return None + + def _deliver_insertmanyvalues_batches( + self, + statement: str, + parameters: _DBAPIMultiExecuteParams, + compiled_parameters: List[_MutableCoreSingleExecuteParams], + generic_setinputsizes: Optional[_GenericSetInputSizesType], + batch_size: int, + sort_by_parameter_order: bool, + schema_translate_map: Optional[SchemaTranslateMapType], + ) -> Iterator[_InsertManyValuesBatch]: + imv = self._insertmanyvalues + assert imv is not None + + if not imv.sentinel_param_keys: + _sentinel_from_params = None + else: + _sentinel_from_params = operator.itemgetter( + *imv.sentinel_param_keys + ) + + lenparams = len(parameters) + if imv.is_default_expr and not self.dialect.supports_default_metavalue: + # backend doesn't support + # INSERT INTO table (pk_col) VALUES (DEFAULT), (DEFAULT), ... + # at the moment this is basically SQL Server due to + # not being able to use DEFAULT for identity column + # just yield out that many single statements! still + # faster than a whole connection.execute() call ;) + # + # note we still are taking advantage of the fact that we know + # we are using RETURNING. The generalized approach of fetching + # cursor.lastrowid etc. still goes through the more heavyweight + # "ExecutionContext per statement" system as it isn't usable + # as a generic "RETURNING" approach + use_row_at_a_time = True + downgraded = False + elif not self.dialect.supports_multivalues_insert or ( + sort_by_parameter_order + and self._result_columns + and (imv.sentinel_columns is None or imv.includes_upsert_behaviors) + ): + # deterministic order was requested and the compiler could + # not organize sentinel columns for this dialect/statement. + # use row at a time + use_row_at_a_time = True + downgraded = True + else: + use_row_at_a_time = False + downgraded = False + + if use_row_at_a_time: + for batchnum, (param, compiled_param) in enumerate( + cast( + "Sequence[Tuple[_DBAPISingleExecuteParams, _MutableCoreSingleExecuteParams]]", # noqa: E501 + zip(parameters, compiled_parameters), + ), + 1, + ): + yield _InsertManyValuesBatch( + statement, + param, + generic_setinputsizes, + [param], + ( + [_sentinel_from_params(compiled_param)] + if _sentinel_from_params + else [] + ), + 1, + batchnum, + lenparams, + sort_by_parameter_order, + downgraded, + ) + return + + if schema_translate_map: + rst = functools.partial( + self.preparer._render_schema_translates, + schema_translate_map=schema_translate_map, + ) + else: + rst = None + + imv_single_values_expr = imv.single_values_expr + if rst: + imv_single_values_expr = rst(imv_single_values_expr) + + executemany_values = f"({imv_single_values_expr})" + statement = statement.replace(executemany_values, "__EXECMANY_TOKEN__") + + # Use optional insertmanyvalues_max_parameters + # to further shrink the batch size so that there are no more than + # insertmanyvalues_max_parameters params. + # Currently used by SQL Server, which limits statements to 2100 bound + # parameters (actually 2099). + max_params = self.dialect.insertmanyvalues_max_parameters + if max_params: + total_num_of_params = len(self.bind_names) + num_params_per_batch = len(imv.insert_crud_params) + num_params_outside_of_batch = ( + total_num_of_params - num_params_per_batch + ) + batch_size = min( + batch_size, + ( + (max_params - num_params_outside_of_batch) + // num_params_per_batch + ), + ) + + batches = cast("List[Sequence[Any]]", list(parameters)) + compiled_batches = cast( + "List[Sequence[Any]]", list(compiled_parameters) + ) + + processed_setinputsizes: Optional[_GenericSetInputSizesType] = None + batchnum = 1 + total_batches = lenparams // batch_size + ( + 1 if lenparams % batch_size else 0 + ) + + insert_crud_params = imv.insert_crud_params + assert insert_crud_params is not None + + if rst: + insert_crud_params = [ + (col, key, rst(expr), st) + for col, key, expr, st in insert_crud_params + ] + + escaped_bind_names: Mapping[str, str] + expand_pos_lower_index = expand_pos_upper_index = 0 + + if not self.positional: + if self.escaped_bind_names: + escaped_bind_names = self.escaped_bind_names + else: + escaped_bind_names = {} + + all_keys = set(parameters[0]) + + def apply_placeholders(keys, formatted): + for key in keys: + key = escaped_bind_names.get(key, key) + formatted = formatted.replace( + self.bindtemplate % {"name": key}, + self.bindtemplate + % {"name": f"{key}__EXECMANY_INDEX__"}, + ) + return formatted + + if imv.embed_values_counter: + imv_values_counter = ", _IMV_VALUES_COUNTER" + else: + imv_values_counter = "" + formatted_values_clause = f"""({', '.join( + apply_placeholders(bind_keys, formatted) + for _, _, formatted, bind_keys in insert_crud_params + )}{imv_values_counter})""" + + keys_to_replace = all_keys.intersection( + escaped_bind_names.get(key, key) + for _, _, _, bind_keys in insert_crud_params + for key in bind_keys + ) + base_parameters = { + key: parameters[0][key] + for key in all_keys.difference(keys_to_replace) + } + executemany_values_w_comma = "" + else: + formatted_values_clause = "" + keys_to_replace = set() + base_parameters = {} + + if imv.embed_values_counter: + executemany_values_w_comma = ( + f"({imv_single_values_expr}, _IMV_VALUES_COUNTER), " + ) + else: + executemany_values_w_comma = f"({imv_single_values_expr}), " + + all_names_we_will_expand: Set[str] = set() + for elem in imv.insert_crud_params: + all_names_we_will_expand.update(elem[3]) + + # get the start and end position in a particular list + # of parameters where we will be doing the "expanding". + # statements can have params on either side or both sides, + # given RETURNING and CTEs + if all_names_we_will_expand: + positiontup = self.positiontup + assert positiontup is not None + + all_expand_positions = { + idx + for idx, name in enumerate(positiontup) + if name in all_names_we_will_expand + } + expand_pos_lower_index = min(all_expand_positions) + expand_pos_upper_index = max(all_expand_positions) + 1 + assert ( + len(all_expand_positions) + == expand_pos_upper_index - expand_pos_lower_index + ) + + if self._numeric_binds: + escaped = re.escape(self._numeric_binds_identifier_char) + executemany_values_w_comma = re.sub( + rf"{escaped}\d+", "%s", executemany_values_w_comma + ) + + while batches: + batch = batches[0:batch_size] + compiled_batch = compiled_batches[0:batch_size] + + batches[0:batch_size] = [] + compiled_batches[0:batch_size] = [] + + if batches: + current_batch_size = batch_size + else: + current_batch_size = len(batch) + + if generic_setinputsizes: + # if setinputsizes is present, expand this collection to + # suit the batch length as well + # currently this will be mssql+pyodbc for internal dialects + processed_setinputsizes = [ + (new_key, len_, typ) + for new_key, len_, typ in ( + (f"{key}_{index}", len_, typ) + for index in range(current_batch_size) + for key, len_, typ in generic_setinputsizes + ) + ] + + replaced_parameters: Any + if self.positional: + num_ins_params = imv.num_positional_params_counted + + batch_iterator: Iterable[Sequence[Any]] + extra_params_left: Sequence[Any] + extra_params_right: Sequence[Any] + + if num_ins_params == len(batch[0]): + extra_params_left = extra_params_right = () + batch_iterator = batch + else: + extra_params_left = batch[0][:expand_pos_lower_index] + extra_params_right = batch[0][expand_pos_upper_index:] + batch_iterator = ( + b[expand_pos_lower_index:expand_pos_upper_index] + for b in batch + ) + + if imv.embed_values_counter: + expanded_values_string = ( + "".join( + executemany_values_w_comma.replace( + "_IMV_VALUES_COUNTER", str(i) + ) + for i, _ in enumerate(batch) + ) + )[:-2] + else: + expanded_values_string = ( + (executemany_values_w_comma * current_batch_size) + )[:-2] + + if self._numeric_binds and num_ins_params > 0: + # numeric will always number the parameters inside of + # VALUES (and thus order self.positiontup) to be higher + # than non-VALUES parameters, no matter where in the + # statement those non-VALUES parameters appear (this is + # ensured in _process_numeric by numbering first all + # params that are not in _values_bindparam) + # therefore all extra params are always + # on the left side and numbered lower than the VALUES + # parameters + assert not extra_params_right + + start = expand_pos_lower_index + 1 + end = num_ins_params * (current_batch_size) + start + + # need to format here, since statement may contain + # unescaped %, while values_string contains just (%s, %s) + positions = tuple( + f"{self._numeric_binds_identifier_char}{i}" + for i in range(start, end) + ) + expanded_values_string = expanded_values_string % positions + + replaced_statement = statement.replace( + "__EXECMANY_TOKEN__", expanded_values_string + ) + + replaced_parameters = tuple( + itertools.chain.from_iterable(batch_iterator) + ) + + replaced_parameters = ( + extra_params_left + + replaced_parameters + + extra_params_right + ) + + else: + replaced_values_clauses = [] + replaced_parameters = base_parameters.copy() + + for i, param in enumerate(batch): + fmv = formatted_values_clause.replace( + "EXECMANY_INDEX__", str(i) + ) + if imv.embed_values_counter: + fmv = fmv.replace("_IMV_VALUES_COUNTER", str(i)) + + replaced_values_clauses.append(fmv) + replaced_parameters.update( + {f"{key}__{i}": param[key] for key in keys_to_replace} + ) + + replaced_statement = statement.replace( + "__EXECMANY_TOKEN__", + ", ".join(replaced_values_clauses), + ) + + yield _InsertManyValuesBatch( + replaced_statement, + replaced_parameters, + processed_setinputsizes, + batch, + ( + [_sentinel_from_params(cb) for cb in compiled_batch] + if _sentinel_from_params + else [] + ), + current_batch_size, + batchnum, + total_batches, + sort_by_parameter_order, + False, + ) + batchnum += 1 + + def visit_insert( + self, insert_stmt, visited_bindparam=None, visiting_cte=None, **kw + ): + compile_state = insert_stmt._compile_state_factory( + insert_stmt, self, **kw + ) + insert_stmt = compile_state.statement + + if visiting_cte is not None: + kw["visiting_cte"] = visiting_cte + toplevel = False + else: + toplevel = not self.stack + + if toplevel: + self.isinsert = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state + + self.stack.append( + { + "correlate_froms": set(), + "asfrom_froms": set(), + "selectable": insert_stmt, + } + ) + + counted_bindparam = 0 + + # reset any incoming "visited_bindparam" collection + visited_bindparam = None + + # for positional, insertmanyvalues needs to know how many + # bound parameters are in the VALUES sequence; there's no simple + # rule because default expressions etc. can have zero or more + # params inside them. After multiple attempts to figure this out, + # this very simplistic "count after" works and is + # likely the least amount of callcounts, though looks clumsy + if self.positional and visiting_cte is None: + # if we are inside a CTE, don't count parameters + # here since they wont be for insertmanyvalues. keep + # visited_bindparam at None so no counting happens. + # see #9173 + visited_bindparam = [] + + crud_params_struct = crud._get_crud_params( + self, + insert_stmt, + compile_state, + toplevel, + visited_bindparam=visited_bindparam, + **kw, + ) + + if self.positional and visited_bindparam is not None: + counted_bindparam = len(visited_bindparam) + if self._numeric_binds: + if self._values_bindparam is not None: + self._values_bindparam += visited_bindparam + else: + self._values_bindparam = visited_bindparam + + crud_params_single = crud_params_struct.single_params + + if ( + not crud_params_single + and not self.dialect.supports_default_values + and not self.dialect.supports_default_metavalue + and not self.dialect.supports_empty_insert + ): + raise exc.CompileError( + "The '%s' dialect with current database " + "version settings does not support empty " + "inserts." % self.dialect.name + ) + + if compile_state._has_multi_parameters: + if not self.dialect.supports_multivalues_insert: + raise exc.CompileError( + "The '%s' dialect with current database " + "version settings does not support " + "in-place multirow inserts." % self.dialect.name + ) + elif ( + self.implicit_returning or insert_stmt._returning + ) and insert_stmt._sort_by_parameter_order: + raise exc.CompileError( + "RETURNING cannot be determinstically sorted when " + "using an INSERT which includes multi-row values()." + ) + crud_params_single = crud_params_struct.single_params + else: + crud_params_single = crud_params_struct.single_params + + preparer = self.preparer + supports_default_values = self.dialect.supports_default_values + + text = "INSERT " + + if insert_stmt._prefixes: + text += self._generate_prefixes( + insert_stmt, insert_stmt._prefixes, **kw + ) + + text += "INTO " + table_text = preparer.format_table(insert_stmt.table) + + if insert_stmt._hints: + _, table_text = self._setup_crud_hints(insert_stmt, table_text) + + if insert_stmt._independent_ctes: + self._dispatch_independent_ctes(insert_stmt, kw) + + text += table_text + + if crud_params_single or not supports_default_values: + text += " (%s)" % ", ".join( + [expr for _, expr, _, _ in crud_params_single] + ) + + # look for insertmanyvalues attributes that would have been configured + # by crud.py as it scanned through the columns to be part of the + # INSERT + use_insertmanyvalues = crud_params_struct.use_insertmanyvalues + named_sentinel_params: Optional[Sequence[str]] = None + add_sentinel_cols = None + implicit_sentinel = False + + returning_cols = self.implicit_returning or insert_stmt._returning + if returning_cols: + add_sentinel_cols = crud_params_struct.use_sentinel_columns + if add_sentinel_cols is not None: + assert use_insertmanyvalues + + # search for the sentinel column explicitly present + # in the INSERT columns list, and additionally check that + # this column has a bound parameter name set up that's in the + # parameter list. If both of these cases are present, it means + # we will have a client side value for the sentinel in each + # parameter set. + + _params_by_col = { + col: param_names + for col, _, _, param_names in crud_params_single + } + named_sentinel_params = [] + for _add_sentinel_col in add_sentinel_cols: + if _add_sentinel_col not in _params_by_col: + named_sentinel_params = None + break + param_name = self._within_exec_param_key_getter( + _add_sentinel_col + ) + if param_name not in _params_by_col[_add_sentinel_col]: + named_sentinel_params = None + break + named_sentinel_params.append(param_name) + + if named_sentinel_params is None: + # if we are not going to have a client side value for + # the sentinel in the parameter set, that means it's + # an autoincrement, an IDENTITY, or a server-side SQL + # expression like nextval('seqname'). So this is + # an "implicit" sentinel; we will look for it in + # RETURNING + # only, and then sort on it. For this case on PG, + # SQL Server we have to use a special INSERT form + # that guarantees the server side function lines up with + # the entries in the VALUES. + if ( + self.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT + ): + implicit_sentinel = True + else: + # here, we are not using a sentinel at all + # and we are likely the SQLite dialect. + # The first add_sentinel_col that we have should not + # be marked as "insert_sentinel=True". if it was, + # an error should have been raised in + # _get_sentinel_column_for_table. + assert not add_sentinel_cols[0]._insert_sentinel, ( + "sentinel selection rules should have prevented " + "us from getting here for this dialect" + ) + + # always put the sentinel columns last. even if they are + # in the returning list already, they will be there twice + # then. + returning_cols = list(returning_cols) + list(add_sentinel_cols) + + returning_clause = self.returning_clause( + insert_stmt, + returning_cols, + populate_result_map=toplevel, + ) + + if self.returning_precedes_values: + text += " " + returning_clause + + else: + returning_clause = None + + if insert_stmt.select is not None: + # placed here by crud.py + select_text = self.process( + self.stack[-1]["insert_from_select"], insert_into=True, **kw + ) + + if self.ctes and self.dialect.cte_follows_insert: + nesting_level = len(self.stack) if not toplevel else None + text += " %s%s" % ( + self._render_cte_clause( + nesting_level=nesting_level, + include_following_stack=True, + ), + select_text, + ) + else: + text += " %s" % select_text + elif not crud_params_single and supports_default_values: + text += " DEFAULT VALUES" + if use_insertmanyvalues: + self._insertmanyvalues = _InsertManyValues( + True, + self.dialect.default_metavalue_token, + cast( + "List[crud._CrudParamElementStr]", crud_params_single + ), + counted_bindparam, + sort_by_parameter_order=( + insert_stmt._sort_by_parameter_order + ), + includes_upsert_behaviors=( + insert_stmt._post_values_clause is not None + ), + sentinel_columns=add_sentinel_cols, + num_sentinel_columns=( + len(add_sentinel_cols) if add_sentinel_cols else 0 + ), + implicit_sentinel=implicit_sentinel, + ) + elif compile_state._has_multi_parameters: + text += " VALUES %s" % ( + ", ".join( + "(%s)" + % (", ".join(value for _, _, value, _ in crud_param_set)) + for crud_param_set in crud_params_struct.all_multi_params + ), + ) + else: + insert_single_values_expr = ", ".join( + [ + value + for _, _, value, _ in cast( + "List[crud._CrudParamElementStr]", + crud_params_single, + ) + ] + ) + + if use_insertmanyvalues: + if ( + implicit_sentinel + and ( + self.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.USE_INSERT_FROM_SELECT + ) + # this is checking if we have + # INSERT INTO table (id) VALUES (DEFAULT). + and not (crud_params_struct.is_default_metavalue_only) + ): + # if we have a sentinel column that is server generated, + # then for selected backends render the VALUES list as a + # subquery. This is the orderable form supported by + # PostgreSQL and SQL Server. + embed_sentinel_value = True + + render_bind_casts = ( + self.dialect.insertmanyvalues_implicit_sentinel + & InsertmanyvaluesSentinelOpts.RENDER_SELECT_COL_CASTS + ) + + colnames = ", ".join( + f"p{i}" for i, _ in enumerate(crud_params_single) + ) + + if render_bind_casts: + # render casts for the SELECT list. For PG, we are + # already rendering bind casts in the parameter list, + # selectively for the more "tricky" types like ARRAY. + # however, even for the "easy" types, if the parameter + # is NULL for every entry, PG gives up and says + # "it must be TEXT", which fails for other easy types + # like ints. So we cast on this side too. + colnames_w_cast = ", ".join( + self.render_bind_cast( + col.type, + col.type._unwrapped_dialect_impl(self.dialect), + f"p{i}", + ) + for i, (col, *_) in enumerate(crud_params_single) + ) + else: + colnames_w_cast = colnames + + text += ( + f" SELECT {colnames_w_cast} FROM " + f"(VALUES ({insert_single_values_expr})) " + f"AS imp_sen({colnames}, sen_counter) " + "ORDER BY sen_counter" + ) + else: + # otherwise, if no sentinel or backend doesn't support + # orderable subquery form, use a plain VALUES list + embed_sentinel_value = False + text += f" VALUES ({insert_single_values_expr})" + + self._insertmanyvalues = _InsertManyValues( + is_default_expr=False, + single_values_expr=insert_single_values_expr, + insert_crud_params=cast( + "List[crud._CrudParamElementStr]", + crud_params_single, + ), + num_positional_params_counted=counted_bindparam, + sort_by_parameter_order=( + insert_stmt._sort_by_parameter_order + ), + includes_upsert_behaviors=( + insert_stmt._post_values_clause is not None + ), + sentinel_columns=add_sentinel_cols, + num_sentinel_columns=( + len(add_sentinel_cols) if add_sentinel_cols else 0 + ), + sentinel_param_keys=named_sentinel_params, + implicit_sentinel=implicit_sentinel, + embed_values_counter=embed_sentinel_value, + ) + + else: + text += f" VALUES ({insert_single_values_expr})" + + if insert_stmt._post_values_clause is not None: + post_values_clause = self.process( + insert_stmt._post_values_clause, **kw + ) + if post_values_clause: + text += " " + post_values_clause + + if returning_clause and not self.returning_precedes_values: + text += " " + returning_clause + + if self.ctes and not self.dialect.cte_follows_insert: + nesting_level = len(self.stack) if not toplevel else None + text = ( + self._render_cte_clause( + nesting_level=nesting_level, + include_following_stack=True, + ) + + text + ) + + self.stack.pop(-1) + + return text + + def update_limit_clause(self, update_stmt): + """Provide a hook for MySQL to add LIMIT to the UPDATE""" + return None + + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + """Provide a hook to override the initial table clause + in an UPDATE statement. + + MySQL overrides this. + + """ + kw["asfrom"] = True + return from_table._compiler_dispatch(self, iscrud=True, **kw) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + """Provide a hook to override the generation of an + UPDATE..FROM clause. + + MySQL and MSSQL override this. + + """ + raise NotImplementedError( + "This backend does not support multiple-table " + "criteria within UPDATE" + ) + + def visit_update(self, update_stmt, visiting_cte=None, **kw): + compile_state = update_stmt._compile_state_factory( + update_stmt, self, **kw + ) + update_stmt = compile_state.statement + + if visiting_cte is not None: + kw["visiting_cte"] = visiting_cte + toplevel = False + else: + toplevel = not self.stack + + if toplevel: + self.isupdate = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state + + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + + extra_froms = compile_state._extra_froms + is_multitable = bool(extra_froms) + + if is_multitable: + # main table might be a JOIN + main_froms = set(_from_objects(update_stmt.table)) + render_extra_froms = [ + f for f in extra_froms if f not in main_froms + ] + correlate_froms = main_froms.union(extra_froms) + else: + render_extra_froms = [] + correlate_froms = {update_stmt.table} + + self.stack.append( + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": update_stmt, + } + ) + + text = "UPDATE " + + if update_stmt._prefixes: + text += self._generate_prefixes( + update_stmt, update_stmt._prefixes, **kw + ) + + table_text = self.update_tables_clause( + update_stmt, + update_stmt.table, + render_extra_froms, + from_linter=from_linter, + **kw, + ) + crud_params_struct = crud._get_crud_params( + self, update_stmt, compile_state, toplevel, **kw + ) + crud_params = crud_params_struct.single_params + + if update_stmt._hints: + dialect_hints, table_text = self._setup_crud_hints( + update_stmt, table_text + ) + else: + dialect_hints = None + + if update_stmt._independent_ctes: + self._dispatch_independent_ctes(update_stmt, kw) + + text += table_text + + text += " SET " + text += ", ".join( + expr + "=" + value + for _, expr, value, _ in cast( + "List[Tuple[Any, str, str, Any]]", crud_params + ) + ) + + if self.implicit_returning or update_stmt._returning: + if self.returning_precedes_values: + text += " " + self.returning_clause( + update_stmt, + self.implicit_returning or update_stmt._returning, + populate_result_map=toplevel, + ) + + if extra_froms: + extra_from_text = self.update_from_clause( + update_stmt, + update_stmt.table, + render_extra_froms, + dialect_hints, + from_linter=from_linter, + **kw, + ) + if extra_from_text: + text += " " + extra_from_text + + if update_stmt._where_criteria: + t = self._generate_delimited_and_list( + update_stmt._where_criteria, from_linter=from_linter, **kw + ) + if t: + text += " WHERE " + t + + limit_clause = self.update_limit_clause(update_stmt) + if limit_clause: + text += " " + limit_clause + + if ( + self.implicit_returning or update_stmt._returning + ) and not self.returning_precedes_values: + text += " " + self.returning_clause( + update_stmt, + self.implicit_returning or update_stmt._returning, + populate_result_map=toplevel, + ) + + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + if warn_linting: + assert from_linter is not None + from_linter.warn(stmt_type="UPDATE") + + self.stack.pop(-1) + + return text + + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + """Provide a hook to override the generation of an + DELETE..FROM clause. + + This can be used to implement DELETE..USING for example. + + MySQL and MSSQL override this. + + """ + raise NotImplementedError( + "This backend does not support multiple-table " + "criteria within DELETE" + ) + + def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): + return from_table._compiler_dispatch( + self, asfrom=True, iscrud=True, **kw + ) + + def visit_delete(self, delete_stmt, visiting_cte=None, **kw): + compile_state = delete_stmt._compile_state_factory( + delete_stmt, self, **kw + ) + delete_stmt = compile_state.statement + + if visiting_cte is not None: + kw["visiting_cte"] = visiting_cte + toplevel = False + else: + toplevel = not self.stack + + if toplevel: + self.isdelete = True + if not self.dml_compile_state: + self.dml_compile_state = compile_state + if not self.compile_state: + self.compile_state = compile_state + + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + warn_linting = self.linting & WARN_LINTING + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + warn_linting = False + + extra_froms = compile_state._extra_froms + + correlate_froms = {delete_stmt.table}.union(extra_froms) + self.stack.append( + { + "correlate_froms": correlate_froms, + "asfrom_froms": correlate_froms, + "selectable": delete_stmt, + } + ) + + text = "DELETE " + + if delete_stmt._prefixes: + text += self._generate_prefixes( + delete_stmt, delete_stmt._prefixes, **kw + ) + + text += "FROM " + + try: + table_text = self.delete_table_clause( + delete_stmt, + delete_stmt.table, + extra_froms, + from_linter=from_linter, + ) + except TypeError: + # anticipate 3rd party dialects that don't include **kw + # TODO: remove in 2.1 + table_text = self.delete_table_clause( + delete_stmt, delete_stmt.table, extra_froms + ) + if from_linter: + _ = self.process(delete_stmt.table, from_linter=from_linter) + + crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw) + + if delete_stmt._hints: + dialect_hints, table_text = self._setup_crud_hints( + delete_stmt, table_text + ) + else: + dialect_hints = None + + if delete_stmt._independent_ctes: + self._dispatch_independent_ctes(delete_stmt, kw) + + text += table_text + + if ( + self.implicit_returning or delete_stmt._returning + ) and self.returning_precedes_values: + text += " " + self.returning_clause( + delete_stmt, + self.implicit_returning or delete_stmt._returning, + populate_result_map=toplevel, + ) + + if extra_froms: + extra_from_text = self.delete_extra_from_clause( + delete_stmt, + delete_stmt.table, + extra_froms, + dialect_hints, + from_linter=from_linter, + **kw, + ) + if extra_from_text: + text += " " + extra_from_text + + if delete_stmt._where_criteria: + t = self._generate_delimited_and_list( + delete_stmt._where_criteria, from_linter=from_linter, **kw + ) + if t: + text += " WHERE " + t + + if ( + self.implicit_returning or delete_stmt._returning + ) and not self.returning_precedes_values: + text += " " + self.returning_clause( + delete_stmt, + self.implicit_returning or delete_stmt._returning, + populate_result_map=toplevel, + ) + + if self.ctes: + nesting_level = len(self.stack) if not toplevel else None + text = self._render_cte_clause(nesting_level=nesting_level) + text + + if warn_linting: + assert from_linter is not None + from_linter.warn(stmt_type="DELETE") + + self.stack.pop(-1) + + return text + + def visit_savepoint(self, savepoint_stmt, **kw): + return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_rollback_to_savepoint(self, savepoint_stmt, **kw): + return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) + + def visit_release_savepoint(self, savepoint_stmt, **kw): + return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint( + savepoint_stmt + ) + + +class StrSQLCompiler(SQLCompiler): + """A :class:`.SQLCompiler` subclass which allows a small selection + of non-standard SQL features to render into a string value. + + The :class:`.StrSQLCompiler` is invoked whenever a Core expression + element is directly stringified without calling upon the + :meth:`_expression.ClauseElement.compile` method. + It can render a limited set + of non-standard SQL constructs to assist in basic stringification, + however for more substantial custom or dialect-specific SQL constructs, + it will be necessary to make use of + :meth:`_expression.ClauseElement.compile` + directly. + + .. seealso:: + + :ref:`faq_sql_expression_string` + + """ + + def _fallback_column_name(self, column): + return "" + + @util.preload_module("sqlalchemy.engine.url") + def visit_unsupported_compilation(self, element, err, **kw): + if element.stringify_dialect != "default": + url = util.preloaded.engine_url + dialect = url.URL.create(element.stringify_dialect).get_dialect()() + + compiler = dialect.statement_compiler( + dialect, None, _supporting_against=self + ) + if not isinstance(compiler, StrSQLCompiler): + return compiler.process(element, **kw) + + return super().visit_unsupported_compilation(element, err) + + def visit_getitem_binary(self, binary, operator, **kw): + return "%s[%s]" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def visit_json_getitem_op_binary(self, binary, operator, **kw): + return self.visit_getitem_binary(binary, operator, **kw) + + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + return self.visit_getitem_binary(binary, operator, **kw) + + def visit_sequence(self, seq, **kw): + return "" % self.preparer.format_sequence(seq) + + def returning_clause( + self, + stmt: UpdateBase, + returning_cols: Sequence[ColumnElement[Any]], + *, + populate_result_map: bool, + **kw: Any, + ) -> str: + columns = [ + self._label_select_column(None, c, True, False, {}) + for c in base._select_iterables(returning_cols) + ] + return "RETURNING " + ", ".join(columns) + + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + kw["asfrom"] = True + return "FROM " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def delete_extra_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + kw["asfrom"] = True + return ", " + ", ".join( + t._compiler_dispatch(self, fromhints=from_hints, **kw) + for t in extra_froms + ) + + def visit_empty_set_expr(self, type_, **kw): + return "SELECT 1 WHERE 1!=1" + + def get_from_hint_text(self, table, text): + return "[%s]" % text + + def visit_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " ", **kw) + + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): + return self._generate_generic_binary(binary, " ", **kw) + + def visit_regexp_replace_op_binary(self, binary, operator, **kw): + return "(%s, %s)" % ( + binary.left._compiler_dispatch(self, **kw), + binary.right._compiler_dispatch(self, **kw), + ) + + def visit_try_cast(self, cast, **kwargs): + return "TRY_CAST(%s AS %s)" % ( + cast.clause._compiler_dispatch(self, **kwargs), + cast.typeclause._compiler_dispatch(self, **kwargs), + ) + + +class DDLCompiler(Compiled): + is_ddl = True + + if TYPE_CHECKING: + + def __init__( + self, + dialect: Dialect, + statement: ExecutableDDLElement, + schema_translate_map: Optional[SchemaTranslateMapType] = ..., + render_schema_translate: bool = ..., + compile_kwargs: Mapping[str, Any] = ..., + ): ... + + @util.memoized_property + def sql_compiler(self): + return self.dialect.statement_compiler( + self.dialect, None, schema_translate_map=self.schema_translate_map + ) + + @util.memoized_property + def type_compiler(self): + return self.dialect.type_compiler_instance + + def construct_params( + self, + params: Optional[_CoreSingleExecuteParams] = None, + extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None, + escape_names: bool = True, + ) -> Optional[_MutableCoreSingleExecuteParams]: + return None + + def visit_ddl(self, ddl, **kwargs): + # table events can substitute table and schema name + context = ddl.context + if isinstance(ddl.target, schema.Table): + context = context.copy() + + preparer = self.preparer + path = preparer.format_table_seq(ddl.target) + if len(path) == 1: + table, sch = path[0], "" + else: + table, sch = path[-1], path[0] + + context.setdefault("table", table) + context.setdefault("schema", sch) + context.setdefault("fullname", preparer.format_table(ddl.target)) + + return self.sql_compiler.post_process_text(ddl.statement % context) + + def visit_create_schema(self, create, **kw): + text = "CREATE SCHEMA " + if create.if_not_exists: + text += "IF NOT EXISTS " + return text + self.preparer.format_schema(create.element) + + def visit_drop_schema(self, drop, **kw): + text = "DROP SCHEMA " + if drop.if_exists: + text += "IF EXISTS " + text += self.preparer.format_schema(drop.element) + if drop.cascade: + text += " CASCADE" + return text + + def visit_create_table(self, create, **kw): + table = create.element + preparer = self.preparer + + text = "\nCREATE " + if table._prefixes: + text += " ".join(table._prefixes) + " " + + text += "TABLE " + if create.if_not_exists: + text += "IF NOT EXISTS " + + text += preparer.format_table(table) + " " + + create_table_suffix = self.create_table_suffix(table) + if create_table_suffix: + text += create_table_suffix + " " + + text += "(" + + separator = "\n" + + # if only one primary key, specify it along with the column + first_pk = False + for create_column in create.columns: + column = create_column.element + try: + processed = self.process( + create_column, first_pk=column.primary_key and not first_pk + ) + if processed is not None: + text += separator + separator = ", \n" + text += "\t" + processed + if column.primary_key: + first_pk = True + except exc.CompileError as ce: + raise exc.CompileError( + "(in table '%s', column '%s'): %s" + % (table.description, column.name, ce.args[0]) + ) from ce + + const = self.create_table_constraints( + table, + _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa + ) + if const: + text += separator + "\t" + const + + text += "\n)%s\n\n" % self.post_create_table(table) + return text + + def visit_create_column(self, create, first_pk=False, **kw): + column = create.element + + if column.system: + return None + + text = self.get_column_specification(column, first_pk=first_pk) + const = " ".join( + self.process(constraint) for constraint in column.constraints + ) + if const: + text += " " + const + + return text + + def create_table_constraints( + self, table, _include_foreign_key_constraints=None, **kw + ): + # On some DB order is significant: visit PK first, then the + # other constraints (engine.ReflectionTest.testbasic failed on FB2) + constraints = [] + if table.primary_key: + constraints.append(table.primary_key) + + all_fkcs = table.foreign_key_constraints + if _include_foreign_key_constraints is not None: + omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints) + else: + omit_fkcs = set() + + constraints.extend( + [ + c + for c in table._sorted_constraints + if c is not table.primary_key and c not in omit_fkcs + ] + ) + + return ", \n\t".join( + p + for p in ( + self.process(constraint) + for constraint in constraints + if (constraint._should_create_for_compiler(self)) + and ( + not self.dialect.supports_alter + or not getattr(constraint, "use_alter", False) + ) + ) + if p is not None + ) + + def visit_drop_table(self, drop, **kw): + text = "\nDROP TABLE " + if drop.if_exists: + text += "IF EXISTS " + return text + self.preparer.format_table(drop.element) + + def visit_drop_view(self, drop, **kw): + return "\nDROP VIEW " + self.preparer.format_table(drop.element) + + def _verify_index_table(self, index): + if index.table is None: + raise exc.CompileError( + "Index '%s' is not associated with any table." % index.name + ) + + def visit_create_index( + self, create, include_schema=False, include_table_schema=True, **kw + ): + index = create.element + self._verify_index_table(index) + preparer = self.preparer + text = "CREATE " + if index.unique: + text += "UNIQUE " + if index.name is None: + raise exc.CompileError( + "CREATE INDEX requires that the index have a name" + ) + + text += "INDEX " + if create.if_not_exists: + text += "IF NOT EXISTS " + + text += "%s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table( + index.table, use_schema=include_table_schema + ), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) + return text + + def visit_drop_index(self, drop, **kw): + index = drop.element + + if index.name is None: + raise exc.CompileError( + "DROP INDEX requires that the index have a name" + ) + text = "\nDROP INDEX " + if drop.if_exists: + text += "IF EXISTS " + + return text + self._prepared_index_name(index, include_schema=True) + + def _prepared_index_name(self, index, include_schema=False): + if index.table is not None: + effective_schema = self.preparer.schema_for_object(index.table) + else: + effective_schema = None + if include_schema and effective_schema: + schema_name = self.preparer.quote_schema(effective_schema) + else: + schema_name = None + + index_name = self.preparer.format_index(index) + + if schema_name: + index_name = schema_name + "." + index_name + return index_name + + def visit_add_constraint(self, create, **kw): + return "ALTER TABLE %s ADD %s" % ( + self.preparer.format_table(create.element.table), + self.process(create.element), + ) + + def visit_set_table_comment(self, create, **kw): + return "COMMENT ON TABLE %s IS %s" % ( + self.preparer.format_table(create.element), + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.String() + ), + ) + + def visit_drop_table_comment(self, drop, **kw): + return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table( + drop.element + ) + + def visit_set_column_comment(self, create, **kw): + return "COMMENT ON COLUMN %s IS %s" % ( + self.preparer.format_column( + create.element, use_table=True, use_schema=True + ), + self.sql_compiler.render_literal_value( + create.element.comment, sqltypes.String() + ), + ) + + def visit_drop_column_comment(self, drop, **kw): + return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column( + drop.element, use_table=True + ) + + def visit_set_constraint_comment(self, create, **kw): + raise exc.UnsupportedCompilationError(self, type(create)) + + def visit_drop_constraint_comment(self, drop, **kw): + raise exc.UnsupportedCompilationError(self, type(drop)) + + def get_identity_options(self, identity_options): + text = [] + if identity_options.increment is not None: + text.append("INCREMENT BY %d" % identity_options.increment) + if identity_options.start is not None: + text.append("START WITH %d" % identity_options.start) + if identity_options.minvalue is not None: + text.append("MINVALUE %d" % identity_options.minvalue) + if identity_options.maxvalue is not None: + text.append("MAXVALUE %d" % identity_options.maxvalue) + if identity_options.nominvalue is not None: + text.append("NO MINVALUE") + if identity_options.nomaxvalue is not None: + text.append("NO MAXVALUE") + if identity_options.cache is not None: + text.append("CACHE %d" % identity_options.cache) + if identity_options.cycle is not None: + text.append("CYCLE" if identity_options.cycle else "NO CYCLE") + return " ".join(text) + + def visit_create_sequence(self, create, prefix=None, **kw): + text = "CREATE SEQUENCE " + if create.if_not_exists: + text += "IF NOT EXISTS " + text += self.preparer.format_sequence(create.element) + + if prefix: + text += prefix + options = self.get_identity_options(create.element) + if options: + text += " " + options + return text + + def visit_drop_sequence(self, drop, **kw): + text = "DROP SEQUENCE " + if drop.if_exists: + text += "IF EXISTS " + return text + self.preparer.format_sequence(drop.element) + + def visit_drop_constraint(self, drop, **kw): + constraint = drop.element + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + else: + formatted_name = None + + if formatted_name is None: + raise exc.CompileError( + "Can't emit DROP CONSTRAINT for constraint %r; " + "it has no name" % drop.element + ) + return "ALTER TABLE %s DROP CONSTRAINT %s%s%s" % ( + self.preparer.format_table(drop.element.table), + "IF EXISTS " if drop.if_exists else "", + formatted_name, + " CASCADE" if drop.cascade else "", + ) + + def get_column_specification(self, column, **kwargs): + colspec = ( + self.preparer.format_column(column) + + " " + + self.dialect.type_compiler_instance.process( + column.type, type_expression=column + ) + ) + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + if column.computed is not None: + colspec += " " + self.process(column.computed) + + if ( + column.identity is not None + and self.dialect.supports_identity_columns + ): + colspec += " " + self.process(column.identity) + + if not column.nullable and ( + not column.identity or not self.dialect.supports_identity_columns + ): + colspec += " NOT NULL" + return colspec + + def create_table_suffix(self, table): + return "" + + def post_create_table(self, table): + return "" + + def get_column_default_string(self, column): + if isinstance(column.server_default, schema.DefaultClause): + return self.render_default_string(column.server_default.arg) + else: + return None + + def render_default_string(self, default): + if isinstance(default, str): + return self.sql_compiler.render_literal_value( + default, sqltypes.STRINGTYPE + ) + else: + return self.sql_compiler.process(default, literal_binds=True) + + def visit_table_or_column_check_constraint(self, constraint, **kw): + if constraint.is_column_level: + return self.visit_column_check_constraint(constraint) + else: + return self.visit_check_constraint(constraint) + + def visit_check_constraint(self, constraint, **kw): + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_column_check_constraint(self, constraint, **kw): + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "CHECK (%s)" % self.sql_compiler.process( + constraint.sqltext, include_table=False, literal_binds=True + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_primary_key_constraint(self, constraint, **kw): + if len(constraint) == 0: + return "" + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "PRIMARY KEY " + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) + for c in ( + constraint.columns_autoinc_first + if constraint._implicit_generated + else constraint.columns + ) + ) + text += self.define_constraint_deferrability(constraint) + return text + + def visit_foreign_key_constraint(self, constraint, **kw): + preparer = self.preparer + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + remote_table = list(constraint.elements)[0].column.table + text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % ( + ", ".join( + preparer.quote(f.parent.name) for f in constraint.elements + ), + self.define_constraint_remote_table( + constraint, remote_table, preparer + ), + ", ".join( + preparer.quote(f.column.name) for f in constraint.elements + ), + ) + text += self.define_constraint_match(constraint) + text += self.define_constraint_cascades(constraint) + text += self.define_constraint_deferrability(constraint) + return text + + def define_constraint_remote_table(self, constraint, table, preparer): + """Format the remote table clause of a CREATE CONSTRAINT clause.""" + + return preparer.format_table(table) + + def visit_unique_constraint(self, constraint, **kw): + if len(constraint) == 0: + return "" + text = "" + if constraint.name is not None: + formatted_name = self.preparer.format_constraint(constraint) + if formatted_name is not None: + text += "CONSTRAINT %s " % formatted_name + text += "UNIQUE %s(%s)" % ( + self.define_unique_constraint_distinct(constraint, **kw), + ", ".join(self.preparer.quote(c.name) for c in constraint), + ) + text += self.define_constraint_deferrability(constraint) + return text + + def define_unique_constraint_distinct(self, constraint, **kw): + return "" + + def define_constraint_cascades(self, constraint): + text = "" + if constraint.ondelete is not None: + text += " ON DELETE %s" % self.preparer.validate_sql_phrase( + constraint.ondelete, FK_ON_DELETE + ) + if constraint.onupdate is not None: + text += " ON UPDATE %s" % self.preparer.validate_sql_phrase( + constraint.onupdate, FK_ON_UPDATE + ) + return text + + def define_constraint_deferrability(self, constraint): + text = "" + if constraint.deferrable is not None: + if constraint.deferrable: + text += " DEFERRABLE" + else: + text += " NOT DEFERRABLE" + if constraint.initially is not None: + text += " INITIALLY %s" % self.preparer.validate_sql_phrase( + constraint.initially, FK_INITIALLY + ) + return text + + def define_constraint_match(self, constraint): + text = "" + if constraint.match is not None: + text += " MATCH %s" % constraint.match + return text + + def visit_computed_column(self, generated, **kw): + text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process( + generated.sqltext, include_table=False, literal_binds=True + ) + if generated.persisted is True: + text += " STORED" + elif generated.persisted is False: + text += " VIRTUAL" + return text + + def visit_identity_column(self, identity, **kw): + text = "GENERATED %s AS IDENTITY" % ( + "ALWAYS" if identity.always else "BY DEFAULT", + ) + options = self.get_identity_options(identity) + if options: + text += " (%s)" % options + return text + + +class GenericTypeCompiler(TypeCompiler): + def visit_FLOAT(self, type_, **kw): + return "FLOAT" + + def visit_DOUBLE(self, type_, **kw): + return "DOUBLE" + + def visit_DOUBLE_PRECISION(self, type_, **kw): + return "DOUBLE PRECISION" + + def visit_REAL(self, type_, **kw): + return "REAL" + + def visit_NUMERIC(self, type_, **kw): + if type_.precision is None: + return "NUMERIC" + elif type_.scale is None: + return "NUMERIC(%(precision)s)" % {"precision": type_.precision} + else: + return "NUMERIC(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } + + def visit_DECIMAL(self, type_, **kw): + if type_.precision is None: + return "DECIMAL" + elif type_.scale is None: + return "DECIMAL(%(precision)s)" % {"precision": type_.precision} + else: + return "DECIMAL(%(precision)s, %(scale)s)" % { + "precision": type_.precision, + "scale": type_.scale, + } + + def visit_INTEGER(self, type_, **kw): + return "INTEGER" + + def visit_SMALLINT(self, type_, **kw): + return "SMALLINT" + + def visit_BIGINT(self, type_, **kw): + return "BIGINT" + + def visit_TIMESTAMP(self, type_, **kw): + return "TIMESTAMP" + + def visit_DATETIME(self, type_, **kw): + return "DATETIME" + + def visit_DATE(self, type_, **kw): + return "DATE" + + def visit_TIME(self, type_, **kw): + return "TIME" + + def visit_CLOB(self, type_, **kw): + return "CLOB" + + def visit_NCLOB(self, type_, **kw): + return "NCLOB" + + def _render_string_type(self, type_, name, length_override=None): + text = name + if length_override: + text += "(%d)" % length_override + elif type_.length: + text += "(%d)" % type_.length + if type_.collation: + text += ' COLLATE "%s"' % type_.collation + return text + + def visit_CHAR(self, type_, **kw): + return self._render_string_type(type_, "CHAR") + + def visit_NCHAR(self, type_, **kw): + return self._render_string_type(type_, "NCHAR") + + def visit_VARCHAR(self, type_, **kw): + return self._render_string_type(type_, "VARCHAR") + + def visit_NVARCHAR(self, type_, **kw): + return self._render_string_type(type_, "NVARCHAR") + + def visit_TEXT(self, type_, **kw): + return self._render_string_type(type_, "TEXT") + + def visit_UUID(self, type_, **kw): + return "UUID" + + def visit_BLOB(self, type_, **kw): + return "BLOB" + + def visit_BINARY(self, type_, **kw): + return "BINARY" + (type_.length and "(%d)" % type_.length or "") + + def visit_VARBINARY(self, type_, **kw): + return "VARBINARY" + (type_.length and "(%d)" % type_.length or "") + + def visit_BOOLEAN(self, type_, **kw): + return "BOOLEAN" + + def visit_uuid(self, type_, **kw): + if not type_.native_uuid or not self.dialect.supports_native_uuid: + return self._render_string_type(type_, "CHAR", length_override=32) + else: + return self.visit_UUID(type_, **kw) + + def visit_large_binary(self, type_, **kw): + return self.visit_BLOB(type_, **kw) + + def visit_boolean(self, type_, **kw): + return self.visit_BOOLEAN(type_, **kw) + + def visit_time(self, type_, **kw): + return self.visit_TIME(type_, **kw) + + def visit_datetime(self, type_, **kw): + return self.visit_DATETIME(type_, **kw) + + def visit_date(self, type_, **kw): + return self.visit_DATE(type_, **kw) + + def visit_big_integer(self, type_, **kw): + return self.visit_BIGINT(type_, **kw) + + def visit_small_integer(self, type_, **kw): + return self.visit_SMALLINT(type_, **kw) + + def visit_integer(self, type_, **kw): + return self.visit_INTEGER(type_, **kw) + + def visit_real(self, type_, **kw): + return self.visit_REAL(type_, **kw) + + def visit_float(self, type_, **kw): + return self.visit_FLOAT(type_, **kw) + + def visit_double(self, type_, **kw): + return self.visit_DOUBLE(type_, **kw) + + def visit_numeric(self, type_, **kw): + return self.visit_NUMERIC(type_, **kw) + + def visit_string(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_unicode(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) + + def visit_unicode_text(self, type_, **kw): + return self.visit_TEXT(type_, **kw) + + def visit_enum(self, type_, **kw): + return self.visit_VARCHAR(type_, **kw) + + def visit_null(self, type_, **kw): + raise exc.CompileError( + "Can't generate DDL for %r; " + "did you forget to specify a " + "type on this Column?" % type_ + ) + + def visit_type_decorator(self, type_, **kw): + return self.process(type_.type_engine(self.dialect), **kw) + + def visit_user_defined(self, type_, **kw): + return type_.get_col_spec(**kw) + + +class StrSQLTypeCompiler(GenericTypeCompiler): + def process(self, type_, **kw): + try: + _compiler_dispatch = type_._compiler_dispatch + except AttributeError: + return self._visit_unknown(type_, **kw) + else: + return _compiler_dispatch(self, **kw) + + def __getattr__(self, key): + if key.startswith("visit_"): + return self._visit_unknown + else: + raise AttributeError(key) + + def _visit_unknown(self, type_, **kw): + if type_.__class__.__name__ == type_.__class__.__name__.upper(): + return type_.__class__.__name__ + else: + return repr(type_) + + def visit_null(self, type_, **kw): + return "NULL" + + def visit_user_defined(self, type_, **kw): + try: + get_col_spec = type_.get_col_spec + except AttributeError: + return repr(type_) + else: + return get_col_spec(**kw) + + +class _SchemaForObjectCallable(Protocol): + def __call__(self, obj: Any) -> str: ... + + +class _BindNameForColProtocol(Protocol): + def __call__(self, col: ColumnClause[Any]) -> str: ... + + +class IdentifierPreparer: + """Handle quoting and case-folding of identifiers based on options.""" + + reserved_words = RESERVED_WORDS + + legal_characters = LEGAL_CHARACTERS + + illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS + + initial_quote: str + + final_quote: str + + _strings: MutableMapping[str, str] + + schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema") + """Return the .schema attribute for an object. + + For the default IdentifierPreparer, the schema for an object is always + the value of the ".schema" attribute. if the preparer is replaced + with one that has a non-empty schema_translate_map, the value of the + ".schema" attribute is rendered a symbol that will be converted to a + real schema name from the mapping post-compile. + + """ + + _includes_none_schema_translate: bool = False + + def __init__( + self, + dialect, + initial_quote='"', + final_quote=None, + escape_quote='"', + quote_case_sensitive_collations=True, + omit_schema=False, + ): + """Construct a new ``IdentifierPreparer`` object. + + initial_quote + Character that begins a delimited identifier. + + final_quote + Character that ends a delimited identifier. Defaults to + `initial_quote`. + + omit_schema + Prevent prepending schema name. Useful for databases that do + not support schemae. + """ + + self.dialect = dialect + self.initial_quote = initial_quote + self.final_quote = final_quote or self.initial_quote + self.escape_quote = escape_quote + self.escape_to_quote = self.escape_quote * 2 + self.omit_schema = omit_schema + self.quote_case_sensitive_collations = quote_case_sensitive_collations + self._strings = {} + self._double_percents = self.dialect.paramstyle in ( + "format", + "pyformat", + ) + + def _with_schema_translate(self, schema_translate_map): + prep = self.__class__.__new__(self.__class__) + prep.__dict__.update(self.__dict__) + + includes_none = None in schema_translate_map + + def symbol_getter(obj): + name = obj.schema + if obj._use_schema_map and (name is not None or includes_none): + if name is not None and ("[" in name or "]" in name): + raise exc.CompileError( + "Square bracket characters ([]) not supported " + "in schema translate name '%s'" % name + ) + return quoted_name( + "__[SCHEMA_%s]" % (name or "_none"), quote=False + ) + else: + return obj.schema + + prep.schema_for_object = symbol_getter + prep._includes_none_schema_translate = includes_none + return prep + + def _render_schema_translates(self, statement, schema_translate_map): + d = schema_translate_map + if None in d: + if not self._includes_none_schema_translate: + raise exc.InvalidRequestError( + "schema translate map which previously did not have " + "`None` present as a key now has `None` present; compiled " + "statement may lack adequate placeholders. Please use " + "consistent keys in successive " + "schema_translate_map dictionaries." + ) + + d["_none"] = d[None] + + def replace(m): + name = m.group(2) + if name in d: + effective_schema = d[name] + else: + if name in (None, "_none"): + raise exc.InvalidRequestError( + "schema translate map which previously had `None` " + "present as a key now no longer has it present; don't " + "know how to apply schema for compiled statement. " + "Please use consistent keys in successive " + "schema_translate_map dictionaries." + ) + effective_schema = name + + if not effective_schema: + effective_schema = self.dialect.default_schema_name + if not effective_schema: + # TODO: no coverage here + raise exc.CompileError( + "Dialect has no default schema name; can't " + "use None as dynamic schema target." + ) + return self.quote_schema(effective_schema) + + return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement) + + def _escape_identifier(self, value: str) -> str: + """Escape an identifier. + + Subclasses should override this to provide database-dependent + escaping behavior. + """ + + value = value.replace(self.escape_quote, self.escape_to_quote) + if self._double_percents: + value = value.replace("%", "%%") + return value + + def _unescape_identifier(self, value: str) -> str: + """Canonicalize an escaped identifier. + + Subclasses should override this to provide database-dependent + unescaping behavior that reverses _escape_identifier. + """ + + return value.replace(self.escape_to_quote, self.escape_quote) + + def validate_sql_phrase(self, element, reg): + """keyword sequence filter. + + a filter for elements that are intended to represent keyword sequences, + such as "INITIALLY", "INITIALLY DEFERRED", etc. no special characters + should be present. + + .. versionadded:: 1.3 + + """ + + if element is not None and not reg.match(element): + raise exc.CompileError( + "Unexpected SQL phrase: %r (matching against %r)" + % (element, reg.pattern) + ) + return element + + def quote_identifier(self, value: str) -> str: + """Quote an identifier. + + Subclasses should override this to provide database-dependent + quoting behavior. + """ + + return ( + self.initial_quote + + self._escape_identifier(value) + + self.final_quote + ) + + def _requires_quotes(self, value: str) -> bool: + """Return True if the given identifier requires quoting.""" + lc_value = value.lower() + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(str(value)) + or (lc_value != value) + ) + + def _requires_quotes_illegal_chars(self, value): + """Return True if the given identifier requires quoting, but + not taking case convention into account.""" + return not self.legal_characters.match(str(value)) + + def quote_schema(self, schema: str, force: Any = None) -> str: + """Conditionally quote a schema name. + + + The name is quoted if it is a reserved word, contains quote-necessary + characters, or is an instance of :class:`.quoted_name` which includes + ``quote`` set to ``True``. + + Subclasses can override this to provide database-dependent + quoting behavior for schema names. + + :param schema: string schema name + :param force: unused + + .. deprecated:: 0.9 + + The :paramref:`.IdentifierPreparer.quote_schema.force` + parameter is deprecated and will be removed in a future + release. This flag has no effect on the behavior of the + :meth:`.IdentifierPreparer.quote` method; please refer to + :class:`.quoted_name`. + + """ + if force is not None: + # not using the util.deprecated_params() decorator in this + # case because of the additional function call overhead on this + # very performance-critical spot. + util.warn_deprecated( + "The IdentifierPreparer.quote_schema.force parameter is " + "deprecated and will be removed in a future release. This " + "flag has no effect on the behavior of the " + "IdentifierPreparer.quote method; please refer to " + "quoted_name().", + # deprecated 0.9. warning from 1.3 + version="0.9", + ) + + return self.quote(schema) + + def quote(self, ident: str, force: Any = None) -> str: + """Conditionally quote an identifier. + + The identifier is quoted if it is a reserved word, contains + quote-necessary characters, or is an instance of + :class:`.quoted_name` which includes ``quote`` set to ``True``. + + Subclasses can override this to provide database-dependent + quoting behavior for identifier names. + + :param ident: string identifier + :param force: unused + + .. deprecated:: 0.9 + + The :paramref:`.IdentifierPreparer.quote.force` + parameter is deprecated and will be removed in a future + release. This flag has no effect on the behavior of the + :meth:`.IdentifierPreparer.quote` method; please refer to + :class:`.quoted_name`. + + """ + if force is not None: + # not using the util.deprecated_params() decorator in this + # case because of the additional function call overhead on this + # very performance-critical spot. + util.warn_deprecated( + "The IdentifierPreparer.quote.force parameter is " + "deprecated and will be removed in a future release. This " + "flag has no effect on the behavior of the " + "IdentifierPreparer.quote method; please refer to " + "quoted_name().", + # deprecated 0.9. warning from 1.3 + version="0.9", + ) + + force = getattr(ident, "quote", None) + + if force is None: + if ident in self._strings: + return self._strings[ident] + else: + if self._requires_quotes(ident): + self._strings[ident] = self.quote_identifier(ident) + else: + self._strings[ident] = ident + return self._strings[ident] + elif force: + return self.quote_identifier(ident) + else: + return ident + + def format_collation(self, collation_name): + if self.quote_case_sensitive_collations: + return self.quote(collation_name) + else: + return collation_name + + def format_sequence(self, sequence, use_schema=True): + name = self.quote(sequence.name) + + effective_schema = self.schema_for_object(sequence) + + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): + name = self.quote_schema(effective_schema) + "." + name + return name + + def format_label( + self, label: Label[Any], name: Optional[str] = None + ) -> str: + return self.quote(name or label.name) + + def format_alias( + self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None + ) -> str: + if name is None: + assert alias is not None + return self.quote(alias.name) + else: + return self.quote(name) + + def format_savepoint(self, savepoint, name=None): + # Running the savepoint name through quoting is unnecessary + # for all known dialects. This is here to support potential + # third party use cases + ident = name or savepoint.ident + if self._requires_quotes(ident): + ident = self.quote_identifier(ident) + return ident + + @util.preload_module("sqlalchemy.sql.naming") + def format_constraint(self, constraint, _alembic_quote=True): + naming = util.preloaded.sql_naming + + if constraint.name is _NONE_NAME: + name = naming._constraint_name_for_table( + constraint, constraint.table + ) + + if name is None: + return None + else: + name = constraint.name + + if constraint.__visit_name__ == "index": + return self.truncate_and_render_index_name( + name, _alembic_quote=_alembic_quote + ) + else: + return self.truncate_and_render_constraint_name( + name, _alembic_quote=_alembic_quote + ) + + def truncate_and_render_index_name(self, name, _alembic_quote=True): + # calculate these at format time so that ad-hoc changes + # to dialect.max_identifier_length etc. can be reflected + # as IdentifierPreparer is long lived + max_ = ( + self.dialect.max_index_name_length + or self.dialect.max_identifier_length + ) + return self._truncate_and_render_maxlen_name( + name, max_, _alembic_quote + ) + + def truncate_and_render_constraint_name(self, name, _alembic_quote=True): + # calculate these at format time so that ad-hoc changes + # to dialect.max_identifier_length etc. can be reflected + # as IdentifierPreparer is long lived + max_ = ( + self.dialect.max_constraint_name_length + or self.dialect.max_identifier_length + ) + return self._truncate_and_render_maxlen_name( + name, max_, _alembic_quote + ) + + def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote): + if isinstance(name, elements._truncated_label): + if len(name) > max_: + name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:] + else: + self.dialect.validate_identifier(name) + + if not _alembic_quote: + return name + else: + return self.quote(name) + + def format_index(self, index): + return self.format_constraint(index) + + def format_table(self, table, use_schema=True, name=None): + """Prepare a quoted table and schema name.""" + + if name is None: + name = table.name + + result = self.quote(name) + + effective_schema = self.schema_for_object(table) + + if not self.omit_schema and use_schema and effective_schema: + result = self.quote_schema(effective_schema) + "." + result + return result + + def format_schema(self, name): + """Prepare a quoted schema name.""" + + return self.quote(name) + + def format_label_name( + self, + name, + anon_map=None, + ): + """Prepare a quoted column name.""" + + if anon_map is not None and isinstance( + name, elements._truncated_label + ): + name = name.apply_map(anon_map) + + return self.quote(name) + + def format_column( + self, + column, + use_table=False, + name=None, + table_name=None, + use_schema=False, + anon_map=None, + ): + """Prepare a quoted column name.""" + + if name is None: + name = column.name + + if anon_map is not None and isinstance( + name, elements._truncated_label + ): + name = name.apply_map(anon_map) + + if not getattr(column, "is_literal", False): + if use_table: + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + self.quote(name) + ) + else: + return self.quote(name) + else: + # literal textual elements get stuck into ColumnClause a lot, + # which shouldn't get quoted + + if use_table: + return ( + self.format_table( + column.table, use_schema=use_schema, name=table_name + ) + + "." + + name + ) + else: + return name + + def format_table_seq(self, table, use_schema=True): + """Format table name and schema as a tuple.""" + + # Dialects with more levels in their fully qualified references + # ('database', 'owner', etc.) could override this and return + # a longer sequence. + + effective_schema = self.schema_for_object(table) + + if not self.omit_schema and use_schema and effective_schema: + return ( + self.quote_schema(effective_schema), + self.format_table(table, use_schema=False), + ) + else: + return (self.format_table(table, use_schema=False),) + + @util.memoized_property + def _r_identifiers(self): + initial, final, escaped_final = ( + re.escape(s) + for s in ( + self.initial_quote, + self.final_quote, + self._escape_identifier(self.final_quote), + ) + ) + r = re.compile( + r"(?:" + r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s" + r"|([^\.]+))(?=\.|$))+" + % {"initial": initial, "final": final, "escaped": escaped_final} + ) + return r + + def unformat_identifiers(self, identifiers): + """Unpack 'schema.table.column'-like strings into components.""" + + r = self._r_identifiers + return [ + self._unescape_identifier(i) + for i in [a or b for a, b in r.findall(identifiers)] + ] diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/crud.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/crud.py new file mode 100644 index 0000000..499a19d --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/crud.py @@ -0,0 +1,1669 @@ +# sql/crud.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Functions used by compiler.py to determine the parameters rendered +within INSERT and UPDATE statements. + +""" +from __future__ import annotations + +import functools +import operator +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import Iterable +from typing import List +from typing import MutableMapping +from typing import NamedTuple +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from . import coercions +from . import dml +from . import elements +from . import roles +from .base import _DefaultDescriptionTuple +from .dml import isinsert as _compile_state_isinsert +from .elements import ColumnClause +from .schema import default_is_clause_element +from .schema import default_is_sequence +from .selectable import Select +from .selectable import TableClause +from .. import exc +from .. import util +from ..util.typing import Literal + +if TYPE_CHECKING: + from .compiler import _BindNameForColProtocol + from .compiler import SQLCompiler + from .dml import _DMLColumnElement + from .dml import DMLState + from .dml import ValuesBase + from .elements import ColumnElement + from .elements import KeyedColumnElement + from .schema import _SQLExprDefault + from .schema import Column + +REQUIRED = util.symbol( + "REQUIRED", + """ +Placeholder for the value within a :class:`.BindParameter` +which is required to be present when the statement is passed +to :meth:`_engine.Connection.execute`. + +This symbol is typically used when a :func:`_expression.insert` +or :func:`_expression.update` statement is compiled without parameter +values present. + +""", +) + + +def _as_dml_column(c: ColumnElement[Any]) -> ColumnClause[Any]: + if not isinstance(c, ColumnClause): + raise exc.CompileError( + f"Can't create DML statement against column expression {c!r}" + ) + return c + + +_CrudParamElement = Tuple[ + "ColumnElement[Any]", + str, # column name + Optional[ + Union[str, "_SQLExprDefault"] + ], # bound parameter string or SQL expression to apply + Iterable[str], +] +_CrudParamElementStr = Tuple[ + "KeyedColumnElement[Any]", + str, # column name + str, # bound parameter string + Iterable[str], +] +_CrudParamElementSQLExpr = Tuple[ + "ColumnClause[Any]", + str, + "_SQLExprDefault", # SQL expression to apply + Iterable[str], +] + +_CrudParamSequence = List[_CrudParamElement] + + +class _CrudParams(NamedTuple): + single_params: _CrudParamSequence + all_multi_params: List[Sequence[_CrudParamElementStr]] + is_default_metavalue_only: bool = False + use_insertmanyvalues: bool = False + use_sentinel_columns: Optional[Sequence[Column[Any]]] = None + + +def _get_crud_params( + compiler: SQLCompiler, + stmt: ValuesBase, + compile_state: DMLState, + toplevel: bool, + **kw: Any, +) -> _CrudParams: + """create a set of tuples representing column/string pairs for use + in an INSERT or UPDATE statement. + + Also generates the Compiled object's postfetch, prefetch, and + returning column collections, used for default handling and ultimately + populating the CursorResult's prefetch_cols() and postfetch_cols() + collections. + + """ + + # note: the _get_crud_params() system was written with the notion in mind + # that INSERT, UPDATE, DELETE are always the top level statement and + # that there is only one of them. With the addition of CTEs that can + # make use of DML, this assumption is no longer accurate; the DML + # statement is not necessarily the top-level "row returning" thing + # and it is also theoretically possible (fortunately nobody has asked yet) + # to have a single statement with multiple DMLs inside of it via CTEs. + + # the current _get_crud_params() design doesn't accommodate these cases + # right now. It "just works" for a CTE that has a single DML inside of + # it, and for a CTE with multiple DML, it's not clear what would happen. + + # overall, the "compiler.XYZ" collections here would need to be in a + # per-DML structure of some kind, and DefaultDialect would need to + # navigate these collections on a per-statement basis, with additional + # emphasis on the "toplevel returning data" statement. However we + # still need to run through _get_crud_params() for all DML as we have + # Python / SQL generated column defaults that need to be rendered. + + # if there is user need for this kind of thing, it's likely a post 2.0 + # kind of change as it would require deep changes to DefaultDialect + # as well as here. + + compiler.postfetch = [] + compiler.insert_prefetch = [] + compiler.update_prefetch = [] + compiler.implicit_returning = [] + + visiting_cte = kw.get("visiting_cte", None) + if visiting_cte is not None: + # for insert -> CTE -> insert, don't populate an incoming + # _crud_accumulate_bind_names collection; the INSERT we process here + # will not be inline within the VALUES of the enclosing INSERT as the + # CTE is placed on the outside. See issue #9173 + kw.pop("accumulate_bind_names", None) + assert ( + "accumulate_bind_names" not in kw + ), "Don't know how to handle insert within insert without a CTE" + + # getters - these are normally just column.key, + # but in the case of mysql multi-table update, the rules for + # .key must conditionally take tablename into account + ( + _column_as_key, + _getattr_col_key, + _col_bind_name, + ) = _key_getters_for_crud_column(compiler, stmt, compile_state) + + compiler._get_bind_name_for_col = _col_bind_name + + if stmt._returning and stmt._return_defaults: + raise exc.CompileError( + "Can't compile statement that includes returning() and " + "return_defaults() simultaneously" + ) + + if compile_state.isdelete: + _setup_delete_return_defaults( + compiler, + stmt, + compile_state, + (), + _getattr_col_key, + _column_as_key, + _col_bind_name, + (), + (), + toplevel, + kw, + ) + return _CrudParams([], []) + + # no parameters in the statement, no parameters in the + # compiled params - return binds for all columns + if compiler.column_keys is None and compile_state._no_parameters: + return _CrudParams( + [ + ( + c, + compiler.preparer.format_column(c), + _create_bind_param(compiler, c, None, required=True), + (c.key,), + ) + for c in stmt.table.columns + if not c._omit_from_statements + ], + [], + ) + + stmt_parameter_tuples: Optional[ + List[Tuple[Union[str, ColumnClause[Any]], Any]] + ] + spd: Optional[MutableMapping[_DMLColumnElement, Any]] + + if ( + _compile_state_isinsert(compile_state) + and compile_state._has_multi_parameters + ): + mp = compile_state._multi_parameters + assert mp is not None + spd = mp[0] + stmt_parameter_tuples = list(spd.items()) + spd_str_key = {_column_as_key(key) for key in spd} + elif compile_state._ordered_values: + spd = compile_state._dict_parameters + stmt_parameter_tuples = compile_state._ordered_values + assert spd is not None + spd_str_key = {_column_as_key(key) for key in spd} + elif compile_state._dict_parameters: + spd = compile_state._dict_parameters + stmt_parameter_tuples = list(spd.items()) + spd_str_key = {_column_as_key(key) for key in spd} + else: + stmt_parameter_tuples = spd = spd_str_key = None + + # if we have statement parameters - set defaults in the + # compiled params + if compiler.column_keys is None: + parameters = {} + elif stmt_parameter_tuples: + assert spd_str_key is not None + parameters = { + _column_as_key(key): REQUIRED + for key in compiler.column_keys + if key not in spd_str_key + } + else: + parameters = { + _column_as_key(key): REQUIRED for key in compiler.column_keys + } + + # create a list of column assignment clauses as tuples + values: List[_CrudParamElement] = [] + + if stmt_parameter_tuples is not None: + _get_stmt_parameter_tuples_params( + compiler, + compile_state, + parameters, + stmt_parameter_tuples, + _column_as_key, + values, + kw, + ) + + check_columns: Dict[str, ColumnClause[Any]] = {} + + # special logic that only occurs for multi-table UPDATE + # statements + if dml.isupdate(compile_state) and compile_state.is_multitable: + _get_update_multitable_params( + compiler, + stmt, + compile_state, + stmt_parameter_tuples, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, + ) + + if _compile_state_isinsert(compile_state) and stmt._select_names: + # is an insert from select, is not a multiparams + + assert not compile_state._has_multi_parameters + + _scan_insert_from_select_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, + ) + use_insertmanyvalues = False + use_sentinel_columns = None + else: + use_insertmanyvalues, use_sentinel_columns = _scan_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, + ) + + if parameters and stmt_parameter_tuples: + check = ( + set(parameters) + .intersection(_column_as_key(k) for k, v in stmt_parameter_tuples) + .difference(check_columns) + ) + if check: + raise exc.CompileError( + "Unconsumed column names: %s" + % (", ".join("%s" % (c,) for c in check)) + ) + + is_default_metavalue_only = False + + if ( + _compile_state_isinsert(compile_state) + and compile_state._has_multi_parameters + ): + # is a multiparams, is not an insert from a select + assert not stmt._select_names + multi_extended_values = _extend_values_for_multiparams( + compiler, + stmt, + compile_state, + cast( + "Sequence[_CrudParamElementStr]", + values, + ), + cast("Callable[..., str]", _column_as_key), + kw, + ) + return _CrudParams(values, multi_extended_values) + elif ( + not values + and compiler.for_executemany + and compiler.dialect.supports_default_metavalue + ): + # convert an "INSERT DEFAULT VALUES" + # into INSERT (firstcol) VALUES (DEFAULT) which can be turned + # into an in-place multi values. This supports + # insert_executemany_returning mode :) + values = [ + ( + _as_dml_column(stmt.table.columns[0]), + compiler.preparer.format_column(stmt.table.columns[0]), + compiler.dialect.default_metavalue_token, + (), + ) + ] + is_default_metavalue_only = True + + return _CrudParams( + values, + [], + is_default_metavalue_only=is_default_metavalue_only, + use_insertmanyvalues=use_insertmanyvalues, + use_sentinel_columns=use_sentinel_columns, + ) + + +@overload +def _create_bind_param( + compiler: SQLCompiler, + col: ColumnElement[Any], + value: Any, + process: Literal[True] = ..., + required: bool = False, + name: Optional[str] = None, + **kw: Any, +) -> str: ... + + +@overload +def _create_bind_param( + compiler: SQLCompiler, + col: ColumnElement[Any], + value: Any, + **kw: Any, +) -> str: ... + + +def _create_bind_param( + compiler: SQLCompiler, + col: ColumnElement[Any], + value: Any, + process: bool = True, + required: bool = False, + name: Optional[str] = None, + **kw: Any, +) -> Union[str, elements.BindParameter[Any]]: + if name is None: + name = col.key + bindparam = elements.BindParameter( + name, value, type_=col.type, required=required + ) + bindparam._is_crud = True + if process: + return bindparam._compiler_dispatch(compiler, **kw) + else: + return bindparam + + +def _handle_values_anonymous_param(compiler, col, value, name, **kw): + # the insert() and update() constructs as of 1.4 will now produce anonymous + # bindparam() objects in the values() collections up front when given plain + # literal values. This is so that cache key behaviors, which need to + # produce bound parameters in deterministic order without invoking any + # compilation here, can be applied to these constructs when they include + # values() (but not yet multi-values, which are not included in caching + # right now). + # + # in order to produce the desired "crud" style name for these parameters, + # which will also be targetable in engine/default.py through the usual + # conventions, apply our desired name to these unique parameters by + # populating the compiler truncated names cache with the desired name, + # rather than having + # compiler.visit_bindparam()->compiler._truncated_identifier make up a + # name. Saves on call counts also. + + # for INSERT/UPDATE that's a CTE, we don't need names to match to + # external parameters and these would also conflict in the case where + # multiple insert/update are combined together using CTEs + is_cte = "visiting_cte" in kw + + if ( + not is_cte + and value.unique + and isinstance(value.key, elements._truncated_label) + ): + compiler.truncated_names[("bindparam", value.key)] = name + + if value.type._isnull: + # either unique parameter, or other bound parameters that were + # passed in directly + # set type to that of the column unconditionally + value = value._with_binary_element_type(col.type) + + return value._compiler_dispatch(compiler, **kw) + + +def _key_getters_for_crud_column( + compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState +) -> Tuple[ + Callable[[Union[str, ColumnClause[Any]]], Union[str, Tuple[str, str]]], + Callable[[ColumnClause[Any]], Union[str, Tuple[str, str]]], + _BindNameForColProtocol, +]: + if dml.isupdate(compile_state) and compile_state._extra_froms: + # when extra tables are present, refer to the columns + # in those extra tables as table-qualified, including in + # dictionaries and when rendering bind param names. + # the "main" table of the statement remains unqualified, + # allowing the most compatibility with a non-multi-table + # statement. + _et = set(compile_state._extra_froms) + + c_key_role = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) + + def _column_as_key( + key: Union[ColumnClause[Any], str] + ) -> Union[str, Tuple[str, str]]: + str_key = c_key_role(key) + if hasattr(key, "table") and key.table in _et: + return (key.table.name, str_key) # type: ignore + else: + return str_key + + def _getattr_col_key( + col: ColumnClause[Any], + ) -> Union[str, Tuple[str, str]]: + if col.table in _et: + return (col.table.name, col.key) # type: ignore + else: + return col.key + + def _col_bind_name(col: ColumnClause[Any]) -> str: + if col.table in _et: + if TYPE_CHECKING: + assert isinstance(col.table, TableClause) + return "%s_%s" % (col.table.name, col.key) + else: + return col.key + + else: + _column_as_key = functools.partial( + coercions.expect_as_key, roles.DMLColumnRole + ) + _getattr_col_key = _col_bind_name = operator.attrgetter("key") # type: ignore # noqa: E501 + + return _column_as_key, _getattr_col_key, _col_bind_name + + +def _scan_insert_from_select_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, +): + cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names] + + assert compiler.stack[-1]["selectable"] is stmt + + compiler.stack[-1]["insert_from_select"] = stmt.select + + add_select_cols: List[_CrudParamElementSQLExpr] = [] + if stmt.include_insert_from_select_defaults: + col_set = set(cols) + for col in stmt.table.columns: + # omit columns that were not in the SELECT statement. + # this will omit columns marked as omit_from_statements naturally, + # as long as that col was not explicit in the SELECT. + # if an omit_from_statements col has a "default" on it, then + # we need to include it, as these defaults should still fire off. + # but, if it has that default and it's the "sentinel" default, + # we don't do sentinel default operations for insert_from_select + # here so we again omit it. + if ( + col not in col_set + and col.default + and not col.default.is_sentinel + ): + cols.append(col) + + for c in cols: + col_key = _getattr_col_key(c) + if col_key in parameters and col_key not in check_columns: + parameters.pop(col_key) + values.append((c, compiler.preparer.format_column(c), None, ())) + else: + _append_param_insert_select_hasdefault( + compiler, stmt, c, add_select_cols, kw + ) + + if add_select_cols: + values.extend(add_select_cols) + ins_from_select = compiler.stack[-1]["insert_from_select"] + if not isinstance(ins_from_select, Select): + raise exc.CompileError( + f"Can't extend statement for INSERT..FROM SELECT to include " + f"additional default-holding column(s) " + f"""{ + ', '.join(repr(key) for _, key, _, _ in add_select_cols) + }. Convert the selectable to a subquery() first, or pass """ + "include_defaults=False to Insert.from_select() to skip these " + "columns." + ) + ins_from_select = ins_from_select._generate() + # copy raw_columns + ins_from_select._raw_columns = list(ins_from_select._raw_columns) + [ + expr for _, _, expr, _ in add_select_cols + ] + compiler.stack[-1]["insert_from_select"] = ins_from_select + + +def _scan_cols( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, +): + ( + need_pks, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + use_insertmanyvalues, + use_sentinel_columns, + ) = _get_returning_modifiers(compiler, stmt, compile_state, toplevel) + + assert compile_state.isupdate or compile_state.isinsert + + if compile_state._parameter_ordering: + parameter_ordering = [ + _column_as_key(key) for key in compile_state._parameter_ordering + ] + ordered_keys = set(parameter_ordering) + cols = [ + stmt.table.c[key] + for key in parameter_ordering + if isinstance(key, str) and key in stmt.table.c + ] + [c for c in stmt.table.c if c.key not in ordered_keys] + + else: + cols = stmt.table.columns + + isinsert = _compile_state_isinsert(compile_state) + if isinsert and not compile_state._has_multi_parameters: + # new rules for #7998. fetch lastrowid or implicit returning + # for autoincrement column even if parameter is NULL, for DBs that + # override NULL param for primary key (sqlite, mysql/mariadb) + autoincrement_col = stmt.table._autoincrement_column + insert_null_pk_still_autoincrements = ( + compiler.dialect.insert_null_pk_still_autoincrements + ) + else: + autoincrement_col = insert_null_pk_still_autoincrements = None + + if stmt._supplemental_returning: + supplemental_returning = set(stmt._supplemental_returning) + else: + supplemental_returning = set() + + compiler_implicit_returning = compiler.implicit_returning + + # TODO - see TODO(return_defaults_columns) below + # cols_in_params = set() + + for c in cols: + # scan through every column in the target table + + col_key = _getattr_col_key(c) + + if col_key in parameters and col_key not in check_columns: + # parameter is present for the column. use that. + + _append_param_parameter( + compiler, + stmt, + compile_state, + c, + col_key, + parameters, + _col_bind_name, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + values, + autoincrement_col, + insert_null_pk_still_autoincrements, + kw, + ) + + # TODO - see TODO(return_defaults_columns) below + # cols_in_params.add(c) + + elif isinsert: + # no parameter is present and it's an insert. + + if c.primary_key and need_pks: + # it's a primary key column, it will need to be generated by a + # default generator of some kind, and the statement expects + # inserted_primary_key to be available. + + if implicit_returning: + # we can use RETURNING, find out how to invoke this + # column and get the value where RETURNING is an option. + # we can inline server-side functions in this case. + + _append_param_insert_pk_returning( + compiler, stmt, c, values, kw + ) + else: + # otherwise, find out how to invoke this column + # and get its value where RETURNING is not an option. + # if we have to invoke a server-side function, we need + # to pre-execute it. or if this is a straight + # autoincrement column and the dialect supports it + # we can use cursor.lastrowid. + + _append_param_insert_pk_no_returning( + compiler, stmt, c, values, kw + ) + + elif c.default is not None: + # column has a default, but it's not a pk column, or it is but + # we don't need to get the pk back. + if not c.default.is_sentinel or ( + use_sentinel_columns is not None + ): + _append_param_insert_hasdefault( + compiler, stmt, c, implicit_return_defaults, values, kw + ) + + elif c.server_default is not None: + # column has a DDL-level default, and is either not a pk + # column or we don't need the pk. + if implicit_return_defaults and c in implicit_return_defaults: + compiler_implicit_returning.append(c) + elif not c.primary_key: + compiler.postfetch.append(c) + + elif implicit_return_defaults and c in implicit_return_defaults: + compiler_implicit_returning.append(c) + + elif ( + c.primary_key + and c is not stmt.table._autoincrement_column + and not c.nullable + ): + _warn_pk_with_no_anticipated_value(c) + + elif compile_state.isupdate: + # no parameter is present and it's an insert. + + _append_param_update( + compiler, + compile_state, + stmt, + c, + implicit_return_defaults, + values, + kw, + ) + + # adding supplemental cols to implicit_returning in table + # order so that order is maintained between multiple INSERT + # statements which may have different parameters included, but all + # have the same RETURNING clause + if ( + c in supplemental_returning + and c not in compiler_implicit_returning + ): + compiler_implicit_returning.append(c) + + if supplemental_returning: + # we should have gotten every col into implicit_returning, + # however supplemental returning can also have SQL functions etc. + # in it + remaining_supplemental = supplemental_returning.difference( + compiler_implicit_returning + ) + compiler_implicit_returning.extend( + c + for c in stmt._supplemental_returning + if c in remaining_supplemental + ) + + # TODO(return_defaults_columns): there can still be more columns in + # _return_defaults_columns in the case that they are from something like an + # aliased of the table. we can add them here, however this breaks other ORM + # things. so this is for another day. see + # test/orm/dml/test_update_delete_where.py -> test_update_from_alias + + # if stmt._return_defaults_columns: + # compiler_implicit_returning.extend( + # set(stmt._return_defaults_columns) + # .difference(compiler_implicit_returning) + # .difference(cols_in_params) + # ) + + return (use_insertmanyvalues, use_sentinel_columns) + + +def _setup_delete_return_defaults( + compiler, + stmt, + compile_state, + parameters, + _getattr_col_key, + _column_as_key, + _col_bind_name, + check_columns, + values, + toplevel, + kw, +): + (_, _, implicit_return_defaults, *_) = _get_returning_modifiers( + compiler, stmt, compile_state, toplevel + ) + + if not implicit_return_defaults: + return + + if stmt._return_defaults_columns: + compiler.implicit_returning.extend(implicit_return_defaults) + + if stmt._supplemental_returning: + ir_set = set(compiler.implicit_returning) + compiler.implicit_returning.extend( + c for c in stmt._supplemental_returning if c not in ir_set + ) + + +def _append_param_parameter( + compiler, + stmt, + compile_state, + c, + col_key, + parameters, + _col_bind_name, + implicit_returning, + implicit_return_defaults, + postfetch_lastrowid, + values, + autoincrement_col, + insert_null_pk_still_autoincrements, + kw, +): + value = parameters.pop(col_key) + + col_value = compiler.preparer.format_column( + c, use_table=compile_state.include_table_with_column_exprs + ) + + accumulated_bind_names: Set[str] = set() + + if coercions._is_literal(value): + if ( + insert_null_pk_still_autoincrements + and c.primary_key + and c is autoincrement_col + ): + # support use case for #7998, fetch autoincrement cols + # even if value was given. + + if postfetch_lastrowid: + compiler.postfetch_lastrowid = True + elif implicit_returning: + compiler.implicit_returning.append(c) + + value = _create_bind_param( + compiler, + c, + value, + required=value is REQUIRED, + name=( + _col_bind_name(c) + if not _compile_state_isinsert(compile_state) + or not compile_state._has_multi_parameters + else "%s_m0" % _col_bind_name(c) + ), + accumulate_bind_names=accumulated_bind_names, + **kw, + ) + elif value._is_bind_parameter: + if ( + insert_null_pk_still_autoincrements + and value.value is None + and c.primary_key + and c is autoincrement_col + ): + # support use case for #7998, fetch autoincrement cols + # even if value was given + if implicit_returning: + compiler.implicit_returning.append(c) + elif compiler.dialect.postfetch_lastrowid: + compiler.postfetch_lastrowid = True + + value = _handle_values_anonymous_param( + compiler, + c, + value, + name=( + _col_bind_name(c) + if not _compile_state_isinsert(compile_state) + or not compile_state._has_multi_parameters + else "%s_m0" % _col_bind_name(c) + ), + accumulate_bind_names=accumulated_bind_names, + **kw, + ) + else: + # value is a SQL expression + value = compiler.process( + value.self_group(), + accumulate_bind_names=accumulated_bind_names, + **kw, + ) + + if compile_state.isupdate: + if implicit_return_defaults and c in implicit_return_defaults: + compiler.implicit_returning.append(c) + + else: + compiler.postfetch.append(c) + else: + if c.primary_key: + if implicit_returning: + compiler.implicit_returning.append(c) + elif compiler.dialect.postfetch_lastrowid: + compiler.postfetch_lastrowid = True + + elif implicit_return_defaults and (c in implicit_return_defaults): + compiler.implicit_returning.append(c) + + else: + # postfetch specifically means, "we can SELECT the row we just + # inserted by primary key to get back the server generated + # defaults". so by definition this can't be used to get the + # primary key value back, because we need to have it ahead of + # time. + + compiler.postfetch.append(c) + + values.append((c, col_value, value, accumulated_bind_names)) + + +def _append_param_insert_pk_returning(compiler, stmt, c, values, kw): + """Create a primary key expression in the INSERT statement where + we want to populate result.inserted_primary_key and RETURNING + is available. + + """ + if c.default is not None: + if c.default.is_sequence: + if compiler.dialect.supports_sequences and ( + not c.default.optional + or not compiler.dialect.sequences_optional + ): + accumulated_bind_names: Set[str] = set() + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process( + c.default, + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, + ) + ) + compiler.implicit_returning.append(c) + elif c.default.is_clause_element: + accumulated_bind_names = set() + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process( + c.default.arg.self_group(), + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, + ) + ) + compiler.implicit_returning.append(c) + else: + # client side default. OK we can't use RETURNING, need to + # do a "prefetch", which in fact fetches the default value + # on the Python side + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + (c.key,), + ) + ) + elif c is stmt.table._autoincrement_column or c.server_default is not None: + compiler.implicit_returning.append(c) + elif not c.nullable: + # no .default, no .server_default, not autoincrement, we have + # no indication this primary key column will have any value + _warn_pk_with_no_anticipated_value(c) + + +def _append_param_insert_pk_no_returning(compiler, stmt, c, values, kw): + """Create a primary key expression in the INSERT statement where + we want to populate result.inserted_primary_key and we cannot use + RETURNING. + + Depending on the kind of default here we may create a bound parameter + in the INSERT statement and pre-execute a default generation function, + or we may use cursor.lastrowid if supported by the dialect. + + + """ + + if ( + # column has a Python-side default + c.default is not None + and ( + # and it either is not a sequence, or it is and we support + # sequences and want to invoke it + not c.default.is_sequence + or ( + compiler.dialect.supports_sequences + and ( + not c.default.optional + or not compiler.dialect.sequences_optional + ) + ) + ) + ) or ( + # column is the "autoincrement column" + c is stmt.table._autoincrement_column + and ( + # dialect can't use cursor.lastrowid + not compiler.dialect.postfetch_lastrowid + and ( + # column has a Sequence and we support those + ( + c.default is not None + and c.default.is_sequence + and compiler.dialect.supports_sequences + ) + or + # column has no default on it, but dialect can run the + # "autoincrement" mechanism explicitly, e.g. PostgreSQL + # SERIAL we know the sequence name + ( + c.default is None + and compiler.dialect.preexecute_autoincrement_sequences + ) + ) + ) + ): + # do a pre-execute of the default + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + (c.key,), + ) + ) + elif ( + c.default is None + and c.server_default is None + and not c.nullable + and c is not stmt.table._autoincrement_column + ): + # no .default, no .server_default, not autoincrement, we have + # no indication this primary key column will have any value + _warn_pk_with_no_anticipated_value(c) + elif compiler.dialect.postfetch_lastrowid: + # finally, where it seems like there will be a generated primary key + # value and we haven't set up any other way to fetch it, and the + # dialect supports cursor.lastrowid, switch on the lastrowid flag so + # that the DefaultExecutionContext calls upon cursor.lastrowid + compiler.postfetch_lastrowid = True + + +def _append_param_insert_hasdefault( + compiler, stmt, c, implicit_return_defaults, values, kw +): + if c.default.is_sequence: + if compiler.dialect.supports_sequences and ( + not c.default.optional or not compiler.dialect.sequences_optional + ): + accumulated_bind_names: Set[str] = set() + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process( + c.default, + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, + ) + ) + if implicit_return_defaults and c in implicit_return_defaults: + compiler.implicit_returning.append(c) + elif not c.primary_key: + compiler.postfetch.append(c) + elif c.default.is_clause_element: + accumulated_bind_names = set() + values.append( + ( + c, + compiler.preparer.format_column(c), + compiler.process( + c.default.arg.self_group(), + accumulate_bind_names=accumulated_bind_names, + **kw, + ), + accumulated_bind_names, + ) + ) + + if implicit_return_defaults and c in implicit_return_defaults: + compiler.implicit_returning.append(c) + elif not c.primary_key: + # don't add primary key column to postfetch + compiler.postfetch.append(c) + else: + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param(compiler, c, **kw), + (c.key,), + ) + ) + + +def _append_param_insert_select_hasdefault( + compiler: SQLCompiler, + stmt: ValuesBase, + c: ColumnClause[Any], + values: List[_CrudParamElementSQLExpr], + kw: Dict[str, Any], +) -> None: + if default_is_sequence(c.default): + if compiler.dialect.supports_sequences and ( + not c.default.optional or not compiler.dialect.sequences_optional + ): + values.append( + ( + c, + compiler.preparer.format_column(c), + c.default.next_value(), + (), + ) + ) + elif default_is_clause_element(c.default): + values.append( + ( + c, + compiler.preparer.format_column(c), + c.default.arg.self_group(), + (), + ) + ) + else: + values.append( + ( + c, + compiler.preparer.format_column(c), + _create_insert_prefetch_bind_param( + compiler, c, process=False, **kw + ), + (c.key,), + ) + ) + + +def _append_param_update( + compiler, compile_state, stmt, c, implicit_return_defaults, values, kw +): + include_table = compile_state.include_table_with_column_exprs + if c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append( + ( + c, + compiler.preparer.format_column( + c, + use_table=include_table, + ), + compiler.process(c.onupdate.arg.self_group(), **kw), + (), + ) + ) + if implicit_return_defaults and c in implicit_return_defaults: + compiler.implicit_returning.append(c) + else: + compiler.postfetch.append(c) + else: + values.append( + ( + c, + compiler.preparer.format_column( + c, + use_table=include_table, + ), + _create_update_prefetch_bind_param(compiler, c, **kw), + (c.key,), + ) + ) + elif c.server_onupdate is not None: + if implicit_return_defaults and c in implicit_return_defaults: + compiler.implicit_returning.append(c) + else: + compiler.postfetch.append(c) + elif ( + implicit_return_defaults + and (stmt._return_defaults_columns or not stmt._return_defaults) + and c in implicit_return_defaults + ): + compiler.implicit_returning.append(c) + + +@overload +def _create_insert_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[True] = ..., + **kw: Any, +) -> str: ... + + +@overload +def _create_insert_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[False], + **kw: Any, +) -> elements.BindParameter[Any]: ... + + +def _create_insert_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: bool = True, + name: Optional[str] = None, + **kw: Any, +) -> Union[elements.BindParameter[Any], str]: + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) + compiler.insert_prefetch.append(c) # type: ignore + return param + + +@overload +def _create_update_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[True] = ..., + **kw: Any, +) -> str: ... + + +@overload +def _create_update_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: Literal[False], + **kw: Any, +) -> elements.BindParameter[Any]: ... + + +def _create_update_prefetch_bind_param( + compiler: SQLCompiler, + c: ColumnElement[Any], + process: bool = True, + name: Optional[str] = None, + **kw: Any, +) -> Union[elements.BindParameter[Any], str]: + param = _create_bind_param( + compiler, c, None, process=process, name=name, **kw + ) + compiler.update_prefetch.append(c) # type: ignore + return param + + +class _multiparam_column(elements.ColumnElement[Any]): + _is_multiparam_column = True + + def __init__(self, original, index): + self.index = index + self.key = "%s_m%d" % (original.key, index + 1) + self.original = original + self.default = original.default + self.type = original.type + + def compare(self, other, **kw): + raise NotImplementedError() + + def _copy_internals(self, other, **kw): + raise NotImplementedError() + + def __eq__(self, other): + return ( + isinstance(other, _multiparam_column) + and other.key == self.key + and other.original == self.original + ) + + @util.memoized_property + def _default_description_tuple(self) -> _DefaultDescriptionTuple: + """used by default.py -> _process_execute_defaults()""" + + return _DefaultDescriptionTuple._from_column_default(self.default) + + @util.memoized_property + def _onupdate_description_tuple(self) -> _DefaultDescriptionTuple: + """used by default.py -> _process_execute_defaults()""" + + return _DefaultDescriptionTuple._from_column_default(self.onupdate) + + +def _process_multiparam_default_bind( + compiler: SQLCompiler, + stmt: ValuesBase, + c: KeyedColumnElement[Any], + index: int, + kw: Dict[str, Any], +) -> str: + if not c.default: + raise exc.CompileError( + "INSERT value for column %s is explicitly rendered as a bound" + "parameter in the VALUES clause; " + "a Python-side value or SQL expression is required" % c + ) + elif default_is_clause_element(c.default): + return compiler.process(c.default.arg.self_group(), **kw) + elif c.default.is_sequence: + # these conditions would have been established + # by append_param_insert_(?:hasdefault|pk_returning|pk_no_returning) + # in order for us to be here, so these don't need to be + # checked + # assert compiler.dialect.supports_sequences and ( + # not c.default.optional + # or not compiler.dialect.sequences_optional + # ) + return compiler.process(c.default, **kw) + else: + col = _multiparam_column(c, index) + assert isinstance(stmt, dml.Insert) + return _create_insert_prefetch_bind_param( + compiler, col, process=True, **kw + ) + + +def _get_update_multitable_params( + compiler, + stmt, + compile_state, + stmt_parameter_tuples, + check_columns, + _col_bind_name, + _getattr_col_key, + values, + kw, +): + normalized_params = { + coercions.expect(roles.DMLColumnRole, c): param + for c, param in stmt_parameter_tuples or () + } + + include_table = compile_state.include_table_with_column_exprs + + affected_tables = set() + for t in compile_state._extra_froms: + for c in t.c: + if c in normalized_params: + affected_tables.add(t) + check_columns[_getattr_col_key(c)] = c + value = normalized_params[c] + + col_value = compiler.process(c, include_table=include_table) + if coercions._is_literal(value): + value = _create_bind_param( + compiler, + c, + value, + required=value is REQUIRED, + name=_col_bind_name(c), + **kw, # TODO: no test coverage for literal binds here + ) + accumulated_bind_names: Iterable[str] = (c.key,) + elif value._is_bind_parameter: + cbn = _col_bind_name(c) + value = _handle_values_anonymous_param( + compiler, c, value, name=cbn, **kw + ) + accumulated_bind_names = (cbn,) + else: + compiler.postfetch.append(c) + value = compiler.process(value.self_group(), **kw) + accumulated_bind_names = () + values.append((c, col_value, value, accumulated_bind_names)) + # determine tables which are actually to be updated - process onupdate + # and server_onupdate for these + for t in affected_tables: + for c in t.c: + if c in normalized_params: + continue + elif c.onupdate is not None and not c.onupdate.is_sequence: + if c.onupdate.is_clause_element: + values.append( + ( + c, + compiler.process(c, include_table=include_table), + compiler.process( + c.onupdate.arg.self_group(), **kw + ), + (), + ) + ) + compiler.postfetch.append(c) + else: + values.append( + ( + c, + compiler.process(c, include_table=include_table), + _create_update_prefetch_bind_param( + compiler, c, name=_col_bind_name(c), **kw + ), + (c.key,), + ) + ) + elif c.server_onupdate is not None: + compiler.postfetch.append(c) + + +def _extend_values_for_multiparams( + compiler: SQLCompiler, + stmt: ValuesBase, + compile_state: DMLState, + initial_values: Sequence[_CrudParamElementStr], + _column_as_key: Callable[..., str], + kw: Dict[str, Any], +) -> List[Sequence[_CrudParamElementStr]]: + values_0 = initial_values + values = [initial_values] + + mp = compile_state._multi_parameters + assert mp is not None + for i, row in enumerate(mp[1:]): + extension: List[_CrudParamElementStr] = [] + + row = {_column_as_key(key): v for key, v in row.items()} + + for col, col_expr, param, accumulated_names in values_0: + if col.key in row: + key = col.key + + if coercions._is_literal(row[key]): + new_param = _create_bind_param( + compiler, + col, + row[key], + name="%s_m%d" % (col.key, i + 1), + **kw, + ) + else: + new_param = compiler.process(row[key].self_group(), **kw) + else: + new_param = _process_multiparam_default_bind( + compiler, stmt, col, i, kw + ) + + extension.append((col, col_expr, new_param, accumulated_names)) + + values.append(extension) + + return values + + +def _get_stmt_parameter_tuples_params( + compiler, + compile_state, + parameters, + stmt_parameter_tuples, + _column_as_key, + values, + kw, +): + for k, v in stmt_parameter_tuples: + colkey = _column_as_key(k) + if colkey is not None: + parameters.setdefault(colkey, v) + else: + # a non-Column expression on the left side; + # add it to values() in an "as-is" state, + # coercing right side to bound param + + # note one of the main use cases for this is array slice + # updates on PostgreSQL, as the left side is also an expression. + + col_expr = compiler.process( + k, include_table=compile_state.include_table_with_column_exprs + ) + + if coercions._is_literal(v): + v = compiler.process( + elements.BindParameter(None, v, type_=k.type), **kw + ) + else: + if v._is_bind_parameter and v.type._isnull: + # either unique parameter, or other bound parameters that + # were passed in directly + # set type to that of the column unconditionally + v = v._with_binary_element_type(k.type) + + v = compiler.process(v.self_group(), **kw) + + # TODO: not sure if accumulated_bind_names applies here + values.append((k, col_expr, v, ())) + + +def _get_returning_modifiers(compiler, stmt, compile_state, toplevel): + """determines RETURNING strategy, if any, for the statement. + + This is where it's determined what we need to fetch from the + INSERT or UPDATE statement after it's invoked. + + """ + + dialect = compiler.dialect + + need_pks = ( + toplevel + and _compile_state_isinsert(compile_state) + and not stmt._inline + and ( + not compiler.for_executemany + or (dialect.insert_executemany_returning and stmt._return_defaults) + ) + and not stmt._returning + # and (not stmt._returning or stmt._return_defaults) + and not compile_state._has_multi_parameters + ) + + # check if we have access to simple cursor.lastrowid. we can use that + # after the INSERT if that's all we need. + postfetch_lastrowid = ( + need_pks + and dialect.postfetch_lastrowid + and stmt.table._autoincrement_column is not None + ) + + # see if we want to add RETURNING to an INSERT in order to get + # primary key columns back. This would be instead of postfetch_lastrowid + # if that's set. + implicit_returning = ( + # statement itself can veto it + need_pks + # the dialect can veto it if it just doesnt support RETURNING + # with INSERT + and dialect.insert_returning + # user-defined implicit_returning on Table can veto it + and compile_state._primary_table.implicit_returning + # the compile_state can veto it (SQlite uses this to disable + # RETURNING for an ON CONFLICT insert, as SQLite does not return + # for rows that were updated, which is wrong) + and compile_state._supports_implicit_returning + and ( + # since we support MariaDB and SQLite which also support lastrowid, + # decide if we should use lastrowid or RETURNING. for insert + # that didnt call return_defaults() and has just one set of + # parameters, we can use lastrowid. this is more "traditional" + # and a lot of weird use cases are supported by it. + # SQLite lastrowid times 3x faster than returning, + # Mariadb lastrowid 2x faster than returning + (not postfetch_lastrowid or dialect.favor_returning_over_lastrowid) + or compile_state._has_multi_parameters + or stmt._return_defaults + ) + ) + if implicit_returning: + postfetch_lastrowid = False + + if _compile_state_isinsert(compile_state): + should_implicit_return_defaults = ( + implicit_returning and stmt._return_defaults + ) + explicit_returning = ( + should_implicit_return_defaults + or stmt._returning + or stmt._supplemental_returning + ) + use_insertmanyvalues = ( + toplevel + and compiler.for_executemany + and dialect.use_insertmanyvalues + and ( + explicit_returning or dialect.use_insertmanyvalues_wo_returning + ) + ) + + use_sentinel_columns = None + if ( + use_insertmanyvalues + and explicit_returning + and stmt._sort_by_parameter_order + ): + use_sentinel_columns = compiler._get_sentinel_column_for_table( + stmt.table + ) + + elif compile_state.isupdate: + should_implicit_return_defaults = ( + stmt._return_defaults + and compile_state._primary_table.implicit_returning + and compile_state._supports_implicit_returning + and dialect.update_returning + ) + use_insertmanyvalues = False + use_sentinel_columns = None + elif compile_state.isdelete: + should_implicit_return_defaults = ( + stmt._return_defaults + and compile_state._primary_table.implicit_returning + and compile_state._supports_implicit_returning + and dialect.delete_returning + ) + use_insertmanyvalues = False + use_sentinel_columns = None + else: + should_implicit_return_defaults = False # pragma: no cover + use_insertmanyvalues = False + use_sentinel_columns = None + + if should_implicit_return_defaults: + if not stmt._return_defaults_columns: + # TODO: this is weird. See #9685 where we have to + # take an extra step to prevent this from happening. why + # would this ever be *all* columns? but if we set to blank, then + # that seems to break things also in the ORM. So we should + # try to clean this up and figure out what return_defaults + # needs to do w/ the ORM etc. here + implicit_return_defaults = set(stmt.table.c) + else: + implicit_return_defaults = set(stmt._return_defaults_columns) + else: + implicit_return_defaults = None + + return ( + need_pks, + implicit_returning or should_implicit_return_defaults, + implicit_return_defaults, + postfetch_lastrowid, + use_insertmanyvalues, + use_sentinel_columns, + ) + + +def _warn_pk_with_no_anticipated_value(c): + msg = ( + "Column '%s.%s' is marked as a member of the " + "primary key for table '%s', " + "but has no Python-side or server-side default generator indicated, " + "nor does it indicate 'autoincrement=True' or 'nullable=True', " + "and no explicit value is passed. " + "Primary key columns typically may not store NULL." + % (c.table.fullname, c.name, c.table.fullname) + ) + if len(c.table.primary_key) > 1: + msg += ( + " Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be " + "indicated explicitly for composite (e.g. multicolumn) primary " + "keys if AUTO_INCREMENT/SERIAL/IDENTITY " + "behavior is expected for one of the columns in the primary key. " + "CREATE TABLE statements are impacted by this change as well on " + "most backends." + ) + util.warn(msg) diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/ddl.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/ddl.py new file mode 100644 index 0000000..d9e3f67 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/ddl.py @@ -0,0 +1,1378 @@ +# sql/ddl.py +# Copyright (C) 2009-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +""" +Provides the hierarchy of DDL-defining schema items as well as routines +to invoke them for a create/drop call. + +""" +from __future__ import annotations + +import contextlib +import typing +from typing import Any +from typing import Callable +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence as typing_Sequence +from typing import Tuple + +from . import roles +from .base import _generative +from .base import Executable +from .base import SchemaVisitor +from .elements import ClauseElement +from .. import exc +from .. import util +from ..util import topological +from ..util.typing import Protocol +from ..util.typing import Self + +if typing.TYPE_CHECKING: + from .compiler import Compiled + from .compiler import DDLCompiler + from .elements import BindParameter + from .schema import Constraint + from .schema import ForeignKeyConstraint + from .schema import SchemaItem + from .schema import Sequence + from .schema import Table + from .selectable import TableClause + from ..engine.base import Connection + from ..engine.interfaces import CacheStats + from ..engine.interfaces import CompiledCacheType + from ..engine.interfaces import Dialect + from ..engine.interfaces import SchemaTranslateMapType + + +class BaseDDLElement(ClauseElement): + """The root of DDL constructs, including those that are sub-elements + within the "create table" and other processes. + + .. versionadded:: 2.0 + + """ + + _hierarchy_supports_caching = False + """disable cache warnings for all _DDLCompiles subclasses. """ + + def _compiler(self, dialect, **kw): + """Return a compiler appropriate for this ClauseElement, given a + Dialect.""" + + return dialect.ddl_compiler(dialect, self, **kw) + + def _compile_w_cache( + self, + dialect: Dialect, + *, + compiled_cache: Optional[CompiledCacheType], + column_keys: List[str], + for_executemany: bool = False, + schema_translate_map: Optional[SchemaTranslateMapType] = None, + **kw: Any, + ) -> Tuple[ + Compiled, Optional[typing_Sequence[BindParameter[Any]]], CacheStats + ]: + raise NotImplementedError() + + +class DDLIfCallable(Protocol): + def __call__( + self, + ddl: BaseDDLElement, + target: SchemaItem, + bind: Optional[Connection], + tables: Optional[List[Table]] = None, + state: Optional[Any] = None, + *, + dialect: Dialect, + compiler: Optional[DDLCompiler] = ..., + checkfirst: bool, + ) -> bool: ... + + +class DDLIf(typing.NamedTuple): + dialect: Optional[str] + callable_: Optional[DDLIfCallable] + state: Optional[Any] + + def _should_execute( + self, + ddl: BaseDDLElement, + target: SchemaItem, + bind: Optional[Connection], + compiler: Optional[DDLCompiler] = None, + **kw: Any, + ) -> bool: + if bind is not None: + dialect = bind.dialect + elif compiler is not None: + dialect = compiler.dialect + else: + assert False, "compiler or dialect is required" + + if isinstance(self.dialect, str): + if self.dialect != dialect.name: + return False + elif isinstance(self.dialect, (tuple, list, set)): + if dialect.name not in self.dialect: + return False + if self.callable_ is not None and not self.callable_( + ddl, + target, + bind, + state=self.state, + dialect=dialect, + compiler=compiler, + **kw, + ): + return False + + return True + + +class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement): + """Base class for standalone executable DDL expression constructs. + + This class is the base for the general purpose :class:`.DDL` class, + as well as the various create/drop clause constructs such as + :class:`.CreateTable`, :class:`.DropTable`, :class:`.AddConstraint`, + etc. + + .. versionchanged:: 2.0 :class:`.ExecutableDDLElement` is renamed from + :class:`.DDLElement`, which still exists for backwards compatibility. + + :class:`.ExecutableDDLElement` integrates closely with SQLAlchemy events, + introduced in :ref:`event_toplevel`. An instance of one is + itself an event receiving callable:: + + event.listen( + users, + 'after_create', + AddConstraint(constraint).execute_if(dialect='postgresql') + ) + + .. seealso:: + + :class:`.DDL` + + :class:`.DDLEvents` + + :ref:`event_toplevel` + + :ref:`schema_ddl_sequences` + + """ + + _ddl_if: Optional[DDLIf] = None + target: Optional[SchemaItem] = None + + def _execute_on_connection( + self, connection, distilled_params, execution_options + ): + return connection._execute_ddl( + self, distilled_params, execution_options + ) + + @_generative + def against(self, target: SchemaItem) -> Self: + """Return a copy of this :class:`_schema.ExecutableDDLElement` which + will include the given target. + + This essentially applies the given item to the ``.target`` attribute of + the returned :class:`_schema.ExecutableDDLElement` object. This target + is then usable by event handlers and compilation routines in order to + provide services such as tokenization of a DDL string in terms of a + particular :class:`_schema.Table`. + + When a :class:`_schema.ExecutableDDLElement` object is established as + an event handler for the :meth:`_events.DDLEvents.before_create` or + :meth:`_events.DDLEvents.after_create` events, and the event then + occurs for a given target such as a :class:`_schema.Constraint` or + :class:`_schema.Table`, that target is established with a copy of the + :class:`_schema.ExecutableDDLElement` object using this method, which + then proceeds to the :meth:`_schema.ExecutableDDLElement.execute` + method in order to invoke the actual DDL instruction. + + :param target: a :class:`_schema.SchemaItem` that will be the subject + of a DDL operation. + + :return: a copy of this :class:`_schema.ExecutableDDLElement` with the + ``.target`` attribute assigned to the given + :class:`_schema.SchemaItem`. + + .. seealso:: + + :class:`_schema.DDL` - uses tokenization against the "target" when + processing the DDL string. + + """ + self.target = target + return self + + @_generative + def execute_if( + self, + dialect: Optional[str] = None, + callable_: Optional[DDLIfCallable] = None, + state: Optional[Any] = None, + ) -> Self: + r"""Return a callable that will execute this + :class:`_ddl.ExecutableDDLElement` conditionally within an event + handler. + + Used to provide a wrapper for event listening:: + + event.listen( + metadata, + 'before_create', + DDL("my_ddl").execute_if(dialect='postgresql') + ) + + :param dialect: May be a string or tuple of strings. + If a string, it will be compared to the name of the + executing database dialect:: + + DDL('something').execute_if(dialect='postgresql') + + If a tuple, specifies multiple dialect names:: + + DDL('something').execute_if(dialect=('postgresql', 'mysql')) + + :param callable\_: A callable, which will be invoked with + three positional arguments as well as optional keyword + arguments: + + :ddl: + This DDL element. + + :target: + The :class:`_schema.Table` or :class:`_schema.MetaData` + object which is the + target of this event. May be None if the DDL is executed + explicitly. + + :bind: + The :class:`_engine.Connection` being used for DDL execution. + May be None if this construct is being created inline within + a table, in which case ``compiler`` will be present. + + :tables: + Optional keyword argument - a list of Table objects which are to + be created/ dropped within a MetaData.create_all() or drop_all() + method call. + + :dialect: keyword argument, but always present - the + :class:`.Dialect` involved in the operation. + + :compiler: keyword argument. Will be ``None`` for an engine + level DDL invocation, but will refer to a :class:`.DDLCompiler` + if this DDL element is being created inline within a table. + + :state: + Optional keyword argument - will be the ``state`` argument + passed to this function. + + :checkfirst: + Keyword argument, will be True if the 'checkfirst' flag was + set during the call to ``create()``, ``create_all()``, + ``drop()``, ``drop_all()``. + + If the callable returns a True value, the DDL statement will be + executed. + + :param state: any value which will be passed to the callable\_ + as the ``state`` keyword argument. + + .. seealso:: + + :meth:`.SchemaItem.ddl_if` + + :class:`.DDLEvents` + + :ref:`event_toplevel` + + """ + self._ddl_if = DDLIf(dialect, callable_, state) + return self + + def _should_execute(self, target, bind, **kw): + if self._ddl_if is None: + return True + else: + return self._ddl_if._should_execute(self, target, bind, **kw) + + def _invoke_with(self, bind): + if self._should_execute(self.target, bind): + return bind.execute(self) + + def __call__(self, target, bind, **kw): + """Execute the DDL as a ddl_listener.""" + + self.against(target)._invoke_with(bind) + + def _generate(self): + s = self.__class__.__new__(self.__class__) + s.__dict__ = self.__dict__.copy() + return s + + +DDLElement = ExecutableDDLElement +""":class:`.DDLElement` is renamed to :class:`.ExecutableDDLElement`.""" + + +class DDL(ExecutableDDLElement): + """A literal DDL statement. + + Specifies literal SQL DDL to be executed by the database. DDL objects + function as DDL event listeners, and can be subscribed to those events + listed in :class:`.DDLEvents`, using either :class:`_schema.Table` or + :class:`_schema.MetaData` objects as targets. + Basic templating support allows + a single DDL instance to handle repetitive tasks for multiple tables. + + Examples:: + + from sqlalchemy import event, DDL + + tbl = Table('users', metadata, Column('uid', Integer)) + event.listen(tbl, 'before_create', DDL('DROP TRIGGER users_trigger')) + + spow = DDL('ALTER TABLE %(table)s SET secretpowers TRUE') + event.listen(tbl, 'after_create', spow.execute_if(dialect='somedb')) + + drop_spow = DDL('ALTER TABLE users SET secretpowers FALSE') + connection.execute(drop_spow) + + When operating on Table events, the following ``statement`` + string substitutions are available:: + + %(table)s - the Table name, with any required quoting applied + %(schema)s - the schema name, with any required quoting applied + %(fullname)s - the Table name including schema, quoted if needed + + The DDL's "context", if any, will be combined with the standard + substitutions noted above. Keys present in the context will override + the standard substitutions. + + """ + + __visit_name__ = "ddl" + + def __init__(self, statement, context=None): + """Create a DDL statement. + + :param statement: + A string or unicode string to be executed. Statements will be + processed with Python's string formatting operator using + a fixed set of string substitutions, as well as additional + substitutions provided by the optional :paramref:`.DDL.context` + parameter. + + A literal '%' in a statement must be escaped as '%%'. + + SQL bind parameters are not available in DDL statements. + + :param context: + Optional dictionary, defaults to None. These values will be + available for use in string substitutions on the DDL statement. + + .. seealso:: + + :class:`.DDLEvents` + + :ref:`event_toplevel` + + """ + + if not isinstance(statement, str): + raise exc.ArgumentError( + "Expected a string or unicode SQL statement, got '%r'" + % statement + ) + + self.statement = statement + self.context = context or {} + + def __repr__(self): + parts = [repr(self.statement)] + if self.context: + parts.append(f"context={self.context}") + + return "<%s@%s; %s>" % ( + type(self).__name__, + id(self), + ", ".join(parts), + ) + + +class _CreateDropBase(ExecutableDDLElement): + """Base class for DDL constructs that represent CREATE and DROP or + equivalents. + + The common theme of _CreateDropBase is a single + ``element`` attribute which refers to the element + to be created or dropped. + + """ + + def __init__( + self, + element, + ): + self.element = self.target = element + self._ddl_if = getattr(element, "_ddl_if", None) + + @property + def stringify_dialect(self): + return self.element.create_drop_stringify_dialect + + def _create_rule_disable(self, compiler): + """Allow disable of _create_rule using a callable. + + Pass to _create_rule using + util.portable_instancemethod(self._create_rule_disable) + to retain serializability. + + """ + return False + + +class _CreateBase(_CreateDropBase): + def __init__(self, element, if_not_exists=False): + super().__init__(element) + self.if_not_exists = if_not_exists + + +class _DropBase(_CreateDropBase): + def __init__(self, element, if_exists=False): + super().__init__(element) + self.if_exists = if_exists + + +class CreateSchema(_CreateBase): + """Represent a CREATE SCHEMA statement. + + The argument here is the string name of the schema. + + """ + + __visit_name__ = "create_schema" + + stringify_dialect = "default" + + def __init__( + self, + name, + if_not_exists=False, + ): + """Create a new :class:`.CreateSchema` construct.""" + + super().__init__(element=name, if_not_exists=if_not_exists) + + +class DropSchema(_DropBase): + """Represent a DROP SCHEMA statement. + + The argument here is the string name of the schema. + + """ + + __visit_name__ = "drop_schema" + + stringify_dialect = "default" + + def __init__( + self, + name, + cascade=False, + if_exists=False, + ): + """Create a new :class:`.DropSchema` construct.""" + + super().__init__(element=name, if_exists=if_exists) + self.cascade = cascade + + +class CreateTable(_CreateBase): + """Represent a CREATE TABLE statement.""" + + __visit_name__ = "create_table" + + def __init__( + self, + element: Table, + include_foreign_key_constraints: Optional[ + typing_Sequence[ForeignKeyConstraint] + ] = None, + if_not_exists: bool = False, + ): + """Create a :class:`.CreateTable` construct. + + :param element: a :class:`_schema.Table` that's the subject + of the CREATE + :param on: See the description for 'on' in :class:`.DDL`. + :param include_foreign_key_constraints: optional sequence of + :class:`_schema.ForeignKeyConstraint` objects that will be included + inline within the CREATE construct; if omitted, all foreign key + constraints that do not specify use_alter=True are included. + + :param if_not_exists: if True, an IF NOT EXISTS operator will be + applied to the construct. + + .. versionadded:: 1.4.0b2 + + """ + super().__init__(element, if_not_exists=if_not_exists) + self.columns = [CreateColumn(column) for column in element.columns] + self.include_foreign_key_constraints = include_foreign_key_constraints + + +class _DropView(_DropBase): + """Semi-public 'DROP VIEW' construct. + + Used by the test suite for dialect-agnostic drops of views. + This object will eventually be part of a public "view" API. + + """ + + __visit_name__ = "drop_view" + + +class CreateConstraint(BaseDDLElement): + def __init__(self, element: Constraint): + self.element = element + + +class CreateColumn(BaseDDLElement): + """Represent a :class:`_schema.Column` + as rendered in a CREATE TABLE statement, + via the :class:`.CreateTable` construct. + + This is provided to support custom column DDL within the generation + of CREATE TABLE statements, by using the + compiler extension documented in :ref:`sqlalchemy.ext.compiler_toplevel` + to extend :class:`.CreateColumn`. + + Typical integration is to examine the incoming :class:`_schema.Column` + object, and to redirect compilation if a particular flag or condition + is found:: + + from sqlalchemy import schema + from sqlalchemy.ext.compiler import compiles + + @compiles(schema.CreateColumn) + def compile(element, compiler, **kw): + column = element.element + + if "special" not in column.info: + return compiler.visit_create_column(element, **kw) + + text = "%s SPECIAL DIRECTIVE %s" % ( + column.name, + compiler.type_compiler.process(column.type) + ) + default = compiler.get_column_default_string(column) + if default is not None: + text += " DEFAULT " + default + + if not column.nullable: + text += " NOT NULL" + + if column.constraints: + text += " ".join( + compiler.process(const) + for const in column.constraints) + return text + + The above construct can be applied to a :class:`_schema.Table` + as follows:: + + from sqlalchemy import Table, Metadata, Column, Integer, String + from sqlalchemy import schema + + metadata = MetaData() + + table = Table('mytable', MetaData(), + Column('x', Integer, info={"special":True}, primary_key=True), + Column('y', String(50)), + Column('z', String(20), info={"special":True}) + ) + + metadata.create_all(conn) + + Above, the directives we've added to the :attr:`_schema.Column.info` + collection + will be detected by our custom compilation scheme:: + + CREATE TABLE mytable ( + x SPECIAL DIRECTIVE INTEGER NOT NULL, + y VARCHAR(50), + z SPECIAL DIRECTIVE VARCHAR(20), + PRIMARY KEY (x) + ) + + The :class:`.CreateColumn` construct can also be used to skip certain + columns when producing a ``CREATE TABLE``. This is accomplished by + creating a compilation rule that conditionally returns ``None``. + This is essentially how to produce the same effect as using the + ``system=True`` argument on :class:`_schema.Column`, which marks a column + as an implicitly-present "system" column. + + For example, suppose we wish to produce a :class:`_schema.Table` + which skips + rendering of the PostgreSQL ``xmin`` column against the PostgreSQL + backend, but on other backends does render it, in anticipation of a + triggered rule. A conditional compilation rule could skip this name only + on PostgreSQL:: + + from sqlalchemy.schema import CreateColumn + + @compiles(CreateColumn, "postgresql") + def skip_xmin(element, compiler, **kw): + if element.element.name == 'xmin': + return None + else: + return compiler.visit_create_column(element, **kw) + + + my_table = Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('xmin', Integer) + ) + + Above, a :class:`.CreateTable` construct will generate a ``CREATE TABLE`` + which only includes the ``id`` column in the string; the ``xmin`` column + will be omitted, but only against the PostgreSQL backend. + + """ + + __visit_name__ = "create_column" + + def __init__(self, element): + self.element = element + + +class DropTable(_DropBase): + """Represent a DROP TABLE statement.""" + + __visit_name__ = "drop_table" + + def __init__(self, element: Table, if_exists: bool = False): + """Create a :class:`.DropTable` construct. + + :param element: a :class:`_schema.Table` that's the subject + of the DROP. + :param on: See the description for 'on' in :class:`.DDL`. + :param if_exists: if True, an IF EXISTS operator will be applied to the + construct. + + .. versionadded:: 1.4.0b2 + + """ + super().__init__(element, if_exists=if_exists) + + +class CreateSequence(_CreateBase): + """Represent a CREATE SEQUENCE statement.""" + + __visit_name__ = "create_sequence" + + def __init__(self, element: Sequence, if_not_exists: bool = False): + super().__init__(element, if_not_exists=if_not_exists) + + +class DropSequence(_DropBase): + """Represent a DROP SEQUENCE statement.""" + + __visit_name__ = "drop_sequence" + + def __init__(self, element: Sequence, if_exists: bool = False): + super().__init__(element, if_exists=if_exists) + + +class CreateIndex(_CreateBase): + """Represent a CREATE INDEX statement.""" + + __visit_name__ = "create_index" + + def __init__(self, element, if_not_exists=False): + """Create a :class:`.Createindex` construct. + + :param element: a :class:`_schema.Index` that's the subject + of the CREATE. + :param if_not_exists: if True, an IF NOT EXISTS operator will be + applied to the construct. + + .. versionadded:: 1.4.0b2 + + """ + super().__init__(element, if_not_exists=if_not_exists) + + +class DropIndex(_DropBase): + """Represent a DROP INDEX statement.""" + + __visit_name__ = "drop_index" + + def __init__(self, element, if_exists=False): + """Create a :class:`.DropIndex` construct. + + :param element: a :class:`_schema.Index` that's the subject + of the DROP. + :param if_exists: if True, an IF EXISTS operator will be applied to the + construct. + + .. versionadded:: 1.4.0b2 + + """ + super().__init__(element, if_exists=if_exists) + + +class AddConstraint(_CreateBase): + """Represent an ALTER TABLE ADD CONSTRAINT statement.""" + + __visit_name__ = "add_constraint" + + def __init__(self, element): + super().__init__(element) + element._create_rule = util.portable_instancemethod( + self._create_rule_disable + ) + + +class DropConstraint(_DropBase): + """Represent an ALTER TABLE DROP CONSTRAINT statement.""" + + __visit_name__ = "drop_constraint" + + def __init__(self, element, cascade=False, if_exists=False, **kw): + self.cascade = cascade + super().__init__(element, if_exists=if_exists, **kw) + element._create_rule = util.portable_instancemethod( + self._create_rule_disable + ) + + +class SetTableComment(_CreateDropBase): + """Represent a COMMENT ON TABLE IS statement.""" + + __visit_name__ = "set_table_comment" + + +class DropTableComment(_CreateDropBase): + """Represent a COMMENT ON TABLE '' statement. + + Note this varies a lot across database backends. + + """ + + __visit_name__ = "drop_table_comment" + + +class SetColumnComment(_CreateDropBase): + """Represent a COMMENT ON COLUMN IS statement.""" + + __visit_name__ = "set_column_comment" + + +class DropColumnComment(_CreateDropBase): + """Represent a COMMENT ON COLUMN IS NULL statement.""" + + __visit_name__ = "drop_column_comment" + + +class SetConstraintComment(_CreateDropBase): + """Represent a COMMENT ON CONSTRAINT IS statement.""" + + __visit_name__ = "set_constraint_comment" + + +class DropConstraintComment(_CreateDropBase): + """Represent a COMMENT ON CONSTRAINT IS NULL statement.""" + + __visit_name__ = "drop_constraint_comment" + + +class InvokeDDLBase(SchemaVisitor): + def __init__(self, connection): + self.connection = connection + + @contextlib.contextmanager + def with_ddl_events(self, target, **kw): + """helper context manager that will apply appropriate DDL events + to a CREATE or DROP operation.""" + + raise NotImplementedError() + + +class InvokeCreateDDLBase(InvokeDDLBase): + @contextlib.contextmanager + def with_ddl_events(self, target, **kw): + """helper context manager that will apply appropriate DDL events + to a CREATE or DROP operation.""" + + target.dispatch.before_create( + target, self.connection, _ddl_runner=self, **kw + ) + yield + target.dispatch.after_create( + target, self.connection, _ddl_runner=self, **kw + ) + + +class InvokeDropDDLBase(InvokeDDLBase): + @contextlib.contextmanager + def with_ddl_events(self, target, **kw): + """helper context manager that will apply appropriate DDL events + to a CREATE or DROP operation.""" + + target.dispatch.before_drop( + target, self.connection, _ddl_runner=self, **kw + ) + yield + target.dispatch.after_drop( + target, self.connection, _ddl_runner=self, **kw + ) + + +class SchemaGenerator(InvokeCreateDDLBase): + def __init__( + self, dialect, connection, checkfirst=False, tables=None, **kwargs + ): + super().__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + self.memo = {} + + def _can_create_table(self, table): + self.dialect.validate_identifier(table.name) + effective_schema = self.connection.schema_for_object(table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) + return not self.checkfirst or not self.dialect.has_table( + self.connection, table.name, schema=effective_schema + ) + + def _can_create_index(self, index): + effective_schema = self.connection.schema_for_object(index.table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) + return not self.checkfirst or not self.dialect.has_index( + self.connection, + index.table.name, + index.name, + schema=effective_schema, + ) + + def _can_create_sequence(self, sequence): + effective_schema = self.connection.schema_for_object(sequence) + + return self.dialect.supports_sequences and ( + (not self.dialect.sequences_optional or not sequence.optional) + and ( + not self.checkfirst + or not self.dialect.has_sequence( + self.connection, sequence.name, schema=effective_schema + ) + ) + ) + + def visit_metadata(self, metadata): + if self.tables is not None: + tables = self.tables + else: + tables = list(metadata.tables.values()) + + collection = sort_tables_and_constraints( + [t for t in tables if self._can_create_table(t)] + ) + + seq_coll = [ + s + for s in metadata._sequences.values() + if s.column is None and self._can_create_sequence(s) + ] + + event_collection = [t for (t, fks) in collection if t is not None] + + with self.with_ddl_events( + metadata, + tables=event_collection, + checkfirst=self.checkfirst, + ): + for seq in seq_coll: + self.traverse_single(seq, create_ok=True) + + for table, fkcs in collection: + if table is not None: + self.traverse_single( + table, + create_ok=True, + include_foreign_key_constraints=fkcs, + _is_metadata_operation=True, + ) + else: + for fkc in fkcs: + self.traverse_single(fkc) + + def visit_table( + self, + table, + create_ok=False, + include_foreign_key_constraints=None, + _is_metadata_operation=False, + ): + if not create_ok and not self._can_create_table(table): + return + + with self.with_ddl_events( + table, + checkfirst=self.checkfirst, + _is_metadata_operation=_is_metadata_operation, + ): + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) + + if not self.dialect.supports_alter: + # e.g., don't omit any foreign key constraints + include_foreign_key_constraints = None + + CreateTable( + table, + include_foreign_key_constraints=( + include_foreign_key_constraints + ), + )._invoke_with(self.connection) + + if hasattr(table, "indexes"): + for index in table.indexes: + self.traverse_single(index, create_ok=True) + + if ( + self.dialect.supports_comments + and not self.dialect.inline_comments + ): + if table.comment is not None: + SetTableComment(table)._invoke_with(self.connection) + + for column in table.columns: + if column.comment is not None: + SetColumnComment(column)._invoke_with(self.connection) + + if self.dialect.supports_constraint_comments: + for constraint in table.constraints: + if constraint.comment is not None: + self.connection.execute( + SetConstraintComment(constraint) + ) + + def visit_foreign_key_constraint(self, constraint): + if not self.dialect.supports_alter: + return + + with self.with_ddl_events(constraint): + AddConstraint(constraint)._invoke_with(self.connection) + + def visit_sequence(self, sequence, create_ok=False): + if not create_ok and not self._can_create_sequence(sequence): + return + with self.with_ddl_events(sequence): + CreateSequence(sequence)._invoke_with(self.connection) + + def visit_index(self, index, create_ok=False): + if not create_ok and not self._can_create_index(index): + return + with self.with_ddl_events(index): + CreateIndex(index)._invoke_with(self.connection) + + +class SchemaDropper(InvokeDropDDLBase): + def __init__( + self, dialect, connection, checkfirst=False, tables=None, **kwargs + ): + super().__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.tables = tables + self.preparer = dialect.identifier_preparer + self.dialect = dialect + self.memo = {} + + def visit_metadata(self, metadata): + if self.tables is not None: + tables = self.tables + else: + tables = list(metadata.tables.values()) + + try: + unsorted_tables = [t for t in tables if self._can_drop_table(t)] + collection = list( + reversed( + sort_tables_and_constraints( + unsorted_tables, + filter_fn=lambda constraint: ( + False + if not self.dialect.supports_alter + or constraint.name is None + else None + ), + ) + ) + ) + except exc.CircularDependencyError as err2: + if not self.dialect.supports_alter: + util.warn( + "Can't sort tables for DROP; an " + "unresolvable foreign key " + "dependency exists between tables: %s; and backend does " + "not support ALTER. To restore at least a partial sort, " + "apply use_alter=True to ForeignKey and " + "ForeignKeyConstraint " + "objects involved in the cycle to mark these as known " + "cycles that will be ignored." + % (", ".join(sorted([t.fullname for t in err2.cycles]))) + ) + collection = [(t, ()) for t in unsorted_tables] + else: + raise exc.CircularDependencyError( + err2.args[0], + err2.cycles, + err2.edges, + msg="Can't sort tables for DROP; an " + "unresolvable foreign key " + "dependency exists between tables: %s. Please ensure " + "that the ForeignKey and ForeignKeyConstraint objects " + "involved in the cycle have " + "names so that they can be dropped using " + "DROP CONSTRAINT." + % (", ".join(sorted([t.fullname for t in err2.cycles]))), + ) from err2 + + seq_coll = [ + s + for s in metadata._sequences.values() + if self._can_drop_sequence(s) + ] + + event_collection = [t for (t, fks) in collection if t is not None] + + with self.with_ddl_events( + metadata, + tables=event_collection, + checkfirst=self.checkfirst, + ): + for table, fkcs in collection: + if table is not None: + self.traverse_single( + table, + drop_ok=True, + _is_metadata_operation=True, + _ignore_sequences=seq_coll, + ) + else: + for fkc in fkcs: + self.traverse_single(fkc) + + for seq in seq_coll: + self.traverse_single(seq, drop_ok=seq.column is None) + + def _can_drop_table(self, table): + self.dialect.validate_identifier(table.name) + effective_schema = self.connection.schema_for_object(table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) + return not self.checkfirst or self.dialect.has_table( + self.connection, table.name, schema=effective_schema + ) + + def _can_drop_index(self, index): + effective_schema = self.connection.schema_for_object(index.table) + if effective_schema: + self.dialect.validate_identifier(effective_schema) + return not self.checkfirst or self.dialect.has_index( + self.connection, + index.table.name, + index.name, + schema=effective_schema, + ) + + def _can_drop_sequence(self, sequence): + effective_schema = self.connection.schema_for_object(sequence) + return self.dialect.supports_sequences and ( + (not self.dialect.sequences_optional or not sequence.optional) + and ( + not self.checkfirst + or self.dialect.has_sequence( + self.connection, sequence.name, schema=effective_schema + ) + ) + ) + + def visit_index(self, index, drop_ok=False): + if not drop_ok and not self._can_drop_index(index): + return + + with self.with_ddl_events(index): + DropIndex(index)(index, self.connection) + + def visit_table( + self, + table, + drop_ok=False, + _is_metadata_operation=False, + _ignore_sequences=(), + ): + if not drop_ok and not self._can_drop_table(table): + return + + with self.with_ddl_events( + table, + checkfirst=self.checkfirst, + _is_metadata_operation=_is_metadata_operation, + ): + DropTable(table)._invoke_with(self.connection) + + # traverse client side defaults which may refer to server-side + # sequences. noting that some of these client side defaults may + # also be set up as server side defaults + # (see https://docs.sqlalchemy.org/en/ + # latest/core/defaults.html + # #associating-a-sequence-as-the-server-side- + # default), so have to be dropped after the table is dropped. + for column in table.columns: + if ( + column.default is not None + and column.default not in _ignore_sequences + ): + self.traverse_single(column.default) + + def visit_foreign_key_constraint(self, constraint): + if not self.dialect.supports_alter: + return + with self.with_ddl_events(constraint): + DropConstraint(constraint)._invoke_with(self.connection) + + def visit_sequence(self, sequence, drop_ok=False): + if not drop_ok and not self._can_drop_sequence(sequence): + return + with self.with_ddl_events(sequence): + DropSequence(sequence)._invoke_with(self.connection) + + +def sort_tables( + tables: Iterable[TableClause], + skip_fn: Optional[Callable[[ForeignKeyConstraint], bool]] = None, + extra_dependencies: Optional[ + typing_Sequence[Tuple[TableClause, TableClause]] + ] = None, +) -> List[Table]: + """Sort a collection of :class:`_schema.Table` objects based on + dependency. + + This is a dependency-ordered sort which will emit :class:`_schema.Table` + objects such that they will follow their dependent :class:`_schema.Table` + objects. + Tables are dependent on another based on the presence of + :class:`_schema.ForeignKeyConstraint` + objects as well as explicit dependencies + added by :meth:`_schema.Table.add_is_dependent_on`. + + .. warning:: + + The :func:`._schema.sort_tables` function cannot by itself + accommodate automatic resolution of dependency cycles between + tables, which are usually caused by mutually dependent foreign key + constraints. When these cycles are detected, the foreign keys + of these tables are omitted from consideration in the sort. + A warning is emitted when this condition occurs, which will be an + exception raise in a future release. Tables which are not part + of the cycle will still be returned in dependency order. + + To resolve these cycles, the + :paramref:`_schema.ForeignKeyConstraint.use_alter` parameter may be + applied to those constraints which create a cycle. Alternatively, + the :func:`_schema.sort_tables_and_constraints` function will + automatically return foreign key constraints in a separate + collection when cycles are detected so that they may be applied + to a schema separately. + + .. versionchanged:: 1.3.17 - a warning is emitted when + :func:`_schema.sort_tables` cannot perform a proper sort due to + cyclical dependencies. This will be an exception in a future + release. Additionally, the sort will continue to return + other tables not involved in the cycle in dependency order + which was not the case previously. + + :param tables: a sequence of :class:`_schema.Table` objects. + + :param skip_fn: optional callable which will be passed a + :class:`_schema.ForeignKeyConstraint` object; if it returns True, this + constraint will not be considered as a dependency. Note this is + **different** from the same parameter in + :func:`.sort_tables_and_constraints`, which is + instead passed the owning :class:`_schema.ForeignKeyConstraint` object. + + :param extra_dependencies: a sequence of 2-tuples of tables which will + also be considered as dependent on each other. + + .. seealso:: + + :func:`.sort_tables_and_constraints` + + :attr:`_schema.MetaData.sorted_tables` - uses this function to sort + + + """ + + if skip_fn is not None: + fixed_skip_fn = skip_fn + + def _skip_fn(fkc): + for fk in fkc.elements: + if fixed_skip_fn(fk): + return True + else: + return None + + else: + _skip_fn = None # type: ignore + + return [ + t + for (t, fkcs) in sort_tables_and_constraints( + tables, + filter_fn=_skip_fn, + extra_dependencies=extra_dependencies, + _warn_for_cycles=True, + ) + if t is not None + ] + + +def sort_tables_and_constraints( + tables, filter_fn=None, extra_dependencies=None, _warn_for_cycles=False +): + """Sort a collection of :class:`_schema.Table` / + :class:`_schema.ForeignKeyConstraint` + objects. + + This is a dependency-ordered sort which will emit tuples of + ``(Table, [ForeignKeyConstraint, ...])`` such that each + :class:`_schema.Table` follows its dependent :class:`_schema.Table` + objects. + Remaining :class:`_schema.ForeignKeyConstraint` + objects that are separate due to + dependency rules not satisfied by the sort are emitted afterwards + as ``(None, [ForeignKeyConstraint ...])``. + + Tables are dependent on another based on the presence of + :class:`_schema.ForeignKeyConstraint` objects, explicit dependencies + added by :meth:`_schema.Table.add_is_dependent_on`, + as well as dependencies + stated here using the :paramref:`~.sort_tables_and_constraints.skip_fn` + and/or :paramref:`~.sort_tables_and_constraints.extra_dependencies` + parameters. + + :param tables: a sequence of :class:`_schema.Table` objects. + + :param filter_fn: optional callable which will be passed a + :class:`_schema.ForeignKeyConstraint` object, + and returns a value based on + whether this constraint should definitely be included or excluded as + an inline constraint, or neither. If it returns False, the constraint + will definitely be included as a dependency that cannot be subject + to ALTER; if True, it will **only** be included as an ALTER result at + the end. Returning None means the constraint is included in the + table-based result unless it is detected as part of a dependency cycle. + + :param extra_dependencies: a sequence of 2-tuples of tables which will + also be considered as dependent on each other. + + .. seealso:: + + :func:`.sort_tables` + + + """ + + fixed_dependencies = set() + mutable_dependencies = set() + + if extra_dependencies is not None: + fixed_dependencies.update(extra_dependencies) + + remaining_fkcs = set() + for table in tables: + for fkc in table.foreign_key_constraints: + if fkc.use_alter is True: + remaining_fkcs.add(fkc) + continue + + if filter_fn: + filtered = filter_fn(fkc) + + if filtered is True: + remaining_fkcs.add(fkc) + continue + + dependent_on = fkc.referred_table + if dependent_on is not table: + mutable_dependencies.add((dependent_on, table)) + + fixed_dependencies.update( + (parent, table) for parent in table._extra_dependencies + ) + + try: + candidate_sort = list( + topological.sort( + fixed_dependencies.union(mutable_dependencies), + tables, + ) + ) + except exc.CircularDependencyError as err: + if _warn_for_cycles: + util.warn( + "Cannot correctly sort tables; there are unresolvable cycles " + 'between tables "%s", which is usually caused by mutually ' + "dependent foreign key constraints. Foreign key constraints " + "involving these tables will not be considered; this warning " + "may raise an error in a future release." + % (", ".join(sorted(t.fullname for t in err.cycles)),) + ) + for edge in err.edges: + if edge in mutable_dependencies: + table = edge[1] + if table not in err.cycles: + continue + can_remove = [ + fkc + for fkc in table.foreign_key_constraints + if filter_fn is None or filter_fn(fkc) is not False + ] + remaining_fkcs.update(can_remove) + for fkc in can_remove: + dependent_on = fkc.referred_table + if dependent_on is not table: + mutable_dependencies.discard((dependent_on, table)) + candidate_sort = list( + topological.sort( + fixed_dependencies.union(mutable_dependencies), + tables, + ) + ) + + return [ + (table, table.foreign_key_constraints.difference(remaining_fkcs)) + for table in candidate_sort + ] + [(None, list(remaining_fkcs))] diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/default_comparator.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/default_comparator.py new file mode 100644 index 0000000..76131bc --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/default_comparator.py @@ -0,0 +1,552 @@ +# sql/default_comparator.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""Default implementation of SQL comparison operations. +""" + +from __future__ import annotations + +import typing +from typing import Any +from typing import Callable +from typing import Dict +from typing import NoReturn +from typing import Optional +from typing import Tuple +from typing import Type +from typing import Union + +from . import coercions +from . import operators +from . import roles +from . import type_api +from .elements import and_ +from .elements import BinaryExpression +from .elements import ClauseElement +from .elements import CollationClause +from .elements import CollectionAggregate +from .elements import ExpressionClauseList +from .elements import False_ +from .elements import Null +from .elements import OperatorExpression +from .elements import or_ +from .elements import True_ +from .elements import UnaryExpression +from .operators import OperatorType +from .. import exc +from .. import util + +_T = typing.TypeVar("_T", bound=Any) + +if typing.TYPE_CHECKING: + from .elements import ColumnElement + from .operators import custom_op + from .type_api import TypeEngine + + +def _boolean_compare( + expr: ColumnElement[Any], + op: OperatorType, + obj: Any, + *, + negate_op: Optional[OperatorType] = None, + reverse: bool = False, + _python_is_types: Tuple[Type[Any], ...] = (type(None), bool), + result_type: Optional[TypeEngine[bool]] = None, + **kwargs: Any, +) -> OperatorExpression[bool]: + if result_type is None: + result_type = type_api.BOOLEANTYPE + + if isinstance(obj, _python_is_types + (Null, True_, False_)): + # allow x ==/!= True/False to be treated as a literal. + # this comes out to "== / != true/false" or "1/0" if those + # constants aren't supported and works on all platforms + if op in (operators.eq, operators.ne) and isinstance( + obj, (bool, True_, False_) + ): + return OperatorExpression._construct_for_op( + expr, + coercions.expect(roles.ConstExprRole, obj), + op, + type_=result_type, + negate=negate_op, + modifiers=kwargs, + ) + elif op in ( + operators.is_distinct_from, + operators.is_not_distinct_from, + ): + return OperatorExpression._construct_for_op( + expr, + coercions.expect(roles.ConstExprRole, obj), + op, + type_=result_type, + negate=negate_op, + modifiers=kwargs, + ) + elif expr._is_collection_aggregate: + obj = coercions.expect( + roles.ConstExprRole, element=obj, operator=op, expr=expr + ) + else: + # all other None uses IS, IS NOT + if op in (operators.eq, operators.is_): + return OperatorExpression._construct_for_op( + expr, + coercions.expect(roles.ConstExprRole, obj), + operators.is_, + negate=operators.is_not, + type_=result_type, + ) + elif op in (operators.ne, operators.is_not): + return OperatorExpression._construct_for_op( + expr, + coercions.expect(roles.ConstExprRole, obj), + operators.is_not, + negate=operators.is_, + type_=result_type, + ) + else: + raise exc.ArgumentError( + "Only '=', '!=', 'is_()', 'is_not()', " + "'is_distinct_from()', 'is_not_distinct_from()' " + "operators can be used with None/True/False" + ) + else: + obj = coercions.expect( + roles.BinaryElementRole, element=obj, operator=op, expr=expr + ) + + if reverse: + return OperatorExpression._construct_for_op( + obj, + expr, + op, + type_=result_type, + negate=negate_op, + modifiers=kwargs, + ) + else: + return OperatorExpression._construct_for_op( + expr, + obj, + op, + type_=result_type, + negate=negate_op, + modifiers=kwargs, + ) + + +def _custom_op_operate( + expr: ColumnElement[Any], + op: custom_op[Any], + obj: Any, + reverse: bool = False, + result_type: Optional[TypeEngine[Any]] = None, + **kw: Any, +) -> ColumnElement[Any]: + if result_type is None: + if op.return_type: + result_type = op.return_type + elif op.is_comparison: + result_type = type_api.BOOLEANTYPE + + return _binary_operate( + expr, op, obj, reverse=reverse, result_type=result_type, **kw + ) + + +def _binary_operate( + expr: ColumnElement[Any], + op: OperatorType, + obj: roles.BinaryElementRole[Any], + *, + reverse: bool = False, + result_type: Optional[TypeEngine[_T]] = None, + **kw: Any, +) -> OperatorExpression[_T]: + coerced_obj = coercions.expect( + roles.BinaryElementRole, obj, expr=expr, operator=op + ) + + if reverse: + left, right = coerced_obj, expr + else: + left, right = expr, coerced_obj + + if result_type is None: + op, result_type = left.comparator._adapt_expression( + op, right.comparator + ) + + return OperatorExpression._construct_for_op( + left, right, op, type_=result_type, modifiers=kw + ) + + +def _conjunction_operate( + expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any +) -> ColumnElement[Any]: + if op is operators.and_: + return and_(expr, other) + elif op is operators.or_: + return or_(expr, other) + else: + raise NotImplementedError() + + +def _scalar( + expr: ColumnElement[Any], + op: OperatorType, + fn: Callable[[ColumnElement[Any]], ColumnElement[Any]], + **kw: Any, +) -> ColumnElement[Any]: + return fn(expr) + + +def _in_impl( + expr: ColumnElement[Any], + op: OperatorType, + seq_or_selectable: ClauseElement, + negate_op: OperatorType, + **kw: Any, +) -> ColumnElement[Any]: + seq_or_selectable = coercions.expect( + roles.InElementRole, seq_or_selectable, expr=expr, operator=op + ) + if "in_ops" in seq_or_selectable._annotations: + op, negate_op = seq_or_selectable._annotations["in_ops"] + + return _boolean_compare( + expr, op, seq_or_selectable, negate_op=negate_op, **kw + ) + + +def _getitem_impl( + expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any +) -> ColumnElement[Any]: + if ( + isinstance(expr.type, type_api.INDEXABLE) + or isinstance(expr.type, type_api.TypeDecorator) + and isinstance(expr.type.impl_instance, type_api.INDEXABLE) + ): + other = coercions.expect( + roles.BinaryElementRole, other, expr=expr, operator=op + ) + return _binary_operate(expr, op, other, **kw) + else: + _unsupported_impl(expr, op, other, **kw) + + +def _unsupported_impl( + expr: ColumnElement[Any], op: OperatorType, *arg: Any, **kw: Any +) -> NoReturn: + raise NotImplementedError( + "Operator '%s' is not supported on this expression" % op.__name__ + ) + + +def _inv_impl( + expr: ColumnElement[Any], op: OperatorType, **kw: Any +) -> ColumnElement[Any]: + """See :meth:`.ColumnOperators.__inv__`.""" + + # undocumented element currently used by the ORM for + # relationship.contains() + if hasattr(expr, "negation_clause"): + return expr.negation_clause + else: + return expr._negate() + + +def _neg_impl( + expr: ColumnElement[Any], op: OperatorType, **kw: Any +) -> ColumnElement[Any]: + """See :meth:`.ColumnOperators.__neg__`.""" + return UnaryExpression(expr, operator=operators.neg, type_=expr.type) + + +def _bitwise_not_impl( + expr: ColumnElement[Any], op: OperatorType, **kw: Any +) -> ColumnElement[Any]: + """See :meth:`.ColumnOperators.bitwise_not`.""" + + return UnaryExpression( + expr, operator=operators.bitwise_not_op, type_=expr.type + ) + + +def _match_impl( + expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any +) -> ColumnElement[Any]: + """See :meth:`.ColumnOperators.match`.""" + + return _boolean_compare( + expr, + operators.match_op, + coercions.expect( + roles.BinaryElementRole, + other, + expr=expr, + operator=operators.match_op, + ), + result_type=type_api.MATCHTYPE, + negate_op=( + operators.not_match_op + if op is operators.match_op + else operators.match_op + ), + **kw, + ) + + +def _distinct_impl( + expr: ColumnElement[Any], op: OperatorType, **kw: Any +) -> ColumnElement[Any]: + """See :meth:`.ColumnOperators.distinct`.""" + return UnaryExpression( + expr, operator=operators.distinct_op, type_=expr.type + ) + + +def _between_impl( + expr: ColumnElement[Any], + op: OperatorType, + cleft: Any, + cright: Any, + **kw: Any, +) -> ColumnElement[Any]: + """See :meth:`.ColumnOperators.between`.""" + return BinaryExpression( + expr, + ExpressionClauseList._construct_for_list( + operators.and_, + type_api.NULLTYPE, + coercions.expect( + roles.BinaryElementRole, + cleft, + expr=expr, + operator=operators.and_, + ), + coercions.expect( + roles.BinaryElementRole, + cright, + expr=expr, + operator=operators.and_, + ), + group=False, + ), + op, + negate=( + operators.not_between_op + if op is operators.between_op + else operators.between_op + ), + modifiers=kw, + ) + + +def _collate_impl( + expr: ColumnElement[str], op: OperatorType, collation: str, **kw: Any +) -> ColumnElement[str]: + return CollationClause._create_collation_expression(expr, collation) + + +def _regexp_match_impl( + expr: ColumnElement[str], + op: OperatorType, + pattern: Any, + flags: Optional[str], + **kw: Any, +) -> ColumnElement[Any]: + return BinaryExpression( + expr, + coercions.expect( + roles.BinaryElementRole, + pattern, + expr=expr, + operator=operators.comma_op, + ), + op, + negate=operators.not_regexp_match_op, + modifiers={"flags": flags}, + ) + + +def _regexp_replace_impl( + expr: ColumnElement[Any], + op: OperatorType, + pattern: Any, + replacement: Any, + flags: Optional[str], + **kw: Any, +) -> ColumnElement[Any]: + return BinaryExpression( + expr, + ExpressionClauseList._construct_for_list( + operators.comma_op, + type_api.NULLTYPE, + coercions.expect( + roles.BinaryElementRole, + pattern, + expr=expr, + operator=operators.comma_op, + ), + coercions.expect( + roles.BinaryElementRole, + replacement, + expr=expr, + operator=operators.comma_op, + ), + group=False, + ), + op, + modifiers={"flags": flags}, + ) + + +# a mapping of operators with the method they use, along with +# additional keyword arguments to be passed +operator_lookup: Dict[ + str, + Tuple[ + Callable[..., ColumnElement[Any]], + util.immutabledict[ + str, Union[OperatorType, Callable[..., ColumnElement[Any]]] + ], + ], +] = { + "and_": (_conjunction_operate, util.EMPTY_DICT), + "or_": (_conjunction_operate, util.EMPTY_DICT), + "inv": (_inv_impl, util.EMPTY_DICT), + "add": (_binary_operate, util.EMPTY_DICT), + "mul": (_binary_operate, util.EMPTY_DICT), + "sub": (_binary_operate, util.EMPTY_DICT), + "div": (_binary_operate, util.EMPTY_DICT), + "mod": (_binary_operate, util.EMPTY_DICT), + "bitwise_xor_op": (_binary_operate, util.EMPTY_DICT), + "bitwise_or_op": (_binary_operate, util.EMPTY_DICT), + "bitwise_and_op": (_binary_operate, util.EMPTY_DICT), + "bitwise_not_op": (_bitwise_not_impl, util.EMPTY_DICT), + "bitwise_lshift_op": (_binary_operate, util.EMPTY_DICT), + "bitwise_rshift_op": (_binary_operate, util.EMPTY_DICT), + "truediv": (_binary_operate, util.EMPTY_DICT), + "floordiv": (_binary_operate, util.EMPTY_DICT), + "custom_op": (_custom_op_operate, util.EMPTY_DICT), + "json_path_getitem_op": (_binary_operate, util.EMPTY_DICT), + "json_getitem_op": (_binary_operate, util.EMPTY_DICT), + "concat_op": (_binary_operate, util.EMPTY_DICT), + "any_op": ( + _scalar, + util.immutabledict({"fn": CollectionAggregate._create_any}), + ), + "all_op": ( + _scalar, + util.immutabledict({"fn": CollectionAggregate._create_all}), + ), + "lt": (_boolean_compare, util.immutabledict({"negate_op": operators.ge})), + "le": (_boolean_compare, util.immutabledict({"negate_op": operators.gt})), + "ne": (_boolean_compare, util.immutabledict({"negate_op": operators.eq})), + "gt": (_boolean_compare, util.immutabledict({"negate_op": operators.le})), + "ge": (_boolean_compare, util.immutabledict({"negate_op": operators.lt})), + "eq": (_boolean_compare, util.immutabledict({"negate_op": operators.ne})), + "is_distinct_from": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.is_not_distinct_from}), + ), + "is_not_distinct_from": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.is_distinct_from}), + ), + "like_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_like_op}), + ), + "ilike_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_ilike_op}), + ), + "not_like_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.like_op}), + ), + "not_ilike_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.ilike_op}), + ), + "contains_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_contains_op}), + ), + "icontains_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_icontains_op}), + ), + "startswith_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_startswith_op}), + ), + "istartswith_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_istartswith_op}), + ), + "endswith_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_endswith_op}), + ), + "iendswith_op": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.not_iendswith_op}), + ), + "desc_op": ( + _scalar, + util.immutabledict({"fn": UnaryExpression._create_desc}), + ), + "asc_op": ( + _scalar, + util.immutabledict({"fn": UnaryExpression._create_asc}), + ), + "nulls_first_op": ( + _scalar, + util.immutabledict({"fn": UnaryExpression._create_nulls_first}), + ), + "nulls_last_op": ( + _scalar, + util.immutabledict({"fn": UnaryExpression._create_nulls_last}), + ), + "in_op": ( + _in_impl, + util.immutabledict({"negate_op": operators.not_in_op}), + ), + "not_in_op": ( + _in_impl, + util.immutabledict({"negate_op": operators.in_op}), + ), + "is_": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.is_}), + ), + "is_not": ( + _boolean_compare, + util.immutabledict({"negate_op": operators.is_not}), + ), + "collate": (_collate_impl, util.EMPTY_DICT), + "match_op": (_match_impl, util.EMPTY_DICT), + "not_match_op": (_match_impl, util.EMPTY_DICT), + "distinct_op": (_distinct_impl, util.EMPTY_DICT), + "between_op": (_between_impl, util.EMPTY_DICT), + "not_between_op": (_between_impl, util.EMPTY_DICT), + "neg": (_neg_impl, util.EMPTY_DICT), + "getitem": (_getitem_impl, util.EMPTY_DICT), + "lshift": (_unsupported_impl, util.EMPTY_DICT), + "rshift": (_unsupported_impl, util.EMPTY_DICT), + "contains": (_unsupported_impl, util.EMPTY_DICT), + "regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT), + "not_regexp_match_op": (_regexp_match_impl, util.EMPTY_DICT), + "regexp_replace_op": (_regexp_replace_impl, util.EMPTY_DICT), +} diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/dml.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/dml.py new file mode 100644 index 0000000..779be1d --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/dml.py @@ -0,0 +1,1817 @@ +# sql/dml.py +# Copyright (C) 2009-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +""" +Provide :class:`_expression.Insert`, :class:`_expression.Update` and +:class:`_expression.Delete`. + +""" +from __future__ import annotations + +import collections.abc as collections_abc +import operator +from typing import Any +from typing import cast +from typing import Dict +from typing import Iterable +from typing import List +from typing import MutableMapping +from typing import NoReturn +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import coercions +from . import roles +from . import util as sql_util +from ._typing import _TP +from ._typing import _unexpected_kw +from ._typing import is_column_element +from ._typing import is_named_from_clause +from .base import _entity_namespace_key +from .base import _exclusive_against +from .base import _from_objects +from .base import _generative +from .base import _select_iterables +from .base import ColumnCollection +from .base import CompileState +from .base import DialectKWArgs +from .base import Executable +from .base import Generative +from .base import HasCompileState +from .elements import BooleanClauseList +from .elements import ClauseElement +from .elements import ColumnClause +from .elements import ColumnElement +from .elements import Null +from .selectable import Alias +from .selectable import ExecutableReturnsRows +from .selectable import FromClause +from .selectable import HasCTE +from .selectable import HasPrefixes +from .selectable import Join +from .selectable import SelectLabelStyle +from .selectable import TableClause +from .selectable import TypedReturnsRows +from .sqltypes import NullType +from .visitors import InternalTraversal +from .. import exc +from .. import util +from ..util.typing import Self +from ..util.typing import TypeGuard + +if TYPE_CHECKING: + from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnsClauseArgument + from ._typing import _DMLColumnArgument + from ._typing import _DMLColumnKeyMapping + from ._typing import _DMLTableArgument + from ._typing import _T0 # noqa + from ._typing import _T1 # noqa + from ._typing import _T2 # noqa + from ._typing import _T3 # noqa + from ._typing import _T4 # noqa + from ._typing import _T5 # noqa + from ._typing import _T6 # noqa + from ._typing import _T7 # noqa + from ._typing import _TypedColumnClauseArgument as _TCCA # noqa + from .base import ReadOnlyColumnCollection + from .compiler import SQLCompiler + from .elements import KeyedColumnElement + from .selectable import _ColumnsClauseElement + from .selectable import _SelectIterable + from .selectable import Select + from .selectable import Selectable + + def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]: ... + + def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]: ... + + def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]: ... + +else: + isupdate = operator.attrgetter("isupdate") + isdelete = operator.attrgetter("isdelete") + isinsert = operator.attrgetter("isinsert") + + +_T = TypeVar("_T", bound=Any) + +_DMLColumnElement = Union[str, ColumnClause[Any]] +_DMLTableElement = Union[TableClause, Alias, Join] + + +class DMLState(CompileState): + _no_parameters = True + _dict_parameters: Optional[MutableMapping[_DMLColumnElement, Any]] = None + _multi_parameters: Optional[ + List[MutableMapping[_DMLColumnElement, Any]] + ] = None + _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None + _parameter_ordering: Optional[List[_DMLColumnElement]] = None + _primary_table: FromClause + _supports_implicit_returning = True + + isupdate = False + isdelete = False + isinsert = False + + statement: UpdateBase + + def __init__( + self, statement: UpdateBase, compiler: SQLCompiler, **kw: Any + ): + raise NotImplementedError() + + @classmethod + def get_entity_description(cls, statement: UpdateBase) -> Dict[str, Any]: + return { + "name": ( + statement.table.name + if is_named_from_clause(statement.table) + else None + ), + "table": statement.table, + } + + @classmethod + def get_returning_column_descriptions( + cls, statement: UpdateBase + ) -> List[Dict[str, Any]]: + return [ + { + "name": c.key, + "type": c.type, + "expr": c, + } + for c in statement._all_selected_columns + ] + + @property + def dml_table(self) -> _DMLTableElement: + return self.statement.table + + if TYPE_CHECKING: + + @classmethod + def get_plugin_class(cls, statement: Executable) -> Type[DMLState]: ... + + @classmethod + def _get_multi_crud_kv_pairs( + cls, + statement: UpdateBase, + multi_kv_iterator: Iterable[Dict[_DMLColumnArgument, Any]], + ) -> List[Dict[_DMLColumnElement, Any]]: + return [ + { + coercions.expect(roles.DMLColumnRole, k): v + for k, v in mapping.items() + } + for mapping in multi_kv_iterator + ] + + @classmethod + def _get_crud_kv_pairs( + cls, + statement: UpdateBase, + kv_iterator: Iterable[Tuple[_DMLColumnArgument, Any]], + needs_to_be_cacheable: bool, + ) -> List[Tuple[_DMLColumnElement, Any]]: + return [ + ( + coercions.expect(roles.DMLColumnRole, k), + ( + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=NullType(), + is_crud=True, + ) + ), + ) + for k, v in kv_iterator + ] + + def _make_extra_froms( + self, statement: DMLWhereBase + ) -> Tuple[FromClause, List[FromClause]]: + froms: List[FromClause] = [] + + all_tables = list(sql_util.tables_from_leftmost(statement.table)) + primary_table = all_tables[0] + seen = {primary_table} + + consider = statement._where_criteria + if self._dict_parameters: + consider += tuple(self._dict_parameters.values()) + + for crit in consider: + for item in _from_objects(crit): + if not seen.intersection(item._cloned_set): + froms.append(item) + seen.update(item._cloned_set) + + froms.extend(all_tables[1:]) + return primary_table, froms + + def _process_values(self, statement: ValuesBase) -> None: + if self._no_parameters: + self._dict_parameters = statement._values + self._no_parameters = False + + def _process_select_values(self, statement: ValuesBase) -> None: + assert statement._select_names is not None + parameters: MutableMapping[_DMLColumnElement, Any] = { + name: Null() for name in statement._select_names + } + + if self._no_parameters: + self._no_parameters = False + self._dict_parameters = parameters + else: + # this condition normally not reachable as the Insert + # does not allow this construction to occur + assert False, "This statement already has parameters" + + def _no_multi_values_supported(self, statement: ValuesBase) -> NoReturn: + raise exc.InvalidRequestError( + "%s construct does not support " + "multiple parameter sets." % statement.__visit_name__.upper() + ) + + def _cant_mix_formats_error(self) -> NoReturn: + raise exc.InvalidRequestError( + "Can't mix single and multiple VALUES " + "formats in one INSERT statement; one style appends to a " + "list while the other replaces values, so the intent is " + "ambiguous." + ) + + +@CompileState.plugin_for("default", "insert") +class InsertDMLState(DMLState): + isinsert = True + + include_table_with_column_exprs = False + + _has_multi_parameters = False + + def __init__( + self, + statement: Insert, + compiler: SQLCompiler, + disable_implicit_returning: bool = False, + **kw: Any, + ): + self.statement = statement + self._primary_table = statement.table + + if disable_implicit_returning: + self._supports_implicit_returning = False + + self.isinsert = True + if statement._select_names: + self._process_select_values(statement) + if statement._values is not None: + self._process_values(statement) + if statement._multi_values: + self._process_multi_values(statement) + + @util.memoized_property + def _insert_col_keys(self) -> List[str]: + # this is also done in crud.py -> _key_getters_for_crud_column + return [ + coercions.expect(roles.DMLColumnRole, col, as_key=True) + for col in self._dict_parameters or () + ] + + def _process_values(self, statement: ValuesBase) -> None: + if self._no_parameters: + self._has_multi_parameters = False + self._dict_parameters = statement._values + self._no_parameters = False + elif self._has_multi_parameters: + self._cant_mix_formats_error() + + def _process_multi_values(self, statement: ValuesBase) -> None: + for parameters in statement._multi_values: + multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [ + ( + { + c.key: value + for c, value in zip(statement.table.c, parameter_set) + } + if isinstance(parameter_set, collections_abc.Sequence) + else parameter_set + ) + for parameter_set in parameters + ] + + if self._no_parameters: + self._no_parameters = False + self._has_multi_parameters = True + self._multi_parameters = multi_parameters + self._dict_parameters = self._multi_parameters[0] + elif not self._has_multi_parameters: + self._cant_mix_formats_error() + else: + assert self._multi_parameters + self._multi_parameters.extend(multi_parameters) + + +@CompileState.plugin_for("default", "update") +class UpdateDMLState(DMLState): + isupdate = True + + include_table_with_column_exprs = False + + def __init__(self, statement: Update, compiler: SQLCompiler, **kw: Any): + self.statement = statement + + self.isupdate = True + if statement._ordered_values is not None: + self._process_ordered_values(statement) + elif statement._values is not None: + self._process_values(statement) + elif statement._multi_values: + self._no_multi_values_supported(statement) + t, ef = self._make_extra_froms(statement) + self._primary_table = t + self._extra_froms = ef + + self.is_multitable = mt = ef + self.include_table_with_column_exprs = bool( + mt and compiler.render_table_with_column_in_update_from + ) + + def _process_ordered_values(self, statement: ValuesBase) -> None: + parameters = statement._ordered_values + + if self._no_parameters: + self._no_parameters = False + assert parameters is not None + self._dict_parameters = dict(parameters) + self._ordered_values = parameters + self._parameter_ordering = [key for key, value in parameters] + else: + raise exc.InvalidRequestError( + "Can only invoke ordered_values() once, and not mixed " + "with any other values() call" + ) + + +@CompileState.plugin_for("default", "delete") +class DeleteDMLState(DMLState): + isdelete = True + + def __init__(self, statement: Delete, compiler: SQLCompiler, **kw: Any): + self.statement = statement + + self.isdelete = True + t, ef = self._make_extra_froms(statement) + self._primary_table = t + self._extra_froms = ef + self.is_multitable = ef + + +class UpdateBase( + roles.DMLRole, + HasCTE, + HasCompileState, + DialectKWArgs, + HasPrefixes, + Generative, + ExecutableReturnsRows, + ClauseElement, +): + """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" + + __visit_name__ = "update_base" + + _hints: util.immutabledict[Tuple[_DMLTableElement, str], str] = ( + util.EMPTY_DICT + ) + named_with_column = False + + _label_style: SelectLabelStyle = ( + SelectLabelStyle.LABEL_STYLE_DISAMBIGUATE_ONLY + ) + table: _DMLTableElement + + _return_defaults = False + _return_defaults_columns: Optional[Tuple[_ColumnsClauseElement, ...]] = ( + None + ) + _supplemental_returning: Optional[Tuple[_ColumnsClauseElement, ...]] = None + _returning: Tuple[_ColumnsClauseElement, ...] = () + + is_dml = True + + def _generate_fromclause_column_proxies( + self, fromclause: FromClause + ) -> None: + fromclause._columns._populate_separate_keys( + col._make_proxy(fromclause) + for col in self._all_selected_columns + if is_column_element(col) + ) + + def params(self, *arg: Any, **kw: Any) -> NoReturn: + """Set the parameters for the statement. + + This method raises ``NotImplementedError`` on the base class, + and is overridden by :class:`.ValuesBase` to provide the + SET/VALUES clause of UPDATE and INSERT. + + """ + raise NotImplementedError( + "params() is not supported for INSERT/UPDATE/DELETE statements." + " To set the values for an INSERT or UPDATE statement, use" + " stmt.values(**parameters)." + ) + + @_generative + def with_dialect_options(self, **opt: Any) -> Self: + """Add dialect options to this INSERT/UPDATE/DELETE object. + + e.g.:: + + upd = table.update().dialect_options(mysql_limit=10) + + .. versionadded: 1.4 - this method supersedes the dialect options + associated with the constructor. + + + """ + self._validate_dialect_kwargs(opt) + return self + + @_generative + def return_defaults( + self, + *cols: _DMLColumnArgument, + supplemental_cols: Optional[Iterable[_DMLColumnArgument]] = None, + sort_by_parameter_order: bool = False, + ) -> Self: + """Make use of a :term:`RETURNING` clause for the purpose + of fetching server-side expressions and defaults, for supporting + backends only. + + .. deepalchemy:: + + The :meth:`.UpdateBase.return_defaults` method is used by the ORM + for its internal work in fetching newly generated primary key + and server default values, in particular to provide the underyling + implementation of the :paramref:`_orm.Mapper.eager_defaults` + ORM feature as well as to allow RETURNING support with bulk + ORM inserts. Its behavior is fairly idiosyncratic + and is not really intended for general use. End users should + stick with using :meth:`.UpdateBase.returning` in order to + add RETURNING clauses to their INSERT, UPDATE and DELETE + statements. + + Normally, a single row INSERT statement will automatically populate the + :attr:`.CursorResult.inserted_primary_key` attribute when executed, + which stores the primary key of the row that was just inserted in the + form of a :class:`.Row` object with column names as named tuple keys + (and the :attr:`.Row._mapping` view fully populated as well). The + dialect in use chooses the strategy to use in order to populate this + data; if it was generated using server-side defaults and / or SQL + expressions, dialect-specific approaches such as ``cursor.lastrowid`` + or ``RETURNING`` are typically used to acquire the new primary key + value. + + However, when the statement is modified by calling + :meth:`.UpdateBase.return_defaults` before executing the statement, + additional behaviors take place **only** for backends that support + RETURNING and for :class:`.Table` objects that maintain the + :paramref:`.Table.implicit_returning` parameter at its default value of + ``True``. In these cases, when the :class:`.CursorResult` is returned + from the statement's execution, not only will + :attr:`.CursorResult.inserted_primary_key` be populated as always, the + :attr:`.CursorResult.returned_defaults` attribute will also be + populated with a :class:`.Row` named-tuple representing the full range + of server generated + values from that single row, including values for any columns that + specify :paramref:`_schema.Column.server_default` or which make use of + :paramref:`_schema.Column.default` using a SQL expression. + + When invoking INSERT statements with multiple rows using + :ref:`insertmanyvalues `, the + :meth:`.UpdateBase.return_defaults` modifier will have the effect of + the :attr:`_engine.CursorResult.inserted_primary_key_rows` and + :attr:`_engine.CursorResult.returned_defaults_rows` attributes being + fully populated with lists of :class:`.Row` objects representing newly + inserted primary key values as well as newly inserted server generated + values for each row inserted. The + :attr:`.CursorResult.inserted_primary_key` and + :attr:`.CursorResult.returned_defaults` attributes will also continue + to be populated with the first row of these two collections. + + If the backend does not support RETURNING or the :class:`.Table` in use + has disabled :paramref:`.Table.implicit_returning`, then no RETURNING + clause is added and no additional data is fetched, however the + INSERT, UPDATE or DELETE statement proceeds normally. + + E.g.:: + + stmt = table.insert().values(data='newdata').return_defaults() + + result = connection.execute(stmt) + + server_created_at = result.returned_defaults['created_at'] + + When used against an UPDATE statement + :meth:`.UpdateBase.return_defaults` instead looks for columns that + include :paramref:`_schema.Column.onupdate` or + :paramref:`_schema.Column.server_onupdate` parameters assigned, when + constructing the columns that will be included in the RETURNING clause + by default if explicit columns were not specified. When used against a + DELETE statement, no columns are included in RETURNING by default, they + instead must be specified explicitly as there are no columns that + normally change values when a DELETE statement proceeds. + + .. versionadded:: 2.0 :meth:`.UpdateBase.return_defaults` is supported + for DELETE statements also and has been moved from + :class:`.ValuesBase` to :class:`.UpdateBase`. + + The :meth:`.UpdateBase.return_defaults` method is mutually exclusive + against the :meth:`.UpdateBase.returning` method and errors will be + raised during the SQL compilation process if both are used at the same + time on one statement. The RETURNING clause of the INSERT, UPDATE or + DELETE statement is therefore controlled by only one of these methods + at a time. + + The :meth:`.UpdateBase.return_defaults` method differs from + :meth:`.UpdateBase.returning` in these ways: + + 1. :meth:`.UpdateBase.return_defaults` method causes the + :attr:`.CursorResult.returned_defaults` collection to be populated + with the first row from the RETURNING result. This attribute is not + populated when using :meth:`.UpdateBase.returning`. + + 2. :meth:`.UpdateBase.return_defaults` is compatible with existing + logic used to fetch auto-generated primary key values that are then + populated into the :attr:`.CursorResult.inserted_primary_key` + attribute. By contrast, using :meth:`.UpdateBase.returning` will + have the effect of the :attr:`.CursorResult.inserted_primary_key` + attribute being left unpopulated. + + 3. :meth:`.UpdateBase.return_defaults` can be called against any + backend. Backends that don't support RETURNING will skip the usage + of the feature, rather than raising an exception, *unless* + ``supplemental_cols`` is passed. The return value + of :attr:`_engine.CursorResult.returned_defaults` will be ``None`` + for backends that don't support RETURNING or for which the target + :class:`.Table` sets :paramref:`.Table.implicit_returning` to + ``False``. + + 4. An INSERT statement invoked with executemany() is supported if the + backend database driver supports the + :ref:`insertmanyvalues ` + feature which is now supported by most SQLAlchemy-included backends. + When executemany is used, the + :attr:`_engine.CursorResult.returned_defaults_rows` and + :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors + will return the inserted defaults and primary keys. + + .. versionadded:: 1.4 Added + :attr:`_engine.CursorResult.returned_defaults_rows` and + :attr:`_engine.CursorResult.inserted_primary_key_rows` accessors. + In version 2.0, the underlying implementation which fetches and + populates the data for these attributes was generalized to be + supported by most backends, whereas in 1.4 they were only + supported by the ``psycopg2`` driver. + + + :param cols: optional list of column key names or + :class:`_schema.Column` that acts as a filter for those columns that + will be fetched. + :param supplemental_cols: optional list of RETURNING expressions, + in the same form as one would pass to the + :meth:`.UpdateBase.returning` method. When present, the additional + columns will be included in the RETURNING clause, and the + :class:`.CursorResult` object will be "rewound" when returned, so + that methods like :meth:`.CursorResult.all` will return new rows + mostly as though the statement used :meth:`.UpdateBase.returning` + directly. However, unlike when using :meth:`.UpdateBase.returning` + directly, the **order of the columns is undefined**, so can only be + targeted using names or :attr:`.Row._mapping` keys; they cannot + reliably be targeted positionally. + + .. versionadded:: 2.0 + + :param sort_by_parameter_order: for a batch INSERT that is being + executed against multiple parameter sets, organize the results of + RETURNING so that the returned rows correspond to the order of + parameter sets passed in. This applies only to an :term:`executemany` + execution for supporting dialects and typically makes use of the + :term:`insertmanyvalues` feature. + + .. versionadded:: 2.0.10 + + .. seealso:: + + :ref:`engine_insertmanyvalues_returning_order` - background on + sorting of RETURNING rows for bulk INSERT + + .. seealso:: + + :meth:`.UpdateBase.returning` + + :attr:`_engine.CursorResult.returned_defaults` + + :attr:`_engine.CursorResult.returned_defaults_rows` + + :attr:`_engine.CursorResult.inserted_primary_key` + + :attr:`_engine.CursorResult.inserted_primary_key_rows` + + """ + + if self._return_defaults: + # note _return_defaults_columns = () means return all columns, + # so if we have been here before, only update collection if there + # are columns in the collection + if self._return_defaults_columns and cols: + self._return_defaults_columns = tuple( + util.OrderedSet(self._return_defaults_columns).union( + coercions.expect(roles.ColumnsClauseRole, c) + for c in cols + ) + ) + else: + # set for all columns + self._return_defaults_columns = () + else: + self._return_defaults_columns = tuple( + coercions.expect(roles.ColumnsClauseRole, c) for c in cols + ) + self._return_defaults = True + if sort_by_parameter_order: + if not self.is_insert: + raise exc.ArgumentError( + "The 'sort_by_parameter_order' argument to " + "return_defaults() only applies to INSERT statements" + ) + self._sort_by_parameter_order = True + if supplemental_cols: + # uniquifying while also maintaining order (the maintain of order + # is for test suites but also for vertical splicing + supplemental_col_tup = ( + coercions.expect(roles.ColumnsClauseRole, c) + for c in supplemental_cols + ) + + if self._supplemental_returning is None: + self._supplemental_returning = tuple( + util.unique_list(supplemental_col_tup) + ) + else: + self._supplemental_returning = tuple( + util.unique_list( + self._supplemental_returning + + tuple(supplemental_col_tup) + ) + ) + + return self + + @_generative + def returning( + self, + *cols: _ColumnsClauseArgument[Any], + sort_by_parameter_order: bool = False, + **__kw: Any, + ) -> UpdateBase: + r"""Add a :term:`RETURNING` or equivalent clause to this statement. + + e.g.: + + .. sourcecode:: pycon+sql + + >>> stmt = ( + ... table.update() + ... .where(table.c.data == "value") + ... .values(status="X") + ... .returning(table.c.server_flag, table.c.updated_timestamp) + ... ) + >>> print(stmt) + {printsql}UPDATE some_table SET status=:status + WHERE some_table.data = :data_1 + RETURNING some_table.server_flag, some_table.updated_timestamp + + The method may be invoked multiple times to add new entries to the + list of expressions to be returned. + + .. versionadded:: 1.4.0b2 The method may be invoked multiple times to + add new entries to the list of expressions to be returned. + + The given collection of column expressions should be derived from the + table that is the target of the INSERT, UPDATE, or DELETE. While + :class:`_schema.Column` objects are typical, the elements can also be + expressions: + + .. sourcecode:: pycon+sql + + >>> stmt = table.insert().returning( + ... (table.c.first_name + " " + table.c.last_name).label("fullname") + ... ) + >>> print(stmt) + {printsql}INSERT INTO some_table (first_name, last_name) + VALUES (:first_name, :last_name) + RETURNING some_table.first_name || :first_name_1 || some_table.last_name AS fullname + + Upon compilation, a RETURNING clause, or database equivalent, + will be rendered within the statement. For INSERT and UPDATE, + the values are the newly inserted/updated values. For DELETE, + the values are those of the rows which were deleted. + + Upon execution, the values of the columns to be returned are made + available via the result set and can be iterated using + :meth:`_engine.CursorResult.fetchone` and similar. + For DBAPIs which do not + natively support returning values (i.e. cx_oracle), SQLAlchemy will + approximate this behavior at the result level so that a reasonable + amount of behavioral neutrality is provided. + + Note that not all databases/DBAPIs + support RETURNING. For those backends with no support, + an exception is raised upon compilation and/or execution. + For those who do support it, the functionality across backends + varies greatly, including restrictions on executemany() + and other statements which return multiple rows. Please + read the documentation notes for the database in use in + order to determine the availability of RETURNING. + + :param \*cols: series of columns, SQL expressions, or whole tables + entities to be returned. + :param sort_by_parameter_order: for a batch INSERT that is being + executed against multiple parameter sets, organize the results of + RETURNING so that the returned rows correspond to the order of + parameter sets passed in. This applies only to an :term:`executemany` + execution for supporting dialects and typically makes use of the + :term:`insertmanyvalues` feature. + + .. versionadded:: 2.0.10 + + .. seealso:: + + :ref:`engine_insertmanyvalues_returning_order` - background on + sorting of RETURNING rows for bulk INSERT (Core level discussion) + + :ref:`orm_queryguide_bulk_insert_returning_ordered` - example of + use with :ref:`orm_queryguide_bulk_insert` (ORM level discussion) + + .. seealso:: + + :meth:`.UpdateBase.return_defaults` - an alternative method tailored + towards efficient fetching of server-side defaults and triggers + for single-row INSERTs or UPDATEs. + + :ref:`tutorial_insert_returning` - in the :ref:`unified_tutorial` + + """ # noqa: E501 + if __kw: + raise _unexpected_kw("UpdateBase.returning()", __kw) + if self._return_defaults: + raise exc.InvalidRequestError( + "return_defaults() is already configured on this statement" + ) + self._returning += tuple( + coercions.expect(roles.ColumnsClauseRole, c) for c in cols + ) + if sort_by_parameter_order: + if not self.is_insert: + raise exc.ArgumentError( + "The 'sort_by_parameter_order' argument to returning() " + "only applies to INSERT statements" + ) + self._sort_by_parameter_order = True + return self + + def corresponding_column( + self, column: KeyedColumnElement[Any], require_embedded: bool = False + ) -> Optional[ColumnElement[Any]]: + return self.exported_columns.corresponding_column( + column, require_embedded=require_embedded + ) + + @util.ro_memoized_property + def _all_selected_columns(self) -> _SelectIterable: + return [c for c in _select_iterables(self._returning)] + + @util.ro_memoized_property + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[Optional[str], ColumnElement[Any]]: + """Return the RETURNING columns as a column collection for this + statement. + + .. versionadded:: 1.4 + + """ + return ColumnCollection( + (c.key, c) + for c in self._all_selected_columns + if is_column_element(c) + ).as_readonly() + + @_generative + def with_hint( + self, + text: str, + selectable: Optional[_DMLTableArgument] = None, + dialect_name: str = "*", + ) -> Self: + """Add a table hint for a single table to this + INSERT/UPDATE/DELETE statement. + + .. note:: + + :meth:`.UpdateBase.with_hint` currently applies only to + Microsoft SQL Server. For MySQL INSERT/UPDATE/DELETE hints, use + :meth:`.UpdateBase.prefix_with`. + + The text of the hint is rendered in the appropriate + location for the database backend in use, relative + to the :class:`_schema.Table` that is the subject of this + statement, or optionally to that of the given + :class:`_schema.Table` passed as the ``selectable`` argument. + + The ``dialect_name`` option will limit the rendering of a particular + hint to a particular backend. Such as, to add a hint + that only takes effect for SQL Server:: + + mytable.insert().with_hint("WITH (PAGLOCK)", dialect_name="mssql") + + :param text: Text of the hint. + :param selectable: optional :class:`_schema.Table` that specifies + an element of the FROM clause within an UPDATE or DELETE + to be the subject of the hint - applies only to certain backends. + :param dialect_name: defaults to ``*``, if specified as the name + of a particular dialect, will apply these hints only when + that dialect is in use. + """ + if selectable is None: + selectable = self.table + else: + selectable = coercions.expect(roles.DMLTableRole, selectable) + self._hints = self._hints.union({(selectable, dialect_name): text}) + return self + + @property + def entity_description(self) -> Dict[str, Any]: + """Return a :term:`plugin-enabled` description of the table and/or + entity which this DML construct is operating against. + + This attribute is generally useful when using the ORM, as an + extended structure which includes information about mapped + entities is returned. The section :ref:`queryguide_inspection` + contains more background. + + For a Core statement, the structure returned by this accessor + is derived from the :attr:`.UpdateBase.table` attribute, and + refers to the :class:`.Table` being inserted, updated, or deleted:: + + >>> stmt = insert(user_table) + >>> stmt.entity_description + { + "name": "user_table", + "table": Table("user_table", ...) + } + + .. versionadded:: 1.4.33 + + .. seealso:: + + :attr:`.UpdateBase.returning_column_descriptions` + + :attr:`.Select.column_descriptions` - entity information for + a :func:`.select` construct + + :ref:`queryguide_inspection` - ORM background + + """ + meth = DMLState.get_plugin_class(self).get_entity_description + return meth(self) + + @property + def returning_column_descriptions(self) -> List[Dict[str, Any]]: + """Return a :term:`plugin-enabled` description of the columns + which this DML construct is RETURNING against, in other words + the expressions established as part of :meth:`.UpdateBase.returning`. + + This attribute is generally useful when using the ORM, as an + extended structure which includes information about mapped + entities is returned. The section :ref:`queryguide_inspection` + contains more background. + + For a Core statement, the structure returned by this accessor is + derived from the same objects that are returned by the + :attr:`.UpdateBase.exported_columns` accessor:: + + >>> stmt = insert(user_table).returning(user_table.c.id, user_table.c.name) + >>> stmt.entity_description + [ + { + "name": "id", + "type": Integer, + "expr": Column("id", Integer(), table=, ...) + }, + { + "name": "name", + "type": String(), + "expr": Column("name", String(), table=, ...) + }, + ] + + .. versionadded:: 1.4.33 + + .. seealso:: + + :attr:`.UpdateBase.entity_description` + + :attr:`.Select.column_descriptions` - entity information for + a :func:`.select` construct + + :ref:`queryguide_inspection` - ORM background + + """ # noqa: E501 + meth = DMLState.get_plugin_class( + self + ).get_returning_column_descriptions + return meth(self) + + +class ValuesBase(UpdateBase): + """Supplies support for :meth:`.ValuesBase.values` to + INSERT and UPDATE constructs.""" + + __visit_name__ = "values_base" + + _supports_multi_parameters = False + + select: Optional[Select[Any]] = None + """SELECT statement for INSERT .. FROM SELECT""" + + _post_values_clause: Optional[ClauseElement] = None + """used by extensions to Insert etc. to add additional syntacitcal + constructs, e.g. ON CONFLICT etc.""" + + _values: Optional[util.immutabledict[_DMLColumnElement, Any]] = None + _multi_values: Tuple[ + Union[ + Sequence[Dict[_DMLColumnElement, Any]], + Sequence[Sequence[Any]], + ], + ..., + ] = () + + _ordered_values: Optional[List[Tuple[_DMLColumnElement, Any]]] = None + + _select_names: Optional[List[str]] = None + _inline: bool = False + + def __init__(self, table: _DMLTableArgument): + self.table = coercions.expect( + roles.DMLTableRole, table, apply_propagate_attrs=self + ) + + @_generative + @_exclusive_against( + "_select_names", + "_ordered_values", + msgs={ + "_select_names": "This construct already inserts from a SELECT", + "_ordered_values": "This statement already has ordered " + "values present", + }, + ) + def values( + self, + *args: Union[ + _DMLColumnKeyMapping[Any], + Sequence[Any], + ], + **kwargs: Any, + ) -> Self: + r"""Specify a fixed VALUES clause for an INSERT statement, or the SET + clause for an UPDATE. + + Note that the :class:`_expression.Insert` and + :class:`_expression.Update` + constructs support + per-execution time formatting of the VALUES and/or SET clauses, + based on the arguments passed to :meth:`_engine.Connection.execute`. + However, the :meth:`.ValuesBase.values` method can be used to "fix" a + particular set of parameters into the statement. + + Multiple calls to :meth:`.ValuesBase.values` will produce a new + construct, each one with the parameter list modified to include + the new parameters sent. In the typical case of a single + dictionary of parameters, the newly passed keys will replace + the same keys in the previous construct. In the case of a list-based + "multiple values" construct, each new list of values is extended + onto the existing list of values. + + :param \**kwargs: key value pairs representing the string key + of a :class:`_schema.Column` + mapped to the value to be rendered into the + VALUES or SET clause:: + + users.insert().values(name="some name") + + users.update().where(users.c.id==5).values(name="some name") + + :param \*args: As an alternative to passing key/value parameters, + a dictionary, tuple, or list of dictionaries or tuples can be passed + as a single positional argument in order to form the VALUES or + SET clause of the statement. The forms that are accepted vary + based on whether this is an :class:`_expression.Insert` or an + :class:`_expression.Update` construct. + + For either an :class:`_expression.Insert` or + :class:`_expression.Update` + construct, a single dictionary can be passed, which works the same as + that of the kwargs form:: + + users.insert().values({"name": "some name"}) + + users.update().values({"name": "some new name"}) + + Also for either form but more typically for the + :class:`_expression.Insert` construct, a tuple that contains an + entry for every column in the table is also accepted:: + + users.insert().values((5, "some name")) + + The :class:`_expression.Insert` construct also supports being + passed a list of dictionaries or full-table-tuples, which on the + server will render the less common SQL syntax of "multiple values" - + this syntax is supported on backends such as SQLite, PostgreSQL, + MySQL, but not necessarily others:: + + users.insert().values([ + {"name": "some name"}, + {"name": "some other name"}, + {"name": "yet another name"}, + ]) + + The above form would render a multiple VALUES statement similar to:: + + INSERT INTO users (name) VALUES + (:name_1), + (:name_2), + (:name_3) + + It is essential to note that **passing multiple values is + NOT the same as using traditional executemany() form**. The above + syntax is a **special** syntax not typically used. To emit an + INSERT statement against multiple rows, the normal method is + to pass a multiple values list to the + :meth:`_engine.Connection.execute` + method, which is supported by all database backends and is generally + more efficient for a very large number of parameters. + + .. seealso:: + + :ref:`tutorial_multiple_parameters` - an introduction to + the traditional Core method of multiple parameter set + invocation for INSERTs and other statements. + + The UPDATE construct also supports rendering the SET parameters + in a specific order. For this feature refer to the + :meth:`_expression.Update.ordered_values` method. + + .. seealso:: + + :meth:`_expression.Update.ordered_values` + + + """ + if args: + # positional case. this is currently expensive. we don't + # yet have positional-only args so we have to check the length. + # then we need to check multiparams vs. single dictionary. + # since the parameter format is needed in order to determine + # a cache key, we need to determine this up front. + arg = args[0] + + if kwargs: + raise exc.ArgumentError( + "Can't pass positional and kwargs to values() " + "simultaneously" + ) + elif len(args) > 1: + raise exc.ArgumentError( + "Only a single dictionary/tuple or list of " + "dictionaries/tuples is accepted positionally." + ) + + elif isinstance(arg, collections_abc.Sequence): + if arg and isinstance(arg[0], dict): + multi_kv_generator = DMLState.get_plugin_class( + self + )._get_multi_crud_kv_pairs + self._multi_values += (multi_kv_generator(self, arg),) + return self + + if arg and isinstance(arg[0], (list, tuple)): + self._multi_values += (arg,) + return self + + if TYPE_CHECKING: + # crud.py raises during compilation if this is not the + # case + assert isinstance(self, Insert) + + # tuple values + arg = {c.key: value for c, value in zip(self.table.c, arg)} + + else: + # kwarg path. this is the most common path for non-multi-params + # so this is fairly quick. + arg = cast("Dict[_DMLColumnArgument, Any]", kwargs) + if args: + raise exc.ArgumentError( + "Only a single dictionary/tuple or list of " + "dictionaries/tuples is accepted positionally." + ) + + # for top level values(), convert literals to anonymous bound + # parameters at statement construction time, so that these values can + # participate in the cache key process like any other ClauseElement. + # crud.py now intercepts bound parameters with unique=True from here + # and ensures they get the "crud"-style name when rendered. + + kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs + coerced_arg = dict(kv_generator(self, arg.items(), True)) + if self._values: + self._values = self._values.union(coerced_arg) + else: + self._values = util.immutabledict(coerced_arg) + return self + + +class Insert(ValuesBase): + """Represent an INSERT construct. + + The :class:`_expression.Insert` object is created using the + :func:`_expression.insert()` function. + + """ + + __visit_name__ = "insert" + + _supports_multi_parameters = True + + select = None + include_insert_from_select_defaults = False + + _sort_by_parameter_order: bool = False + + is_insert = True + + table: TableClause + + _traverse_internals = ( + [ + ("table", InternalTraversal.dp_clauseelement), + ("_inline", InternalTraversal.dp_boolean), + ("_select_names", InternalTraversal.dp_string_list), + ("_values", InternalTraversal.dp_dml_values), + ("_multi_values", InternalTraversal.dp_dml_multi_values), + ("select", InternalTraversal.dp_clauseelement), + ("_post_values_clause", InternalTraversal.dp_clauseelement), + ("_returning", InternalTraversal.dp_clauseelement_tuple), + ("_hints", InternalTraversal.dp_table_hint_list), + ("_return_defaults", InternalTraversal.dp_boolean), + ( + "_return_defaults_columns", + InternalTraversal.dp_clauseelement_tuple, + ), + ("_sort_by_parameter_order", InternalTraversal.dp_boolean), + ] + + HasPrefixes._has_prefixes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals + + HasCTE._has_ctes_traverse_internals + ) + + def __init__(self, table: _DMLTableArgument): + super().__init__(table) + + @_generative + def inline(self) -> Self: + """Make this :class:`_expression.Insert` construct "inline" . + + When set, no attempt will be made to retrieve the + SQL-generated default values to be provided within the statement; + in particular, + this allows SQL expressions to be rendered 'inline' within the + statement without the need to pre-execute them beforehand; for + backends that support "returning", this turns off the "implicit + returning" feature for the statement. + + + .. versionchanged:: 1.4 the :paramref:`_expression.Insert.inline` + parameter + is now superseded by the :meth:`_expression.Insert.inline` method. + + """ + self._inline = True + return self + + @_generative + def from_select( + self, + names: Sequence[_DMLColumnArgument], + select: Selectable, + include_defaults: bool = True, + ) -> Self: + """Return a new :class:`_expression.Insert` construct which represents + an ``INSERT...FROM SELECT`` statement. + + e.g.:: + + sel = select(table1.c.a, table1.c.b).where(table1.c.c > 5) + ins = table2.insert().from_select(['a', 'b'], sel) + + :param names: a sequence of string column names or + :class:`_schema.Column` + objects representing the target columns. + :param select: a :func:`_expression.select` construct, + :class:`_expression.FromClause` + or other construct which resolves into a + :class:`_expression.FromClause`, + such as an ORM :class:`_query.Query` object, etc. The order of + columns returned from this FROM clause should correspond to the + order of columns sent as the ``names`` parameter; while this + is not checked before passing along to the database, the database + would normally raise an exception if these column lists don't + correspond. + :param include_defaults: if True, non-server default values and + SQL expressions as specified on :class:`_schema.Column` objects + (as documented in :ref:`metadata_defaults_toplevel`) not + otherwise specified in the list of names will be rendered + into the INSERT and SELECT statements, so that these values are also + included in the data to be inserted. + + .. note:: A Python-side default that uses a Python callable function + will only be invoked **once** for the whole statement, and **not + per row**. + + """ + + if self._values: + raise exc.InvalidRequestError( + "This construct already inserts value expressions" + ) + + self._select_names = [ + coercions.expect(roles.DMLColumnRole, name, as_key=True) + for name in names + ] + self._inline = True + self.include_insert_from_select_defaults = include_defaults + self.select = coercions.expect(roles.DMLSelectRole, select) + return self + + if TYPE_CHECKING: + # START OVERLOADED FUNCTIONS self.returning ReturningInsert 1-8 ", *, sort_by_parameter_order: bool = False" # noqa: E501 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning( + self, __ent0: _TCCA[_T0], *, sort_by_parameter_order: bool = False + ) -> ReturningInsert[Tuple[_T0]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + *, + sort_by_parameter_order: bool = False, + ) -> ReturningInsert[Tuple[_T0, _T1]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + *, + sort_by_parameter_order: bool = False, + ) -> ReturningInsert[Tuple[_T0, _T1, _T2]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + *, + sort_by_parameter_order: bool = False, + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + *, + sort_by_parameter_order: bool = False, + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + *, + sort_by_parameter_order: bool = False, + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + *, + sort_by_parameter_order: bool = False, + ) -> ReturningInsert[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + *, + sort_by_parameter_order: bool = False, + ) -> ReturningInsert[ + Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7] + ]: ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, + *cols: _ColumnsClauseArgument[Any], + sort_by_parameter_order: bool = False, + **__kw: Any, + ) -> ReturningInsert[Any]: ... + + def returning( + self, + *cols: _ColumnsClauseArgument[Any], + sort_by_parameter_order: bool = False, + **__kw: Any, + ) -> ReturningInsert[Any]: ... + + +class ReturningInsert(Insert, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Insert` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Insert.returning` method. + + .. versionadded:: 2.0 + + """ + + +class DMLWhereBase: + table: _DMLTableElement + _where_criteria: Tuple[ColumnElement[Any], ...] = () + + @_generative + def where(self, *whereclause: _ColumnExpressionArgument[bool]) -> Self: + """Return a new construct with the given expression(s) added to + its WHERE clause, joined to the existing clause via AND, if any. + + Both :meth:`_dml.Update.where` and :meth:`_dml.Delete.where` + support multiple-table forms, including database-specific + ``UPDATE...FROM`` as well as ``DELETE..USING``. For backends that + don't have multiple-table support, a backend agnostic approach + to using multiple tables is to make use of correlated subqueries. + See the linked tutorial sections below for examples. + + .. seealso:: + + :ref:`tutorial_correlated_updates` + + :ref:`tutorial_update_from` + + :ref:`tutorial_multi_table_deletes` + + """ + + for criterion in whereclause: + where_criteria: ColumnElement[Any] = coercions.expect( + roles.WhereHavingRole, criterion, apply_propagate_attrs=self + ) + self._where_criteria += (where_criteria,) + return self + + def filter(self, *criteria: roles.ExpressionElementRole[Any]) -> Self: + """A synonym for the :meth:`_dml.DMLWhereBase.where` method. + + .. versionadded:: 1.4 + + """ + + return self.where(*criteria) + + def _filter_by_zero(self) -> _DMLTableElement: + return self.table + + def filter_by(self, **kwargs: Any) -> Self: + r"""apply the given filtering criterion as a WHERE clause + to this select. + + """ + from_entity = self._filter_by_zero() + + clauses = [ + _entity_namespace_key(from_entity, key) == value + for key, value in kwargs.items() + ] + return self.filter(*clauses) + + @property + def whereclause(self) -> Optional[ColumnElement[Any]]: + """Return the completed WHERE clause for this :class:`.DMLWhereBase` + statement. + + This assembles the current collection of WHERE criteria + into a single :class:`_expression.BooleanClauseList` construct. + + + .. versionadded:: 1.4 + + """ + + return BooleanClauseList._construct_for_whereclause( + self._where_criteria + ) + + +class Update(DMLWhereBase, ValuesBase): + """Represent an Update construct. + + The :class:`_expression.Update` object is created using the + :func:`_expression.update()` function. + + """ + + __visit_name__ = "update" + + is_update = True + + _traverse_internals = ( + [ + ("table", InternalTraversal.dp_clauseelement), + ("_where_criteria", InternalTraversal.dp_clauseelement_tuple), + ("_inline", InternalTraversal.dp_boolean), + ("_ordered_values", InternalTraversal.dp_dml_ordered_values), + ("_values", InternalTraversal.dp_dml_values), + ("_returning", InternalTraversal.dp_clauseelement_tuple), + ("_hints", InternalTraversal.dp_table_hint_list), + ("_return_defaults", InternalTraversal.dp_boolean), + ( + "_return_defaults_columns", + InternalTraversal.dp_clauseelement_tuple, + ), + ] + + HasPrefixes._has_prefixes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals + + HasCTE._has_ctes_traverse_internals + ) + + def __init__(self, table: _DMLTableArgument): + super().__init__(table) + + @_generative + def ordered_values(self, *args: Tuple[_DMLColumnArgument, Any]) -> Self: + """Specify the VALUES clause of this UPDATE statement with an explicit + parameter ordering that will be maintained in the SET clause of the + resulting UPDATE statement. + + E.g.:: + + stmt = table.update().ordered_values( + ("name", "ed"), ("ident", "foo") + ) + + .. seealso:: + + :ref:`tutorial_parameter_ordered_updates` - full example of the + :meth:`_expression.Update.ordered_values` method. + + .. versionchanged:: 1.4 The :meth:`_expression.Update.ordered_values` + method + supersedes the + :paramref:`_expression.update.preserve_parameter_order` + parameter, which will be removed in SQLAlchemy 2.0. + + """ + if self._values: + raise exc.ArgumentError( + "This statement already has values present" + ) + elif self._ordered_values: + raise exc.ArgumentError( + "This statement already has ordered values present" + ) + + kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs + self._ordered_values = kv_generator(self, args, True) + return self + + @_generative + def inline(self) -> Self: + """Make this :class:`_expression.Update` construct "inline" . + + When set, SQL defaults present on :class:`_schema.Column` + objects via the + ``default`` keyword will be compiled 'inline' into the statement and + not pre-executed. This means that their values will not be available + in the dictionary returned from + :meth:`_engine.CursorResult.last_updated_params`. + + .. versionchanged:: 1.4 the :paramref:`_expression.update.inline` + parameter + is now superseded by the :meth:`_expression.Update.inline` method. + + """ + self._inline = True + return self + + if TYPE_CHECKING: + # START OVERLOADED FUNCTIONS self.returning ReturningUpdate 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning( + self, __ent0: _TCCA[_T0] + ) -> ReturningUpdate[Tuple[_T0]]: ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningUpdate[Tuple[_T0, _T1]]: ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningUpdate[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningUpdate[ + Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7] + ]: ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningUpdate[Any]: ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningUpdate[Any]: ... + + +class ReturningUpdate(Update, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Update` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Update.returning` method. + + .. versionadded:: 2.0 + + """ + + +class Delete(DMLWhereBase, UpdateBase): + """Represent a DELETE construct. + + The :class:`_expression.Delete` object is created using the + :func:`_expression.delete()` function. + + """ + + __visit_name__ = "delete" + + is_delete = True + + _traverse_internals = ( + [ + ("table", InternalTraversal.dp_clauseelement), + ("_where_criteria", InternalTraversal.dp_clauseelement_tuple), + ("_returning", InternalTraversal.dp_clauseelement_tuple), + ("_hints", InternalTraversal.dp_table_hint_list), + ] + + HasPrefixes._has_prefixes_traverse_internals + + DialectKWArgs._dialect_kwargs_traverse_internals + + Executable._executable_traverse_internals + + HasCTE._has_ctes_traverse_internals + ) + + def __init__(self, table: _DMLTableArgument): + self.table = coercions.expect( + roles.DMLTableRole, table, apply_propagate_attrs=self + ) + + if TYPE_CHECKING: + # START OVERLOADED FUNCTIONS self.returning ReturningDelete 1-8 + + # code within this block is **programmatically, + # statically generated** by tools/generate_tuple_map_overloads.py + + @overload + def returning( + self, __ent0: _TCCA[_T0] + ) -> ReturningDelete[Tuple[_T0]]: ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] + ) -> ReturningDelete[Tuple[_T0, _T1]]: ... + + @overload + def returning( + self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] + ) -> ReturningDelete[Tuple[_T0, _T1, _T2]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + ) -> ReturningDelete[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... + + @overload + def returning( + self, + __ent0: _TCCA[_T0], + __ent1: _TCCA[_T1], + __ent2: _TCCA[_T2], + __ent3: _TCCA[_T3], + __ent4: _TCCA[_T4], + __ent5: _TCCA[_T5], + __ent6: _TCCA[_T6], + __ent7: _TCCA[_T7], + ) -> ReturningDelete[ + Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7] + ]: ... + + # END OVERLOADED FUNCTIONS self.returning + + @overload + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningDelete[Any]: ... + + def returning( + self, *cols: _ColumnsClauseArgument[Any], **__kw: Any + ) -> ReturningDelete[Any]: ... + + +class ReturningDelete(Update, TypedReturnsRows[_TP]): + """Typing-only class that establishes a generic type form of + :class:`.Delete` which tracks returned column types. + + This datatype is delivered when calling the + :meth:`.Delete.returning` method. + + .. versionadded:: 2.0 + + """ diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/sql/elements.py b/venv/lib/python3.11/site-packages/sqlalchemy/sql/elements.py new file mode 100644 index 0000000..bafb5c7 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/sql/elements.py @@ -0,0 +1,5405 @@ +# sql/elements.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: allow-untyped-defs, allow-untyped-calls + +"""Core SQL expression elements, including :class:`_expression.ClauseElement`, +:class:`_expression.ColumnElement`, and derived classes. + +""" + +from __future__ import annotations + +from decimal import Decimal +from enum import IntEnum +import itertools +import operator +import re +import typing +from typing import AbstractSet +from typing import Any +from typing import Callable +from typing import cast +from typing import Dict +from typing import FrozenSet +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import overload +from typing import Sequence +from typing import Set +from typing import Tuple as typing_Tuple +from typing import Type +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from . import coercions +from . import operators +from . import roles +from . import traversals +from . import type_api +from ._typing import has_schema_attr +from ._typing import is_named_from_clause +from ._typing import is_quoted_name +from ._typing import is_tuple_type +from .annotation import Annotated +from .annotation import SupportsWrappingAnnotations +from .base import _clone +from .base import _expand_cloned +from .base import _generative +from .base import _NoArg +from .base import Executable +from .base import Generative +from .base import HasMemoized +from .base import Immutable +from .base import NO_ARG +from .base import SingletonConstant +from .cache_key import MemoizedHasCacheKey +from .cache_key import NO_CACHE +from .coercions import _document_text_coercion # noqa +from .operators import ColumnOperators +from .traversals import HasCopyInternals +from .visitors import cloned_traverse +from .visitors import ExternallyTraversible +from .visitors import InternalTraversal +from .visitors import traverse +from .visitors import Visitable +from .. import exc +from .. import inspection +from .. import util +from ..util import HasMemoized_ro_memoized_attribute +from ..util import TypingOnly +from ..util.typing import Literal +from ..util.typing import Self + +if typing.TYPE_CHECKING: + from ._typing import _ByArgument + from ._typing import _ColumnExpressionArgument + from ._typing import _ColumnExpressionOrStrLabelArgument + from ._typing import _HasDialect + from ._typing import _InfoType + from ._typing import _PropagateAttrsType + from ._typing import _TypeEngineArgument + from .cache_key import _CacheKeyTraversalType + from .cache_key import CacheKey + from .compiler import Compiled + from .compiler import SQLCompiler + from .functions import FunctionElement + from .operators import OperatorType + from .schema import Column + from .schema import DefaultGenerator + from .schema import FetchedValue + from .schema import ForeignKey + from .selectable import _SelectIterable + from .selectable import FromClause + from .selectable import NamedFromClause + from .selectable import TextualSelect + from .sqltypes import TupleType + from .type_api import TypeEngine + from .visitors import _CloneCallableType + from .visitors import _TraverseInternalsType + from .visitors import anon_map + from ..engine import Connection + from ..engine import Dialect + from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import CacheStats + from ..engine.interfaces import CompiledCacheType + from ..engine.interfaces import CoreExecuteOptionsParameter + from ..engine.interfaces import SchemaTranslateMapType + from ..engine.result import Result + +_NUMERIC = Union[float, Decimal] +_NUMBER = Union[float, int, Decimal] + +_T = TypeVar("_T", bound="Any") +_T_co = TypeVar("_T_co", bound=Any, covariant=True) +_OPT = TypeVar("_OPT", bound="Any") +_NT = TypeVar("_NT", bound="_NUMERIC") + +_NMT = TypeVar("_NMT", bound="_NUMBER") + + +@overload +def literal( + value: Any, + type_: _TypeEngineArgument[_T], + literal_execute: bool = False, +) -> BindParameter[_T]: ... + + +@overload +def literal( + value: _T, + type_: None = None, + literal_execute: bool = False, +) -> BindParameter[_T]: ... + + +@overload +def literal( + value: Any, + type_: Optional[_TypeEngineArgument[Any]] = None, + literal_execute: bool = False, +) -> BindParameter[Any]: ... + + +def literal( + value: Any, + type_: Optional[_TypeEngineArgument[Any]] = None, + literal_execute: bool = False, +) -> BindParameter[Any]: + r"""Return a literal clause, bound to a bind parameter. + + Literal clauses are created automatically when non- + :class:`_expression.ClauseElement` objects (such as strings, ints, dates, + etc.) are + used in a comparison operation with a :class:`_expression.ColumnElement` + subclass, + such as a :class:`~sqlalchemy.schema.Column` object. Use this function + to force the generation of a literal clause, which will be created as a + :class:`BindParameter` with a bound value. + + :param value: the value to be bound. Can be any Python object supported by + the underlying DB-API, or is translatable via the given type argument. + + :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine` which will + provide bind-parameter translation for this literal. + + :param literal_execute: optional bool, when True, the SQL engine will + attempt to render the bound value directly in the SQL statement at + execution time rather than providing as a parameter value. + + .. versionadded:: 2.0 + + """ + return coercions.expect( + roles.LiteralValueRole, + value, + type_=type_, + literal_execute=literal_execute, + ) + + +def literal_column( + text: str, type_: Optional[_TypeEngineArgument[_T]] = None +) -> ColumnClause[_T]: + r"""Produce a :class:`.ColumnClause` object that has the + :paramref:`_expression.column.is_literal` flag set to True. + + :func:`_expression.literal_column` is similar to + :func:`_expression.column`, except that + it is more often used as a "standalone" column expression that renders + exactly as stated; while :func:`_expression.column` + stores a string name that + will be assumed to be part of a table and may be quoted as such, + :func:`_expression.literal_column` can be that, + or any other arbitrary column-oriented + expression. + + :param text: the text of the expression; can be any SQL expression. + Quoting rules will not be applied. To specify a column-name expression + which should be subject to quoting rules, use the :func:`column` + function. + + :param type\_: an optional :class:`~sqlalchemy.types.TypeEngine` + object which will + provide result-set translation and additional expression semantics for + this column. If left as ``None`` the type will be :class:`.NullType`. + + .. seealso:: + + :func:`_expression.column` + + :func:`_expression.text` + + :ref:`tutorial_select_arbitrary_text` + + """ + return ColumnClause(text, type_=type_, is_literal=True) + + +class CompilerElement(Visitable): + """base class for SQL elements that can be compiled to produce a + SQL string. + + .. versionadded:: 2.0 + + """ + + __slots__ = () + __visit_name__ = "compiler_element" + + supports_execution = False + + stringify_dialect = "default" + + @util.preload_module("sqlalchemy.engine.default") + @util.preload_module("sqlalchemy.engine.url") + def compile( + self, + bind: Optional[_HasDialect] = None, + dialect: Optional[Dialect] = None, + **kw: Any, + ) -> Compiled: + """Compile this SQL expression. + + The return value is a :class:`~.Compiled` object. + Calling ``str()`` or ``unicode()`` on the returned value will yield a + string representation of the result. The + :class:`~.Compiled` object also can return a + dictionary of bind parameter names and values + using the ``params`` accessor. + + :param bind: An :class:`.Connection` or :class:`.Engine` which + can provide a :class:`.Dialect` in order to generate a + :class:`.Compiled` object. If the ``bind`` and + ``dialect`` parameters are both omitted, a default SQL compiler + is used. + + :param column_keys: Used for INSERT and UPDATE statements, a list of + column names which should be present in the VALUES clause of the + compiled statement. If ``None``, all columns from the target table + object are rendered. + + :param dialect: A :class:`.Dialect` instance which can generate + a :class:`.Compiled` object. This argument takes precedence over + the ``bind`` argument. + + :param compile_kwargs: optional dictionary of additional parameters + that will be passed through to the compiler within all "visit" + methods. This allows any custom flag to be passed through to + a custom compilation construct, for example. It is also used + for the case of passing the ``literal_binds`` flag through:: + + from sqlalchemy.sql import table, column, select + + t = table('t', column('x')) + + s = select(t).where(t.c.x == 5) + + print(s.compile(compile_kwargs={"literal_binds": True})) + + .. seealso:: + + :ref:`faq_sql_expression_string` + + """ + + if dialect is None: + if bind: + dialect = bind.dialect + elif self.stringify_dialect == "default": + default = util.preloaded.engine_default + dialect = default.StrCompileDialect() + else: + url = util.preloaded.engine_url + dialect = url.URL.create( + self.stringify_dialect + ).get_dialect()() + + return self._compiler(dialect, **kw) + + def _compiler(self, dialect: Dialect, **kw: Any) -> Compiled: + """Return a compiler appropriate for this ClauseElement, given a + Dialect.""" + + if TYPE_CHECKING: + assert isinstance(self, ClauseElement) + return dialect.statement_compiler(dialect, self, **kw) + + def __str__(self) -> str: + return str(self.compile()) + + +@inspection._self_inspects +class ClauseElement( + SupportsWrappingAnnotations, + MemoizedHasCacheKey, + HasCopyInternals, + ExternallyTraversible, + CompilerElement, +): + """Base class for elements of a programmatically constructed SQL + expression. + + """ + + __visit_name__ = "clause" + + if TYPE_CHECKING: + + @util.memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: + """like annotations, however these propagate outwards liberally + as SQL constructs are built, and are set up at construction time. + + """ + ... + + else: + _propagate_attrs = util.EMPTY_DICT + + @util.ro_memoized_property + def description(self) -> Optional[str]: + return None + + _is_clone_of: Optional[Self] = None + + is_clause_element = True + is_selectable = False + is_dml = False + _is_column_element = False + _is_keyed_column_element = False + _is_table = False + _gen_static_annotations_cache_key = False + _is_textual = False + _is_from_clause = False + _is_returns_rows = False + _is_text_clause = False + _is_from_container = False + _is_select_container = False + _is_select_base = False + _is_select_statement = False + _is_bind_parameter = False + _is_clause_list = False + _is_lambda_element = False + _is_singleton_constant = False + _is_immutable = False + _is_star = False + + @property + def _order_by_label_element(self) -> Optional[Label[Any]]: + return None + + _cache_key_traversal: _CacheKeyTraversalType = None + + negation_clause: ColumnElement[bool] + + if typing.TYPE_CHECKING: + + def get_children( + self, *, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any + ) -> Iterable[ClauseElement]: ... + + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + return [] + + def _set_propagate_attrs(self, values: Mapping[str, Any]) -> Self: + # usually, self._propagate_attrs is empty here. one case where it's + # not is a subquery against ORM select, that is then pulled as a + # property of an aliased class. should all be good + + # assert not self._propagate_attrs + + self._propagate_attrs = util.immutabledict(values) + return self + + def _clone(self, **kw: Any) -> Self: + """Create a shallow copy of this ClauseElement. + + This method may be used by a generative API. Its also used as + part of the "deep" copy afforded by a traversal that combines + the _copy_internals() method. + + """ + + skip = self._memoized_keys + c = self.__class__.__new__(self.__class__) + + if skip: + # ensure this iteration remains atomic + c.__dict__ = { + k: v for k, v in self.__dict__.copy().items() if k not in skip + } + else: + c.__dict__ = self.__dict__.copy() + + # this is a marker that helps to "equate" clauses to each other + # when a Select returns its list of FROM clauses. the cloning + # process leaves around a lot of remnants of the previous clause + # typically in the form of column expressions still attached to the + # old table. + cc = self._is_clone_of + c._is_clone_of = cc if cc is not None else self + return c + + def _negate_in_binary(self, negated_op, original_op): + """a hook to allow the right side of a binary expression to respond + to a negation of the binary expression. + + Used for the special case of expanding bind parameter with IN. + + """ + return self + + def _with_binary_element_type(self, type_): + """in the context of binary expression, convert the type of this + object to the one given. + + applies only to :class:`_expression.ColumnElement` classes. + + """ + return self + + @property + def _constructor(self): + """return the 'constructor' for this ClauseElement. + + This is for the purposes for creating a new object of + this type. Usually, its just the element's __class__. + However, the "Annotated" version of the object overrides + to return the class of its proxied element. + + """ + return self.__class__ + + @HasMemoized.memoized_attribute + def _cloned_set(self): + """Return the set consisting all cloned ancestors of this + ClauseElement. + + Includes this ClauseElement. This accessor tends to be used for + FromClause objects to identify 'equivalent' FROM clauses, regardless + of transformative operations. + + """ + s = util.column_set() + f: Optional[ClauseElement] = self + + # note this creates a cycle, asserted in test_memusage. however, + # turning this into a plain @property adds tends of thousands of method + # calls to Core / ORM performance tests, so the small overhead + # introduced by the relatively small amount of short term cycles + # produced here is preferable + while f is not None: + s.add(f) + f = f._is_clone_of + return s + + def _de_clone(self): + while self._is_clone_of is not None: + self = self._is_clone_of + return self + + @property + def entity_namespace(self): + raise AttributeError( + "This SQL expression has no entity namespace " + "with which to filter from." + ) + + def __getstate__(self): + d = self.__dict__.copy() + d.pop("_is_clone_of", None) + d.pop("_generate_cache_key", None) + return d + + def _execute_on_connection( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> Result[Any]: + if self.supports_execution: + if TYPE_CHECKING: + assert isinstance(self, Executable) + return connection._execute_clauseelement( + self, distilled_params, execution_options + ) + else: + raise exc.ObjectNotExecutableError(self) + + def _execute_on_scalar( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: CoreExecuteOptionsParameter, + ) -> Any: + """an additional hook for subclasses to provide a different + implementation for connection.scalar() vs. connection.execute(). + + .. versionadded:: 2.0 + + """ + return self._execute_on_connection( + connection, distilled_params, execution_options + ).scalar() + + def _get_embedded_bindparams(self) -> Sequence[BindParameter[Any]]: + """Return the list of :class:`.BindParameter` objects embedded in the + object. + + This accomplishes the same purpose as ``visitors.traverse()`` or + similar would provide, however by making use of the cache key + it takes advantage of memoization of the key to result in fewer + net method calls, assuming the statement is also going to be + executed. + + """ + + key = self._generate_cache_key() + if key is None: + bindparams: List[BindParameter[Any]] = [] + + traverse(self, {}, {"bindparam": bindparams.append}) + return bindparams + + else: + return key.bindparams + + def unique_params( + self, + __optionaldict: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Self: + """Return a copy with :func:`_expression.bindparam` elements + replaced. + + Same functionality as :meth:`_expression.ClauseElement.params`, + except adds `unique=True` + to affected bind parameters so that multiple statements can be + used. + + """ + return self._replace_params(True, __optionaldict, kwargs) + + def params( + self, + __optionaldict: Optional[Mapping[str, Any]] = None, + **kwargs: Any, + ) -> Self: + """Return a copy with :func:`_expression.bindparam` elements + replaced. + + Returns a copy of this ClauseElement with + :func:`_expression.bindparam` + elements replaced with values taken from the given dictionary:: + + >>> clause = column('x') + bindparam('foo') + >>> print(clause.compile().params) + {'foo':None} + >>> print(clause.params({'foo':7}).compile().params) + {'foo':7} + + """ + return self._replace_params(False, __optionaldict, kwargs) + + def _replace_params( + self, + unique: bool, + optionaldict: Optional[Mapping[str, Any]], + kwargs: Dict[str, Any], + ) -> Self: + if optionaldict: + kwargs.update(optionaldict) + + def visit_bindparam(bind: BindParameter[Any]) -> None: + if bind.key in kwargs: + bind.value = kwargs[bind.key] + bind.required = False + if unique: + bind._convert_to_unique() + + return cloned_traverse( + self, + {"maintain_key": True, "detect_subquery_cols": True}, + {"bindparam": visit_bindparam}, + ) + + def compare(self, other: ClauseElement, **kw: Any) -> bool: + r"""Compare this :class:`_expression.ClauseElement` to + the given :class:`_expression.ClauseElement`. + + Subclasses should override the default behavior, which is a + straight identity comparison. + + \**kw are arguments consumed by subclass ``compare()`` methods and + may be used to modify the criteria for comparison + (see :class:`_expression.ColumnElement`). + + """ + return traversals.compare(self, other, **kw) + + def self_group( + self, against: Optional[OperatorType] = None + ) -> ClauseElement: + """Apply a 'grouping' to this :class:`_expression.ClauseElement`. + + This method is overridden by subclasses to return a "grouping" + construct, i.e. parenthesis. In particular it's used by "binary" + expressions to provide a grouping around themselves when placed into a + larger expression, as well as by :func:`_expression.select` + constructs when placed into the FROM clause of another + :func:`_expression.select`. (Note that subqueries should be + normally created using the :meth:`_expression.Select.alias` method, + as many + platforms require nested SELECT statements to be named). + + As expressions are composed together, the application of + :meth:`self_group` is automatic - end-user code should never + need to use this method directly. Note that SQLAlchemy's + clause constructs take operator precedence into account - + so parenthesis might not be needed, for example, in + an expression like ``x OR (y AND z)`` - AND takes precedence + over OR. + + The base :meth:`self_group` method of + :class:`_expression.ClauseElement` + just returns self. + """ + return self + + def _ungroup(self) -> ClauseElement: + """Return this :class:`_expression.ClauseElement` + without any groupings. + """ + + return self + + def _compile_w_cache( + self, + dialect: Dialect, + *, + compiled_cache: Optional[CompiledCacheType], + column_keys: List[str], + for_executemany: bool = False, + schema_translate_map: Optional[SchemaTranslateMapType] = None, + **kw: Any, + ) -> typing_Tuple[ + Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats + ]: + elem_cache_key: Optional[CacheKey] + + if compiled_cache is not None and dialect._supports_statement_cache: + elem_cache_key = self._generate_cache_key() + else: + elem_cache_key = None + + if elem_cache_key is not None: + if TYPE_CHECKING: + assert compiled_cache is not None + + cache_key, extracted_params = elem_cache_key + key = ( + dialect, + cache_key, + tuple(column_keys), + bool(schema_translate_map), + for_executemany, + ) + compiled_sql = compiled_cache.get(key) + + if compiled_sql is None: + cache_hit = dialect.CACHE_MISS + compiled_sql = self._compiler( + dialect, + cache_key=elem_cache_key, + column_keys=column_keys, + for_executemany=for_executemany, + schema_translate_map=schema_translate_map, + **kw, + ) + compiled_cache[key] = compiled_sql + else: + cache_hit = dialect.CACHE_HIT + else: + extracted_params = None + compiled_sql = self._compiler( + dialect, + cache_key=elem_cache_key, + column_keys=column_keys, + for_executemany=for_executemany, + schema_translate_map=schema_translate_map, + **kw, + ) + + if not dialect._supports_statement_cache: + cache_hit = dialect.NO_DIALECT_SUPPORT + elif compiled_cache is None: + cache_hit = dialect.CACHING_DISABLED + else: + cache_hit = dialect.NO_CACHE_KEY + + return compiled_sql, extracted_params, cache_hit + + def __invert__(self): + # undocumented element currently used by the ORM for + # relationship.contains() + if hasattr(self, "negation_clause"): + return self.negation_clause + else: + return self._negate() + + def _negate(self) -> ClauseElement: + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return UnaryExpression(grouped, operator=operators.inv) + + def __bool__(self): + raise TypeError("Boolean value of this clause is not defined") + + def __repr__(self): + friendly = self.description + if friendly is None: + return object.__repr__(self) + else: + return "<%s.%s at 0x%x; %s>" % ( + self.__module__, + self.__class__.__name__, + id(self), + friendly, + ) + + +class DQLDMLClauseElement(ClauseElement): + """represents a :class:`.ClauseElement` that compiles to a DQL or DML + expression, not DDL. + + .. versionadded:: 2.0 + + """ + + if typing.TYPE_CHECKING: + + def _compiler(self, dialect: Dialect, **kw: Any) -> SQLCompiler: + """Return a compiler appropriate for this ClauseElement, given a + Dialect.""" + ... + + def compile( # noqa: A001 + self, + bind: Optional[_HasDialect] = None, + dialect: Optional[Dialect] = None, + **kw: Any, + ) -> SQLCompiler: ... + + +class CompilerColumnElement( + roles.DMLColumnRole, + roles.DDLConstraintColumnRole, + roles.ColumnsClauseRole, + CompilerElement, +): + """A compiler-only column element used for ad-hoc string compilations. + + .. versionadded:: 2.0 + + """ + + __slots__ = () + + _propagate_attrs = util.EMPTY_DICT + _is_collection_aggregate = False + + +# SQLCoreOperations should be suiting the ExpressionElementRole +# and ColumnsClauseRole. however the MRO issues become too elaborate +# at the moment. +class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): + __slots__ = () + + # annotations for comparison methods + # these are from operators->Operators / ColumnOperators, + # redefined with the specific types returned by ColumnElement hierarchies + if typing.TYPE_CHECKING: + + @util.non_memoized_property + def _propagate_attrs(self) -> _PropagateAttrsType: ... + + def operate( + self, op: OperatorType, *other: Any, **kwargs: Any + ) -> ColumnElement[Any]: ... + + def reverse_operate( + self, op: OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: ... + + @overload + def op( + self, + opstring: str, + precedence: int = ..., + is_comparison: bool = ..., + *, + return_type: _TypeEngineArgument[_OPT], + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], BinaryExpression[_OPT]]: ... + + @overload + def op( + self, + opstring: str, + precedence: int = ..., + is_comparison: bool = ..., + return_type: Optional[_TypeEngineArgument[Any]] = ..., + python_impl: Optional[Callable[..., Any]] = ..., + ) -> Callable[[Any], BinaryExpression[Any]]: ... + + def op( + self, + opstring: str, + precedence: int = 0, + is_comparison: bool = False, + return_type: Optional[_TypeEngineArgument[Any]] = None, + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], BinaryExpression[Any]]: ... + + def bool_op( + self, + opstring: str, + precedence: int = 0, + python_impl: Optional[Callable[..., Any]] = None, + ) -> Callable[[Any], BinaryExpression[bool]]: ... + + def __and__(self, other: Any) -> BooleanClauseList: ... + + def __or__(self, other: Any) -> BooleanClauseList: ... + + def __invert__(self) -> ColumnElement[_T_co]: ... + + def __lt__(self, other: Any) -> ColumnElement[bool]: ... + + def __le__(self, other: Any) -> ColumnElement[bool]: ... + + # declare also that this class has an hash method otherwise + # it may be assumed to be None by type checkers since the + # object defines __eq__ and python sets it to None in that case: + # https://docs.python.org/3/reference/datamodel.html#object.__hash__ + def __hash__(self) -> int: ... + + def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + ... + + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 + ... + + def is_distinct_from(self, other: Any) -> ColumnElement[bool]: ... + + def is_not_distinct_from(self, other: Any) -> ColumnElement[bool]: ... + + def __gt__(self, other: Any) -> ColumnElement[bool]: ... + + def __ge__(self, other: Any) -> ColumnElement[bool]: ... + + def __neg__(self) -> UnaryExpression[_T_co]: ... + + def __contains__(self, other: Any) -> ColumnElement[bool]: ... + + def __getitem__(self, index: Any) -> ColumnElement[Any]: ... + + @overload + def __lshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ... + + @overload + def __lshift__(self, other: Any) -> ColumnElement[Any]: ... + + def __lshift__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def __rshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ... + + @overload + def __rshift__(self, other: Any) -> ColumnElement[Any]: ... + + def __rshift__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def concat(self: _SQO[str], other: Any) -> ColumnElement[str]: ... + + @overload + def concat(self, other: Any) -> ColumnElement[Any]: ... + + def concat(self, other: Any) -> ColumnElement[Any]: ... + + def like( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... + + def ilike( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... + + def bitwise_xor(self, other: Any) -> BinaryExpression[Any]: ... + + def bitwise_or(self, other: Any) -> BinaryExpression[Any]: ... + + def bitwise_and(self, other: Any) -> BinaryExpression[Any]: ... + + def bitwise_not(self) -> UnaryExpression[_T_co]: ... + + def bitwise_lshift(self, other: Any) -> BinaryExpression[Any]: ... + + def bitwise_rshift(self, other: Any) -> BinaryExpression[Any]: ... + + def in_( + self, + other: Union[ + Iterable[Any], BindParameter[Any], roles.InElementRole + ], + ) -> BinaryExpression[bool]: ... + + def not_in( + self, + other: Union[ + Iterable[Any], BindParameter[Any], roles.InElementRole + ], + ) -> BinaryExpression[bool]: ... + + def notin_( + self, + other: Union[ + Iterable[Any], BindParameter[Any], roles.InElementRole + ], + ) -> BinaryExpression[bool]: ... + + def not_like( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... + + def notlike( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... + + def not_ilike( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... + + def notilike( + self, other: Any, escape: Optional[str] = None + ) -> BinaryExpression[bool]: ... + + def is_(self, other: Any) -> BinaryExpression[bool]: ... + + def is_not(self, other: Any) -> BinaryExpression[bool]: ... + + def isnot(self, other: Any) -> BinaryExpression[bool]: ... + + def startswith( + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnElement[bool]: ... + + def istartswith( + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnElement[bool]: ... + + def endswith( + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnElement[bool]: ... + + def iendswith( + self, + other: Any, + escape: Optional[str] = None, + autoescape: bool = False, + ) -> ColumnElement[bool]: ... + + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: ... + + def icontains(self, other: Any, **kw: Any) -> ColumnElement[bool]: ... + + def match(self, other: Any, **kwargs: Any) -> ColumnElement[bool]: ... + + def regexp_match( + self, pattern: Any, flags: Optional[str] = None + ) -> ColumnElement[bool]: ... + + def regexp_replace( + self, pattern: Any, replacement: Any, flags: Optional[str] = None + ) -> ColumnElement[str]: ... + + def desc(self) -> UnaryExpression[_T_co]: ... + + def asc(self) -> UnaryExpression[_T_co]: ... + + def nulls_first(self) -> UnaryExpression[_T_co]: ... + + def nullsfirst(self) -> UnaryExpression[_T_co]: ... + + def nulls_last(self) -> UnaryExpression[_T_co]: ... + + def nullslast(self) -> UnaryExpression[_T_co]: ... + + def collate(self, collation: str) -> CollationClause: ... + + def between( + self, cleft: Any, cright: Any, symmetric: bool = False + ) -> BinaryExpression[bool]: ... + + def distinct(self: _SQO[_T_co]) -> UnaryExpression[_T_co]: ... + + def any_(self) -> CollectionAggregate[Any]: ... + + def all_(self) -> CollectionAggregate[Any]: ... + + # numeric overloads. These need more tweaking + # in particular they all need to have a variant for Optiona[_T] + # because Optional only applies to the data side, not the expression + # side + + @overload + def __add__( + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... + + @overload + def __add__( + self: _SQO[str], + other: Any, + ) -> ColumnElement[str]: ... + + def __add__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def __radd__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... + + @overload + def __radd__(self: _SQO[str], other: Any) -> ColumnElement[str]: ... + + def __radd__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def __sub__( + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... + + @overload + def __sub__(self, other: Any) -> ColumnElement[Any]: ... + + def __sub__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def __rsub__( + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... + + @overload + def __rsub__(self, other: Any) -> ColumnElement[Any]: ... + + def __rsub__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def __mul__( + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... + + @overload + def __mul__(self, other: Any) -> ColumnElement[Any]: ... + + def __mul__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def __rmul__( + self: _SQO[_NMT], + other: Any, + ) -> ColumnElement[_NMT]: ... + + @overload + def __rmul__(self, other: Any) -> ColumnElement[Any]: ... + + def __rmul__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def __mod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... + + @overload + def __mod__(self, other: Any) -> ColumnElement[Any]: ... + + def __mod__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def __rmod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... + + @overload + def __rmod__(self, other: Any) -> ColumnElement[Any]: ... + + def __rmod__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def __truediv__( + self: _SQO[int], other: Any + ) -> ColumnElement[_NUMERIC]: ... + + @overload + def __truediv__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: ... + + @overload + def __truediv__(self, other: Any) -> ColumnElement[Any]: ... + + def __truediv__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def __rtruediv__( + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NUMERIC]: ... + + @overload + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: ... + + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def __floordiv__( + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NMT]: ... + + @overload + def __floordiv__(self, other: Any) -> ColumnElement[Any]: ... + + def __floordiv__(self, other: Any) -> ColumnElement[Any]: ... + + @overload + def __rfloordiv__( + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NMT]: ... + + @overload + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: ... + + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: ... + + +class SQLColumnExpression( + SQLCoreOperations[_T_co], roles.ExpressionElementRole[_T_co], TypingOnly +): + """A type that may be used to indicate any SQL column element or object + that acts in place of one. + + :class:`.SQLColumnExpression` is a base of + :class:`.ColumnElement`, as well as within the bases of ORM elements + such as :class:`.InstrumentedAttribute`, and may be used in :pep:`484` + typing to indicate arguments or return values that should behave + as column expressions. + + .. versionadded:: 2.0.0b4 + + + """ + + __slots__ = () + + +_SQO = SQLCoreOperations + + +class ColumnElement( + roles.ColumnArgumentOrKeyRole, + roles.StatementOptionRole, + roles.WhereHavingRole, + roles.BinaryElementRole[_T], + roles.OrderByRole, + roles.ColumnsClauseRole, + roles.LimitOffsetRole, + roles.DMLColumnRole, + roles.DDLConstraintColumnRole, + roles.DDLExpressionRole, + SQLColumnExpression[_T], + DQLDMLClauseElement, +): + """Represent a column-oriented SQL expression suitable for usage in the + "columns" clause, WHERE clause etc. of a statement. + + While the most familiar kind of :class:`_expression.ColumnElement` is the + :class:`_schema.Column` object, :class:`_expression.ColumnElement` + serves as the basis + for any unit that may be present in a SQL expression, including + the expressions themselves, SQL functions, bound parameters, + literal expressions, keywords such as ``NULL``, etc. + :class:`_expression.ColumnElement` + is the ultimate base class for all such elements. + + A wide variety of SQLAlchemy Core functions work at the SQL expression + level, and are intended to accept instances of + :class:`_expression.ColumnElement` as + arguments. These functions will typically document that they accept a + "SQL expression" as an argument. What this means in terms of SQLAlchemy + usually refers to an input which is either already in the form of a + :class:`_expression.ColumnElement` object, + or a value which can be **coerced** into + one. The coercion rules followed by most, but not all, SQLAlchemy Core + functions with regards to SQL expressions are as follows: + + * a literal Python value, such as a string, integer or floating + point value, boolean, datetime, ``Decimal`` object, or virtually + any other Python object, will be coerced into a "literal bound + value". This generally means that a :func:`.bindparam` will be + produced featuring the given value embedded into the construct; the + resulting :class:`.BindParameter` object is an instance of + :class:`_expression.ColumnElement`. + The Python value will ultimately be sent + to the DBAPI at execution time as a parameterized argument to the + ``execute()`` or ``executemany()`` methods, after SQLAlchemy + type-specific converters (e.g. those provided by any associated + :class:`.TypeEngine` objects) are applied to the value. + + * any special object value, typically ORM-level constructs, which + feature an accessor called ``__clause_element__()``. The Core + expression system looks for this method when an object of otherwise + unknown type is passed to a function that is looking to coerce the + argument into a :class:`_expression.ColumnElement` and sometimes a + :class:`_expression.SelectBase` expression. + It is used within the ORM to + convert from ORM-specific objects like mapped classes and + mapped attributes into Core expression objects. + + * The Python ``None`` value is typically interpreted as ``NULL``, + which in SQLAlchemy Core produces an instance of :func:`.null`. + + A :class:`_expression.ColumnElement` provides the ability to generate new + :class:`_expression.ColumnElement` + objects using Python expressions. This means that Python operators + such as ``==``, ``!=`` and ``<`` are overloaded to mimic SQL operations, + and allow the instantiation of further :class:`_expression.ColumnElement` + instances + which are composed from other, more fundamental + :class:`_expression.ColumnElement` + objects. For example, two :class:`.ColumnClause` objects can be added + together with the addition operator ``+`` to produce + a :class:`.BinaryExpression`. + Both :class:`.ColumnClause` and :class:`.BinaryExpression` are subclasses + of :class:`_expression.ColumnElement`: + + .. sourcecode:: pycon+sql + + >>> from sqlalchemy.sql import column + >>> column('a') + column('b') + + >>> print(column('a') + column('b')) + {printsql}a + b + + .. seealso:: + + :class:`_schema.Column` + + :func:`_expression.column` + + """ + + __visit_name__ = "column_element" + + primary_key: bool = False + _is_clone_of: Optional[ColumnElement[_T]] + _is_column_element = True + _insert_sentinel: bool = False + _omit_from_statements = False + _is_collection_aggregate = False + + foreign_keys: AbstractSet[ForeignKey] = frozenset() + + @util.memoized_property + def _proxies(self) -> List[ColumnElement[Any]]: + return [] + + @util.non_memoized_property + def _tq_label(self) -> Optional[str]: + """The named label that can be used to target + this column in a result set in a "table qualified" context. + + This label is almost always the label used when + rendering AS AS "; typically columns that don't have + any parent table and are named the same as what the label would be + in any case. + + """ + + _allow_label_resolve = True + """A flag that can be flipped to prevent a column from being resolvable + by string label name. + + The joined eager loader strategy in the ORM uses this, for example. + + """ + + _is_implicitly_boolean = False + + _alt_names: Sequence[str] = () + + @overload + def self_group( + self: ColumnElement[_T], against: Optional[OperatorType] = None + ) -> ColumnElement[_T]: ... + + @overload + def self_group( + self: ColumnElement[Any], against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: ... + + def self_group( + self, against: Optional[OperatorType] = None + ) -> ColumnElement[Any]: + if ( + against in (operators.and_, operators.or_, operators._asbool) + and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity + ): + return AsBoolean(self, operators.is_true, operators.is_false) + elif against in (operators.any_op, operators.all_op): + return Grouping(self) + else: + return self + + @overload + def _negate(self: ColumnElement[bool]) -> ColumnElement[bool]: ... + + @overload + def _negate(self: ColumnElement[_T]) -> ColumnElement[_T]: ... + + def _negate(self) -> ColumnElement[Any]: + if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: + return AsBoolean(self, operators.is_false, operators.is_true) + else: + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return UnaryExpression( + grouped, operator=operators.inv, wraps_column_expression=True + ) + + type: TypeEngine[_T] + + if not TYPE_CHECKING: + + @util.memoized_property + def type(self) -> TypeEngine[_T]: # noqa: A001 + # used for delayed setup of + # type_api + return type_api.NULLTYPE + + @HasMemoized.memoized_attribute + def comparator(self) -> TypeEngine.Comparator[_T]: + try: + comparator_factory = self.type.comparator_factory + except AttributeError as err: + raise TypeError( + "Object %r associated with '.type' attribute " + "is not a TypeEngine class or object" % self.type + ) from err + else: + return comparator_factory(self) + + def __setstate__(self, state): + self.__dict__.update(state) + + def __getattr__(self, key: str) -> Any: + try: + return getattr(self.comparator, key) + except AttributeError as err: + raise AttributeError( + "Neither %r object nor %r object has an attribute %r" + % ( + type(self).__name__, + type(self.comparator).__name__, + key, + ) + ) from err + + def operate( + self, + op: operators.OperatorType, + *other: Any, + **kwargs: Any, + ) -> ColumnElement[Any]: + return op(self.comparator, *other, **kwargs) # type: ignore[no-any-return] # noqa: E501 + + def reverse_operate( + self, op: operators.OperatorType, other: Any, **kwargs: Any + ) -> ColumnElement[Any]: + return op(other, self.comparator, **kwargs) # type: ignore[no-any-return] # noqa: E501 + + def _bind_param( + self, + operator: operators.OperatorType, + obj: Any, + type_: Optional[TypeEngine[_T]] = None, + expanding: bool = False, + ) -> BindParameter[_T]: + return BindParameter( + None, + obj, + _compared_to_operator=operator, + type_=type_, + _compared_to_type=self.type, + unique=True, + expanding=expanding, + ) + + @property + def expression(self) -> ColumnElement[Any]: + """Return a column expression. + + Part of the inspection interface; returns self. + + """ + return self + + @property + def _select_iterable(self) -> _SelectIterable: + return (self,) + + @util.memoized_property + def base_columns(self) -> FrozenSet[ColumnElement[Any]]: + return frozenset(c for c in self.proxy_set if not c._proxies) + + @util.memoized_property + def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: + """set of all columns we are proxying + + as of 2.0 this is explicitly deannotated columns. previously it was + effectively deannotated columns but wasn't enforced. annotated + columns should basically not go into sets if at all possible because + their hashing behavior is very non-performant. + + """ + return frozenset([self._deannotate()]).union( + itertools.chain(*[c.proxy_set for c in self._proxies]) + ) + + @util.memoized_property + def _expanded_proxy_set(self) -> FrozenSet[ColumnElement[Any]]: + return frozenset(_expand_cloned(self.proxy_set)) + + def _uncached_proxy_list(self) -> List[ColumnElement[Any]]: + """An 'uncached' version of proxy set. + + This list includes annotated columns which perform very poorly in + set operations. + + """ + + return [self] + list( + itertools.chain(*[c._uncached_proxy_list() for c in self._proxies]) + ) + + def shares_lineage(self, othercolumn: ColumnElement[Any]) -> bool: + """Return True if the given :class:`_expression.ColumnElement` + has a common ancestor to this :class:`_expression.ColumnElement`.""" + + return bool(self.proxy_set.intersection(othercolumn.proxy_set)) + + def _compare_name_for_result(self, other: ColumnElement[Any]) -> bool: + """Return True if the given column element compares to this one + when targeting within a result row.""" + + return ( + hasattr(other, "name") + and hasattr(self, "name") + and other.name == self.name + ) + + @HasMemoized.memoized_attribute + def _proxy_key(self) -> Optional[str]: + if self._annotations and "proxy_key" in self._annotations: + return cast(str, self._annotations["proxy_key"]) + + name = self.key + if not name: + # there's a bit of a seeming contradiction which is that the + # "_non_anon_label" of a column can in fact be an + # "_anonymous_label"; this is when it's on a column that is + # proxying for an anonymous expression in a subquery. + name = self._non_anon_label + + if isinstance(name, _anonymous_label): + return None + else: + return name + + @HasMemoized.memoized_attribute + def _expression_label(self) -> Optional[str]: + """a suggested label to use in the case that the column has no name, + which should be used if possible as the explicit 'AS