diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py | 1782 |
1 files changed, 0 insertions, 1782 deletions
diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py deleted file mode 100644 index 369fc59..0000000 --- a/venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py +++ /dev/null @@ -1,1782 +0,0 @@ -# orm/persistence.py -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# <see AUTHORS file> -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors - - -"""private module containing functions used to emit INSERT, UPDATE -and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending -mappers. - -The functions here are called only by the unit of work functions -in unitofwork.py. - -""" -from __future__ import annotations - -from itertools import chain -from itertools import groupby -from itertools import zip_longest -import operator - -from . import attributes -from . import exc as orm_exc -from . import loading -from . import sync -from .base import state_str -from .. import exc as sa_exc -from .. import future -from .. import sql -from .. import util -from ..engine import cursor as _cursor -from ..sql import operators -from ..sql.elements import BooleanClauseList -from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL - - -def save_obj(base_mapper, states, uowtransaction, single=False): - """Issue ``INSERT`` and/or ``UPDATE`` statements for a list - of objects. - - This is called within the context of a UOWTransaction during a - flush operation, given a list of states to be flushed. The - base mapper in an inheritance hierarchy handles the inserts/ - updates for all descendant mappers. - - """ - - # if batch=false, call _save_obj separately for each object - if not single and not base_mapper.batch: - for state in _sort_states(base_mapper, states): - save_obj(base_mapper, [state], uowtransaction, single=True) - return - - states_to_update = [] - states_to_insert = [] - - for ( - state, - dict_, - mapper, - connection, - has_identity, - row_switch, - update_version_id, - ) in _organize_states_for_save(base_mapper, states, uowtransaction): - if has_identity or row_switch: - states_to_update.append( - (state, dict_, mapper, connection, update_version_id) - ) - else: - states_to_insert.append((state, dict_, mapper, connection)) - - for table, mapper in base_mapper._sorted_tables.items(): - if table not in mapper._pks_by_table: - continue - insert = _collect_insert_commands(table, states_to_insert) - - update = _collect_update_commands( - uowtransaction, table, states_to_update - ) - - _emit_update_statements( - base_mapper, - uowtransaction, - mapper, - table, - update, - ) - - _emit_insert_statements( - base_mapper, - uowtransaction, - mapper, - table, - insert, - ) - - _finalize_insert_update_commands( - base_mapper, - uowtransaction, - chain( - ( - (state, state_dict, mapper, connection, False) - for (state, state_dict, mapper, connection) in states_to_insert - ), - ( - (state, state_dict, mapper, connection, True) - for ( - state, - state_dict, - mapper, - connection, - update_version_id, - ) in states_to_update - ), - ), - ) - - -def post_update(base_mapper, states, uowtransaction, post_update_cols): - """Issue UPDATE statements on behalf of a relationship() which - specifies post_update. - - """ - - states_to_update = list( - _organize_states_for_post_update(base_mapper, states, uowtransaction) - ) - - for table, mapper in base_mapper._sorted_tables.items(): - if table not in mapper._pks_by_table: - continue - - update = ( - ( - state, - state_dict, - sub_mapper, - connection, - ( - mapper._get_committed_state_attr_by_column( - state, state_dict, mapper.version_id_col - ) - if mapper.version_id_col is not None - else None - ), - ) - for state, state_dict, sub_mapper, connection in states_to_update - if table in sub_mapper._pks_by_table - ) - - update = _collect_post_update_commands( - base_mapper, uowtransaction, table, update, post_update_cols - ) - - _emit_post_update_statements( - base_mapper, - uowtransaction, - mapper, - table, - update, - ) - - -def delete_obj(base_mapper, states, uowtransaction): - """Issue ``DELETE`` statements for a list of objects. - - This is called within the context of a UOWTransaction during a - flush operation. - - """ - - states_to_delete = list( - _organize_states_for_delete(base_mapper, states, uowtransaction) - ) - - table_to_mapper = base_mapper._sorted_tables - - for table in reversed(list(table_to_mapper.keys())): - mapper = table_to_mapper[table] - if table not in mapper._pks_by_table: - continue - elif mapper.inherits and mapper.passive_deletes: - continue - - delete = _collect_delete_commands( - base_mapper, uowtransaction, table, states_to_delete - ) - - _emit_delete_statements( - base_mapper, - uowtransaction, - mapper, - table, - delete, - ) - - for ( - state, - state_dict, - mapper, - connection, - update_version_id, - ) in states_to_delete: - mapper.dispatch.after_delete(mapper, connection, state) - - -def _organize_states_for_save(base_mapper, states, uowtransaction): - """Make an initial pass across a set of states for INSERT or - UPDATE. - - This includes splitting out into distinct lists for - each, calling before_insert/before_update, obtaining - key information for each state including its dictionary, - mapper, the connection to use for the execution per state, - and the identity flag. - - """ - - for state, dict_, mapper, connection in _connections_for_states( - base_mapper, uowtransaction, states - ): - has_identity = bool(state.key) - - instance_key = state.key or mapper._identity_key_from_state(state) - - row_switch = update_version_id = None - - # call before_XXX extensions - if not has_identity: - mapper.dispatch.before_insert(mapper, connection, state) - else: - mapper.dispatch.before_update(mapper, connection, state) - - if mapper._validate_polymorphic_identity: - mapper._validate_polymorphic_identity(mapper, state, dict_) - - # detect if we have a "pending" instance (i.e. has - # no instance_key attached to it), and another instance - # with the same identity key already exists as persistent. - # convert to an UPDATE if so. - if ( - not has_identity - and instance_key in uowtransaction.session.identity_map - ): - instance = uowtransaction.session.identity_map[instance_key] - existing = attributes.instance_state(instance) - - if not uowtransaction.was_already_deleted(existing): - if not uowtransaction.is_deleted(existing): - util.warn( - "New instance %s with identity key %s conflicts " - "with persistent instance %s" - % (state_str(state), instance_key, state_str(existing)) - ) - else: - base_mapper._log_debug( - "detected row switch for identity %s. " - "will update %s, remove %s from " - "transaction", - instance_key, - state_str(state), - state_str(existing), - ) - - # remove the "delete" flag from the existing element - uowtransaction.remove_state_actions(existing) - row_switch = existing - - if (has_identity or row_switch) and mapper.version_id_col is not None: - update_version_id = mapper._get_committed_state_attr_by_column( - row_switch if row_switch else state, - row_switch.dict if row_switch else dict_, - mapper.version_id_col, - ) - - yield ( - state, - dict_, - mapper, - connection, - has_identity, - row_switch, - update_version_id, - ) - - -def _organize_states_for_post_update(base_mapper, states, uowtransaction): - """Make an initial pass across a set of states for UPDATE - corresponding to post_update. - - This includes obtaining key information for each state - including its dictionary, mapper, the connection to use for - the execution per state. - - """ - return _connections_for_states(base_mapper, uowtransaction, states) - - -def _organize_states_for_delete(base_mapper, states, uowtransaction): - """Make an initial pass across a set of states for DELETE. - - This includes calling out before_delete and obtaining - key information for each state including its dictionary, - mapper, the connection to use for the execution per state. - - """ - for state, dict_, mapper, connection in _connections_for_states( - base_mapper, uowtransaction, states - ): - mapper.dispatch.before_delete(mapper, connection, state) - - if mapper.version_id_col is not None: - update_version_id = mapper._get_committed_state_attr_by_column( - state, dict_, mapper.version_id_col - ) - else: - update_version_id = None - - yield (state, dict_, mapper, connection, update_version_id) - - -def _collect_insert_commands( - table, - states_to_insert, - *, - bulk=False, - return_defaults=False, - render_nulls=False, - include_bulk_keys=(), -): - """Identify sets of values to use in INSERT statements for a - list of states. - - """ - for state, state_dict, mapper, connection in states_to_insert: - if table not in mapper._pks_by_table: - continue - - params = {} - value_params = {} - - propkey_to_col = mapper._propkey_to_col[table] - - eval_none = mapper._insert_cols_evaluating_none[table] - - for propkey in set(propkey_to_col).intersection(state_dict): - value = state_dict[propkey] - col = propkey_to_col[propkey] - if value is None and col not in eval_none and not render_nulls: - continue - elif not bulk and ( - hasattr(value, "__clause_element__") - or isinstance(value, sql.ClauseElement) - ): - value_params[col] = ( - value.__clause_element__() - if hasattr(value, "__clause_element__") - else value - ) - else: - params[col.key] = value - - if not bulk: - # for all the columns that have no default and we don't have - # a value and where "None" is not a special value, add - # explicit None to the INSERT. This is a legacy behavior - # which might be worth removing, as it should not be necessary - # and also produces confusion, given that "missing" and None - # now have distinct meanings - for colkey in ( - mapper._insert_cols_as_none[table] - .difference(params) - .difference([c.key for c in value_params]) - ): - params[colkey] = None - - if not bulk or return_defaults: - # params are in terms of Column key objects, so - # compare to pk_keys_by_table - has_all_pks = mapper._pk_keys_by_table[table].issubset(params) - - if mapper.base_mapper._prefer_eager_defaults( - connection.dialect, table - ): - has_all_defaults = mapper._server_default_col_keys[ - table - ].issubset(params) - else: - has_all_defaults = True - else: - has_all_defaults = has_all_pks = True - - if ( - mapper.version_id_generator is not False - and mapper.version_id_col is not None - and mapper.version_id_col in mapper._cols_by_table[table] - ): - params[mapper.version_id_col.key] = mapper.version_id_generator( - None - ) - - if bulk: - if mapper._set_polymorphic_identity: - params.setdefault( - mapper._polymorphic_attr_key, mapper.polymorphic_identity - ) - - if include_bulk_keys: - params.update((k, state_dict[k]) for k in include_bulk_keys) - - yield ( - state, - state_dict, - params, - mapper, - connection, - value_params, - has_all_pks, - has_all_defaults, - ) - - -def _collect_update_commands( - uowtransaction, - table, - states_to_update, - *, - bulk=False, - use_orm_update_stmt=None, - include_bulk_keys=(), -): - """Identify sets of values to use in UPDATE statements for a - list of states. - - This function works intricately with the history system - to determine exactly what values should be updated - as well as how the row should be matched within an UPDATE - statement. Includes some tricky scenarios where the primary - key of an object might have been changed. - - """ - - for ( - state, - state_dict, - mapper, - connection, - update_version_id, - ) in states_to_update: - if table not in mapper._pks_by_table: - continue - - pks = mapper._pks_by_table[table] - - if use_orm_update_stmt is not None: - # TODO: ordered values, etc - value_params = use_orm_update_stmt._values - else: - value_params = {} - - propkey_to_col = mapper._propkey_to_col[table] - - if bulk: - # keys here are mapped attribute keys, so - # look at mapper attribute keys for pk - params = { - propkey_to_col[propkey].key: state_dict[propkey] - for propkey in set(propkey_to_col) - .intersection(state_dict) - .difference(mapper._pk_attr_keys_by_table[table]) - } - has_all_defaults = True - else: - params = {} - for propkey in set(propkey_to_col).intersection( - state.committed_state - ): - value = state_dict[propkey] - col = propkey_to_col[propkey] - - if hasattr(value, "__clause_element__") or isinstance( - value, sql.ClauseElement - ): - value_params[col] = ( - value.__clause_element__() - if hasattr(value, "__clause_element__") - else value - ) - # guard against values that generate non-__nonzero__ - # objects for __eq__() - elif ( - state.manager[propkey].impl.is_equal( - value, state.committed_state[propkey] - ) - is not True - ): - params[col.key] = value - - if mapper.base_mapper.eager_defaults is True: - has_all_defaults = ( - mapper._server_onupdate_default_col_keys[table] - ).issubset(params) - else: - has_all_defaults = True - - if ( - update_version_id is not None - and mapper.version_id_col in mapper._cols_by_table[table] - ): - if not bulk and not (params or value_params): - # HACK: check for history in other tables, in case the - # history is only in a different table than the one - # where the version_id_col is. This logic was lost - # from 0.9 -> 1.0.0 and restored in 1.0.6. - for prop in mapper._columntoproperty.values(): - history = state.manager[prop.key].impl.get_history( - state, state_dict, attributes.PASSIVE_NO_INITIALIZE - ) - if history.added: - break - else: - # no net change, break - continue - - col = mapper.version_id_col - no_params = not params and not value_params - params[col._label] = update_version_id - - if ( - bulk or col.key not in params - ) and mapper.version_id_generator is not False: - val = mapper.version_id_generator(update_version_id) - params[col.key] = val - elif mapper.version_id_generator is False and no_params: - # no version id generator, no values set on the table, - # and version id wasn't manually incremented. - # set version id to itself so we get an UPDATE - # statement - params[col.key] = update_version_id - - elif not (params or value_params): - continue - - has_all_pks = True - expect_pk_cascaded = False - if bulk: - # keys here are mapped attribute keys, so - # look at mapper attribute keys for pk - pk_params = { - propkey_to_col[propkey]._label: state_dict.get(propkey) - for propkey in set(propkey_to_col).intersection( - mapper._pk_attr_keys_by_table[table] - ) - } - if util.NONE_SET.intersection(pk_params.values()): - raise sa_exc.InvalidRequestError( - f"No primary key value supplied for column(s) " - f"""{ - ', '.join( - str(c) for c in pks if pk_params[c._label] is None - ) - }; """ - "per-row ORM Bulk UPDATE by Primary Key requires that " - "records contain primary key values", - code="bupq", - ) - - else: - pk_params = {} - for col in pks: - propkey = mapper._columntoproperty[col].key - - history = state.manager[propkey].impl.get_history( - state, state_dict, attributes.PASSIVE_OFF - ) - - if history.added: - if ( - not history.deleted - or ("pk_cascaded", state, col) - in uowtransaction.attributes - ): - expect_pk_cascaded = True - pk_params[col._label] = history.added[0] - params.pop(col.key, None) - else: - # else, use the old value to locate the row - pk_params[col._label] = history.deleted[0] - if col in value_params: - has_all_pks = False - else: - pk_params[col._label] = history.unchanged[0] - if pk_params[col._label] is None: - raise orm_exc.FlushError( - "Can't update table %s using NULL for primary " - "key value on column %s" % (table, col) - ) - - if include_bulk_keys: - params.update((k, state_dict[k]) for k in include_bulk_keys) - - if params or value_params: - params.update(pk_params) - yield ( - state, - state_dict, - params, - mapper, - connection, - value_params, - has_all_defaults, - has_all_pks, - ) - elif expect_pk_cascaded: - # no UPDATE occurs on this table, but we expect that CASCADE rules - # have changed the primary key of the row; propagate this event to - # other columns that expect to have been modified. this normally - # occurs after the UPDATE is emitted however we invoke it here - # explicitly in the absence of our invoking an UPDATE - for m, equated_pairs in mapper._table_to_equated[table]: - sync.populate( - state, - m, - state, - m, - equated_pairs, - uowtransaction, - mapper.passive_updates, - ) - - -def _collect_post_update_commands( - base_mapper, uowtransaction, table, states_to_update, post_update_cols -): - """Identify sets of values to use in UPDATE statements for a - list of states within a post_update operation. - - """ - - for ( - state, - state_dict, - mapper, - connection, - update_version_id, - ) in states_to_update: - # assert table in mapper._pks_by_table - - pks = mapper._pks_by_table[table] - params = {} - hasdata = False - - for col in mapper._cols_by_table[table]: - if col in pks: - params[col._label] = mapper._get_state_attr_by_column( - state, state_dict, col, passive=attributes.PASSIVE_OFF - ) - - elif col in post_update_cols or col.onupdate is not None: - prop = mapper._columntoproperty[col] - history = state.manager[prop.key].impl.get_history( - state, state_dict, attributes.PASSIVE_NO_INITIALIZE - ) - if history.added: - value = history.added[0] - params[col.key] = value - hasdata = True - if hasdata: - if ( - update_version_id is not None - and mapper.version_id_col in mapper._cols_by_table[table] - ): - col = mapper.version_id_col - params[col._label] = update_version_id - - if ( - bool(state.key) - and col.key not in params - and mapper.version_id_generator is not False - ): - val = mapper.version_id_generator(update_version_id) - params[col.key] = val - yield state, state_dict, mapper, connection, params - - -def _collect_delete_commands( - base_mapper, uowtransaction, table, states_to_delete -): - """Identify values to use in DELETE statements for a list of - states to be deleted.""" - - for ( - state, - state_dict, - mapper, - connection, - update_version_id, - ) in states_to_delete: - if table not in mapper._pks_by_table: - continue - - params = {} - for col in mapper._pks_by_table[table]: - params[col.key] = value = ( - mapper._get_committed_state_attr_by_column( - state, state_dict, col - ) - ) - if value is None: - raise orm_exc.FlushError( - "Can't delete from table %s " - "using NULL for primary " - "key value on column %s" % (table, col) - ) - - if ( - update_version_id is not None - and mapper.version_id_col in mapper._cols_by_table[table] - ): - params[mapper.version_id_col.key] = update_version_id - yield params, connection - - -def _emit_update_statements( - base_mapper, - uowtransaction, - mapper, - table, - update, - *, - bookkeeping=True, - use_orm_update_stmt=None, - enable_check_rowcount=True, -): - """Emit UPDATE statements corresponding to value lists collected - by _collect_update_commands().""" - - needs_version_id = ( - mapper.version_id_col is not None - and mapper.version_id_col in mapper._cols_by_table[table] - ) - - execution_options = {"compiled_cache": base_mapper._compiled_cache} - - def update_stmt(existing_stmt=None): - clauses = BooleanClauseList._construct_raw(operators.and_) - - for col in mapper._pks_by_table[table]: - clauses._append_inplace( - col == sql.bindparam(col._label, type_=col.type) - ) - - if needs_version_id: - clauses._append_inplace( - mapper.version_id_col - == sql.bindparam( - mapper.version_id_col._label, - type_=mapper.version_id_col.type, - ) - ) - - if existing_stmt is not None: - stmt = existing_stmt.where(clauses) - else: - stmt = table.update().where(clauses) - return stmt - - if use_orm_update_stmt is not None: - cached_stmt = update_stmt(use_orm_update_stmt) - - else: - cached_stmt = base_mapper._memo(("update", table), update_stmt) - - for ( - (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), - records, - ) in groupby( - update, - lambda rec: ( - rec[4], # connection - set(rec[2]), # set of parameter keys - bool(rec[5]), # whether or not we have "value" parameters - rec[6], # has_all_defaults - rec[7], # has all pks - ), - ): - rows = 0 - records = list(records) - - statement = cached_stmt - - if use_orm_update_stmt is not None: - statement = statement._annotate( - { - "_emit_update_table": table, - "_emit_update_mapper": mapper, - } - ) - - return_defaults = False - - if not has_all_pks: - statement = statement.return_defaults(*mapper._pks_by_table[table]) - return_defaults = True - - if ( - bookkeeping - and not has_all_defaults - and mapper.base_mapper.eager_defaults is True - # change as of #8889 - if RETURNING is not going to be used anyway, - # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure - # we can do an executemany UPDATE which is more efficient - and table.implicit_returning - and connection.dialect.update_returning - ): - statement = statement.return_defaults( - *mapper._server_onupdate_default_cols[table] - ) - return_defaults = True - - if mapper._version_id_has_server_side_value: - statement = statement.return_defaults(mapper.version_id_col) - return_defaults = True - - assert_singlerow = connection.dialect.supports_sane_rowcount - - assert_multirow = ( - assert_singlerow - and connection.dialect.supports_sane_multi_rowcount - ) - - # change as of #8889 - if RETURNING is not going to be used anyway, - # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure - # we can do an executemany UPDATE which is more efficient - allow_executemany = not return_defaults and not needs_version_id - - if hasvalue: - for ( - state, - state_dict, - params, - mapper, - connection, - value_params, - has_all_defaults, - has_all_pks, - ) in records: - c = connection.execute( - statement.values(value_params), - params, - execution_options=execution_options, - ) - if bookkeeping: - _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params, - True, - c.returned_defaults, - ) - rows += c.rowcount - check_rowcount = enable_check_rowcount and assert_singlerow - else: - if not allow_executemany: - check_rowcount = enable_check_rowcount and assert_singlerow - for ( - state, - state_dict, - params, - mapper, - connection, - value_params, - has_all_defaults, - has_all_pks, - ) in records: - c = connection.execute( - statement, params, execution_options=execution_options - ) - - # TODO: why with bookkeeping=False? - if bookkeeping: - _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params, - True, - c.returned_defaults, - ) - rows += c.rowcount - else: - multiparams = [rec[2] for rec in records] - - check_rowcount = enable_check_rowcount and ( - assert_multirow - or (assert_singlerow and len(multiparams) == 1) - ) - - c = connection.execute( - statement, multiparams, execution_options=execution_options - ) - - rows += c.rowcount - - for ( - state, - state_dict, - params, - mapper, - connection, - value_params, - has_all_defaults, - has_all_pks, - ) in records: - if bookkeeping: - _postfetch( - mapper, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - value_params, - True, - ( - c.returned_defaults - if not c.context.executemany - else None - ), - ) - - if check_rowcount: - if rows != len(records): - raise orm_exc.StaleDataError( - "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." - % (table.description, len(records), rows) - ) - - elif needs_version_id: - util.warn( - "Dialect %s does not support updated rowcount " - "- versioning cannot be verified." - % c.dialect.dialect_description - ) - - -def _emit_insert_statements( - base_mapper, - uowtransaction, - mapper, - table, - insert, - *, - bookkeeping=True, - use_orm_insert_stmt=None, - execution_options=None, -): - """Emit INSERT statements corresponding to value lists collected - by _collect_insert_commands().""" - - if use_orm_insert_stmt is not None: - cached_stmt = use_orm_insert_stmt - exec_opt = util.EMPTY_DICT - - # if a user query with RETURNING was passed, we definitely need - # to use RETURNING. - returning_is_required_anyway = bool(use_orm_insert_stmt._returning) - deterministic_results_reqd = ( - returning_is_required_anyway - and use_orm_insert_stmt._sort_by_parameter_order - ) or bookkeeping - else: - returning_is_required_anyway = False - deterministic_results_reqd = bookkeeping - cached_stmt = base_mapper._memo(("insert", table), table.insert) - exec_opt = {"compiled_cache": base_mapper._compiled_cache} - - if execution_options: - execution_options = util.EMPTY_DICT.merge_with( - exec_opt, execution_options - ) - else: - execution_options = exec_opt - - return_result = None - - for ( - (connection, _, hasvalue, has_all_pks, has_all_defaults), - records, - ) in groupby( - insert, - lambda rec: ( - rec[4], # connection - set(rec[2]), # parameter keys - bool(rec[5]), # whether we have "value" parameters - rec[6], - rec[7], - ), - ): - statement = cached_stmt - - if use_orm_insert_stmt is not None: - statement = statement._annotate( - { - "_emit_insert_table": table, - "_emit_insert_mapper": mapper, - } - ) - - if ( - ( - not bookkeeping - or ( - has_all_defaults - or not base_mapper._prefer_eager_defaults( - connection.dialect, table - ) - or not table.implicit_returning - or not connection.dialect.insert_returning - ) - ) - and not returning_is_required_anyway - and has_all_pks - and not hasvalue - ): - # the "we don't need newly generated values back" section. - # here we have all the PKs, all the defaults or we don't want - # to fetch them, or the dialect doesn't support RETURNING at all - # so we have to post-fetch / use lastrowid anyway. - records = list(records) - multiparams = [rec[2] for rec in records] - - result = connection.execute( - statement, multiparams, execution_options=execution_options - ) - if bookkeeping: - for ( - ( - state, - state_dict, - params, - mapper_rec, - conn, - value_params, - has_all_pks, - has_all_defaults, - ), - last_inserted_params, - ) in zip(records, result.context.compiled_parameters): - if state: - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - result, - last_inserted_params, - value_params, - False, - ( - result.returned_defaults - if not result.context.executemany - else None - ), - ) - else: - _postfetch_bulk_save(mapper_rec, state_dict, table) - - else: - # here, we need defaults and/or pk values back or we otherwise - # know that we are using RETURNING in any case - - records = list(records) - - if returning_is_required_anyway or ( - table.implicit_returning and not hasvalue and len(records) > 1 - ): - if ( - deterministic_results_reqd - and connection.dialect.insert_executemany_returning_sort_by_parameter_order # noqa: E501 - ) or ( - not deterministic_results_reqd - and connection.dialect.insert_executemany_returning - ): - do_executemany = True - elif returning_is_required_anyway: - if deterministic_results_reqd: - dt = " with RETURNING and sort by parameter order" - else: - dt = " with RETURNING" - raise sa_exc.InvalidRequestError( - f"Can't use explicit RETURNING for bulk INSERT " - f"operation with " - f"{connection.dialect.dialect_description} backend; " - f"executemany{dt} is not enabled for this dialect." - ) - else: - do_executemany = False - else: - do_executemany = False - - if use_orm_insert_stmt is None: - if ( - not has_all_defaults - and base_mapper._prefer_eager_defaults( - connection.dialect, table - ) - ): - statement = statement.return_defaults( - *mapper._server_default_cols[table], - sort_by_parameter_order=bookkeeping, - ) - - if mapper.version_id_col is not None: - statement = statement.return_defaults( - mapper.version_id_col, - sort_by_parameter_order=bookkeeping, - ) - elif do_executemany: - statement = statement.return_defaults( - *table.primary_key, sort_by_parameter_order=bookkeeping - ) - - if do_executemany: - multiparams = [rec[2] for rec in records] - - result = connection.execute( - statement, multiparams, execution_options=execution_options - ) - - if use_orm_insert_stmt is not None: - if return_result is None: - return_result = result - else: - return_result = return_result.splice_vertically(result) - - if bookkeeping: - for ( - ( - state, - state_dict, - params, - mapper_rec, - conn, - value_params, - has_all_pks, - has_all_defaults, - ), - last_inserted_params, - inserted_primary_key, - returned_defaults, - ) in zip_longest( - records, - result.context.compiled_parameters, - result.inserted_primary_key_rows, - result.returned_defaults_rows or (), - ): - if inserted_primary_key is None: - # this is a real problem and means that we didn't - # get back as many PK rows. we can't continue - # since this indicates PK rows were missing, which - # means we likely mis-populated records starting - # at that point with incorrectly matched PK - # values. - raise orm_exc.FlushError( - "Multi-row INSERT statement for %s did not " - "produce " - "the correct number of INSERTed rows for " - "RETURNING. Ensure there are no triggers or " - "special driver issues preventing INSERT from " - "functioning properly." % mapper_rec - ) - - for pk, col in zip( - inserted_primary_key, - mapper._pks_by_table[table], - ): - prop = mapper_rec._columntoproperty[col] - if state_dict.get(prop.key) is None: - state_dict[prop.key] = pk - - if state: - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - result, - last_inserted_params, - value_params, - False, - returned_defaults, - ) - else: - _postfetch_bulk_save(mapper_rec, state_dict, table) - else: - assert not returning_is_required_anyway - - for ( - state, - state_dict, - params, - mapper_rec, - connection, - value_params, - has_all_pks, - has_all_defaults, - ) in records: - if value_params: - result = connection.execute( - statement.values(value_params), - params, - execution_options=execution_options, - ) - else: - result = connection.execute( - statement, - params, - execution_options=execution_options, - ) - - primary_key = result.inserted_primary_key - if primary_key is None: - raise orm_exc.FlushError( - "Single-row INSERT statement for %s " - "did not produce a " - "new primary key result " - "being invoked. Ensure there are no triggers or " - "special driver issues preventing INSERT from " - "functioning properly." % (mapper_rec,) - ) - for pk, col in zip( - primary_key, mapper._pks_by_table[table] - ): - prop = mapper_rec._columntoproperty[col] - if ( - col in value_params - or state_dict.get(prop.key) is None - ): - state_dict[prop.key] = pk - if bookkeeping: - if state: - _postfetch( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - result, - result.context.compiled_parameters[0], - value_params, - False, - ( - result.returned_defaults - if not result.context.executemany - else None - ), - ) - else: - _postfetch_bulk_save(mapper_rec, state_dict, table) - - if use_orm_insert_stmt is not None: - if return_result is None: - return _cursor.null_dml_result() - else: - return return_result - - -def _emit_post_update_statements( - base_mapper, uowtransaction, mapper, table, update -): - """Emit UPDATE statements corresponding to value lists collected - by _collect_post_update_commands().""" - - execution_options = {"compiled_cache": base_mapper._compiled_cache} - - needs_version_id = ( - mapper.version_id_col is not None - and mapper.version_id_col in mapper._cols_by_table[table] - ) - - def update_stmt(): - clauses = BooleanClauseList._construct_raw(operators.and_) - - for col in mapper._pks_by_table[table]: - clauses._append_inplace( - col == sql.bindparam(col._label, type_=col.type) - ) - - if needs_version_id: - clauses._append_inplace( - mapper.version_id_col - == sql.bindparam( - mapper.version_id_col._label, - type_=mapper.version_id_col.type, - ) - ) - - stmt = table.update().where(clauses) - - return stmt - - statement = base_mapper._memo(("post_update", table), update_stmt) - - if mapper._version_id_has_server_side_value: - statement = statement.return_defaults(mapper.version_id_col) - - # execute each UPDATE in the order according to the original - # list of states to guarantee row access order, but - # also group them into common (connection, cols) sets - # to support executemany(). - for key, records in groupby( - update, - lambda rec: (rec[3], set(rec[4])), # connection # parameter keys - ): - rows = 0 - - records = list(records) - connection = key[0] - - assert_singlerow = connection.dialect.supports_sane_rowcount - assert_multirow = ( - assert_singlerow - and connection.dialect.supports_sane_multi_rowcount - ) - allow_executemany = not needs_version_id or assert_multirow - - if not allow_executemany: - check_rowcount = assert_singlerow - for state, state_dict, mapper_rec, connection, params in records: - c = connection.execute( - statement, params, execution_options=execution_options - ) - - _postfetch_post_update( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - ) - rows += c.rowcount - else: - multiparams = [ - params - for state, state_dict, mapper_rec, conn, params in records - ] - - check_rowcount = assert_multirow or ( - assert_singlerow and len(multiparams) == 1 - ) - - c = connection.execute( - statement, multiparams, execution_options=execution_options - ) - - rows += c.rowcount - for state, state_dict, mapper_rec, connection, params in records: - _postfetch_post_update( - mapper_rec, - uowtransaction, - table, - state, - state_dict, - c, - c.context.compiled_parameters[0], - ) - - if check_rowcount: - if rows != len(records): - raise orm_exc.StaleDataError( - "UPDATE statement on table '%s' expected to " - "update %d row(s); %d were matched." - % (table.description, len(records), rows) - ) - - elif needs_version_id: - util.warn( - "Dialect %s does not support updated rowcount " - "- versioning cannot be verified." - % c.dialect.dialect_description - ) - - -def _emit_delete_statements( - base_mapper, uowtransaction, mapper, table, delete -): - """Emit DELETE statements corresponding to value lists collected - by _collect_delete_commands().""" - - need_version_id = ( - mapper.version_id_col is not None - and mapper.version_id_col in mapper._cols_by_table[table] - ) - - def delete_stmt(): - clauses = BooleanClauseList._construct_raw(operators.and_) - - for col in mapper._pks_by_table[table]: - clauses._append_inplace( - col == sql.bindparam(col.key, type_=col.type) - ) - - if need_version_id: - clauses._append_inplace( - mapper.version_id_col - == sql.bindparam( - mapper.version_id_col.key, type_=mapper.version_id_col.type - ) - ) - - return table.delete().where(clauses) - - statement = base_mapper._memo(("delete", table), delete_stmt) - for connection, recs in groupby(delete, lambda rec: rec[1]): # connection - del_objects = [params for params, connection in recs] - - execution_options = {"compiled_cache": base_mapper._compiled_cache} - expected = len(del_objects) - rows_matched = -1 - only_warn = False - - if ( - need_version_id - and not connection.dialect.supports_sane_multi_rowcount - ): - if connection.dialect.supports_sane_rowcount: - rows_matched = 0 - # execute deletes individually so that versioned - # rows can be verified - for params in del_objects: - c = connection.execute( - statement, params, execution_options=execution_options - ) - rows_matched += c.rowcount - else: - util.warn( - "Dialect %s does not support deleted rowcount " - "- versioning cannot be verified." - % connection.dialect.dialect_description - ) - connection.execute( - statement, del_objects, execution_options=execution_options - ) - else: - c = connection.execute( - statement, del_objects, execution_options=execution_options - ) - - if not need_version_id: - only_warn = True - - rows_matched = c.rowcount - - if ( - base_mapper.confirm_deleted_rows - and rows_matched > -1 - and expected != rows_matched - and ( - connection.dialect.supports_sane_multi_rowcount - or len(del_objects) == 1 - ) - ): - # TODO: why does this "only warn" if versioning is turned off, - # whereas the UPDATE raises? - if only_warn: - util.warn( - "DELETE statement on table '%s' expected to " - "delete %d row(s); %d were matched. Please set " - "confirm_deleted_rows=False within the mapper " - "configuration to prevent this warning." - % (table.description, expected, rows_matched) - ) - else: - raise orm_exc.StaleDataError( - "DELETE statement on table '%s' expected to " - "delete %d row(s); %d were matched. Please set " - "confirm_deleted_rows=False within the mapper " - "configuration to prevent this warning." - % (table.description, expected, rows_matched) - ) - - -def _finalize_insert_update_commands(base_mapper, uowtransaction, states): - """finalize state on states that have been inserted or updated, - including calling after_insert/after_update events. - - """ - for state, state_dict, mapper, connection, has_identity in states: - if mapper._readonly_props: - readonly = state.unmodified_intersection( - [ - p.key - for p in mapper._readonly_props - if ( - p.expire_on_flush - and (not p.deferred or p.key in state.dict) - ) - or ( - not p.expire_on_flush - and not p.deferred - and p.key not in state.dict - ) - ] - ) - if readonly: - state._expire_attributes(state.dict, readonly) - - # if eager_defaults option is enabled, load - # all expired cols. Else if we have a version_id_col, make sure - # it isn't expired. - toload_now = [] - - # this is specifically to emit a second SELECT for eager_defaults, - # so only if it's set to True, not "auto" - if base_mapper.eager_defaults is True: - toload_now.extend( - state._unloaded_non_object.intersection( - mapper._server_default_plus_onupdate_propkeys - ) - ) - - if ( - mapper.version_id_col is not None - and mapper.version_id_generator is False - ): - if mapper._version_id_prop.key in state.unloaded: - toload_now.extend([mapper._version_id_prop.key]) - - if toload_now: - state.key = base_mapper._identity_key_from_state(state) - stmt = future.select(mapper).set_label_style( - LABEL_STYLE_TABLENAME_PLUS_COL - ) - loading.load_on_ident( - uowtransaction.session, - stmt, - state.key, - refresh_state=state, - only_load_props=toload_now, - ) - - # call after_XXX extensions - if not has_identity: - mapper.dispatch.after_insert(mapper, connection, state) - else: - mapper.dispatch.after_update(mapper, connection, state) - - if ( - mapper.version_id_generator is False - and mapper.version_id_col is not None - ): - if state_dict[mapper._version_id_prop.key] is None: - raise orm_exc.FlushError( - "Instance does not contain a non-NULL version value" - ) - - -def _postfetch_post_update( - mapper, uowtransaction, table, state, dict_, result, params -): - needs_version_id = ( - mapper.version_id_col is not None - and mapper.version_id_col in mapper._cols_by_table[table] - ) - - if not uowtransaction.is_deleted(state): - # post updating after a regular INSERT or UPDATE, do a full postfetch - prefetch_cols = result.context.compiled.prefetch - postfetch_cols = result.context.compiled.postfetch - elif needs_version_id: - # post updating before a DELETE with a version_id_col, need to - # postfetch just version_id_col - prefetch_cols = postfetch_cols = () - else: - # post updating before a DELETE without a version_id_col, - # don't need to postfetch - return - - if needs_version_id: - prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] - - refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) - if refresh_flush: - load_evt_attrs = [] - - for c in prefetch_cols: - if c.key in params and c in mapper._columntoproperty: - dict_[mapper._columntoproperty[c].key] = params[c.key] - if refresh_flush: - load_evt_attrs.append(mapper._columntoproperty[c].key) - - if refresh_flush and load_evt_attrs: - mapper.class_manager.dispatch.refresh_flush( - state, uowtransaction, load_evt_attrs - ) - - if postfetch_cols: - state._expire_attributes( - state.dict, - [ - mapper._columntoproperty[c].key - for c in postfetch_cols - if c in mapper._columntoproperty - ], - ) - - -def _postfetch( - mapper, - uowtransaction, - table, - state, - dict_, - result, - params, - value_params, - isupdate, - returned_defaults, -): - """Expire attributes in need of newly persisted database state, - after an INSERT or UPDATE statement has proceeded for that - state.""" - - prefetch_cols = result.context.compiled.prefetch - postfetch_cols = result.context.compiled.postfetch - returning_cols = result.context.compiled.effective_returning - - if ( - mapper.version_id_col is not None - and mapper.version_id_col in mapper._cols_by_table[table] - ): - prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] - - refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) - if refresh_flush: - load_evt_attrs = [] - - if returning_cols: - row = returned_defaults - if row is not None: - for row_value, col in zip(row, returning_cols): - # pk cols returned from insert are handled - # distinctly, don't step on the values here - if col.primary_key and result.context.isinsert: - continue - - # note that columns can be in the "return defaults" that are - # not mapped to this mapper, typically because they are - # "excluded", which can be specified directly or also occurs - # when using declarative w/ single table inheritance - prop = mapper._columntoproperty.get(col) - if prop: - dict_[prop.key] = row_value - if refresh_flush: - load_evt_attrs.append(prop.key) - - for c in prefetch_cols: - if c.key in params and c in mapper._columntoproperty: - pkey = mapper._columntoproperty[c].key - - # set prefetched value in dict and also pop from committed_state, - # since this is new database state that replaces whatever might - # have previously been fetched (see #10800). this is essentially a - # shorthand version of set_committed_value(), which could also be - # used here directly (with more overhead) - dict_[pkey] = params[c.key] - state.committed_state.pop(pkey, None) - - if refresh_flush: - load_evt_attrs.append(pkey) - - if refresh_flush and load_evt_attrs: - mapper.class_manager.dispatch.refresh_flush( - state, uowtransaction, load_evt_attrs - ) - - if isupdate and value_params: - # explicitly suit the use case specified by - # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING - # database which are set to themselves in order to do a version bump. - postfetch_cols.extend( - [ - col - for col in value_params - if col.primary_key and col not in returning_cols - ] - ) - - if postfetch_cols: - state._expire_attributes( - state.dict, - [ - mapper._columntoproperty[c].key - for c in postfetch_cols - if c in mapper._columntoproperty - ], - ) - - # synchronize newly inserted ids from one table to the next - # TODO: this still goes a little too often. would be nice to - # have definitive list of "columns that changed" here - for m, equated_pairs in mapper._table_to_equated[table]: - sync.populate( - state, - m, - state, - m, - equated_pairs, - uowtransaction, - mapper.passive_updates, - ) - - -def _postfetch_bulk_save(mapper, dict_, table): - for m, equated_pairs in mapper._table_to_equated[table]: - sync.bulk_populate_inherit_keys(dict_, m, equated_pairs) - - -def _connections_for_states(base_mapper, uowtransaction, states): - """Return an iterator of (state, state.dict, mapper, connection). - - The states are sorted according to _sort_states, then paired - with the connection they should be using for the given - unit of work transaction. - - """ - # if session has a connection callable, - # organize individual states with the connection - # to use for update - if uowtransaction.session.connection_callable: - connection_callable = uowtransaction.session.connection_callable - else: - connection = uowtransaction.transaction.connection(base_mapper) - connection_callable = None - - for state in _sort_states(base_mapper, states): - if connection_callable: - connection = connection_callable(base_mapper, state.obj()) - - mapper = state.manager.mapper - - yield state, state.dict, mapper, connection - - -def _sort_states(mapper, states): - pending = set(states) - persistent = {s for s in pending if s.key is not None} - pending.difference_update(persistent) - - try: - persistent_sorted = sorted( - persistent, key=mapper._persistent_sortkey_fn - ) - except TypeError as err: - raise sa_exc.InvalidRequestError( - "Could not sort objects by primary key; primary key " - "values must be sortable in Python (was: %s)" % err - ) from err - return ( - sorted(pending, key=operator.attrgetter("insert_order")) - + persistent_sorted - ) |