summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/channels/plugin.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/channels/plugin.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/channels/plugin.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/channels/plugin.py359
1 files changed, 359 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/channels/plugin.py b/venv/lib/python3.11/site-packages/litestar/channels/plugin.py
new file mode 100644
index 0000000..5988445
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/channels/plugin.py
@@ -0,0 +1,359 @@
+from __future__ import annotations
+
+import asyncio
+from asyncio import CancelledError, Queue, Task, create_task
+from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress
+from functools import partial
+from typing import TYPE_CHECKING, AsyncGenerator, Awaitable, Callable, Iterable
+
+import msgspec.json
+
+from litestar.di import Provide
+from litestar.exceptions import ImproperlyConfiguredException, LitestarException
+from litestar.handlers import WebsocketRouteHandler
+from litestar.plugins import InitPluginProtocol
+from litestar.serialization import default_serializer
+
+from .subscriber import BacklogStrategy, EventCallback, Subscriber
+
+if TYPE_CHECKING:
+ from types import TracebackType
+
+ from litestar.channels.backends.base import ChannelsBackend
+ from litestar.config.app import AppConfig
+ from litestar.connection import WebSocket
+ from litestar.types import LitestarEncodableType, TypeEncodersMap
+ from litestar.types.asgi_types import WebSocketMode
+
+
+class ChannelsException(LitestarException):
+ pass
+
+
+class ChannelsPlugin(InitPluginProtocol, AbstractAsyncContextManager):
+ def __init__(
+ self,
+ backend: ChannelsBackend,
+ *,
+ channels: Iterable[str] | None = None,
+ arbitrary_channels_allowed: bool = False,
+ create_ws_route_handlers: bool = False,
+ ws_handler_send_history: int = 0,
+ ws_handler_base_path: str = "/",
+ ws_send_mode: WebSocketMode = "text",
+ subscriber_max_backlog: int | None = None,
+ subscriber_backlog_strategy: BacklogStrategy = "backoff",
+ subscriber_class: type[Subscriber] = Subscriber,
+ type_encoders: TypeEncodersMap | None = None,
+ ) -> None:
+ """Plugin to handle broadcasting to WebSockets with support for channels.
+
+ This plugin is available as an injected dependency using the ``channels`` key.
+
+ Args:
+ backend: Backend to store data in
+ channels: Channels to serve. If ``arbitrary_channels_allowed`` is ``False`` (the default), trying to
+ subscribe to a channel not in ``channels`` will raise an exception
+ arbitrary_channels_allowed: Allow the creation of channels on the fly
+ create_ws_route_handlers: If ``True``, websocket route handlers will be created for all channels defined in
+ ``channels``. If ``arbitrary_channels_allowed`` is ``True``, a single handler will be created instead,
+ handling all channels. The handlers created will accept WebSocket connections and start sending received
+ events for their respective channels.
+ ws_handler_send_history: Amount of history entries to send from the generated websocket route handlers after
+ a client has connected. A value of ``0`` indicates to not send a history
+ ws_handler_base_path: Path prefix used for the generated route handlers
+ ws_send_mode: Send mode to use for sending data through a :class:`WebSocket <.connection.WebSocket>`.
+ This will be used when sending within generated route handlers or :meth:`Subscriber.run_in_background`
+ subscriber_max_backlog: Maximum amount of unsent messages to be held in memory for a given subscriber. If
+ that limit is reached, new messages will be treated accordingly to ``backlog_strategy``
+ subscriber_backlog_strategy: Define the behaviour if ``max_backlog`` is reached for a subscriber. `
+ `backoff`` will result in new messages being dropped until older ones have been processed. ``dropleft``
+ will drop older messages in favour of new ones.
+ subscriber_class: A :class:`Subscriber` subclass to return from :meth:`subscribe`
+ type_encoders: An additional mapping of type encoders used to encode data before sending
+
+ """
+ self._backend = backend
+ self._pub_queue: Queue[tuple[bytes, list[str]]] | None = None
+ self._pub_task: Task | None = None
+ self._sub_task: Task | None = None
+
+ if not (channels or arbitrary_channels_allowed):
+ raise ImproperlyConfiguredException("Must define either channels or set arbitrary_channels_allowed=True")
+
+ # make the path absolute, so we can simply concatenate it later
+ if not ws_handler_base_path.endswith("/"):
+ ws_handler_base_path += "/"
+
+ self._arbitrary_channels_allowed = arbitrary_channels_allowed
+ self._create_route_handlers = create_ws_route_handlers
+ self._handler_root_path = ws_handler_base_path
+ self._socket_send_mode: WebSocketMode = ws_send_mode
+ self._encode_json = msgspec.json.Encoder(
+ enc_hook=partial(default_serializer, type_encoders=type_encoders)
+ ).encode
+ self._handler_should_send_history = bool(ws_handler_send_history)
+ self._history_limit = None if ws_handler_send_history < 0 else ws_handler_send_history
+ self._max_backlog = subscriber_max_backlog
+ self._backlog_strategy: BacklogStrategy = subscriber_backlog_strategy
+ self._subscriber_class = subscriber_class
+
+ self._channels: dict[str, set[Subscriber]] = {channel: set() for channel in channels or []}
+
+ def encode_data(self, data: LitestarEncodableType) -> bytes:
+ """Encode data before storing it in the backend"""
+ if isinstance(data, bytes):
+ return data
+
+ return data.encode() if isinstance(data, str) else self._encode_json(data)
+
+ def on_app_init(self, app_config: AppConfig) -> AppConfig:
+ """Plugin hook. Set up a ``channels`` dependency, add route handlers and register application hooks"""
+ app_config.dependencies["channels"] = Provide(lambda: self, use_cache=True, sync_to_thread=False)
+ app_config.lifespan.append(self)
+ app_config.signature_namespace.update(ChannelsPlugin=ChannelsPlugin)
+
+ if self._create_route_handlers:
+ if self._arbitrary_channels_allowed:
+ path = self._handler_root_path + "{channel_name:str}"
+ route_handlers = [WebsocketRouteHandler(path)(self._ws_handler_func)]
+ else:
+ route_handlers = [
+ WebsocketRouteHandler(self._handler_root_path + channel_name)(
+ self._create_ws_handler_func(channel_name)
+ )
+ for channel_name in self._channels
+ ]
+ app_config.route_handlers.extend(route_handlers)
+
+ return app_config
+
+ def publish(self, data: LitestarEncodableType, channels: str | Iterable[str]) -> None:
+ """Schedule ``data`` to be published to ``channels``.
+
+ .. note::
+ This is a synchronous method that returns immediately. There are no
+ guarantees that when this method returns the data will have been published
+ to the backend. For that, use :meth:`wait_published`
+
+ """
+ if isinstance(channels, str):
+ channels = [channels]
+ data = self.encode_data(data)
+ try:
+ self._pub_queue.put_nowait((data, list(channels))) # type: ignore[union-attr]
+ except AttributeError as e:
+ raise RuntimeError("Plugin not yet initialized. Did you forget to call on_startup?") from e
+
+ async def wait_published(self, data: LitestarEncodableType, channels: str | Iterable[str]) -> None:
+ """Publish ``data`` to ``channels``"""
+ if isinstance(channels, str):
+ channels = [channels]
+ data = self.encode_data(data)
+
+ await self._backend.publish(data, channels)
+
+ async def subscribe(self, channels: str | Iterable[str], history: int | None = None) -> Subscriber:
+ """Create a :class:`Subscriber`, providing a stream of all events in ``channels``.
+
+ The created subscriber will be passive by default and has to be consumed manually,
+ either by using :meth:`Subscriber.run_in_background` or iterating over events
+ using :meth:`Subscriber.iter_events`.
+
+ Args:
+ channels: Channel(s) to subscribe to
+ history: If a non-negative integer, add this amount of history entries from
+ each channel to the subscriber's event stream. Note that this will wait
+ until all history entries are fetched and pushed to the subscriber's
+ stream. For more control use :meth:`put_subscriber_history`.
+
+ Returns:
+ A :class:`Subscriber`
+
+ Raises:
+ ChannelsException: If a channel in ``channels`` has not been declared on this backend and
+ ``arbitrary_channels_allowed`` has not been set to ``True``
+ """
+ if isinstance(channels, str):
+ channels = [channels]
+
+ subscriber = self._subscriber_class(
+ plugin=self,
+ max_backlog=self._max_backlog,
+ backlog_strategy=self._backlog_strategy,
+ )
+ channels_to_subscribe = set()
+
+ for channel in channels:
+ if channel not in self._channels:
+ if not self._arbitrary_channels_allowed:
+ raise ChannelsException(
+ f"Unknown channel: {channel!r}. Either explicitly defined the channel or set "
+ "arbitrary_channels_allowed=True"
+ )
+ self._channels[channel] = set()
+ channel_subscribers = self._channels[channel]
+ if not channel_subscribers:
+ channels_to_subscribe.add(channel)
+
+ channel_subscribers.add(subscriber)
+
+ if channels_to_subscribe:
+ await self._backend.subscribe(channels_to_subscribe)
+
+ if history:
+ await self.put_subscriber_history(subscriber=subscriber, limit=history, channels=channels)
+
+ return subscriber
+
+ async def unsubscribe(self, subscriber: Subscriber, channels: str | Iterable[str] | None = None) -> None:
+ """Unsubscribe a :class:`Subscriber` from ``channels``. If the subscriber has a running sending task, it will
+ be stopped.
+
+ Args:
+ channels: Channels to unsubscribe from. If ``None``, unsubscribe from all channels
+ subscriber: :class:`Subscriber` to unsubscribe
+ """
+ if channels is None:
+ channels = list(self._channels.keys())
+ elif isinstance(channels, str):
+ channels = [channels]
+
+ channels_to_unsubscribe: set[str] = set()
+
+ for channel in channels:
+ channel_subscribers = self._channels[channel]
+
+ try:
+ channel_subscribers.remove(subscriber)
+ except KeyError: # subscriber was not subscribed to this channel. This may happen if channels is None
+ continue
+
+ if not channel_subscribers:
+ channels_to_unsubscribe.add(channel)
+
+ if all(subscriber not in queues for queues in self._channels.values()):
+ await subscriber.put(None) # this will stop any running task or generator by breaking the inner loop
+ if subscriber.is_running:
+ await subscriber.stop()
+
+ if channels_to_unsubscribe:
+ await self._backend.unsubscribe(channels_to_unsubscribe)
+
+ @asynccontextmanager
+ async def start_subscription(
+ self, channels: str | Iterable[str], history: int | None = None
+ ) -> AsyncGenerator[Subscriber, None]:
+ """Create a :class:`Subscriber` and tie its subscriptions to a context manager;
+ Upon exiting the context, :meth:`unsubscribe` will be called.
+
+ Args:
+ channels: Channel(s) to subscribe to
+ history: If a non-negative integer, add this amount of history entries from
+ each channel to the subscriber's event stream. Note that this will wait
+ until all history entries are fetched and pushed to the subscriber's
+ stream. For more control use :meth:`put_subscriber_history`.
+
+ Returns:
+ A :class:`Subscriber`
+ """
+ subscriber = await self.subscribe(channels, history=history)
+
+ try:
+ yield subscriber
+ finally:
+ await self.unsubscribe(subscriber, channels)
+
+ async def put_subscriber_history(
+ self, subscriber: Subscriber, channels: str | Iterable[str], limit: int | None = None
+ ) -> None:
+ """Fetch the history of ``channels`` from the backend and put them in the
+ subscriber's stream
+ """
+ if isinstance(channels, str):
+ channels = [channels]
+
+ for channel in channels:
+ history = await self._backend.get_history(channel, limit)
+ for entry in history:
+ await subscriber.put(entry)
+
+ async def _ws_handler_func(self, channel_name: str, socket: WebSocket) -> None:
+ await socket.accept()
+
+ # the ternary operator triggers a mypy bug: https://github.com/python/mypy/issues/10740
+ on_event: EventCallback = socket.send_text if self._socket_send_mode == "text" else socket.send_bytes # type: ignore[assignment]
+
+ async with self.start_subscription(channel_name) as subscriber:
+ if self._handler_should_send_history:
+ await self.put_subscriber_history(subscriber, channels=channel_name, limit=self._history_limit)
+
+ # use the background task, so we can block on receive(), breaking the loop when a connection closes
+ async with subscriber.run_in_background(on_event):
+ while (await socket.receive())["type"] != "websocket.disconnect":
+ continue
+
+ def _create_ws_handler_func(self, channel_name: str) -> Callable[[WebSocket], Awaitable[None]]:
+ async def ws_handler_func(socket: WebSocket) -> None:
+ await self._ws_handler_func(channel_name=channel_name, socket=socket)
+
+ return ws_handler_func
+
+ async def _pub_worker(self) -> None:
+ while self._pub_queue:
+ data, channels = await self._pub_queue.get()
+ await self._backend.publish(data, channels)
+ self._pub_queue.task_done()
+
+ async def _sub_worker(self) -> None:
+ async for channel, payload in self._backend.stream_events():
+ for subscriber in self._channels.get(channel, []):
+ subscriber.put_nowait(payload)
+
+ async def _on_startup(self) -> None:
+ await self._backend.on_startup()
+ self._pub_queue = Queue()
+ self._pub_task = create_task(self._pub_worker())
+ self._sub_task = create_task(self._sub_worker())
+ if self._channels:
+ await self._backend.subscribe(list(self._channels))
+
+ async def _on_shutdown(self) -> None:
+ if self._pub_queue:
+ await self._pub_queue.join()
+ self._pub_queue = None
+
+ await asyncio.gather(
+ *[
+ subscriber.stop(join=False)
+ for subscribers in self._channels.values()
+ for subscriber in subscribers
+ if subscriber.is_running
+ ]
+ )
+
+ if self._sub_task:
+ self._sub_task.cancel()
+ with suppress(CancelledError):
+ await self._sub_task
+ self._sub_task = None
+
+ if self._pub_task:
+ self._pub_task.cancel()
+ with suppress(CancelledError):
+ await self._pub_task
+ self._sub_task = None
+
+ await self._backend.on_shutdown()
+
+ async def __aenter__(self) -> ChannelsPlugin:
+ await self._on_startup()
+ return self
+
+ async def __aexit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ await self._on_shutdown()