summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/sqlalchemy/orm/identity.py
blob: 23682f7ef22ff8dd549761bec78e14b146299e12 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
# orm/identity.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

from __future__ import annotations

from typing import Any
from typing import cast
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import NoReturn
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
import weakref

from . import util as orm_util
from .. import exc as sa_exc

if TYPE_CHECKING:
    from ._typing import _IdentityKeyType
    from .state import InstanceState


# Type variable for mapped object types.  It ties an identity key
# ``_IdentityKeyType[_O]`` to the object / ``InstanceState[_O]`` that the
# map returns for it (see ``__getitem__``, ``get`` and ``fast_get_state``).
_O = TypeVar("_O", bound=object)

# Unconstrained generic type variable; not referenced elsewhere in this
# module.
_T = TypeVar("_T", bound=Any)


class IdentityMap:
    """Base class for the identity map owned by a Session.

    Keeps identity-key -> entry storage in ``_dict`` and the set of
    ``InstanceState`` objects tracked as modified in ``_modified``.  The
    lookup / mutation API (``__getitem__``, ``get``, ``add``, ``replace``,
    ...) raises ``NotImplementedError`` here and is supplied by subclasses.
    """

    # Weak reference to this map; assigned to ``state._instance_dict`` for
    # every state that enters the map.
    _wr: weakref.ref[IdentityMap]

    # identity key -> stored entry (entry type is decided by subclasses)
    _dict: Dict[_IdentityKeyType[Any], Any]
    # states tracked as modified; maintained by _manage_incoming_state /
    # _manage_removed_state
    _modified: Set[InstanceState[Any]]

    def __init__(self) -> None:
        self._wr = weakref.ref(self)
        self._modified = set()
        self._dict = {}

    def _kill(self) -> None:
        """Invalidate this map for further loading.

        Shadows ``_add_unpresent`` on this instance with the module-level
        ``_killed`` function, which accepts the same ``(state, key)``
        arguments and always raises.
        """
        self._add_unpresent = _killed  # type: ignore

    def all_states(self) -> List[InstanceState[Any]]:
        """Return every state held by the map (subclass responsibility)."""
        raise NotImplementedError()

    def contains_state(self, state: InstanceState[Any]) -> bool:
        """Return True if *state* is the entry for its key (abstract)."""
        raise NotImplementedError()

    def __contains__(self, key: _IdentityKeyType[Any]) -> bool:
        """Return True if *key* is present (abstract)."""
        raise NotImplementedError()

    def safe_discard(self, state: InstanceState[Any]) -> None:
        """Remove *state* if it is the entry for its key (abstract)."""
        raise NotImplementedError()

    def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
        """Return the object stored for *key* (abstract)."""
        raise NotImplementedError()

    def get(
        self, key: _IdentityKeyType[_O], default: Optional[_O] = None
    ) -> Optional[_O]:
        """Return the object for *key*, or *default* (abstract)."""
        raise NotImplementedError()

    def fast_get_state(
        self, key: _IdentityKeyType[_O]
    ) -> Optional[InstanceState[_O]]:
        """Return the raw state stored for *key*, if any (abstract)."""
        raise NotImplementedError()

    def keys(self) -> Iterable[_IdentityKeyType[Any]]:
        """Return a live view of the identity keys held in ``_dict``."""
        return self._dict.keys()

    def values(self) -> Iterable[object]:
        """Return the objects held by the map (abstract)."""
        raise NotImplementedError()

    def replace(self, state: InstanceState[_O]) -> Optional[InstanceState[_O]]:
        """Store *state* under its key, returning any displaced state
        (abstract)."""
        raise NotImplementedError()

    def add(self, state: InstanceState[Any]) -> bool:
        """Add *state* under its key (abstract)."""
        raise NotImplementedError()

    def _fast_discard(self, state: InstanceState[Any]) -> None:
        """Drop *state*'s entry without extra bookkeeping (abstract)."""
        raise NotImplementedError()

    def _add_unpresent(
        self, state: InstanceState[Any], key: _IdentityKeyType[Any]
    ) -> None:
        """Add *state*, which the caller guarantees is not yet present.

        This is the optional inlined form of :meth:`add`.  The default
        implementation ignores *key* and simply delegates to :meth:`add`.
        """
        self.add(state)

    def _manage_incoming_state(self, state: InstanceState[Any]) -> None:
        """Bookkeeping for a state that has just been placed in the map."""
        # link the state back to this map through the weak reference
        state._instance_dict = self._wr

        if not state.modified:
            return
        self._modified.add(state)

    def _manage_removed_state(self, state: InstanceState[Any]) -> None:
        """Undo the bookkeeping done by :meth:`_manage_incoming_state`."""
        del state._instance_dict

        if not state.modified:
            return
        self._modified.discard(state)

    def _dirty_states(self) -> Set[InstanceState[Any]]:
        """Return the (live, not copied) set of states tracked as modified."""
        return self._modified

    def check_modified(self) -> bool:
        """Return True if any state in the map is tracked as 'modified'."""
        return len(self._modified) != 0

    def has_key(self, key: _IdentityKeyType[Any]) -> bool:
        """Equivalent to ``key in self``."""
        return key in self

    def __len__(self) -> int:
        return len(self._dict)


