1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
|
from __future__ import annotations
import asyncio
from contextlib import AsyncExitStack
from functools import partial
from typing import AsyncGenerator, Awaitable, Callable, Iterable, overload
import asyncpg
from litestar.channels import ChannelsBackend
from litestar.exceptions import ImproperlyConfiguredException
class AsyncPgChannelsBackend(ChannelsBackend):
    """Channels backend built on PostgreSQL notifications through asyncpg.

    Messages are published with ``pg_notify`` over a short-lived connection.
    They are received through asyncpg listeners registered on one long-lived
    listener connection. Incoming notifications are buffered in an
    ``asyncio.Queue`` and yielded by :meth:`stream_events`.
    """

    # Long-lived connection that carries the listeners; assigned in on_startup().
    _listener_conn: asyncpg.Connection

    @overload
    def __init__(self, dsn: str) -> None: ...

    @overload
    def __init__(
        self,
        *,
        make_connection: Callable[[], Awaitable[asyncpg.Connection]],
    ) -> None: ...

    def __init__(
        self,
        dsn: str | None = None,
        *,
        make_connection: Callable[[], Awaitable[asyncpg.Connection]] | None = None,
    ) -> None:
        """Configure how connections are created.

        Args:
            dsn: Connection string passed to ``asyncpg.connect``. Used only if
                ``make_connection`` is not given.
            make_connection: Zero-argument async factory returning a new
                ``asyncpg.Connection``. Takes precedence over ``dsn``.

        Raises:
            ImproperlyConfiguredException: If neither ``dsn`` nor
                ``make_connection`` is provided.
        """
        if not (dsn or make_connection):
            raise ImproperlyConfiguredException("Need to specify dsn or make_connection")

        # Channels that currently have a listener on the listener connection.
        self._subscribed_channels: set[str] = set()
        self._exit_stack = AsyncExitStack()
        self._connect = make_connection or partial(asyncpg.connect, dsn=dsn)
        # None means "not started" (or already shut down).
        self._queue: asyncio.Queue[tuple[str, bytes]] | None = None

    async def on_startup(self) -> None:
        """Create the event queue and open the listener connection."""
        self._queue = asyncio.Queue()
        self._listener_conn = await self._connect()

    async def on_shutdown(self) -> None:
        """Close the listener connection and reset the backend state."""
        try:
            await self._listener_conn.close()
        finally:
            # The listeners lived on the connection that was just closed.
            # Forget them so that a later on_startup() + subscribe() registers
            # them on the new connection instead of skipping them as
            # "already subscribed".
            self._subscribed_channels.clear()
            self._queue = None

    async def publish(self, data: bytes, channels: Iterable[str]) -> None:
        """Send ``data`` to every channel in ``channels`` via ``pg_notify``.

        Args:
            data: Message payload; must be valid UTF-8 because it is sent as text.
            channels: Names of the channels to notify.

        Raises:
            RuntimeError: If :meth:`on_startup` has not been called.
            UnicodeDecodeError: If ``data`` is not valid UTF-8.
        """
        if self._queue is None:
            raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?")

        # The payload is passed to pg_notify as text; _listener re-encodes it.
        dec_data = data.decode("utf-8")
        targets = list(channels)
        if not targets:
            # Nothing to notify, so don't open a connection at all.
            return

        # A dedicated connection per publish keeps these queries off the
        # listener connection.
        conn = await self._connect()
        try:
            for channel in targets:
                await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data)
        finally:
            await conn.close()

    async def subscribe(self, channels: Iterable[str]) -> None:
        """Register a listener for each channel that is not already subscribed."""
        for channel in set(channels) - self._subscribed_channels:
            await self._listener_conn.add_listener(channel, self._listener)  # type: ignore[arg-type]
            self._subscribed_channels.add(channel)

    async def unsubscribe(self, channels: Iterable[str]) -> None:
        """Remove the listener for each channel in ``channels``."""
        # Materialize once: ``channels`` may be a one-shot iterator (e.g. a
        # generator). Iterating it twice would leave the bookkeeping below
        # untouched and keep the channels marked as subscribed.
        for channel in set(channels):
            await self._listener_conn.remove_listener(channel, self._listener)  # type: ignore[arg-type]
            self._subscribed_channels.discard(channel)

    async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
        """Yield ``(channel, payload)`` pairs for subscribed channels, forever.

        Raises:
            RuntimeError: If :meth:`on_startup` has not been called.
        """
        if self._queue is None:
            raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?")

        while True:
            channel, message = await self._queue.get()
            self._queue.task_done()
            # an UNLISTEN may be in transit while we're getting here, so we double-check
            # that we are actually supposed to deliver this message
            if channel in self._subscribed_channels:
                yield channel, message

    async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
        """History is not supported by this backend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def _listener(self, /, connection: asyncpg.Connection, pid: int, channel: str, payload: object) -> None:
        """asyncpg listener callback: enqueue the notification as ``(channel, bytes)``.

        Raises:
            RuntimeError: If the payload is not a string.
        """
        if not isinstance(payload, str):
            raise RuntimeError("Invalid data received")
        queue = self._queue
        if queue is None:
            # A notification may still be dispatched after on_shutdown() has
            # dropped the queue. There is no consumer left, so discard it.
            return
        queue.put_nowait((channel, payload.encode("utf-8")))
|