summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/events/emitter.py
blob: 7c33c9e73f51abc19defbee2c31061f7c4555a7e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from __future__ import annotations

import math
import sys
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import AsyncExitStack
from functools import partial
from typing import TYPE_CHECKING, Any, Sequence

if sys.version_info < (3, 9):
    from typing import AsyncContextManager
else:
    from contextlib import AbstractAsyncContextManager as AsyncContextManager

import anyio

from litestar.exceptions import ImproperlyConfiguredException

if TYPE_CHECKING:
    from types import TracebackType

    from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

    from litestar.events.listener import EventListener

__all__ = ("BaseEventEmitterBackend", "SimpleEventEmitter")


class BaseEventEmitterBackend(AsyncContextManager["BaseEventEmitterBackend"], ABC):
    """Abstract base class for event emitter backends.

    The base class only builds the lookup table from event IDs to listeners;
    delivering events is left to subclasses via :meth:`emit`.
    """

    __slots__ = ("listeners",)

    # Maps an event ID to the set of listeners registered for that ID.
    listeners: defaultdict[str, set[EventListener]]

    def __init__(self, listeners: Sequence[EventListener]) -> None:
        """Index the given listeners by every event ID they declare.

        Args:
            listeners: The listeners to register. Each one is added under all
                of the IDs found in its ``event_ids``.
        """
        self.listeners = defaultdict(set)
        # Flatten "listener -> many event IDs" into (event_id, listener) pairs.
        registrations = ((event_id, listener) for listener in listeners for event_id in listener.event_ids)
        for event_id, listener in registrations:
            self.listeners[event_id].add(listener)

    @abstractmethod
    def emit(self, event_id: str, *args: Any, **kwargs: Any) -> None:
        """Dispatch an event to every listener attached to ``event_id``.

        Args:
            event_id: Identifier of the event being emitted, e.g. 'my_event'.
            *args: Positional arguments forwarded to the listener(s).
            **kwargs: Keyword arguments forwarded to the listener(s).

        Returns:
            None
        """
        raise NotImplementedError("not implemented")


class SimpleEventEmitter(BaseEventEmitterBackend):
    """Event emitter that works only in the current process.

    Emitted events are pushed onto an in-memory anyio object stream and picked
    up by a background worker task, which schedules every listener callback in
    a task group. The emitter has to be entered as an async context manager
    before :meth:`emit` can be used.
    """

    # NOTE: "_queue" is declared as a slot but is never assigned by this class.
    __slots__ = ("_queue", "_exit_stack", "_receive_stream", "_send_stream")

    def __init__(self, listeners: Sequence[EventListener]) -> None:
        """Create an event emitter instance.

        Args:
            listeners: A list of listeners.
        """
        super().__init__(listeners=listeners)
        # Runtime state is only created in ``__aenter__``; until then the
        # emitter counts as "not initialized".
        self._receive_stream: MemoryObjectReceiveStream | None = None
        self._send_stream: MemoryObjectSendStream | None = None
        self._exit_stack: AsyncExitStack | None = None

    async def _worker(self, receive_stream: MemoryObjectReceiveStream) -> None:
        """Run items from ``receive_stream`` in a task group.

        Each item is a ``(fn, args, kwargs)`` tuple as queued by :meth:`emit`.

        Args:
            receive_stream: The receiving end of the stream that ``emit`` writes to.

        Returns:
            None
        """
        async with receive_stream, anyio.create_task_group() as task_group:
            async for item in receive_stream:
                fn, args, kwargs = item
                # ``start_soon`` is only given positional arguments here, so
                # keyword arguments are bound up front with ``partial``.
                if kwargs:
                    fn = partial(fn, **kwargs)
                task_group.start_soon(fn, *args)  # pyright: ignore[reportGeneralTypeIssues]

    async def __aenter__(self) -> SimpleEventEmitter:
        """Create the event stream and start the background worker.

        Returns:
            The emitter itself.
        """
        self._exit_stack = AsyncExitStack()
        send_stream, receive_stream = anyio.create_memory_object_stream(math.inf)  # type: ignore[var-annotated]
        self._send_stream = send_stream
        task_group = anyio.create_task_group()

        # The exit stack unwinds in reverse order: the send stream is closed
        # first, which ends the worker's receive loop, and only then is the
        # task group exited, waiting for the worker to finish.
        await self._exit_stack.enter_async_context(task_group)
        await self._exit_stack.enter_async_context(send_stream)
        task_group.start_soon(self._worker, receive_stream)

        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the event stream, wait for the worker and reset the emitter.

        Args:
            exc_type: Type of the exception that left the ``async with`` body, if any.
            exc_val: The exception instance, if any.
            exc_tb: The traceback of the exception, if any.

        Returns:
            None
        """
        exit_stack = self._exit_stack
        try:
            if exit_stack is not None:
                await exit_stack.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            # Reset even when unwinding raises (e.g. a listener failed), so a
            # later ``emit`` reports "Emitter not initialized" instead of
            # writing to an already closed stream.
            self._exit_stack = None
            self._send_stream = None

    def emit(self, event_id: str, *args: Any, **kwargs: Any) -> None:
        """Emit an event to all attached listeners.

        Args:
            event_id: The ID of the event to emit, e.g 'my_event'.
            *args: args to pass to the listener(s).
            **kwargs: kwargs to pass to the listener(s)

        Raises:
            RuntimeError: If the emitter has not been entered as a context manager.
            ImproperlyConfiguredException: If no listener is registered for ``event_id``.

        Returns:
            None
        """
        if self._send_stream is None or self._exit_stack is None:
            raise RuntimeError("Emitter not initialized")

        # ``.get`` avoids creating an empty entry in the defaultdict for
        # unknown event IDs.
        if listeners := self.listeners.get(event_id):
            for listener in listeners:
                self._send_stream.send_nowait((listener.fn, args, kwargs))
            return
        raise ImproperlyConfiguredException(f"no event listeners are registered for event ID: {event_id}")