class WeakInstanceDict(IdentityMap):
    """Identity map whose ``_dict`` values are ``InstanceState`` objects.

    The mapped object for an entry is obtained through ``state.obj()``.
    When that returns ``None`` (presumably because the object has been
    garbage collected), the lookup methods treat the key as absent.
    Several methods wrap the ``_dict`` access that follows a membership
    test in ``try/except KeyError`` so that a key removed by garbage
    collection in between is tolerated.
    """

    _dict: Dict[_IdentityKeyType[Any], InstanceState[Any]]

    def __getitem__(self, key: _IdentityKeyType[_O]) -> _O:
        """Return the live object stored for *key*.

        :raises KeyError: if *key* is not present, or its state's object
            is no longer available.
        """
        state = cast("InstanceState[_O]", self._dict[key])
        o = state.obj()
        if o is None:
            raise KeyError(key)
        return o

    def __contains__(self, key: _IdentityKeyType[Any]) -> bool:
        """Return True if *key* maps to a state whose object is live."""
        try:
            if key in self._dict:
                state = self._dict[key]
                o = state.obj()
            else:
                return False
        except KeyError:
            # catch gc removed the key after we just checked for it
            return False
        else:
            return o is not None

    def contains_state(self, state: InstanceState[Any]) -> bool:
        """Return True if *state* itself is the entry stored under its key."""
        if state.key in self._dict:
            if TYPE_CHECKING:
                assert state.key is not None
            try:
                return self._dict[state.key] is state
            except KeyError:
                # catch gc removed the key after we just checked for it
                return False
        else:
            return False

    def replace(
        self, state: InstanceState[Any]
    ) -> Optional[InstanceState[Any]]:
        """Store *state* under its key, displacing any different state.

        :return: the displaced state (after ``_manage_removed_state`` has
            been applied to it), or ``None`` if nothing was displaced.  If
            *state* is already the stored entry, nothing changes and
            ``None`` is returned.
        """
        assert state.key is not None
        if state.key in self._dict:
            try:
                existing = existing_non_none = self._dict[state.key]
            except KeyError:
                # catch gc removed the key after we just checked for it
                existing = None
            else:
                if existing_non_none is not state:
                    self._manage_removed_state(existing_non_none)
                else:
                    return None
        else:
            existing = None

        self._dict[state.key] = state
        self._manage_incoming_state(state)
        return existing

    def add(self, state: InstanceState[Any]) -> bool:
        """Add *state* under its key.

        A different state already stored under the key is silently
        overwritten when its object is no longer live.

        :return: ``True`` if *state* was stored, ``False`` if it was
            already the entry for its key.
        :raises sa_exc.InvalidRequestError: if a different state whose
            object is still live occupies the key.
        """
        key = state.key
        assert key is not None
        # inline of self.__contains__
        if key in self._dict:
            try:
                existing_state = self._dict[key]
            except KeyError:
                # catch gc removed the key after we just checked for it
                pass
            else:
                if existing_state is not state:
                    o = existing_state.obj()
                    if o is not None:
                        raise sa_exc.InvalidRequestError(
                            "Can't attach instance "
                            "%s; another instance with key %s is already "
                            "present in this session."
                            % (orm_util.state_str(state), state.key)
                        )
                else:
                    return False
        self._dict[key] = state
        self._manage_incoming_state(state)
        return True

    def _add_unpresent(
        self, state: InstanceState[Any], key: _IdentityKeyType[Any]
    ) -> None:
        """Store *state* under *key* without any presence checks.

        This is the inlined form of :meth:`add` called by loading.py.  It
        only sets ``state._instance_dict``.  Unlike
        ``_manage_incoming_state``, it does not add the state to
        ``_modified``.
        """
        self._dict[key] = state
        state._instance_dict = self._wr

    def fast_get_state(
        self, key: _IdentityKeyType[_O]
    ) -> Optional[InstanceState[_O]]:
        """Return the raw state for *key* (or ``None``) without checking
        whether its object is still live."""
        return self._dict.get(key)

    def get(
        self, key: _IdentityKeyType[_O], default: Optional[_O] = None
    ) -> Optional[_O]:
        """Return the live object for *key*, or *default* if it is absent
        or its object is no longer available."""
        if key not in self._dict:
            return default
        try:
            state = cast("InstanceState[_O]", self._dict[key])
        except KeyError:
            # catch gc removed the key after we just checked for it
            return default
        else:
            o = state.obj()
            if o is None:
                return default
            return o

    def items(self) -> List[Tuple[_IdentityKeyType[Any], object]]:
        """Return ``(key, object)`` pairs for every entry whose object is
        live.

        The second element of each pair is the mapped object returned by
        ``state.obj()``, not the ``InstanceState``.
        """
        values = self.all_states()
        result = []
        for state in values:
            value = state.obj()
            key = state.key
            assert key is not None
            if value is not None:
                result.append((key, value))
        return result

    def values(self) -> List[object]:
        """Return the live mapped objects held by the map."""
        values = self.all_states()
        result = []
        for state in values:
            value = state.obj()
            if value is not None:
                result.append(value)

        return result

    def __iter__(self) -> Iterator[_IdentityKeyType[Any]]:
        """Iterate over all stored identity keys, including keys whose
        objects are no longer live."""
        return iter(self.keys())

    def all_states(self) -> List[InstanceState[Any]]:
        """Return a snapshot list of every stored state."""
        return list(self._dict.values())

    def _fast_discard(self, state: InstanceState[Any]) -> None:
        """Remove *state*'s entry if *state* is still the one stored.

        This is used by InstanceState when the state is being garbage
        collected.  Unlike :meth:`safe_discard`, it does not call
        ``_manage_removed_state``.  The ``_instance_dict`` and
        ``_modified`` bookkeeping is left untouched.
        """
        key = state.key
        assert key is not None
        try:
            st = self._dict[key]
        except KeyError:
            # key is already gone; nothing to remove
            pass
        else:
            # only remove if the key still maps to this exact state; a
            # different state may have replaced it in the meantime
            if st is state:
                self._dict.pop(key, None)

    def discard(self, state: InstanceState[Any]) -> None:
        """Alias for :meth:`safe_discard`."""
        self.safe_discard(state)

    def safe_discard(self, state: InstanceState[Any]) -> None:
        """Remove *state* and run ``_manage_removed_state`` on it, but only
        if it is the entry stored under its key.  Otherwise, do nothing."""
        key = state.key
        if key in self._dict:
            assert key is not None
            try:
                st = self._dict[key]
            except KeyError:
                # catch gc removed the key after we just checked for it
                pass
            else:
                if st is state:
                    self._dict.pop(key, None)
                    self._manage_removed_state(state)


def _killed(state: InstanceState[Any], key: _IdentityKeyType[Any]) -> NoReturn:
    """Replacement for ``IdentityMap._add_unpresent`` installed by
    ``IdentityMap._kill()``.

    Accepts the same ``(state, key)`` arguments as ``_add_unpresent`` and
    always raises; *key* is unused.

    :raises sa_exc.InvalidRequestError: unconditionally, with error code
        ``"lkrp"``.
    """
    # Defined at module level rather than as a method or closure so that
    # assigning it onto an IdentityMap instance does not create a
    # reference cycle.
    described = orm_util.state_str(state)
    raise sa_exc.InvalidRequestError(
        "Object %s cannot be converted to 'persistent' state, as this "
        "identity map is no longer valid.  Has the owning Session "
        "been closed?" % described,
        code="lkrp",
    )