# Source path: venv/lib/python3.11/site-packages/litestar/channels/plugin.py
# Blob: 59884454d4a2d60206e4f2537a06fd90fe531e81
from __future__ import annotations

import asyncio
from asyncio import CancelledError, Queue, Task, create_task
from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
from functools import partial
from typing import TYPE_CHECKING, AsyncGenerator, Awaitable, Callable, Iterable

import msgspec.json

from litestar.di import Provide
from litestar.exceptions import ImproperlyConfiguredException, LitestarException
from litestar.handlers import WebsocketRouteHandler
from litestar.plugins import InitPluginProtocol
from litestar.serialization import default_serializer

from .subscriber import BacklogStrategy, EventCallback, Subscriber

if TYPE_CHECKING:
    from types import TracebackType

    from litestar.channels.backends.base import ChannelsBackend
    from litestar.config.app import AppConfig
    from litestar.connection import WebSocket
    from litestar.types import LitestarEncodableType, TypeEncodersMap
    from litestar.types.asgi_types import WebSocketMode


class ChannelsException(LitestarException):
    """Exception raised by :class:`ChannelsPlugin`, e.g. when subscribing to a channel that has not been declared."""

    pass


class ChannelsPlugin(InitPluginProtocol, AbstractAsyncContextManager):
    def __init__(
        self,
        backend: ChannelsBackend,
        *,
        channels: Iterable[str] | None = None,
        arbitrary_channels_allowed: bool = False,
        create_ws_route_handlers: bool = False,
        ws_handler_send_history: int = 0,
        ws_handler_base_path: str = "/",
        ws_send_mode: WebSocketMode = "text",
        subscriber_max_backlog: int | None = None,
        subscriber_backlog_strategy: BacklogStrategy = "backoff",
        subscriber_class: type[Subscriber] = Subscriber,
        type_encoders: TypeEncodersMap | None = None,
    ) -> None:
        """Plugin to handle broadcasting to WebSockets with support for channels.

        This plugin is available as an injected dependency using the ``channels`` key.

        Args:
            backend: Backend to store data in
            channels: Channels to serve. If ``arbitrary_channels_allowed`` is ``False`` (the default), trying to
                subscribe to a channel not in ``channels`` will raise an exception. Any iterable is accepted; it is
                consumed exactly once
            arbitrary_channels_allowed: Allow the creation of channels on the fly
            create_ws_route_handlers: If ``True``, websocket route handlers will be created for all channels defined in
                ``channels``. If ``arbitrary_channels_allowed`` is ``True``, a single handler will be created instead,
                handling all channels. The handlers created will accept WebSocket connections and start sending received
                events for their respective channels.
            ws_handler_send_history: Amount of history entries to send from the generated websocket route handlers after
                a client has connected. A value of ``0`` indicates to not send a history. For a negative value, the
                history is requested from the backend with ``limit=None``
            ws_handler_base_path: Path prefix used for the generated route handlers. A trailing ``/`` is appended if
                it is missing
            ws_send_mode: Send mode to use for sending data through a :class:`WebSocket <.connection.WebSocket>`.
                This will be used when sending within generated route handlers or :meth:`Subscriber.run_in_background`
            subscriber_max_backlog: Maximum amount of unsent messages to be held in memory for a given subscriber. If
                that limit is reached, new messages will be treated accordingly to ``backlog_strategy``
            subscriber_backlog_strategy: Define the behaviour if ``max_backlog`` is reached for a subscriber.
                ``backoff`` will result in new messages being dropped until older ones have been processed. ``dropleft``
                will drop older messages in favour of new ones.
            subscriber_class: A :class:`Subscriber` subclass to return from :meth:`subscribe`
            type_encoders: An additional mapping of type encoders used to encode data before sending

        Raises:
            ImproperlyConfiguredException: If no channels are given and ``arbitrary_channels_allowed`` is ``False``
        """
        self._backend = backend
        self._pub_queue: Queue[tuple[bytes, list[str]]] | None = None
        self._pub_task: Task | None = None
        self._sub_task: Task | None = None

        # Materialize the iterable once. A generator or iterator is always truthy, so testing ``channels`` directly
        # would let an empty one-shot iterable slip through the validation below
        channel_names = list(channels) if channels is not None else []

        if not (channel_names or arbitrary_channels_allowed):
            raise ImproperlyConfiguredException("Must define either channels or set arbitrary_channels_allowed=True")

        # ensure the prefix ends with a slash, so channel names can simply be appended to it later
        if not ws_handler_base_path.endswith("/"):
            ws_handler_base_path += "/"

        self._arbitrary_channels_allowed = arbitrary_channels_allowed
        self._create_route_handlers = create_ws_route_handlers
        self._handler_root_path = ws_handler_base_path
        self._socket_send_mode: WebSocketMode = ws_send_mode
        self._encode_json = msgspec.json.Encoder(
            enc_hook=partial(default_serializer, type_encoders=type_encoders)
        ).encode
        # any non-zero value enables sending history; a negative value is translated into ``limit=None``
        self._handler_should_send_history = bool(ws_handler_send_history)
        self._history_limit = None if ws_handler_send_history < 0 else ws_handler_send_history
        self._max_backlog = subscriber_max_backlog
        self._backlog_strategy: BacklogStrategy = subscriber_backlog_strategy
        self._subscriber_class = subscriber_class

        # maps every known channel name to the set of subscribers currently attached to it
        self._channels: dict[str, set[Subscriber]] = {channel: set() for channel in channel_names}

    def encode_data(self, data: LitestarEncodableType) -> bytes:
        """Encode data before storing it in the backend"""
        if isinstance(data, bytes):
            return data

        return data.encode() if isinstance(data, str) else self._encode_json(data)

    def on_app_init(self, app_config: AppConfig) -> AppConfig:
        """Plugin hook, invoked with the application config.

        Registers this plugin as the ``channels`` dependency, adds it to the application's lifespan, exposes
        ``ChannelsPlugin`` in the signature namespace and, if route handler creation is enabled, adds the
        generated websocket route handlers.
        """
        app_config.dependencies["channels"] = Provide(lambda: self, use_cache=True, sync_to_thread=False)
        app_config.lifespan.append(self)
        app_config.signature_namespace.update(ChannelsPlugin=ChannelsPlugin)

        if not self._create_route_handlers:
            return app_config

        if self._arbitrary_channels_allowed:
            # a single handler serves every channel; the channel name is taken from the path parameter
            catch_all_path = self._handler_root_path + "{channel_name:str}"
            generated = [WebsocketRouteHandler(catch_all_path)(self._ws_handler_func)]
        else:
            # one dedicated handler per declared channel
            generated = []
            for name in self._channels:
                decorate = WebsocketRouteHandler(self._handler_root_path + name)
                generated.append(decorate(self._create_ws_handler_func(name)))

        app_config.route_handlers.extend(generated)
        return app_config

    def publish(self, data: LitestarEncodableType, channels: str | Iterable[str]) -> None:
        """Schedule ``data`` to be published to ``channels``.

        .. note::
            This is a synchronous method that returns immediately. There are no
            guarantees that when this method returns the data will have been published
            to the backend. For that, use :meth:`wait_published`

        """
        if isinstance(channels, str):
            channels = [channels]
        data = self.encode_data(data)
        try:
            self._pub_queue.put_nowait((data, list(channels)))  # type: ignore[union-attr]
        except AttributeError as e:
            raise RuntimeError("Plugin not yet initialized. Did you forget to call on_startup?") from e

    async def wait_published(self, data: LitestarEncodableType, channels: str | Iterable[str]) -> None:
        """Publish ``data`` to ``channels``"""
        if isinstance(channels, str):
            channels = [channels]
        data = self.encode_data(data)

        await self._backend.publish(data, channels)

    async def subscribe(self, channels: str | Iterable[str], history: int | None = None) -> Subscriber:
        """Create a :class:`Subscriber`, providing a stream of all events in ``channels``.

        The created subscriber will be passive by default and has to be consumed manually,
        either by using :meth:`Subscriber.run_in_background` or iterating over events
        using :meth:`Subscriber.iter_events`.

        Args:
            channels: Channel(s) to subscribe to
            history: If a non-negative integer, add this amount of history entries from
                each channel to the subscriber's event stream. Note that this will wait
                until all history entries are fetched and pushed to the subscriber's
                stream. For more control use :meth:`put_subscriber_history`.

        Returns:
            A :class:`Subscriber`

        Raises:
            ChannelsException: If a channel in ``channels`` has not been declared on this backend and
                ``arbitrary_channels_allowed`` has not been set to ``True``
        """
        if isinstance(channels, str):
            channels = [channels]

        subscriber = self._subscriber_class(
            plugin=self,
            max_backlog=self._max_backlog,
            backlog_strategy=self._backlog_strategy,
        )
        channels_to_subscribe = set()

        for channel in channels:
            if channel not in self._channels:
                if not self._arbitrary_channels_allowed:
                    raise ChannelsException(
                        f"Unknown channel: {channel!r}. Either explicitly defined the channel or set "
                        "arbitrary_channels_allowed=True"
                    )
                self._channels[channel] = set()
            channel_subscribers = self._channels[channel]
            if not channel_subscribers:
                channels_to_subscribe.add(channel)

            channel_subscribers.add(subscriber)

        if channels_to_subscribe:
            await self._backend.subscribe(channels_to_subscribe)

        if history:
            await self.put_subscriber_history(subscriber=subscriber, limit=history, channels=channels)

        return subscriber

    async def unsubscribe(self, subscriber: Subscriber, channels: str | Iterable[str] | None = None) -> None:
        """Unsubscribe a :class:`Subscriber` from ``channels``. If the subscriber has a running sending task, it will
        be stopped.

        Args:
            channels: Channels to unsubscribe from. If ``None``, unsubscribe from all channels
            subscriber: :class:`Subscriber` to unsubscribe
        """
        if channels is None:
            channels = list(self._channels.keys())
        elif isinstance(channels, str):
            channels = [channels]

        channels_to_unsubscribe: set[str] = set()

        for channel in channels:
            channel_subscribers = self._channels[channel]

            try:
                channel_subscribers.remove(subscriber)
            except KeyError:  # subscriber was not subscribed to this channel. This may happen if channels is None
                continue

            if not channel_subscribers:
                channels_to_unsubscribe.add(channel)

        if all(subscriber not in queues for queues in self._channels.values()):
            await subscriber.put(None)  # this will stop any running task or generator by breaking the inner loop
            if subscriber.is_running:
                await subscriber.stop()

        if channels_to_unsubscribe:
            await self._backend.unsubscribe(channels_to_unsubscribe)

    @asynccontextmanager
    async def start_subscription(
        self, channels: str | Iterable[str], history: int | None = None
    ) -> AsyncGenerator[Subscriber, None]:
        """Create a :class:`Subscriber` and tie its subscriptions to a context manager;
        Upon exiting the context, :meth:`unsubscribe` will be called.

        Args:
            channels: Channel(s) to subscribe to
            history: If a non-negative integer, add this amount of history entries from
                each channel to the subscriber's event stream. Note that this will wait
                until all history entries are fetched and pushed to the subscriber's
                stream. For more control use :meth:`put_subscriber_history`.

        Returns:
            A :class:`Subscriber`
        """
        subscriber = await self.subscribe(channels, history=history)

        try:
            yield subscriber
        finally:
            await self.unsubscribe(subscriber, channels)

    async def put_subscriber_history(
        self, subscriber: Subscriber, channels: str | Iterable[str], limit: int | None = None
    ) -> None:
        """Fetch the history of ``channels`` from the backend and put them in the
        subscriber's stream
        """
        if isinstance(channels, str):
            channels = [channels]

        for channel in channels:
            history = await self._backend.get_history(channel, limit)
            for entry in history:
                await subscriber.put(entry)

    async def _ws_handler_func(self, channel_name: str, socket: WebSocket) -> None:
        await socket.accept()

        # the ternary operator triggers a mypy bug: https://github.com/python/mypy/issues/10740
        on_event: EventCallback = socket.send_text if self._socket_send_mode == "text" else socket.send_bytes  # type: ignore[assignment]

        async with self.start_subscription(channel_name) as subscriber:
            if self._handler_should_send_history:
                await self.put_subscriber_history(subscriber, channels=channel_name, limit=self._history_limit)

            # use the background task, so we can block on receive(), breaking the loop when a connection closes
            async with subscriber.run_in_background(on_event):
                while (await socket.receive())["type"] != "websocket.disconnect":
                    continue

    def _create_ws_handler_func(self, channel_name: str) -> Callable[[WebSocket], Awaitable[None]]:
        """Create a websocket handler function that is bound to a single, fixed channel.

        Args:
            channel_name: Name of the channel the returned handler serves

        Returns:
            An async callable that receives only the ``socket`` and delegates to :meth:`_ws_handler_func`
        """

        async def ws_handler_func(socket: WebSocket) -> None:
            # ``channel_name`` is captured from the enclosing call, so the handler needs no channel parameter
            await self._ws_handler_func(channel_name=channel_name, socket=socket)

        return ws_handler_func

    async def _pub_worker(self) -> None:
        while self._pub_queue:
            data, channels = await self._pub_queue.get()
            await self._backend.publish(data, channels)
            self._pub_queue.task_done()

    async def _sub_worker(self) -> None:
        async for channel, payload in self._backend.stream_events():
            for subscriber in self._channels.get(channel, []):
                subscriber.put_nowait(payload)

    async def _on_startup(self) -> None:
        await self._backend.on_startup()
        self._pub_queue = Queue()
        self._pub_task = create_task(self._pub_worker())
        self._sub_task = create_task(self._sub_worker())
        if self._channels:
            await self._backend.subscribe(list(self._channels))

    async def _on_shutdown(self) -> None:
        if self._pub_queue:
            await self._pub_queue.join()
            self._pub_queue = None

        await asyncio.gather(
            *[
                subscriber.stop(join=False)
                for subscribers in self._channels.values()
                for subscriber in subscribers
                if subscriber.is_running
            ]
        )

        if self._sub_task:
            self._sub_task.cancel()
            with suppress(CancelledError):
                await self._sub_task
            self._sub_task = None

        if self._pub_task:
            self._pub_task.cancel()
            with suppress(CancelledError):
                await self._pub_task
            self._sub_task = None

        await self._backend.on_shutdown()

    async def __aenter__(self) -> ChannelsPlugin:
        """Enter the plugin's async context: run the startup routine and return the plugin itself"""
        await self._on_startup()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the plugin's async context: run the shutdown routine.

        The exception details are not inspected and ``None`` is returned, so an exception raised inside the
        context is not suppressed.
        """
        await self._on_shutdown()