diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/channels')
21 files changed, 1078 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/channels/__init__.py b/venv/lib/python3.11/site-packages/litestar/channels/__init__.py new file mode 100644 index 0000000..0167223 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/__init__.py @@ -0,0 +1,5 @@ +from .backends.base import ChannelsBackend +from .plugin import ChannelsPlugin +from .subscriber import Subscriber + +__all__ = ("ChannelsPlugin", "ChannelsBackend", "Subscriber") diff --git a/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..bf9d6bd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/plugin.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/plugin.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..08361dc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/plugin.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/subscriber.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/subscriber.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..8d609b4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/__pycache__/subscriber.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__init__.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__init__.py diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..ab4e477 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/asyncpg.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/asyncpg.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..a577096 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/asyncpg.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..334d295 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/memory.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/memory.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..9a87da5 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/memory.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/psycopg.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/psycopg.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f663280 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/psycopg.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/redis.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/redis.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..bf86a3e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/__pycache__/redis.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_flushall_streams.lua b/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_flushall_streams.lua new file mode 100644 index 0000000..a3faa6e --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_flushall_streams.lua @@ -0,0 +1,15 @@ +local key_pattern = ARGV[1] + +local cursor = 0 +local deleted_streams = 0 + +repeat + local result = redis.call('SCAN', cursor, 'MATCH', key_pattern) + for _,key in ipairs(result[2]) do + redis.call('DEL', key) + deleted_streams = deleted_streams + 1 + end + cursor = tonumber(result[1]) +until cursor == 0 + +return deleted_streams diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_pubsub_publish.lua b/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_pubsub_publish.lua new file mode 100644 index 0000000..8402d08 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_pubsub_publish.lua @@ -0,0 +1,5 @@ +local data = ARGV[1] + +for _, channel in ipairs(KEYS) do + redis.call("PUBLISH", channel, data) +end diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_xadd_expire.lua b/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_xadd_expire.lua new file mode 100644 index 0000000..f6b322f --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/_redis_xadd_expire.lua @@ -0,0 +1,13 @@ +local data = ARGV[1] +local limit = ARGV[2] +local exp = ARGV[3] +local maxlen_approx = ARGV[4] + +for i, key in ipairs(KEYS) do + if maxlen_approx == 1 then + redis.call("XADD", key, "MAXLEN", "~", limit, "*", "data", data, "channel", ARGV[i + 4]) + else + redis.call("XADD", key, "MAXLEN", limit, "*", "data", data, "channel", ARGV[i + 4]) + end + redis.call("PEXPIRE", key, exp) +end diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py new file mode 100644 index 0000000..967b208 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import asyncio +from contextlib import AsyncExitStack +from functools import partial +from typing import AsyncGenerator, Awaitable, Callable, Iterable, overload + +import asyncpg + +from litestar.channels import ChannelsBackend +from litestar.exceptions import ImproperlyConfiguredException + + +class AsyncPgChannelsBackend(ChannelsBackend): + _listener_conn: asyncpg.Connection + + @overload + def __init__(self, dsn: str) -> None: ... + + @overload + def __init__( + self, + *, + make_connection: Callable[[], Awaitable[asyncpg.Connection]], + ) -> None: ... + + def __init__( + self, + dsn: str | None = None, + *, + make_connection: Callable[[], Awaitable[asyncpg.Connection]] | None = None, + ) -> None: + if not (dsn or make_connection): + raise ImproperlyConfiguredException("Need to specify dsn or make_connection") + + self._subscribed_channels: set[str] = set() + self._exit_stack = AsyncExitStack() + self._connect = make_connection or partial(asyncpg.connect, dsn=dsn) + self._queue: asyncio.Queue[tuple[str, bytes]] | None = None + + async def on_startup(self) -> None: + self._queue = asyncio.Queue() + self._listener_conn = await self._connect() + + async def on_shutdown(self) -> None: + await self._listener_conn.close() + self._queue = None + + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + if self._queue is None: + raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?") + + dec_data = data.decode("utf-8") + + conn = await self._connect() + try: + for channel in channels: + await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data) + finally: + await conn.close() + + async def subscribe(self, channels: Iterable[str]) -> None: + for channel in set(channels) - self._subscribed_channels: + await self._listener_conn.add_listener(channel, self._listener) # type: ignore[arg-type] + self._subscribed_channels.add(channel) + + async def unsubscribe(self, channels: Iterable[str]) -> None: + for channel in channels: + await self._listener_conn.remove_listener(channel, self._listener) # type: ignore[arg-type] + self._subscribed_channels = self._subscribed_channels - set(channels) + + async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]: + if self._queue is None: + raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?") + + while True: + channel, message = await self._queue.get() + self._queue.task_done() + # an UNLISTEN may be in transit while we're getting here, so we double-check + # that we are actually supposed to deliver this message + if channel in self._subscribed_channels: + yield channel, message + + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + raise NotImplementedError() + + def _listener(self, /, connection: asyncpg.Connection, pid: int, channel: str, payload: object) -> None: + if not isinstance(payload, str): + raise RuntimeError("Invalid data received") + self._queue.put_nowait((channel, payload.encode("utf-8"))) # type: ignore[union-attr] diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/base.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/base.py new file mode 100644 index 0000000..ce7ee81 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/base.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import AsyncGenerator, Iterable + + +class ChannelsBackend(ABC): + @abstractmethod + async def on_startup(self) -> None: + """Called by the plugin on application startup""" + ... + + @abstractmethod + async def on_shutdown(self) -> None: + """Called by the plugin on application shutdown""" + ... + + @abstractmethod + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + """Publish the message ``data`` to all ``channels``""" + ... + + @abstractmethod + async def subscribe(self, channels: Iterable[str]) -> None: + """Start listening for events on ``channels``""" + ... + + @abstractmethod + async def unsubscribe(self, channels: Iterable[str]) -> None: + """Stop listening for events on ``channels``""" + ... + + @abstractmethod + def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]: + """Return a generator, iterating over events of subscribed channels as they become available""" + ... + + @abstractmethod + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + """Return the event history of ``channel``, at most ``limit`` entries""" + ... diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/memory.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/memory.py new file mode 100644 index 0000000..a96a66b --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/memory.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from asyncio import Queue +from collections import defaultdict, deque +from typing import Any, AsyncGenerator, Iterable + +from litestar.channels.backends.base import ChannelsBackend + + +class MemoryChannelsBackend(ChannelsBackend): + """An in-memory channels backend""" + + def __init__(self, history: int = 0) -> None: + self._max_history_length = history + self._channels: set[str] = set() + self._queue: Queue[tuple[str, bytes]] | None = None + self._history: defaultdict[str, deque[bytes]] = defaultdict(lambda: deque(maxlen=self._max_history_length)) + + async def on_startup(self) -> None: + self._queue = Queue() + + async def on_shutdown(self) -> None: + self._queue = None + + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + """Publish ``data`` to ``channels``. If a channel has not yet been subscribed to, + this will be a no-op. + + Args: + data: Data to publish + channels: Channels to publish to + + Returns: + None + + Raises: + RuntimeError: If ``on_startup`` has not been called yet + """ + if not self._queue: + raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?") + + for channel in channels: + if channel not in self._channels: + continue + + self._queue.put_nowait((channel, data)) + if self._max_history_length: + for channel in channels: + self._history[channel].append(data) + + async def subscribe(self, channels: Iterable[str]) -> None: + """Subscribe to ``channels``, and enable publishing to them""" + self._channels.update(channels) + + async def unsubscribe(self, channels: Iterable[str]) -> None: + """Unsubscribe from ``channels``""" + self._channels -= set(channels) + try: + for channel in channels: + del self._history[channel] + except KeyError: + pass + + async def stream_events(self) -> AsyncGenerator[tuple[str, Any], None]: + """Return a generator, iterating over events of subscribed channels as they become available""" + if self._queue is None: + raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?") + + while True: + channel, message = await self._queue.get() + self._queue.task_done() + + # if a message is published to a channel and the channel is then + # unsubscribed before retrieving that message from the stream, it can still + # end up here, so we double-check if we still are interested in this message + if channel in self._channels: + yield channel, message + + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + """Return the event history of ``channel``, at most ``limit`` entries""" + history = list(self._history[channel]) + if limit: + history = history[-limit:] + return history diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py new file mode 100644 index 0000000..14b53bc --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from contextlib import AsyncExitStack +from typing import AsyncGenerator, Iterable + +import psycopg + +from .base import ChannelsBackend + + +def _safe_quote(ident: str) -> str: + return '"{}"'.format(ident.replace('"', '""')) # sourcery skip + + +class PsycoPgChannelsBackend(ChannelsBackend): + _listener_conn: psycopg.AsyncConnection + + def __init__(self, pg_dsn: str) -> None: + self._pg_dsn = pg_dsn + self._subscribed_channels: set[str] = set() + self._exit_stack = AsyncExitStack() + + async def on_startup(self) -> None: + self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True) + await self._exit_stack.enter_async_context(self._listener_conn) + + async def on_shutdown(self) -> None: + await self._exit_stack.aclose() + + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + dec_data = data.decode("utf-8") + async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn: + for channel in channels: + await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data)) + + async def subscribe(self, channels: Iterable[str]) -> None: + for channel in set(channels) - self._subscribed_channels: + # can't use placeholders in LISTEN + await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};") # pyright: ignore + + self._subscribed_channels.add(channel) + + async def unsubscribe(self, channels: Iterable[str]) -> None: + for channel in channels: + # can't use placeholders in UNLISTEN + await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};") # pyright: ignore + self._subscribed_channels = self._subscribed_channels - set(channels) + + async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]: + async for notify in self._listener_conn.notifies(): + yield notify.channel, notify.payload.encode("utf-8") + + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + raise NotImplementedError() diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/redis.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/redis.py new file mode 100644 index 0000000..f03c9f2 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/redis.py @@ -0,0 +1,277 @@ +from __future__ import annotations + +import asyncio +import sys + +if sys.version_info < (3, 9): + import importlib_resources # pyright: ignore +else: + import importlib.resources as importlib_resources +from abc import ABC +from datetime import timedelta +from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, cast + +from litestar.channels.backends.base import ChannelsBackend + +if TYPE_CHECKING: + from redis.asyncio import Redis + from redis.asyncio.client import PubSub + +_resource_path = importlib_resources.files("litestar.channels.backends") +_PUBSUB_PUBLISH_SCRIPT = (_resource_path / "_redis_pubsub_publish.lua").read_text() +_FLUSHALL_STREAMS_SCRIPT = (_resource_path / "_redis_flushall_streams.lua").read_text() +_XADD_EXPIRE_SCRIPT = (_resource_path / "_redis_xadd_expire.lua").read_text() + + +class _LazyEvent: + """A lazy proxy to asyncio.Event that only creates the event once it's accessed. + + It ensures that the Event is created within a running event loop. If it's not, there can be an issue where a future + within the event itself is attached to a different loop. + + This happens in our tests and could also happen when a user creates an instance of the backend outside an event loop + in their application. + """ + + def __init__(self) -> None: + self.__event: asyncio.Event | None = None + + @property + def _event(self) -> asyncio.Event: + if self.__event is None: + self.__event = asyncio.Event() + return self.__event + + def set(self) -> None: + self._event.set() + + def clear(self) -> None: + self._event.clear() + + async def wait(self) -> None: + await self._event.wait() + + +class RedisChannelsBackend(ChannelsBackend, ABC): + def __init__(self, *, redis: Redis, key_prefix: str, stream_sleep_no_subscriptions: int) -> None: + """Base redis channels backend. + + Args: + redis: A :class:`redis.asyncio.Redis` instance + key_prefix: Key prefix to use for storing data in redis + stream_sleep_no_subscriptions: Amount of time in milliseconds to pause the + :meth:`stream_events` generator, should no subscribers exist + """ + self._redis = redis + self._key_prefix = key_prefix + self._stream_sleep_no_subscriptions = stream_sleep_no_subscriptions + + def _make_key(self, channel: str) -> str: + return f"{self._key_prefix}_{channel.upper()}" + + +class RedisChannelsPubSubBackend(RedisChannelsBackend): + def __init__( + self, *, redis: Redis, stream_sleep_no_subscriptions: int = 1, key_prefix: str = "LITESTAR_CHANNELS" + ) -> None: + """Redis channels backend, `Pub/Sub <https://redis.io/docs/manual/pubsub/>`_. + + This backend provides low overhead and resource usage but no support for history. + + Args: + redis: A :class:`redis.asyncio.Redis` instance + key_prefix: Key prefix to use for storing data in redis + stream_sleep_no_subscriptions: Amount of time in milliseconds to pause the + :meth:`stream_events` generator, should no subscribers exist + """ + super().__init__( + redis=redis, stream_sleep_no_subscriptions=stream_sleep_no_subscriptions, key_prefix=key_prefix + ) + self.__pub_sub: PubSub | None = None + self._publish_script = self._redis.register_script(_PUBSUB_PUBLISH_SCRIPT) + self._has_subscribed = _LazyEvent() + + @property + def _pub_sub(self) -> PubSub: + if self.__pub_sub is None: + self.__pub_sub = self._redis.pubsub() + return self.__pub_sub + + async def on_startup(self) -> None: + # this method should not do anything in this case + pass + + async def on_shutdown(self) -> None: + await self._pub_sub.reset() + + async def subscribe(self, channels: Iterable[str]) -> None: + """Subscribe to ``channels``, and enable publishing to them""" + await self._pub_sub.subscribe(*channels) + self._has_subscribed.set() + + async def unsubscribe(self, channels: Iterable[str]) -> None: + """Stop listening for events on ``channels``""" + await self._pub_sub.unsubscribe(*channels) + # if we have no active subscriptions, or only subscriptions which are pending + # to be unsubscribed we consider the backend to be unsubscribed from all + # channels, so we reset the event + if not self._pub_sub.channels.keys() - self._pub_sub.pending_unsubscribe_channels: + self._has_subscribed.clear() + + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + """Publish ``data`` to ``channels`` + + .. note:: + This operation is performed atomically, using a lua script + """ + await self._publish_script(keys=list(set(channels)), args=[data]) + + async def stream_events(self) -> AsyncGenerator[tuple[str, Any], None]: + """Return a generator, iterating over events of subscribed channels as they become available. + + If no channels have been subscribed to yet via :meth:`subscribe`, sleep for ``stream_sleep_no_subscriptions`` + milliseconds. + """ + + while True: + await self._has_subscribed.wait() + message = await self._pub_sub.get_message( + ignore_subscribe_messages=True, timeout=self._stream_sleep_no_subscriptions + ) + if message is None: + continue + + channel: str = message["channel"].decode() + data: bytes = message["data"] + # redis handles the unsubscibes with a queue; Unsubscribing doesn't mean the + # unsubscribe will happen immediately after requesting it, so we could + # receive a message on a channel that, from a client's perspective, it's not + # subscribed to anymore + if channel.encode() in self._pub_sub.channels.keys() - self._pub_sub.pending_unsubscribe_channels: + yield channel, data + + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + """Not implemented""" + raise NotImplementedError() + + +class RedisChannelsStreamBackend(RedisChannelsBackend): + def __init__( + self, + history: int, + *, + redis: Redis, + stream_sleep_no_subscriptions: int = 1, + cap_streams_approximate: bool = True, + stream_ttl: int | timedelta = timedelta(seconds=60), + key_prefix: str = "LITESTAR_CHANNELS", + ) -> None: + """Redis channels backend, `streams <https://redis.io/docs/data-types/streams/>`_. + + Args: + history: Amount of messages to keep. This will set a ``MAXLEN`` to the streams + redis: A :class:`redis.asyncio.Redis` instance + key_prefix: Key prefix to use for streams + stream_sleep_no_subscriptions: Amount of time in milliseconds to pause the + :meth:`stream_events` generator, should no subscribers exist + cap_streams_approximate: Set the streams ``MAXLEN`` using the ``~`` approximation + operator for improved performance + stream_ttl: TTL of a stream in milliseconds or as a timedelta. A streams TTL will be set on each publishing + operation using ``PEXPIRE`` + """ + super().__init__( + redis=redis, stream_sleep_no_subscriptions=stream_sleep_no_subscriptions, key_prefix=key_prefix + ) + + self._history_limit = history + self._subscribed_channels: set[str] = set() + self._cap_streams_approximate = cap_streams_approximate + self._stream_ttl = stream_ttl if isinstance(stream_ttl, int) else int(stream_ttl.total_seconds() * 1000) + self._flush_all_streams_script = self._redis.register_script(_FLUSHALL_STREAMS_SCRIPT) + self._publish_script = self._redis.register_script(_XADD_EXPIRE_SCRIPT) + self._has_subscribed_channels = _LazyEvent() + + async def on_startup(self) -> None: + """Called on application startup""" + + async def on_shutdown(self) -> None: + """Called on application shutdown""" + + async def subscribe(self, channels: Iterable[str]) -> None: + """Subscribe to ``channels``""" + self._subscribed_channels.update(channels) + self._has_subscribed_channels.set() + + async def unsubscribe(self, channels: Iterable[str]) -> None: + """Unsubscribe from ``channels``""" + self._subscribed_channels -= set(channels) + if not len(self._subscribed_channels): + self._has_subscribed_channels.clear() + + async def publish(self, data: bytes, channels: Iterable[str]) -> None: + """Publish ``data`` to ``channels``. + + .. note:: + This operation is performed atomically, using a Lua script + """ + channels = set(channels) + await self._publish_script( + keys=[self._make_key(key) for key in channels], + args=[ + data, + self._history_limit, + self._stream_ttl, + int(self._cap_streams_approximate), + *channels, + ], + ) + + async def _get_subscribed_channels(self) -> set[str]: + """Get subscribed channels. If no channels are currently subscribed, wait""" + await self._has_subscribed_channels.wait() + return self._subscribed_channels + + async def stream_events(self) -> AsyncGenerator[tuple[str, Any], None]: + """Return a generator, iterating over events of subscribed channels as they become available. + + If no channels have been subscribed to yet via :meth:`subscribe`, sleep for ``stream_sleep_no_subscriptions`` + milliseconds. + """ + stream_ids: dict[str, bytes] = {} + while True: + # We wait for subscribed channels, because we can't pass an empty dict to + # xread and block for subscribers + stream_keys = [self._make_key(c) for c in await self._get_subscribed_channels()] + + data: list[tuple[bytes, list[tuple[bytes, dict[bytes, bytes]]]]] = await self._redis.xread( + {key: stream_ids.get(key, 0) for key in stream_keys}, block=self._stream_sleep_no_subscriptions + ) + + if not data: + continue + + for stream_key, channel_events in data: + for event in channel_events: + event_data = event[1][b"data"] + channel_name = event[1][b"channel"].decode() + stream_ids[stream_key.decode()] = event[0] + yield channel_name, event_data + + async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]: + """Return the history of ``channels``, returning at most ``limit`` messages""" + data: Iterable[tuple[bytes, dict[bytes, bytes]]] + if limit: + data = reversed(await self._redis.xrevrange(self._make_key(channel), count=limit)) + else: + data = await self._redis.xrange(self._make_key(channel)) + + return [event[b"data"] for _, event in data] + + async def flush_all(self) -> int: + """Delete all stream keys with the ``key_prefix``. + + .. important:: + This method is incompatible with redis clusters + """ + deleted_streams = await self._flush_all_streams_script(keys=[], args=[f"{self._key_prefix}*"]) + return cast("int", deleted_streams) diff --git a/venv/lib/python3.11/site-packages/litestar/channels/plugin.py b/venv/lib/python3.11/site-packages/litestar/channels/plugin.py new file mode 100644 index 0000000..5988445 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/plugin.py @@ -0,0 +1,359 @@ +from __future__ import annotations + +import asyncio +from asyncio import CancelledError, Queue, Task, create_task +from contextlib import AbstractAsyncContextManager, asynccontextmanager, suppress +from functools import partial +from typing import TYPE_CHECKING, AsyncGenerator, Awaitable, Callable, Iterable + +import msgspec.json + +from litestar.di import Provide +from litestar.exceptions import ImproperlyConfiguredException, LitestarException +from litestar.handlers import WebsocketRouteHandler +from litestar.plugins import InitPluginProtocol +from litestar.serialization import default_serializer + +from .subscriber import BacklogStrategy, EventCallback, Subscriber + +if TYPE_CHECKING: + from types import TracebackType + + from litestar.channels.backends.base import ChannelsBackend + from litestar.config.app import AppConfig + from litestar.connection import WebSocket + from litestar.types import LitestarEncodableType, TypeEncodersMap + from litestar.types.asgi_types import WebSocketMode + + +class ChannelsException(LitestarException): + pass + + +class ChannelsPlugin(InitPluginProtocol, AbstractAsyncContextManager): + def __init__( + self, + backend: ChannelsBackend, + *, + channels: Iterable[str] | None = None, + arbitrary_channels_allowed: bool = False, + create_ws_route_handlers: bool = False, + ws_handler_send_history: int = 0, + ws_handler_base_path: str = "/", + ws_send_mode: WebSocketMode = "text", + subscriber_max_backlog: int | None = None, + subscriber_backlog_strategy: BacklogStrategy = "backoff", + subscriber_class: type[Subscriber] = Subscriber, + type_encoders: TypeEncodersMap | None = None, + ) -> None: + """Plugin to handle broadcasting to WebSockets with support for channels. + + This plugin is available as an injected dependency using the ``channels`` key. + + Args: + backend: Backend to store data in + channels: Channels to serve. If ``arbitrary_channels_allowed`` is ``False`` (the default), trying to + subscribe to a channel not in ``channels`` will raise an exception + arbitrary_channels_allowed: Allow the creation of channels on the fly + create_ws_route_handlers: If ``True``, websocket route handlers will be created for all channels defined in + ``channels``. If ``arbitrary_channels_allowed`` is ``True``, a single handler will be created instead, + handling all channels. The handlers created will accept WebSocket connections and start sending received + events for their respective channels. + ws_handler_send_history: Amount of history entries to send from the generated websocket route handlers after + a client has connected. A value of ``0`` indicates to not send a history + ws_handler_base_path: Path prefix used for the generated route handlers + ws_send_mode: Send mode to use for sending data through a :class:`WebSocket <.connection.WebSocket>`. + This will be used when sending within generated route handlers or :meth:`Subscriber.run_in_background` + subscriber_max_backlog: Maximum amount of unsent messages to be held in memory for a given subscriber. If + that limit is reached, new messages will be treated accordingly to ``backlog_strategy`` + subscriber_backlog_strategy: Define the behaviour if ``max_backlog`` is reached for a subscriber. ` + `backoff`` will result in new messages being dropped until older ones have been processed. ``dropleft`` + will drop older messages in favour of new ones. + subscriber_class: A :class:`Subscriber` subclass to return from :meth:`subscribe` + type_encoders: An additional mapping of type encoders used to encode data before sending + + """ + self._backend = backend + self._pub_queue: Queue[tuple[bytes, list[str]]] | None = None + self._pub_task: Task | None = None + self._sub_task: Task | None = None + + if not (channels or arbitrary_channels_allowed): + raise ImproperlyConfiguredException("Must define either channels or set arbitrary_channels_allowed=True") + + # make the path absolute, so we can simply concatenate it later + if not ws_handler_base_path.endswith("/"): + ws_handler_base_path += "/" + + self._arbitrary_channels_allowed = arbitrary_channels_allowed + self._create_route_handlers = create_ws_route_handlers + self._handler_root_path = ws_handler_base_path + self._socket_send_mode: WebSocketMode = ws_send_mode + self._encode_json = msgspec.json.Encoder( + enc_hook=partial(default_serializer, type_encoders=type_encoders) + ).encode + self._handler_should_send_history = bool(ws_handler_send_history) + self._history_limit = None if ws_handler_send_history < 0 else ws_handler_send_history + self._max_backlog = subscriber_max_backlog + self._backlog_strategy: BacklogStrategy = subscriber_backlog_strategy + self._subscriber_class = subscriber_class + + self._channels: dict[str, set[Subscriber]] = {channel: set() for channel in channels or []} + + def encode_data(self, data: LitestarEncodableType) -> bytes: + """Encode data before storing it in the backend""" + if isinstance(data, bytes): + return data + + return data.encode() if isinstance(data, str) else self._encode_json(data) + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + """Plugin hook. Set up a ``channels`` dependency, add route handlers and register application hooks""" + app_config.dependencies["channels"] = Provide(lambda: self, use_cache=True, sync_to_thread=False) + app_config.lifespan.append(self) + app_config.signature_namespace.update(ChannelsPlugin=ChannelsPlugin) + + if self._create_route_handlers: + if self._arbitrary_channels_allowed: + path = self._handler_root_path + "{channel_name:str}" + route_handlers = [WebsocketRouteHandler(path)(self._ws_handler_func)] + else: + route_handlers = [ + WebsocketRouteHandler(self._handler_root_path + channel_name)( + self._create_ws_handler_func(channel_name) + ) + for channel_name in self._channels + ] + app_config.route_handlers.extend(route_handlers) + + return app_config + + def publish(self, data: LitestarEncodableType, channels: str | Iterable[str]) -> None: + """Schedule ``data`` to be published to ``channels``. + + .. note:: + This is a synchronous method that returns immediately. There are no + guarantees that when this method returns the data will have been published + to the backend. For that, use :meth:`wait_published` + + """ + if isinstance(channels, str): + channels = [channels] + data = self.encode_data(data) + try: + self._pub_queue.put_nowait((data, list(channels))) # type: ignore[union-attr] + except AttributeError as e: + raise RuntimeError("Plugin not yet initialized. Did you forget to call on_startup?") from e + + async def wait_published(self, data: LitestarEncodableType, channels: str | Iterable[str]) -> None: + """Publish ``data`` to ``channels``""" + if isinstance(channels, str): + channels = [channels] + data = self.encode_data(data) + + await self._backend.publish(data, channels) + + async def subscribe(self, channels: str | Iterable[str], history: int | None = None) -> Subscriber: + """Create a :class:`Subscriber`, providing a stream of all events in ``channels``. + + The created subscriber will be passive by default and has to be consumed manually, + either by using :meth:`Subscriber.run_in_background` or iterating over events + using :meth:`Subscriber.iter_events`. + + Args: + channels: Channel(s) to subscribe to + history: If a non-negative integer, add this amount of history entries from + each channel to the subscriber's event stream. Note that this will wait + until all history entries are fetched and pushed to the subscriber's + stream. For more control use :meth:`put_subscriber_history`. + + Returns: + A :class:`Subscriber` + + Raises: + ChannelsException: If a channel in ``channels`` has not been declared on this backend and + ``arbitrary_channels_allowed`` has not been set to ``True`` + """ + if isinstance(channels, str): + channels = [channels] + + subscriber = self._subscriber_class( + plugin=self, + max_backlog=self._max_backlog, + backlog_strategy=self._backlog_strategy, + ) + channels_to_subscribe = set() + + for channel in channels: + if channel not in self._channels: + if not self._arbitrary_channels_allowed: + raise ChannelsException( + f"Unknown channel: {channel!r}. Either explicitly defined the channel or set " + "arbitrary_channels_allowed=True" + ) + self._channels[channel] = set() + channel_subscribers = self._channels[channel] + if not channel_subscribers: + channels_to_subscribe.add(channel) + + channel_subscribers.add(subscriber) + + if channels_to_subscribe: + await self._backend.subscribe(channels_to_subscribe) + + if history: + await self.put_subscriber_history(subscriber=subscriber, limit=history, channels=channels) + + return subscriber + + async def unsubscribe(self, subscriber: Subscriber, channels: str | Iterable[str] | None = None) -> None: + """Unsubscribe a :class:`Subscriber` from ``channels``. If the subscriber has a running sending task, it will + be stopped. + + Args: + channels: Channels to unsubscribe from. If ``None``, unsubscribe from all channels + subscriber: :class:`Subscriber` to unsubscribe + """ + if channels is None: + channels = list(self._channels.keys()) + elif isinstance(channels, str): + channels = [channels] + + channels_to_unsubscribe: set[str] = set() + + for channel in channels: + channel_subscribers = self._channels[channel] + + try: + channel_subscribers.remove(subscriber) + except KeyError: # subscriber was not subscribed to this channel. This may happen if channels is None + continue + + if not channel_subscribers: + channels_to_unsubscribe.add(channel) + + if all(subscriber not in queues for queues in self._channels.values()): + await subscriber.put(None) # this will stop any running task or generator by breaking the inner loop + if subscriber.is_running: + await subscriber.stop() + + if channels_to_unsubscribe: + await self._backend.unsubscribe(channels_to_unsubscribe) + + @asynccontextmanager + async def start_subscription( + self, channels: str | Iterable[str], history: int | None = None + ) -> AsyncGenerator[Subscriber, None]: + """Create a :class:`Subscriber` and tie its subscriptions to a context manager; + Upon exiting the context, :meth:`unsubscribe` will be called. + + Args: + channels: Channel(s) to subscribe to + history: If a non-negative integer, add this amount of history entries from + each channel to the subscriber's event stream. Note that this will wait + until all history entries are fetched and pushed to the subscriber's + stream. For more control use :meth:`put_subscriber_history`. + + Returns: + A :class:`Subscriber` + """ + subscriber = await self.subscribe(channels, history=history) + + try: + yield subscriber + finally: + await self.unsubscribe(subscriber, channels) + + async def put_subscriber_history( + self, subscriber: Subscriber, channels: str | Iterable[str], limit: int | None = None + ) -> None: + """Fetch the history of ``channels`` from the backend and put them in the + subscriber's stream + """ + if isinstance(channels, str): + channels = [channels] + + for channel in channels: + history = await self._backend.get_history(channel, limit) + for entry in history: + await subscriber.put(entry) + + async def _ws_handler_func(self, channel_name: str, socket: WebSocket) -> None: + await socket.accept() + + # the ternary operator triggers a mypy bug: https://github.com/python/mypy/issues/10740 + on_event: EventCallback = socket.send_text if self._socket_send_mode == "text" else socket.send_bytes # type: ignore[assignment] + + async with self.start_subscription(channel_name) as subscriber: + if self._handler_should_send_history: + await self.put_subscriber_history(subscriber, channels=channel_name, limit=self._history_limit) + + # use the background task, so we can block on receive(), breaking the loop when a connection closes + async with subscriber.run_in_background(on_event): + while (await socket.receive())["type"] != "websocket.disconnect": + continue + + def _create_ws_handler_func(self, channel_name: str) -> Callable[[WebSocket], Awaitable[None]]: + async def ws_handler_func(socket: WebSocket) -> None: + await self._ws_handler_func(channel_name=channel_name, socket=socket) + + return ws_handler_func + + async def _pub_worker(self) -> None: + while self._pub_queue: + data, channels = await self._pub_queue.get() + await self._backend.publish(data, channels) + self._pub_queue.task_done() + + async def _sub_worker(self) -> None: + async for channel, payload in self._backend.stream_events(): + for subscriber in self._channels.get(channel, []): + subscriber.put_nowait(payload) + + async def _on_startup(self) -> None: + await self._backend.on_startup() + self._pub_queue = Queue() + self._pub_task = create_task(self._pub_worker()) + self._sub_task = create_task(self._sub_worker()) + if self._channels: + await self._backend.subscribe(list(self._channels)) + + async def _on_shutdown(self) -> None: + if self._pub_queue: + await self._pub_queue.join() + self._pub_queue = None + + await asyncio.gather( + *[ + subscriber.stop(join=False) + for subscribers in self._channels.values() + for subscriber in subscribers + if subscriber.is_running + ] + ) + + if self._sub_task: + self._sub_task.cancel() + with suppress(CancelledError): + await self._sub_task + self._sub_task = None + + if self._pub_task: + self._pub_task.cancel() + with suppress(CancelledError): + await self._pub_task + self._sub_task = None + + await self._backend.on_shutdown() + + async def __aenter__(self) -> ChannelsPlugin: + await self._on_startup() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self._on_shutdown() diff --git a/venv/lib/python3.11/site-packages/litestar/channels/subscriber.py b/venv/lib/python3.11/site-packages/litestar/channels/subscriber.py new file mode 100644 index 0000000..b358bc4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/channels/subscriber.py @@ -0,0 +1,135 @@ +from __future__ import annotations + +import asyncio +from asyncio import CancelledError, Queue, QueueFull +from collections import deque +from contextlib import AsyncExitStack, asynccontextmanager, suppress +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, Generic, Literal, TypeVar + +if TYPE_CHECKING: + from litestar.channels import ChannelsPlugin + + +T = TypeVar("T") + +BacklogStrategy = Literal["backoff", "dropleft"] + +EventCallback = Callable[[bytes], Awaitable[Any]] + + +class AsyncDeque(Queue, Generic[T]): + def __init__(self, maxsize: int | None) -> None: + self._deque_maxlen = maxsize + super().__init__() + + def _init(self, maxsize: int) -> None: + self._queue: deque[T] = deque(maxlen=self._deque_maxlen) + + +class Subscriber: + """A wrapper around a stream of events published to subscribed channels""" + + def __init__( + self, + plugin: ChannelsPlugin, + max_backlog: int | None = None, + backlog_strategy: BacklogStrategy = "backoff", + ) -> None: + self._task: asyncio.Task | None = None + self._plugin = plugin + self._backend = plugin._backend + self._queue: Queue[bytes | None] | AsyncDeque[bytes | None] + + if max_backlog and backlog_strategy == "dropleft": + self._queue = AsyncDeque(maxsize=max_backlog or 0) + else: + self._queue = Queue(maxsize=max_backlog or 0) + + async def put(self, item: bytes | None) -> None: + await self._queue.put(item) + + def put_nowait(self, item: bytes | None) -> bool: + """Put an item in the subscriber's stream without waiting""" + try: + self._queue.put_nowait(item) + return True + except QueueFull: + return False + + @property + def qsize(self) -> int: + return self._queue.qsize() + + async def iter_events(self) -> AsyncGenerator[bytes, None]: + """Iterate over the stream of events. If no items are available, block until + one becomes available + """ + while True: + item = await self._queue.get() + if item is None: + self._queue.task_done() + break + yield item + self._queue.task_done() + + @asynccontextmanager + async def run_in_background(self, on_event: EventCallback, join: bool = True) -> AsyncGenerator[None, None]: + """Start a task in the background that sends events from the subscriber's stream + to ``socket`` as they become available. On exit, it will prevent the stream from + accepting new events and wait until the currently enqueued ones are processed. + Should the context be left with an exception, the task will be cancelled + immediately. + + Args: + on_event: Callback to invoke with the event data for every event + join: If ``True``, wait for all items in the stream to be processed before + stopping the worker. Note that an error occurring within the context + will always lead to the immediate cancellation of the worker + """ + self._start_in_background(on_event=on_event) + async with AsyncExitStack() as exit_stack: + exit_stack.push_async_callback(self.stop, join=False) + yield + exit_stack.pop_all() + await self.stop(join=join) + + async def _worker(self, on_event: EventCallback) -> None: + async for event in self.iter_events(): + await on_event(event) + + def _start_in_background(self, on_event: EventCallback) -> None: + """Start a task in the background that sends events from the subscriber's stream + to ``socket`` as they become available. + + Args: + on_event: Callback to invoke with the event data for every event + """ + if self._task is not None: + raise RuntimeError("Subscriber is already running") + self._task = asyncio.create_task(self._worker(on_event)) + + @property + def is_running(self) -> bool: + """Return whether a sending task is currently running""" + return self._task is not None + + async def stop(self, join: bool = False) -> None: + """Stop a task was previously started with :meth:`run_in_background`. If the + task is not yet done it will be cancelled and awaited + + Args: + join: If ``True`` wait for all items to be processed before stopping the task + """ + if not self._task: + return + + if join: + await self._queue.join() + + if not self._task.done(): + self._task.cancel() + + with suppress(CancelledError): + await self._task + + self._task = None |