summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/_kwargs/cleanup.py
blob: 8839d360df5c45703c8a6a13980d3f521f78d54c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from __future__ import annotations

from inspect import Traceback, isasyncgen
from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generator

from anyio import create_task_group

from litestar.utils import ensure_async_callable
from litestar.utils.compat import async_next

__all__ = ("DependencyCleanupGroup",)


if TYPE_CHECKING:
    from types import TracebackType

    from litestar.types import AnyGenerator


class DependencyCleanupGroup:
    """Wrapper for generator based dependencies.

    Simplify cleanup by wrapping :func:`next` / :func:`anext` calls and providing facilities to
    :meth:`throw <generator.throw>` / :meth:`athrow <agen.athrow>` into all generators consecutively. An instance of
    this class can be used as a contextmanager, which will automatically throw any exceptions into its generators. All
    exceptions caught in this manner will be re-raised after they have been thrown in the generators.
    """

    __slots__ = ("_generators", "_closed")

    def __init__(self, generators: list[AnyGenerator] | None = None) -> None:
        """Initialize ``DependencyCleanupGroup``.

        Args:
            generators: An optional list of generators to be called at cleanup. A non-empty list is stored as-is
                (not copied); ``None`` or an empty list makes the group start with a fresh empty list.
        """
        self._generators = generators or []
        # Flipped to True by ``cleanup``; afterwards ``add`` and ``cleanup`` raise ``RuntimeError``.
        self._closed = False

    def add(self, generator: Generator[Any, None, None] | AsyncGenerator[Any, None]) -> None:
        """Add a new generator to the group.

        Args:
            generator: The generator to add

        Raises:
            RuntimeError: If :meth:`cleanup` has already been called on this group.

        Returns:
            None
        """
        if self._closed:
            raise RuntimeError("Cannot add a generator to a closed DependencyCleanupGroup")
        self._generators.append(generator)

    @staticmethod
    def _wrap_next(generator: AnyGenerator) -> Callable[[], Awaitable[None]]:
        """Build a zero-argument coroutine function that advances ``generator`` by a single step.

        Sync generators are advanced with ``next(generator, None)``, so an exhausted generator does not raise
        :exc:`StopIteration`; the sync callable is passed through ``ensure_async_callable`` so the caller can always
        await the result. Async generators go through ``async_next`` with the same ``None`` default (presumably
        mirroring :func:`anext` - confirm in ``litestar.utils.compat``).
        """
        if isasyncgen(generator):

            async def wrapped_async() -> None:
                await async_next(generator, None)

            return wrapped_async

        def wrapped() -> None:
            next(generator, None)  # type: ignore[arg-type]

        return ensure_async_callable(wrapped)

    async def cleanup(self) -> None:
        """Execute cleanup by calling :func:`next` / :func:`anext` on all generators.

        Each generator is advanced exactly once. If there are multiple generators to be called, they will be
        executed in a :class:`anyio.TaskGroup`. The group is marked closed before any generator runs.

        Raises:
            RuntimeError: If cleanup has already been called on this group.

        Returns:
            None
        """
        if self._closed:
            raise RuntimeError("Cannot call cleanup on a closed DependencyCleanupGroup")

        self._closed = True

        if not self._generators:
            return

        # Fast path: a single generator does not need a task group.
        if len(self._generators) == 1:
            await self._wrap_next(self._generators[0])()
            return

        async with create_task_group() as task_group:
            for generator in self._generators:
                task_group.start_soon(self._wrap_next(generator))

    async def __aenter__(self) -> None:
        """Support the async contextmanager protocol to allow for easier catching and throwing of exceptions into the
        generators.
        """

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """If an exception was raised within the contextmanager block, throw it into all generators.

        This method returns ``None``, so the original exception is never suppressed. If a generator raised while the
        exception was thrown into it, that error propagates from here instead.
        """
        if exc_val is not None:
            await self.throw(exc_val)

    async def throw(self, exc: BaseException) -> None:
        """Throw an exception into every generator of the group, one after another.

        A generator that finishes in response (raising :exc:`StopIteration` / :exc:`StopAsyncIteration`) is
        considered cleaned up. A generator that raises anything else is still followed by the remaining generators.
        The most common case is a generator re-raising ``exc`` from a ``try``/``finally``. Continuing ensures every
        generator gets a chance to run its cleanup code. The first such error is re-raised once all generators have
        been visited.

        Args:
            exc: Exception to throw

        Raises:
            BaseException: The first exception raised by a generator while ``exc`` was thrown into it.
        """
        first_error: BaseException | None = None
        for gen in self._generators:
            try:
                if isasyncgen(gen):
                    await gen.athrow(exc)
                else:
                    gen.throw(exc)  # type: ignore[union-attr]
            except (StopIteration, StopAsyncIteration):
                continue
            except BaseException as error:  # noqa: BLE001
                # Do not abort: later generators must still receive ``exc`` so their cleanup runs.
                if first_error is None:
                    first_error = error
        if first_error is not None:
            raise first_error