summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/channels/backends/asyncpg.py
blob: 967b2085d1dd176a3a75d3490dc74a49c5b7cfae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from __future__ import annotations

import asyncio
from contextlib import AsyncExitStack
from functools import partial
from typing import AsyncGenerator, Awaitable, Callable, Iterable, overload

import asyncpg

from litestar.channels import ChannelsBackend
from litestar.exceptions import ImproperlyConfiguredException


class AsyncPgChannelsBackend(ChannelsBackend):
    _listener_conn: asyncpg.Connection

    @overload
    def __init__(self, dsn: str) -> None: ...

    @overload
    def __init__(
        self,
        *,
        make_connection: Callable[[], Awaitable[asyncpg.Connection]],
    ) -> None: ...

    def __init__(
        self,
        dsn: str | None = None,
        *,
        make_connection: Callable[[], Awaitable[asyncpg.Connection]] | None = None,
    ) -> None:
        if not (dsn or make_connection):
            raise ImproperlyConfiguredException("Need to specify dsn or make_connection")

        self._subscribed_channels: set[str] = set()
        self._exit_stack = AsyncExitStack()
        self._connect = make_connection or partial(asyncpg.connect, dsn=dsn)
        self._queue: asyncio.Queue[tuple[str, bytes]] | None = None

    async def on_startup(self) -> None:
        self._queue = asyncio.Queue()
        self._listener_conn = await self._connect()

    async def on_shutdown(self) -> None:
        await self._listener_conn.close()
        self._queue = None

    async def publish(self, data: bytes, channels: Iterable[str]) -> None:
        if self._queue is None:
            raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?")

        dec_data = data.decode("utf-8")

        conn = await self._connect()
        try:
            for channel in channels:
                await conn.execute("SELECT pg_notify($1, $2);", channel, dec_data)
        finally:
            await conn.close()

    async def subscribe(self, channels: Iterable[str]) -> None:
        for channel in set(channels) - self._subscribed_channels:
            await self._listener_conn.add_listener(channel, self._listener)  # type: ignore[arg-type]
            self._subscribed_channels.add(channel)

    async def unsubscribe(self, channels: Iterable[str]) -> None:
        for channel in channels:
            await self._listener_conn.remove_listener(channel, self._listener)  # type: ignore[arg-type]
        self._subscribed_channels = self._subscribed_channels - set(channels)

    async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
        if self._queue is None:
            raise RuntimeError("Backend not yet initialized. Did you forget to call on_startup?")

        while True:
            channel, message = await self._queue.get()
            self._queue.task_done()
            # an UNLISTEN may be in transit while we're getting here, so we double-check
            # that we are actually supposed to deliver this message
            if channel in self._subscribed_channels:
                yield channel, message

    async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
        raise NotImplementedError()

    def _listener(self, /, connection: asyncpg.Connection, pid: int, channel: str, payload: object) -> None:
        if not isinstance(payload, str):
            raise RuntimeError("Invalid data received")
        self._queue.put_nowait((channel, payload.encode("utf-8")))  # type: ignore[union-attr]