summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py54
1 files changed, 54 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py b/venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py
new file mode 100644
index 0000000..14b53bc
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from contextlib import AsyncExitStack
+from typing import AsyncGenerator, Iterable
+
+import psycopg
+
+from .base import ChannelsBackend
+
+
+def _safe_quote(ident: str) -> str:
+ return '"{}"'.format(ident.replace('"', '""')) # sourcery skip
+
+
+class PsycoPgChannelsBackend(ChannelsBackend):
+ _listener_conn: psycopg.AsyncConnection
+
+ def __init__(self, pg_dsn: str) -> None:
+ self._pg_dsn = pg_dsn
+ self._subscribed_channels: set[str] = set()
+ self._exit_stack = AsyncExitStack()
+
+ async def on_startup(self) -> None:
+ self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True)
+ await self._exit_stack.enter_async_context(self._listener_conn)
+
+ async def on_shutdown(self) -> None:
+ await self._exit_stack.aclose()
+
+ async def publish(self, data: bytes, channels: Iterable[str]) -> None:
+ dec_data = data.decode("utf-8")
+ async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn:
+ for channel in channels:
+ await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data))
+
+ async def subscribe(self, channels: Iterable[str]) -> None:
+ for channel in set(channels) - self._subscribed_channels:
+ # can't use placeholders in LISTEN
+ await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};") # pyright: ignore
+
+ self._subscribed_channels.add(channel)
+
+ async def unsubscribe(self, channels: Iterable[str]) -> None:
+ for channel in channels:
+ # can't use placeholders in UNLISTEN
+ await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};") # pyright: ignore
+ self._subscribed_channels = self._subscribed_channels - set(channels)
+
+ async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
+ async for notify in self._listener_conn.notifies():
+ yield notify.channel, notify.payload.encode("utf-8")
+
+ async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
+ raise NotImplementedError()