diff options
| author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 | 
|---|---|---|
| committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 | 
| commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
| tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py | |
| parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) | |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py')
| -rw-r--r-- | venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py | 1782 | 
1 files changed, 1782 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py b/venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py new file mode 100644 index 0000000..369fc59 --- /dev/null +++ b/venv/lib/python3.11/site-packages/sqlalchemy/orm/persistence.py @@ -0,0 +1,1782 @@ +# orm/persistence.py +# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +"""private module containing functions used to emit INSERT, UPDATE +and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending +mappers. + +The functions here are called only by the unit of work functions +in unitofwork.py. + +""" +from __future__ import annotations + +from itertools import chain +from itertools import groupby +from itertools import zip_longest +import operator + +from . import attributes +from . import exc as orm_exc +from . import loading +from . import sync +from .base import state_str +from .. import exc as sa_exc +from .. import future +from .. import sql +from .. import util +from ..engine import cursor as _cursor +from ..sql import operators +from ..sql.elements import BooleanClauseList +from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL + + +def save_obj(base_mapper, states, uowtransaction, single=False): +    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list +    of objects. + +    This is called within the context of a UOWTransaction during a +    flush operation, given a list of states to be flushed.  The +    base mapper in an inheritance hierarchy handles the inserts/ +    updates for all descendant mappers. + +    """ + +    # if batch=false, call _save_obj separately for each object +    if not single and not base_mapper.batch: +        for state in _sort_states(base_mapper, states): +            save_obj(base_mapper, [state], uowtransaction, single=True) +        return + +    states_to_update = [] +    states_to_insert = [] + +    for ( +        state, +        dict_, +        mapper, +        connection, +        has_identity, +        row_switch, +        update_version_id, +    ) in _organize_states_for_save(base_mapper, states, uowtransaction): +        if has_identity or row_switch: +            states_to_update.append( +                (state, dict_, mapper, connection, update_version_id) +            ) +        else: +            states_to_insert.append((state, dict_, mapper, connection)) + +    for table, mapper in base_mapper._sorted_tables.items(): +        if table not in mapper._pks_by_table: +            continue +        insert = _collect_insert_commands(table, states_to_insert) + +        update = _collect_update_commands( +            uowtransaction, table, states_to_update +        ) + +        _emit_update_statements( +            base_mapper, +            uowtransaction, +            mapper, +            table, +            update, +        ) + +        _emit_insert_statements( +            base_mapper, +            uowtransaction, +            mapper, +            table, +            insert, +        ) + +    _finalize_insert_update_commands( +        base_mapper, +        uowtransaction, +        chain( +            ( +                (state, state_dict, mapper, connection, False) +                for (state, state_dict, mapper, connection) in states_to_insert +            ), +            ( +                (state, state_dict, mapper, connection, True) +                for ( +                    state, +                    state_dict, +                    mapper, +                    connection, +                    update_version_id, +                ) in states_to_update +            ), +        ), +    ) + + +def post_update(base_mapper, states, uowtransaction, post_update_cols): +    """Issue UPDATE statements on behalf of a relationship() which +    specifies post_update. + +    """ + +    states_to_update = list( +        _organize_states_for_post_update(base_mapper, states, uowtransaction) +    ) + +    for table, mapper in base_mapper._sorted_tables.items(): +        if table not in mapper._pks_by_table: +            continue + +        update = ( +            ( +                state, +                state_dict, +                sub_mapper, +                connection, +                ( +                    mapper._get_committed_state_attr_by_column( +                        state, state_dict, mapper.version_id_col +                    ) +                    if mapper.version_id_col is not None +                    else None +                ), +            ) +            for state, state_dict, sub_mapper, connection in states_to_update +            if table in sub_mapper._pks_by_table +        ) + +        update = _collect_post_update_commands( +            base_mapper, uowtransaction, table, update, post_update_cols +        ) + +        _emit_post_update_statements( +            base_mapper, +            uowtransaction, +            mapper, +            table, +            update, +        ) + + +def delete_obj(base_mapper, states, uowtransaction): +    """Issue ``DELETE`` statements for a list of objects. + +    This is called within the context of a UOWTransaction during a +    flush operation. + +    """ + +    states_to_delete = list( +        _organize_states_for_delete(base_mapper, states, uowtransaction) +    ) + +    table_to_mapper = base_mapper._sorted_tables + +    for table in reversed(list(table_to_mapper.keys())): +        mapper = table_to_mapper[table] +        if table not in mapper._pks_by_table: +            continue +        elif mapper.inherits and mapper.passive_deletes: +            continue + +        delete = _collect_delete_commands( +            base_mapper, uowtransaction, table, states_to_delete +        ) + +        _emit_delete_statements( +            base_mapper, +            uowtransaction, +            mapper, +            table, +            delete, +        ) + +    for ( +        state, +        state_dict, +        mapper, +        connection, +        update_version_id, +    ) in states_to_delete: +        mapper.dispatch.after_delete(mapper, connection, state) + + +def _organize_states_for_save(base_mapper, states, uowtransaction): +    """Make an initial pass across a set of states for INSERT or +    UPDATE. + +    This includes splitting out into distinct lists for +    each, calling before_insert/before_update, obtaining +    key information for each state including its dictionary, +    mapper, the connection to use for the execution per state, +    and the identity flag. + +    """ + +    for state, dict_, mapper, connection in _connections_for_states( +        base_mapper, uowtransaction, states +    ): +        has_identity = bool(state.key) + +        instance_key = state.key or mapper._identity_key_from_state(state) + +        row_switch = update_version_id = None + +        # call before_XXX extensions +        if not has_identity: +            mapper.dispatch.before_insert(mapper, connection, state) +        else: +            mapper.dispatch.before_update(mapper, connection, state) + +        if mapper._validate_polymorphic_identity: +            mapper._validate_polymorphic_identity(mapper, state, dict_) + +        # detect if we have a "pending" instance (i.e. has +        # no instance_key attached to it), and another instance +        # with the same identity key already exists as persistent. +        # convert to an UPDATE if so. +        if ( +            not has_identity +            and instance_key in uowtransaction.session.identity_map +        ): +            instance = uowtransaction.session.identity_map[instance_key] +            existing = attributes.instance_state(instance) + +            if not uowtransaction.was_already_deleted(existing): +                if not uowtransaction.is_deleted(existing): +                    util.warn( +                        "New instance %s with identity key %s conflicts " +                        "with persistent instance %s" +                        % (state_str(state), instance_key, state_str(existing)) +                    ) +                else: +                    base_mapper._log_debug( +                        "detected row switch for identity %s.  " +                        "will update %s, remove %s from " +                        "transaction", +                        instance_key, +                        state_str(state), +                        state_str(existing), +                    ) + +                    # remove the "delete" flag from the existing element +                    uowtransaction.remove_state_actions(existing) +                    row_switch = existing + +        if (has_identity or row_switch) and mapper.version_id_col is not None: +            update_version_id = mapper._get_committed_state_attr_by_column( +                row_switch if row_switch else state, +                row_switch.dict if row_switch else dict_, +                mapper.version_id_col, +            ) + +        yield ( +            state, +            dict_, +            mapper, +            connection, +            has_identity, +            row_switch, +            update_version_id, +        ) + + +def _organize_states_for_post_update(base_mapper, states, uowtransaction): +    """Make an initial pass across a set of states for UPDATE +    corresponding to post_update. + +    This includes obtaining key information for each state +    including its dictionary, mapper, the connection to use for +    the execution per state. + +    """ +    return _connections_for_states(base_mapper, uowtransaction, states) + + +def _organize_states_for_delete(base_mapper, states, uowtransaction): +    """Make an initial pass across a set of states for DELETE. + +    This includes calling out before_delete and obtaining +    key information for each state including its dictionary, +    mapper, the connection to use for the execution per state. + +    """ +    for state, dict_, mapper, connection in _connections_for_states( +        base_mapper, uowtransaction, states +    ): +        mapper.dispatch.before_delete(mapper, connection, state) + +        if mapper.version_id_col is not None: +            update_version_id = mapper._get_committed_state_attr_by_column( +                state, dict_, mapper.version_id_col +            ) +        else: +            update_version_id = None + +        yield (state, dict_, mapper, connection, update_version_id) + + +def _collect_insert_commands( +    table, +    states_to_insert, +    *, +    bulk=False, +    return_defaults=False, +    render_nulls=False, +    include_bulk_keys=(), +): +    """Identify sets of values to use in INSERT statements for a +    list of states. + +    """ +    for state, state_dict, mapper, connection in states_to_insert: +        if table not in mapper._pks_by_table: +            continue + +        params = {} +        value_params = {} + +        propkey_to_col = mapper._propkey_to_col[table] + +        eval_none = mapper._insert_cols_evaluating_none[table] + +        for propkey in set(propkey_to_col).intersection(state_dict): +            value = state_dict[propkey] +            col = propkey_to_col[propkey] +            if value is None and col not in eval_none and not render_nulls: +                continue +            elif not bulk and ( +                hasattr(value, "__clause_element__") +                or isinstance(value, sql.ClauseElement) +            ): +                value_params[col] = ( +                    value.__clause_element__() +                    if hasattr(value, "__clause_element__") +                    else value +                ) +            else: +                params[col.key] = value + +        if not bulk: +            # for all the columns that have no default and we don't have +            # a value and where "None" is not a special value, add +            # explicit None to the INSERT.   This is a legacy behavior +            # which might be worth removing, as it should not be necessary +            # and also produces confusion, given that "missing" and None +            # now have distinct meanings +            for colkey in ( +                mapper._insert_cols_as_none[table] +                .difference(params) +                .difference([c.key for c in value_params]) +            ): +                params[colkey] = None + +        if not bulk or return_defaults: +            # params are in terms of Column key objects, so +            # compare to pk_keys_by_table +            has_all_pks = mapper._pk_keys_by_table[table].issubset(params) + +            if mapper.base_mapper._prefer_eager_defaults( +                connection.dialect, table +            ): +                has_all_defaults = mapper._server_default_col_keys[ +                    table +                ].issubset(params) +            else: +                has_all_defaults = True +        else: +            has_all_defaults = has_all_pks = True + +        if ( +            mapper.version_id_generator is not False +            and mapper.version_id_col is not None +            and mapper.version_id_col in mapper._cols_by_table[table] +        ): +            params[mapper.version_id_col.key] = mapper.version_id_generator( +                None +            ) + +        if bulk: +            if mapper._set_polymorphic_identity: +                params.setdefault( +                    mapper._polymorphic_attr_key, mapper.polymorphic_identity +                ) + +            if include_bulk_keys: +                params.update((k, state_dict[k]) for k in include_bulk_keys) + +        yield ( +            state, +            state_dict, +            params, +            mapper, +            connection, +            value_params, +            has_all_pks, +            has_all_defaults, +        ) + + +def _collect_update_commands( +    uowtransaction, +    table, +    states_to_update, +    *, +    bulk=False, +    use_orm_update_stmt=None, +    include_bulk_keys=(), +): +    """Identify sets of values to use in UPDATE statements for a +    list of states. + +    This function works intricately with the history system +    to determine exactly what values should be updated +    as well as how the row should be matched within an UPDATE +    statement.  Includes some tricky scenarios where the primary +    key of an object might have been changed. + +    """ + +    for ( +        state, +        state_dict, +        mapper, +        connection, +        update_version_id, +    ) in states_to_update: +        if table not in mapper._pks_by_table: +            continue + +        pks = mapper._pks_by_table[table] + +        if use_orm_update_stmt is not None: +            # TODO: ordered values, etc +            value_params = use_orm_update_stmt._values +        else: +            value_params = {} + +        propkey_to_col = mapper._propkey_to_col[table] + +        if bulk: +            # keys here are mapped attribute keys, so +            # look at mapper attribute keys for pk +            params = { +                propkey_to_col[propkey].key: state_dict[propkey] +                for propkey in set(propkey_to_col) +                .intersection(state_dict) +                .difference(mapper._pk_attr_keys_by_table[table]) +            } +            has_all_defaults = True +        else: +            params = {} +            for propkey in set(propkey_to_col).intersection( +                state.committed_state +            ): +                value = state_dict[propkey] +                col = propkey_to_col[propkey] + +                if hasattr(value, "__clause_element__") or isinstance( +                    value, sql.ClauseElement +                ): +                    value_params[col] = ( +                        value.__clause_element__() +                        if hasattr(value, "__clause_element__") +                        else value +                    ) +                # guard against values that generate non-__nonzero__ +                # objects for __eq__() +                elif ( +                    state.manager[propkey].impl.is_equal( +                        value, state.committed_state[propkey] +                    ) +                    is not True +                ): +                    params[col.key] = value + +            if mapper.base_mapper.eager_defaults is True: +                has_all_defaults = ( +                    mapper._server_onupdate_default_col_keys[table] +                ).issubset(params) +            else: +                has_all_defaults = True + +        if ( +            update_version_id is not None +            and mapper.version_id_col in mapper._cols_by_table[table] +        ): +            if not bulk and not (params or value_params): +                # HACK: check for history in other tables, in case the +                # history is only in a different table than the one +                # where the version_id_col is.  This logic was lost +                # from 0.9 -> 1.0.0 and restored in 1.0.6. +                for prop in mapper._columntoproperty.values(): +                    history = state.manager[prop.key].impl.get_history( +                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE +                    ) +                    if history.added: +                        break +                else: +                    # no net change, break +                    continue + +            col = mapper.version_id_col +            no_params = not params and not value_params +            params[col._label] = update_version_id + +            if ( +                bulk or col.key not in params +            ) and mapper.version_id_generator is not False: +                val = mapper.version_id_generator(update_version_id) +                params[col.key] = val +            elif mapper.version_id_generator is False and no_params: +                # no version id generator, no values set on the table, +                # and version id wasn't manually incremented. +                # set version id to itself so we get an UPDATE +                # statement +                params[col.key] = update_version_id + +        elif not (params or value_params): +            continue + +        has_all_pks = True +        expect_pk_cascaded = False +        if bulk: +            # keys here are mapped attribute keys, so +            # look at mapper attribute keys for pk +            pk_params = { +                propkey_to_col[propkey]._label: state_dict.get(propkey) +                for propkey in set(propkey_to_col).intersection( +                    mapper._pk_attr_keys_by_table[table] +                ) +            } +            if util.NONE_SET.intersection(pk_params.values()): +                raise sa_exc.InvalidRequestError( +                    f"No primary key value supplied for column(s) " +                    f"""{ +                        ', '.join( +                            str(c) for c in pks if pk_params[c._label] is None +                        ) +                    }; """ +                    "per-row ORM Bulk UPDATE by Primary Key requires that " +                    "records contain primary key values", +                    code="bupq", +                ) + +        else: +            pk_params = {} +            for col in pks: +                propkey = mapper._columntoproperty[col].key + +                history = state.manager[propkey].impl.get_history( +                    state, state_dict, attributes.PASSIVE_OFF +                ) + +                if history.added: +                    if ( +                        not history.deleted +                        or ("pk_cascaded", state, col) +                        in uowtransaction.attributes +                    ): +                        expect_pk_cascaded = True +                        pk_params[col._label] = history.added[0] +                        params.pop(col.key, None) +                    else: +                        # else, use the old value to locate the row +                        pk_params[col._label] = history.deleted[0] +                        if col in value_params: +                            has_all_pks = False +                else: +                    pk_params[col._label] = history.unchanged[0] +                if pk_params[col._label] is None: +                    raise orm_exc.FlushError( +                        "Can't update table %s using NULL for primary " +                        "key value on column %s" % (table, col) +                    ) + +        if include_bulk_keys: +            params.update((k, state_dict[k]) for k in include_bulk_keys) + +        if params or value_params: +            params.update(pk_params) +            yield ( +                state, +                state_dict, +                params, +                mapper, +                connection, +                value_params, +                has_all_defaults, +                has_all_pks, +            ) +        elif expect_pk_cascaded: +            # no UPDATE occurs on this table, but we expect that CASCADE rules +            # have changed the primary key of the row; propagate this event to +            # other columns that expect to have been modified. this normally +            # occurs after the UPDATE is emitted however we invoke it here +            # explicitly in the absence of our invoking an UPDATE +            for m, equated_pairs in mapper._table_to_equated[table]: +                sync.populate( +                    state, +                    m, +                    state, +                    m, +                    equated_pairs, +                    uowtransaction, +                    mapper.passive_updates, +                ) + + +def _collect_post_update_commands( +    base_mapper, uowtransaction, table, states_to_update, post_update_cols +): +    """Identify sets of values to use in UPDATE statements for a +    list of states within a post_update operation. + +    """ + +    for ( +        state, +        state_dict, +        mapper, +        connection, +        update_version_id, +    ) in states_to_update: +        # assert table in mapper._pks_by_table + +        pks = mapper._pks_by_table[table] +        params = {} +        hasdata = False + +        for col in mapper._cols_by_table[table]: +            if col in pks: +                params[col._label] = mapper._get_state_attr_by_column( +                    state, state_dict, col, passive=attributes.PASSIVE_OFF +                ) + +            elif col in post_update_cols or col.onupdate is not None: +                prop = mapper._columntoproperty[col] +                history = state.manager[prop.key].impl.get_history( +                    state, state_dict, attributes.PASSIVE_NO_INITIALIZE +                ) +                if history.added: +                    value = history.added[0] +                    params[col.key] = value +                    hasdata = True +        if hasdata: +            if ( +                update_version_id is not None +                and mapper.version_id_col in mapper._cols_by_table[table] +            ): +                col = mapper.version_id_col +                params[col._label] = update_version_id + +                if ( +                    bool(state.key) +                    and col.key not in params +                    and mapper.version_id_generator is not False +                ): +                    val = mapper.version_id_generator(update_version_id) +                    params[col.key] = val +            yield state, state_dict, mapper, connection, params + + +def _collect_delete_commands( +    base_mapper, uowtransaction, table, states_to_delete +): +    """Identify values to use in DELETE statements for a list of +    states to be deleted.""" + +    for ( +        state, +        state_dict, +        mapper, +        connection, +        update_version_id, +    ) in states_to_delete: +        if table not in mapper._pks_by_table: +            continue + +        params = {} +        for col in mapper._pks_by_table[table]: +            params[col.key] = value = ( +                mapper._get_committed_state_attr_by_column( +                    state, state_dict, col +                ) +            ) +            if value is None: +                raise orm_exc.FlushError( +                    "Can't delete from table %s " +                    "using NULL for primary " +                    "key value on column %s" % (table, col) +                ) + +        if ( +            update_version_id is not None +            and mapper.version_id_col in mapper._cols_by_table[table] +        ): +            params[mapper.version_id_col.key] = update_version_id +        yield params, connection + + +def _emit_update_statements( +    base_mapper, +    uowtransaction, +    mapper, +    table, +    update, +    *, +    bookkeeping=True, +    use_orm_update_stmt=None, +    enable_check_rowcount=True, +): +    """Emit UPDATE statements corresponding to value lists collected +    by _collect_update_commands().""" + +    needs_version_id = ( +        mapper.version_id_col is not None +        and mapper.version_id_col in mapper._cols_by_table[table] +    ) + +    execution_options = {"compiled_cache": base_mapper._compiled_cache} + +    def update_stmt(existing_stmt=None): +        clauses = BooleanClauseList._construct_raw(operators.and_) + +        for col in mapper._pks_by_table[table]: +            clauses._append_inplace( +                col == sql.bindparam(col._label, type_=col.type) +            ) + +        if needs_version_id: +            clauses._append_inplace( +                mapper.version_id_col +                == sql.bindparam( +                    mapper.version_id_col._label, +                    type_=mapper.version_id_col.type, +                ) +            ) + +        if existing_stmt is not None: +            stmt = existing_stmt.where(clauses) +        else: +            stmt = table.update().where(clauses) +        return stmt + +    if use_orm_update_stmt is not None: +        cached_stmt = update_stmt(use_orm_update_stmt) + +    else: +        cached_stmt = base_mapper._memo(("update", table), update_stmt) + +    for ( +        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks), +        records, +    ) in groupby( +        update, +        lambda rec: ( +            rec[4],  # connection +            set(rec[2]),  # set of parameter keys +            bool(rec[5]),  # whether or not we have "value" parameters +            rec[6],  # has_all_defaults +            rec[7],  # has all pks +        ), +    ): +        rows = 0 +        records = list(records) + +        statement = cached_stmt + +        if use_orm_update_stmt is not None: +            statement = statement._annotate( +                { +                    "_emit_update_table": table, +                    "_emit_update_mapper": mapper, +                } +            ) + +        return_defaults = False + +        if not has_all_pks: +            statement = statement.return_defaults(*mapper._pks_by_table[table]) +            return_defaults = True + +        if ( +            bookkeeping +            and not has_all_defaults +            and mapper.base_mapper.eager_defaults is True +            # change as of #8889 - if RETURNING is not going to be used anyway, +            # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure +            # we can do an executemany UPDATE which is more efficient +            and table.implicit_returning +            and connection.dialect.update_returning +        ): +            statement = statement.return_defaults( +                *mapper._server_onupdate_default_cols[table] +            ) +            return_defaults = True + +        if mapper._version_id_has_server_side_value: +            statement = statement.return_defaults(mapper.version_id_col) +            return_defaults = True + +        assert_singlerow = connection.dialect.supports_sane_rowcount + +        assert_multirow = ( +            assert_singlerow +            and connection.dialect.supports_sane_multi_rowcount +        ) + +        # change as of #8889 - if RETURNING is not going to be used anyway, +        # (applies to MySQL, MariaDB which lack UPDATE RETURNING) ensure +        # we can do an executemany UPDATE which is more efficient +        allow_executemany = not return_defaults and not needs_version_id + +        if hasvalue: +            for ( +                state, +                state_dict, +                params, +                mapper, +                connection, +                value_params, +                has_all_defaults, +                has_all_pks, +            ) in records: +                c = connection.execute( +                    statement.values(value_params), +                    params, +                    execution_options=execution_options, +                ) +                if bookkeeping: +                    _postfetch( +                        mapper, +                        uowtransaction, +                        table, +                        state, +                        state_dict, +                        c, +                        c.context.compiled_parameters[0], +                        value_params, +                        True, +                        c.returned_defaults, +                    ) +                rows += c.rowcount +                check_rowcount = enable_check_rowcount and assert_singlerow +        else: +            if not allow_executemany: +                check_rowcount = enable_check_rowcount and assert_singlerow +                for ( +                    state, +                    state_dict, +                    params, +                    mapper, +                    connection, +                    value_params, +                    has_all_defaults, +                    has_all_pks, +                ) in records: +                    c = connection.execute( +                        statement, params, execution_options=execution_options +                    ) + +                    # TODO: why with bookkeeping=False? +                    if bookkeeping: +                        _postfetch( +                            mapper, +                            uowtransaction, +                            table, +                            state, +                            state_dict, +                            c, +                            c.context.compiled_parameters[0], +                            value_params, +                            True, +                            c.returned_defaults, +                        ) +                    rows += c.rowcount +            else: +                multiparams = [rec[2] for rec in records] + +                check_rowcount = enable_check_rowcount and ( +                    assert_multirow +                    or (assert_singlerow and len(multiparams) == 1) +                ) + +                c = connection.execute( +                    statement, multiparams, execution_options=execution_options +                ) + +                rows += c.rowcount + +                for ( +                    state, +                    state_dict, +                    params, +                    mapper, +                    connection, +                    value_params, +                    has_all_defaults, +                    has_all_pks, +                ) in records: +                    if bookkeeping: +                        _postfetch( +                            mapper, +                            uowtransaction, +                            table, +                            state, +                            state_dict, +                            c, +                            c.context.compiled_parameters[0], +                            value_params, +                            True, +                            ( +                                c.returned_defaults +                                if not c.context.executemany +                                else None +                            ), +                        ) + +        if check_rowcount: +            if rows != len(records): +                raise orm_exc.StaleDataError( +                    "UPDATE statement on table '%s' expected to " +                    "update %d row(s); %d were matched." +                    % (table.description, len(records), rows) +                ) + +        elif needs_version_id: +            util.warn( +                "Dialect %s does not support updated rowcount " +                "- versioning cannot be verified." +                % c.dialect.dialect_description +            ) + + +def _emit_insert_statements( +    base_mapper, +    uowtransaction, +    mapper, +    table, +    insert, +    *, +    bookkeeping=True, +    use_orm_insert_stmt=None, +    execution_options=None, +): +    """Emit INSERT statements corresponding to value lists collected +    by _collect_insert_commands().""" + +    if use_orm_insert_stmt is not None: +        cached_stmt = use_orm_insert_stmt +        exec_opt = util.EMPTY_DICT + +        # if a user query with RETURNING was passed, we definitely need +        # to use RETURNING. +        returning_is_required_anyway = bool(use_orm_insert_stmt._returning) +        deterministic_results_reqd = ( +            returning_is_required_anyway +            and use_orm_insert_stmt._sort_by_parameter_order +        ) or bookkeeping +    else: +        returning_is_required_anyway = False +        deterministic_results_reqd = bookkeeping +        cached_stmt = base_mapper._memo(("insert", table), table.insert) +        exec_opt = {"compiled_cache": base_mapper._compiled_cache} + +    if execution_options: +        execution_options = util.EMPTY_DICT.merge_with( +            exec_opt, execution_options +        ) +    else: +        execution_options = exec_opt + +    return_result = None + +    for ( +        (connection, _, hasvalue, has_all_pks, has_all_defaults), +        records, +    ) in groupby( +        insert, +        lambda rec: ( +            rec[4],  # connection +            set(rec[2]),  # parameter keys +            bool(rec[5]),  # whether we have "value" parameters +            rec[6], +            rec[7], +        ), +    ): +        statement = cached_stmt + +        if use_orm_insert_stmt is not None: +            statement = statement._annotate( +                { +                    "_emit_insert_table": table, +                    "_emit_insert_mapper": mapper, +                } +            ) + +        if ( +            ( +                not bookkeeping +                or ( +                    has_all_defaults +                    or not base_mapper._prefer_eager_defaults( +                        connection.dialect, table +                    ) +                    or not table.implicit_returning +                    or not connection.dialect.insert_returning +                ) +            ) +            and not returning_is_required_anyway +            and has_all_pks +            and not hasvalue +        ): +            # the "we don't need newly generated values back" section. +            # here we have all the PKs, all the defaults or we don't want +            # to fetch them, or the dialect doesn't support RETURNING at all +            # so we have to post-fetch / use lastrowid anyway. +            records = list(records) +            multiparams = [rec[2] for rec in records] + +            result = connection.execute( +                statement, multiparams, execution_options=execution_options +            ) +            if bookkeeping: +                for ( +                    ( +                        state, +                        state_dict, +                        params, +                        mapper_rec, +                        conn, +                        value_params, +                        has_all_pks, +                        has_all_defaults, +                    ), +                    last_inserted_params, +                ) in zip(records, result.context.compiled_parameters): +                    if state: +                        _postfetch( +                            mapper_rec, +                            uowtransaction, +                            table, +                            state, +                            state_dict, +                            result, +                            last_inserted_params, +                            value_params, +                            False, +                            ( +                                result.returned_defaults +                                if not result.context.executemany +                                else None +                            ), +                        ) +                    else: +                        _postfetch_bulk_save(mapper_rec, state_dict, table) + +        else: +            # here, we need defaults and/or pk values back or we otherwise +            # know that we are using RETURNING in any case + +            records = list(records) + +            if returning_is_required_anyway or ( +                table.implicit_returning and not hasvalue and len(records) > 1 +            ): +                if ( +                    deterministic_results_reqd +                    and connection.dialect.insert_executemany_returning_sort_by_parameter_order  # noqa: E501 +                ) or ( +                    not deterministic_results_reqd +                    and connection.dialect.insert_executemany_returning +                ): +                    do_executemany = True +                elif returning_is_required_anyway: +                    if deterministic_results_reqd: +                        dt = " with RETURNING and sort by parameter order" +                    else: +                        dt = " with RETURNING" +                    raise sa_exc.InvalidRequestError( +                        f"Can't use explicit RETURNING for bulk INSERT " +                        f"operation with " +                        f"{connection.dialect.dialect_description} backend; " +                        f"executemany{dt} is not enabled for this dialect." +                    ) +                else: +                    do_executemany = False +            else: +                do_executemany = False + +            if use_orm_insert_stmt is None: +                if ( +                    not has_all_defaults +                    and base_mapper._prefer_eager_defaults( +                        connection.dialect, table +                    ) +                ): +                    statement = statement.return_defaults( +                        *mapper._server_default_cols[table], +                        sort_by_parameter_order=bookkeeping, +                    ) + +            if mapper.version_id_col is not None: +                statement = statement.return_defaults( +                    mapper.version_id_col, +                    sort_by_parameter_order=bookkeeping, +                ) +            elif do_executemany: +                statement = statement.return_defaults( +                    *table.primary_key, sort_by_parameter_order=bookkeeping +                ) + +            if do_executemany: +                multiparams = [rec[2] for rec in records] + +                result = connection.execute( +                    statement, multiparams, execution_options=execution_options +                ) + +                if use_orm_insert_stmt is not None: +                    if return_result is None: +                        return_result = result +                    else: +                        return_result = return_result.splice_vertically(result) + +                if bookkeeping: +                    for ( +                        ( +                            state, +                            state_dict, +                            params, +                            mapper_rec, +                            conn, +                            value_params, +                            has_all_pks, +                            has_all_defaults, +                        ), +                        last_inserted_params, +                        inserted_primary_key, +                        returned_defaults, +                    ) in zip_longest( +                        records, +                        result.context.compiled_parameters, +                        result.inserted_primary_key_rows, +                        result.returned_defaults_rows or (), +                    ): +                        if inserted_primary_key is None: +                            # this is a real problem and means that we didn't +                            # get back as many PK rows.  we can't continue +                            # since this indicates PK rows were missing, which +                            # means we likely mis-populated records starting +                            # at that point with incorrectly matched PK +                            # values. +                            raise orm_exc.FlushError( +                                "Multi-row INSERT statement for %s did not " +                                "produce " +                                "the correct number of INSERTed rows for " +                                "RETURNING.  Ensure there are no triggers or " +                                "special driver issues preventing INSERT from " +                                "functioning properly." % mapper_rec +                            ) + +                        for pk, col in zip( +                            inserted_primary_key, +                            mapper._pks_by_table[table], +                        ): +                            prop = mapper_rec._columntoproperty[col] +                            if state_dict.get(prop.key) is None: +                                state_dict[prop.key] = pk + +                        if state: +                            _postfetch( +                                mapper_rec, +                                uowtransaction, +                                table, +                                state, +                                state_dict, +                                result, +                                last_inserted_params, +                                value_params, +                                False, +                                returned_defaults, +                            ) +                        else: +                            _postfetch_bulk_save(mapper_rec, state_dict, table) +            else: +                assert not returning_is_required_anyway + +                for ( +                    state, +                    state_dict, +                    params, +                    mapper_rec, +                    connection, +                    value_params, +                    has_all_pks, +                    has_all_defaults, +                ) in records: +                    if value_params: +                        result = connection.execute( +                            statement.values(value_params), +                            params, +                            execution_options=execution_options, +                        ) +                    else: +                        result = connection.execute( +                            statement, +                            params, +                            execution_options=execution_options, +                        ) + +                    primary_key = result.inserted_primary_key +                    if primary_key is None: +                        raise orm_exc.FlushError( +                            "Single-row INSERT statement for %s " +                            "did not produce a " +                            "new primary key result " +                            "being invoked.  Ensure there are no triggers or " +                            "special driver issues preventing INSERT from " +                            "functioning properly." % (mapper_rec,) +                        ) +                    for pk, col in zip( +                        primary_key, mapper._pks_by_table[table] +                    ): +                        prop = mapper_rec._columntoproperty[col] +                        if ( +                            col in value_params +                            or state_dict.get(prop.key) is None +                        ): +                            state_dict[prop.key] = pk +                    if bookkeeping: +                        if state: +                            _postfetch( +                                mapper_rec, +                                uowtransaction, +                                table, +                                state, +                                state_dict, +                                result, +                                result.context.compiled_parameters[0], +                                value_params, +                                False, +                                ( +                                    result.returned_defaults +                                    if not result.context.executemany +                                    else None +                                ), +                            ) +                        else: +                            _postfetch_bulk_save(mapper_rec, state_dict, table) + +    if use_orm_insert_stmt is not None: +        if return_result is None: +            return _cursor.null_dml_result() +        else: +            return return_result + + +def _emit_post_update_statements( +    base_mapper, uowtransaction, mapper, table, update +): +    """Emit UPDATE statements corresponding to value lists collected +    by _collect_post_update_commands().""" + +    execution_options = {"compiled_cache": base_mapper._compiled_cache} + +    needs_version_id = ( +        mapper.version_id_col is not None +        and mapper.version_id_col in mapper._cols_by_table[table] +    ) + +    def update_stmt(): +        clauses = BooleanClauseList._construct_raw(operators.and_) + +        for col in mapper._pks_by_table[table]: +            clauses._append_inplace( +                col == sql.bindparam(col._label, type_=col.type) +            ) + +        if needs_version_id: +            clauses._append_inplace( +                mapper.version_id_col +                == sql.bindparam( +                    mapper.version_id_col._label, +                    type_=mapper.version_id_col.type, +                ) +            ) + +        stmt = table.update().where(clauses) + +        return stmt + +    statement = base_mapper._memo(("post_update", table), update_stmt) + +    if mapper._version_id_has_server_side_value: +        statement = statement.return_defaults(mapper.version_id_col) + +    # execute each UPDATE in the order according to the original +    # list of states to guarantee row access order, but +    # also group them into common (connection, cols) sets +    # to support executemany(). +    for key, records in groupby( +        update, +        lambda rec: (rec[3], set(rec[4])),  # connection  # parameter keys +    ): +        rows = 0 + +        records = list(records) +        connection = key[0] + +        assert_singlerow = connection.dialect.supports_sane_rowcount +        assert_multirow = ( +            assert_singlerow +            and connection.dialect.supports_sane_multi_rowcount +        ) +        allow_executemany = not needs_version_id or assert_multirow + +        if not allow_executemany: +            check_rowcount = assert_singlerow +            for state, state_dict, mapper_rec, connection, params in records: +                c = connection.execute( +                    statement, params, execution_options=execution_options +                ) + +                _postfetch_post_update( +                    mapper_rec, +                    uowtransaction, +                    table, +                    state, +                    state_dict, +                    c, +                    c.context.compiled_parameters[0], +                ) +                rows += c.rowcount +        else: +            multiparams = [ +                params +                for state, state_dict, mapper_rec, conn, params in records +            ] + +            check_rowcount = assert_multirow or ( +                assert_singlerow and len(multiparams) == 1 +            ) + +            c = connection.execute( +                statement, multiparams, execution_options=execution_options +            ) + +            rows += c.rowcount +            for state, state_dict, mapper_rec, connection, params in records: +                _postfetch_post_update( +                    mapper_rec, +                    uowtransaction, +                    table, +                    state, +                    state_dict, +                    c, +                    c.context.compiled_parameters[0], +                ) + +        if check_rowcount: +            if rows != len(records): +                raise orm_exc.StaleDataError( +                    "UPDATE statement on table '%s' expected to " +                    "update %d row(s); %d were matched." +                    % (table.description, len(records), rows) +                ) + +        elif needs_version_id: +            util.warn( +                "Dialect %s does not support updated rowcount " +                "- versioning cannot be verified." +                % c.dialect.dialect_description +            ) + + +def _emit_delete_statements( +    base_mapper, uowtransaction, mapper, table, delete +): +    """Emit DELETE statements corresponding to value lists collected +    by _collect_delete_commands().""" + +    need_version_id = ( +        mapper.version_id_col is not None +        and mapper.version_id_col in mapper._cols_by_table[table] +    ) + +    def delete_stmt(): +        clauses = BooleanClauseList._construct_raw(operators.and_) + +        for col in mapper._pks_by_table[table]: +            clauses._append_inplace( +                col == sql.bindparam(col.key, type_=col.type) +            ) + +        if need_version_id: +            clauses._append_inplace( +                mapper.version_id_col +                == sql.bindparam( +                    mapper.version_id_col.key, type_=mapper.version_id_col.type +                ) +            ) + +        return table.delete().where(clauses) + +    statement = base_mapper._memo(("delete", table), delete_stmt) +    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection +        del_objects = [params for params, connection in recs] + +        execution_options = {"compiled_cache": base_mapper._compiled_cache} +        expected = len(del_objects) +        rows_matched = -1 +        only_warn = False + +        if ( +            need_version_id +            and not connection.dialect.supports_sane_multi_rowcount +        ): +            if connection.dialect.supports_sane_rowcount: +                rows_matched = 0 +                # execute deletes individually so that versioned +                # rows can be verified +                for params in del_objects: +                    c = connection.execute( +                        statement, params, execution_options=execution_options +                    ) +                    rows_matched += c.rowcount +            else: +                util.warn( +                    "Dialect %s does not support deleted rowcount " +                    "- versioning cannot be verified." +                    % connection.dialect.dialect_description +                ) +                connection.execute( +                    statement, del_objects, execution_options=execution_options +                ) +        else: +            c = connection.execute( +                statement, del_objects, execution_options=execution_options +            ) + +            if not need_version_id: +                only_warn = True + +            rows_matched = c.rowcount + +        if ( +            base_mapper.confirm_deleted_rows +            and rows_matched > -1 +            and expected != rows_matched +            and ( +                connection.dialect.supports_sane_multi_rowcount +                or len(del_objects) == 1 +            ) +        ): +            # TODO: why does this "only warn" if versioning is turned off, +            # whereas the UPDATE raises? +            if only_warn: +                util.warn( +                    "DELETE statement on table '%s' expected to " +                    "delete %d row(s); %d were matched.  Please set " +                    "confirm_deleted_rows=False within the mapper " +                    "configuration to prevent this warning." +                    % (table.description, expected, rows_matched) +                ) +            else: +                raise orm_exc.StaleDataError( +                    "DELETE statement on table '%s' expected to " +                    "delete %d row(s); %d were matched.  Please set " +                    "confirm_deleted_rows=False within the mapper " +                    "configuration to prevent this warning." +                    % (table.description, expected, rows_matched) +                ) + + +def _finalize_insert_update_commands(base_mapper, uowtransaction, states): +    """finalize state on states that have been inserted or updated, +    including calling after_insert/after_update events. + +    """ +    for state, state_dict, mapper, connection, has_identity in states: +        if mapper._readonly_props: +            readonly = state.unmodified_intersection( +                [ +                    p.key +                    for p in mapper._readonly_props +                    if ( +                        p.expire_on_flush +                        and (not p.deferred or p.key in state.dict) +                    ) +                    or ( +                        not p.expire_on_flush +                        and not p.deferred +                        and p.key not in state.dict +                    ) +                ] +            ) +            if readonly: +                state._expire_attributes(state.dict, readonly) + +        # if eager_defaults option is enabled, load +        # all expired cols.  Else if we have a version_id_col, make sure +        # it isn't expired. +        toload_now = [] + +        # this is specifically to emit a second SELECT for eager_defaults, +        # so only if it's set to True, not "auto" +        if base_mapper.eager_defaults is True: +            toload_now.extend( +                state._unloaded_non_object.intersection( +                    mapper._server_default_plus_onupdate_propkeys +                ) +            ) + +        if ( +            mapper.version_id_col is not None +            and mapper.version_id_generator is False +        ): +            if mapper._version_id_prop.key in state.unloaded: +                toload_now.extend([mapper._version_id_prop.key]) + +        if toload_now: +            state.key = base_mapper._identity_key_from_state(state) +            stmt = future.select(mapper).set_label_style( +                LABEL_STYLE_TABLENAME_PLUS_COL +            ) +            loading.load_on_ident( +                uowtransaction.session, +                stmt, +                state.key, +                refresh_state=state, +                only_load_props=toload_now, +            ) + +        # call after_XXX extensions +        if not has_identity: +            mapper.dispatch.after_insert(mapper, connection, state) +        else: +            mapper.dispatch.after_update(mapper, connection, state) + +        if ( +            mapper.version_id_generator is False +            and mapper.version_id_col is not None +        ): +            if state_dict[mapper._version_id_prop.key] is None: +                raise orm_exc.FlushError( +                    "Instance does not contain a non-NULL version value" +                ) + + +def _postfetch_post_update( +    mapper, uowtransaction, table, state, dict_, result, params +): +    needs_version_id = ( +        mapper.version_id_col is not None +        and mapper.version_id_col in mapper._cols_by_table[table] +    ) + +    if not uowtransaction.is_deleted(state): +        # post updating after a regular INSERT or UPDATE, do a full postfetch +        prefetch_cols = result.context.compiled.prefetch +        postfetch_cols = result.context.compiled.postfetch +    elif needs_version_id: +        # post updating before a DELETE with a version_id_col, need to +        # postfetch just version_id_col +        prefetch_cols = postfetch_cols = () +    else: +        # post updating before a DELETE without a version_id_col, +        # don't need to postfetch +        return + +    if needs_version_id: +        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + +    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) +    if refresh_flush: +        load_evt_attrs = [] + +    for c in prefetch_cols: +        if c.key in params and c in mapper._columntoproperty: +            dict_[mapper._columntoproperty[c].key] = params[c.key] +            if refresh_flush: +                load_evt_attrs.append(mapper._columntoproperty[c].key) + +    if refresh_flush and load_evt_attrs: +        mapper.class_manager.dispatch.refresh_flush( +            state, uowtransaction, load_evt_attrs +        ) + +    if postfetch_cols: +        state._expire_attributes( +            state.dict, +            [ +                mapper._columntoproperty[c].key +                for c in postfetch_cols +                if c in mapper._columntoproperty +            ], +        ) + + +def _postfetch( +    mapper, +    uowtransaction, +    table, +    state, +    dict_, +    result, +    params, +    value_params, +    isupdate, +    returned_defaults, +): +    """Expire attributes in need of newly persisted database state, +    after an INSERT or UPDATE statement has proceeded for that +    state.""" + +    prefetch_cols = result.context.compiled.prefetch +    postfetch_cols = result.context.compiled.postfetch +    returning_cols = result.context.compiled.effective_returning + +    if ( +        mapper.version_id_col is not None +        and mapper.version_id_col in mapper._cols_by_table[table] +    ): +        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] + +    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) +    if refresh_flush: +        load_evt_attrs = [] + +    if returning_cols: +        row = returned_defaults +        if row is not None: +            for row_value, col in zip(row, returning_cols): +                # pk cols returned from insert are handled +                # distinctly, don't step on the values here +                if col.primary_key and result.context.isinsert: +                    continue + +                # note that columns can be in the "return defaults" that are +                # not mapped to this mapper, typically because they are +                # "excluded", which can be specified directly or also occurs +                # when using declarative w/ single table inheritance +                prop = mapper._columntoproperty.get(col) +                if prop: +                    dict_[prop.key] = row_value +                    if refresh_flush: +                        load_evt_attrs.append(prop.key) + +    for c in prefetch_cols: +        if c.key in params and c in mapper._columntoproperty: +            pkey = mapper._columntoproperty[c].key + +            # set prefetched value in dict and also pop from committed_state, +            # since this is new database state that replaces whatever might +            # have previously been fetched (see #10800).  this is essentially a +            # shorthand version of set_committed_value(), which could also be +            # used here directly (with more overhead) +            dict_[pkey] = params[c.key] +            state.committed_state.pop(pkey, None) + +            if refresh_flush: +                load_evt_attrs.append(pkey) + +    if refresh_flush and load_evt_attrs: +        mapper.class_manager.dispatch.refresh_flush( +            state, uowtransaction, load_evt_attrs +        ) + +    if isupdate and value_params: +        # explicitly suit the use case specified by +        # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING +        # database which are set to themselves in order to do a version bump. +        postfetch_cols.extend( +            [ +                col +                for col in value_params +                if col.primary_key and col not in returning_cols +            ] +        ) + +    if postfetch_cols: +        state._expire_attributes( +            state.dict, +            [ +                mapper._columntoproperty[c].key +                for c in postfetch_cols +                if c in mapper._columntoproperty +            ], +        ) + +    # synchronize newly inserted ids from one table to the next +    # TODO: this still goes a little too often.  would be nice to +    # have definitive list of "columns that changed" here +    for m, equated_pairs in mapper._table_to_equated[table]: +        sync.populate( +            state, +            m, +            state, +            m, +            equated_pairs, +            uowtransaction, +            mapper.passive_updates, +        ) + + +def _postfetch_bulk_save(mapper, dict_, table): +    for m, equated_pairs in mapper._table_to_equated[table]: +        sync.bulk_populate_inherit_keys(dict_, m, equated_pairs) + + +def _connections_for_states(base_mapper, uowtransaction, states): +    """Return an iterator of (state, state.dict, mapper, connection). + +    The states are sorted according to _sort_states, then paired +    with the connection they should be using for the given +    unit of work transaction. + +    """ +    # if session has a connection callable, +    # organize individual states with the connection +    # to use for update +    if uowtransaction.session.connection_callable: +        connection_callable = uowtransaction.session.connection_callable +    else: +        connection = uowtransaction.transaction.connection(base_mapper) +        connection_callable = None + +    for state in _sort_states(base_mapper, states): +        if connection_callable: +            connection = connection_callable(base_mapper, state.obj()) + +        mapper = state.manager.mapper + +        yield state, state.dict, mapper, connection + + +def _sort_states(mapper, states): +    pending = set(states) +    persistent = {s for s in pending if s.key is not None} +    pending.difference_update(persistent) + +    try: +        persistent_sorted = sorted( +            persistent, key=mapper._persistent_sortkey_fn +        ) +    except TypeError as err: +        raise sa_exc.InvalidRequestError( +            "Could not sort objects by primary key; primary key " +            "values must be sortable in Python (was: %s)" % err +        ) from err +    return ( +        sorted(pending, key=operator.attrgetter("insert_order")) +        + persistent_sorted +    )  | 
