summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/channels/backends/psycopg.py
blob: 14b53bcd1aba86770dbb7979e21f8a33ad236379 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from __future__ import annotations

from contextlib import AsyncExitStack
from typing import AsyncGenerator, Iterable

import psycopg

from .base import ChannelsBackend


def _safe_quote(ident: str) -> str:
    return '"{}"'.format(ident.replace('"', '""'))  # sourcery skip


class PsycoPgChannelsBackend(ChannelsBackend):
    """Channels backend built on PostgreSQL ``LISTEN``/``NOTIFY`` using psycopg 3.

    A single long-lived autocommit connection, opened in :meth:`on_startup`,
    is used for ``LISTEN``/``UNLISTEN`` and for receiving notifications. Each
    :meth:`publish` call opens its own short-lived connection.
    """

    # Connection used for LISTEN/UNLISTEN and for streaming notifications.
    # It is only assigned once on_startup() has run.
    _listener_conn: psycopg.AsyncConnection

    def __init__(self, pg_dsn: str) -> None:
        """Store configuration; no connection is opened until :meth:`on_startup`.

        Args:
            pg_dsn: Connection string passed to ``psycopg.AsyncConnection.connect``.
        """
        self._pg_dsn = pg_dsn
        # Channels for which LISTEN has been issued on the listener connection.
        self._subscribed_channels: set[str] = set()
        self._exit_stack = AsyncExitStack()

    async def on_startup(self) -> None:
        """Open the listener connection and register it to be closed on shutdown."""
        # autocommit: LISTEN/UNLISTEN take effect immediately instead of at the
        # end of an implicit transaction.
        self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True)
        await self._exit_stack.enter_async_context(self._listener_conn)

    async def on_shutdown(self) -> None:
        """Close everything registered on the exit stack (the listener connection)."""
        await self._exit_stack.aclose()

    async def publish(self, data: bytes, channels: Iterable[str]) -> None:
        """Send ``data`` as a notification payload to each channel in ``channels``.

        Args:
            data: UTF-8 encoded payload; it is decoded to text for ``pg_notify``.
            channels: Names of the channels to notify.

        Raises:
            UnicodeDecodeError: If ``data`` is not valid UTF-8.
        """
        dec_data = data.decode("utf-8")
        # Uses a fresh, non-autocommit connection. Leaving the ``async with``
        # block commits the transaction, which is when PostgreSQL delivers the
        # notifications, and then closes the connection.
        async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn:
            for channel in channels:
                await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data))

    async def subscribe(self, channels: Iterable[str]) -> None:
        """Issue ``LISTEN`` for every channel in ``channels`` not already subscribed.

        Args:
            channels: Channel names to start listening on.
        """
        for channel in set(channels) - self._subscribed_channels:
            # can't use placeholders in LISTEN, so the identifier is quoted manually
            await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};")  # pyright: ignore
            # Record the channel only after LISTEN succeeded, so tracked state
            # matches what the server has accepted.
            self._subscribed_channels.add(channel)

    async def unsubscribe(self, channels: Iterable[str]) -> None:
        """Issue ``UNLISTEN`` for every channel in ``channels`` and stop tracking it.

        Args:
            channels: Channel names to stop listening on. Any iterable is
                accepted, including one-shot iterators such as generators.
        """
        # Materialise once. Iterating ``channels`` twice would exhaust a one-shot
        # iterable in the UNLISTEN loop and leave the tracked set unchanged.
        # Materialising also removes duplicates.
        to_remove = set(channels)
        for channel in to_remove:
            # can't use placeholders in UNLISTEN
            await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};")  # pyright: ignore
            # Update per channel so that a failure part-way through leaves the
            # tracked set consistent with the UNLISTENs already executed.
            self._subscribed_channels.discard(channel)

    async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
        """Yield ``(channel, payload)`` for each notification on the listener connection.

        Payloads are re-encoded to UTF-8 bytes, mirroring the decode in :meth:`publish`.
        """
        async for notify in self._listener_conn.notifies():
            yield notify.channel, notify.payload.encode("utf-8")

    async def get_history(self, channel: str, limit: int | None = None) -> list[bytes]:
        """Not supported by this backend.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()