summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/sqlalchemy/orm/state_changes.py
blob: 56963c6af1d30a126b3b3135a16d1d24791ae95c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# orm/state_changes.py
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

"""State tracking utilities used by :class:`_orm.Session`.

"""

from __future__ import annotations

import contextlib
from enum import Enum
from typing import Any
from typing import Callable
from typing import cast
from typing import Iterator
from typing import NoReturn
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union

from .. import exc as sa_exc
from .. import util
from ..util.typing import Literal

# Type variable for methods wrapped by ``_StateChange.declare_states``; that
# decorator is annotated ``Callable[[_F], _F]`` so type checkers see the
# decorated method with its original signature.
_F = TypeVar("_F", bound=Callable[..., Any])


class _StateChangeState(Enum):
    """Common base for the state enumerations tracked by ``_StateChange``.

    It deliberately defines no members, which is what allows concrete
    state enumerations to subclass it.
    """


class _StateChangeStates(_StateChangeState):
    """Sentinel states that ``_StateChange`` itself interprets.

    Values stay ``1``/``2``/``3`` in declaration order.
    """

    # Wildcard: as a prerequisite it means "any current state is fine";
    # as a pending next state it means "destination is unrestricted".
    ANY = 1

    # As a ``moves_to`` target the method must leave the state untouched;
    # also the class-level starting value of ``_StateChange._state``.
    NO_CHANGE = 2

    # Placeholder pending next-state while a decorated method is executing.
    CHANGE_IN_PROGRESS = 3


