summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py90
1 files changed, 90 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py
new file mode 100644
index 0000000..967b208
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import asyncio
+from contextlib import AsyncExitStack
+from functools import partial
+from typing import AsyncGenerator, Awaitable, Callable, Iterable, overload
+
+import asyncpg
+
+from litestar.channels import ChannelsBackend
+from litestar.exceptions import ImproperlyConfiguredException
+
+
+class AsyncPgChannelsBackend(ChannelsBackend):
+ _listener_conn: asyncpg.Connection
+
+ @overload
+ def __init__(self, dsn: str) -> None: ...
+
+ @overload
+ def __init__(
+ self,
+ *,
+ make_connection: Callable[[], Awaitable[asyncpg.Connection]],
+ ) -> None: ...
+
+ def __init__(
+ self,
+ dsn: str | None = None,
+ *,
+ make_connection: Callable[[], Awaitable[asyncpg.Connection]] | None = None,
+ ) -> None:
+ if not (dsn or make_connection):
+ raise ImproperlyConfiguredException("Need to specify dsn or make_connection")
+
+ self._subscribed_channels: set[str] = set()
+ self._exit_stack = AsyncExitStack()
+ self._connect = make_connection or partial(asyncpg.connect, dsn=dsn)
+ self._queue: asyncio.Queue[tuple[str, bytes]] | None = None
+
+ async def on_startup(self) -> None:
+ self._queue = asyncio.Queue()
+ self._listener_conn = await self._connect()
+
+ async def on_shutdown(self) -> None:
+ await self._listener_conn.close()
+ self._queue = None
+
+ async def publish(self, data: bytes, channels: Iterable[str]) -> None:
+ if self._queue is None:
+ raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?")
+
+ dec_data = data.decode("utf-8")
+
+ conn = await self._connect()
+ try:
+ for channel in channels:
+ await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data)
+ finally:
+ await conn.close()
+
+ async def subscribe(self, channels: Iterable[str]) -> None:
+ for channel in set(channels) - self._subscribed_channels:
+ await self._listener_conn.add_listener(channel, self._listener) # type: ignore[arg-type]
+ self._subscribed_channels.add(channel)
+
+ async def unsubscribe(self, channels: Iterable[str]) -> None:
+ for channel in channels:
+ await self._listener_conn.remove_listener(channel, self._listener) # type: ignore[arg-type]
+ self._subscribed_channels = self._subscribed_channels - set(channels)
+
+ async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
+ if self._queue is None:
+ raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?")
+
+ while True:
+ channel, message = await self._queue.get()
+ self._queue.task_done()
+ # an UNLISTEN may be in transit while we're getting here, so we double-check
+ # that we are actually supposed to deliver this message
+ if channel in self._subscribed_channels:
+ yield channel, message
+
+ async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
+ raise NotImplementedError()
+
+ def _listener(self, /, connection: asyncpg.Connection, pid: int, channel: str, payload: object) -> None:
+ if not isinstance(payload, str):
+ raise RuntimeError("Invalid data received")
+ self._queue.put_nowait((channel, payload.encode("utf-8"))) # type: ignore[union-attr]