class _StateChange:
    """Supplies state assertion decorators.

    The current use case is for the :class:`_orm.SessionTransaction` class. The
    :class:`_StateChange` class itself is agnostic of the
    :class:`_orm.SessionTransaction` class so could in theory be generalized
    for other systems as well.

    Subclasses decorate their state-changing methods with
    :meth:`declare_states`. Inside such a method they may use
    :meth:`_expect_state` to permit one specific state change from the calls
    made within it.

    The decorator only *reads* ``_state``; assigning it is left to the
    decorated methods themselves.

    """

    # Destination state that nested decorated calls are currently allowed to
    # move to. ``ANY`` means unrestricted. The decorator sets it to
    # ``CHANGE_IN_PROGRESS`` while a decorated method runs, and
    # ``_expect_state()`` narrows it to a single state.
    _next_state: _StateChangeState = _StateChangeStates.ANY
    # Current state of the object; class-level default is ``NO_CHANGE``.
    _state: _StateChangeState = _StateChangeStates.NO_CHANGE
    # The decorated function currently executing, or ``None``. Used to build
    # clearer error messages when decorated methods are nested.
    _current_fn: Optional[Callable[..., Any]] = None

    def _raise_for_prerequisite_state(
        self, operation_name: str, state: _StateChangeState
    ) -> NoReturn:
        """Raise because ``operation_name`` is not allowed in ``state``.

        Invoked through ``self`` by :meth:`declare_states` when the current
        state is not one of a method's prerequisite states, so a subclass may
        override it to customize the error.

        :param operation_name: ``__name__`` of the rejected method.
        :param state: the state the object was in at call time.
        :raises sa_exc.IllegalStateChangeError: always.
        """
        # NOTE(review): the message names "Session" although this class is
        # documented as Session-agnostic.
        raise sa_exc.IllegalStateChangeError(
            f"Can't run operation '{operation_name}()' when Session "
            f"is in state {state!r}",
            code="isce",
        )

    @classmethod
    def declare_states(
        cls,
        prerequisite_states: Union[
            Literal[_StateChangeStates.ANY], Tuple[_StateChangeState, ...]
        ],
        moves_to: _StateChangeState,
    ) -> Callable[[_F], _F]:
        """Method decorator declaring valid states.

        :param prerequisite_states: sequence of acceptable prerequisite
         states.   Can be the single constant ``_StateChangeStates.ANY`` to
         indicate no prerequisite state

        :param moves_to: the expected state at the end of the method, assuming
         no exceptions raised.   Can be the constant
         ``_StateChangeStates.NO_CHANGE`` to indicate state should not change
         at the end of the method.

        :return: a decorator that wraps the method with the checks below.

        Each call of the wrapped method:

        1. raises (via :meth:`_raise_for_prerequisite_state`) if the current
           ``_state`` is not among ``prerequisite_states``;
        2. if an enclosing call has restricted ``_next_state`` (i.e. it is
           not ``ANY``) and this method changes state, raises unless
           ``moves_to`` is that state;
        3. runs the method with ``_current_fn`` set to it and
           ``_next_state`` set to ``CHANGE_IN_PROGRESS``;
        4. on normal return, raises unless ``_state`` is ``moves_to`` (or,
           for ``NO_CHANGE``, still the state it started in);
        5. always restores the previous ``_next_state`` and ``_current_fn``.

        Exceptions raised by the method itself propagate unchanged.

        """
        # Declaration-time sanity check (asserts are skipped under ``-O``).
        assert prerequisite_states, "no prequisite states sent"
        # ``ANY`` disables the prerequisite check entirely.
        has_prerequisite_states = (
            prerequisite_states is not _StateChangeStates.ANY
        )

        # Typing-only narrowing; the tuple is consulted only when
        # ``has_prerequisite_states`` is true.
        prerequisite_state_collection = cast(
            "Tuple[_StateChangeState, ...]", prerequisite_states
        )
        expect_state_change = moves_to is not _StateChangeStates.NO_CHANGE

        @util.decorator
        def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any:
            current_state = self._state

            if (
                has_prerequisite_states
                and current_state not in prerequisite_state_collection
            ):
                self._raise_for_prerequisite_state(fn.__name__, current_state)

            # Snapshot the caller's context; it is restored in ``finally``.
            next_state = self._next_state
            existing_fn = self._current_fn
            # The state this call must leave behind on success.
            expect_state = moves_to if expect_state_change else current_state

            if (
                # destination states are restricted
                next_state is not _StateChangeStates.ANY
                # method seeks to change state
                and expect_state_change
                # destination state incorrect
                and next_state is not expect_state
            ):
                if existing_fn and next_state in (
                    _StateChangeStates.NO_CHANGE,
                    _StateChangeStates.CHANGE_IN_PROGRESS,
                ):
                    # Another decorated method is running and has not
                    # permitted a state change at this point.
                    raise sa_exc.IllegalStateChangeError(
                        f"Method '{fn.__name__}()' can't be called here; "
                        f"method '{existing_fn.__name__}()' is already "
                        f"in progress and this would cause an unexpected "
                        f"state change to {moves_to!r}",
                        code="isce",
                    )
                else:
                    # An ``_expect_state()`` (or other restriction) allows a
                    # different destination than this method's ``moves_to``.
                    # NOTE(review): "Cant" lacks an apostrophe; changing it
                    # may affect code or tests that match on the message.
                    raise sa_exc.IllegalStateChangeError(
                        f"Cant run operation '{fn.__name__}()' here; "
                        f"will move to state {moves_to!r} where we are "
                        f"expecting {next_state!r}",
                        code="isce",
                    )

            self._current_fn = fn
            # While the method runs, nested decorated calls that change state
            # are rejected unless ``_expect_state()`` narrows ``_next_state``
            # to their target.
            self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS
            try:
                ret_value = fn(self, *arg, **kw)
            except:
                # Present only so the ``else`` clause is legal; the method's
                # own exception propagates untouched.
                raise
            else:
                if self._state is expect_state:
                    return ret_value

                # The state did not end where expected; choose the most
                # specific message.
                if self._state is current_state:
                    # A state-changing method left the state as it was.
                    raise sa_exc.IllegalStateChangeError(
                        f"Method '{fn.__name__}()' failed to "
                        "change state "
                        f"to {moves_to!r} as expected",
                        code="isce",
                    )
                elif existing_fn:
                    raise sa_exc.IllegalStateChangeError(
                        f"While method '{existing_fn.__name__}()' was "
                        "running, "
                        f"method '{fn.__name__}()' caused an "
                        "unexpected "
                        f"state change to {self._state!r}",
                        code="isce",
                    )
                else:
                    raise sa_exc.IllegalStateChangeError(
                        f"Method '{fn.__name__}()' caused an unexpected "
                        f"state change to {self._state!r}",
                        code="isce",
                    )

            finally:
                # Restore the caller's context whether fn returned or raised.
                self._next_state = next_state
                self._current_fn = existing_fn

        return _go

    @contextlib.contextmanager
    def _expect_state(self, expected: _StateChangeState) -> Iterator[Any]:
        """called within a method that changes states.

        method must also use the ``@declare_states()`` decorator.

        Behavior of the ``with`` block:

        - ``_next_state`` is narrowed to ``expected``, so decorated methods
          called inside the block may move the object to ``expected`` and to
          no other state.
        - If the block finishes without an exception, ``_state`` must then
          be ``expected``.
        - ``_next_state`` is reset to ``CHANGE_IN_PROGRESS`` afterwards in
          either case.

        :param expected: the single state the block is permitted to reach.
        :raises sa_exc.IllegalStateChangeError: if the block completes
         normally but ``_state`` is not ``expected``.

        """
        # The decorator is what sets CHANGE_IN_PROGRESS. As a consequence,
        # two ``_expect_state()`` blocks cannot be nested without an
        # intervening decorated call.
        assert self._next_state is _StateChangeStates.CHANGE_IN_PROGRESS, (
            "Unexpected call to _expect_state outside of "
            "state-changing method"
        )

        self._next_state = expected
        try:
            yield
        except:
            # Exceptions from the ``with`` body propagate unchanged.
            raise
        else:
            if self._state is not expected:
                raise sa_exc.IllegalStateChangeError(
                    f"Unexpected state change to {self._state!r}", code="isce"
                )
        finally:
            # Back to the decorator's in-progress marker.
            self._next_state = _StateChangeStates.CHANGE_IN_PROGRESS