summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/response/sse.py
blob: 48a9192b98c448541001b2798d9fe0b1c3389fb9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from __future__ import annotations

import io
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, AsyncIterator, Iterable, Iterator

from litestar.concurrency import sync_to_thread
from litestar.exceptions import ImproperlyConfiguredException
from litestar.response.streaming import Stream
from litestar.utils import AsyncIteratorWrapper

if TYPE_CHECKING:
    from litestar.background_tasks import BackgroundTask, BackgroundTasks
    from litestar.types import ResponseCookies, ResponseHeaders, SSEData, StreamType

# Matches a single line break in any of the three forms: CRLF, a lone CR, or a lone LF. ``\r\n`` is listed first
# so a CRLF pair is consumed as one break instead of being split into two.
_LINE_BREAK_RE = re.compile(r"\r\n|\r|\n")
# Line terminator used when encoding event fields: the default of ``ServerSentEventMessage.sep`` and the value
# ``_ServerSentEventIterator`` passes when it encodes streamed items.
DEFAULT_SEPARATOR = "\r\n"


class _ServerSentEventIterator(AsyncIteratorWrapper[bytes]):
    """Async iterator yielding the encoded byte chunks of a server-sent event stream.

    The stream starts with optional header lines (comment, ``id``, ``event`` and ``retry``) built from the
    constructor arguments, followed by every item of ``content`` encoded via :class:`ServerSentEventMessage`.
    """

    __slots__ = ("content_async_iterator", "event_id", "event_type", "retry_duration", "comment_message")

    content_async_iterator: AsyncIterable[SSEData]

    def __init__(
        self,
        content: str | bytes | StreamType[SSEData],
        event_type: str | None = None,
        event_id: int | str | None = None,
        retry_duration: int | None = None,
        comment_message: str | None = None,
    ) -> None:
        """Build the header chunks and normalize ``content`` into an async iterable.

        Args:
            content: A string, bytes, a sync or async iterable / iterator of event data, or a callable
                returning one of those.
            event_type: Default ``event`` field. Also emitted once as a header line.
            event_id: Default ``id`` field. Also emitted once as a header line.
            retry_duration: Default ``retry`` field. Also emitted once as a header line.
            comment_message: Comment emitted at the start of the stream, one ``:`` line per input line.

        Raises:
            ImproperlyConfiguredException: If ``content`` is none of the supported types.
        """
        self.comment_message = comment_message
        self.event_id = event_id
        self.event_type = event_type
        self.retry_duration = retry_duration

        header_lines: list[str] = []
        if comment_message is not None:
            header_lines.extend(f": {chunk}" for chunk in _LINE_BREAK_RE.split(comment_message))

        # Strip line breaks from single-line fields, exactly as ``ServerSentEventMessage.encode`` does, so a
        # value containing a newline cannot smuggle additional fields into the stream.
        if event_id is not None:
            header_lines.append(_LINE_BREAK_RE.sub("", f"id: {event_id}"))

        if event_type is not None:
            header_lines.append(_LINE_BREAK_RE.sub("", f"event: {event_type}"))

        if retry_duration is not None:
            header_lines.append(f"retry: {retry_duration}")

        # One encoded chunk per header line, each terminated with the module-wide default separator.
        super().__init__(iterator=[f"{line}{DEFAULT_SEPARATOR}".encode() for line in header_lines])

        # A callable that is not itself an iterator is treated as a factory for the actual content.
        if not isinstance(content, (Iterator, AsyncIterator, AsyncIteratorWrapper)) and callable(content):
            content = content()  # type: ignore[unreachable]

        if isinstance(content, (str, bytes)):
            # str / bytes are iterable themselves; wrap them so they are sent as a single event.
            self.content_async_iterator = AsyncIteratorWrapper([content])
        elif isinstance(content, (Iterable, Iterator)):
            self.content_async_iterator = AsyncIteratorWrapper(content)
        elif isinstance(content, (AsyncIterable, AsyncIterator, AsyncIteratorWrapper)):
            self.content_async_iterator = content
        else:
            raise ImproperlyConfiguredException(f"Invalid type {type(content)} for ServerSentEvent")

    def ensure_bytes(self, data: str | int | bytes | dict | ServerSentEventMessage | Any, sep: str) -> bytes:
        """Encode a single streamed item into the bytes of one event.

        Args:
            data: A ready-made :class:`ServerSentEventMessage`, a dict of keyword arguments for one, or a raw
                value used as the message's ``data`` together with this iterator's default id / event / retry.
            sep: Line separator to use. Ignored when ``data`` already is a :class:`ServerSentEventMessage`.

        Returns:
            The encoded event.
        """
        if isinstance(data, ServerSentEventMessage):
            return data.encode()
        if isinstance(data, dict):
            # Work on a copy so the dict yielded by the caller is not mutated. ``sep`` still overrides any
            # "sep" key already present in ``data``.
            return ServerSentEventMessage(**{**data, "sep": sep}).encode()

        return ServerSentEventMessage(
            data=data, id=self.event_id, event=self.event_type, retry=self.retry_duration, sep=sep
        ).encode()

    def _call_next(self) -> bytes:
        """Return the next header chunk, raising ``ValueError`` once the header chunks are exhausted."""
        try:
            return next(self.iterator)
        except StopIteration as e:
            raise ValueError from e

    async def _async_generator(self) -> AsyncGenerator[bytes, None]:
        """Yield all header chunks, then every content item encoded as an event."""
        while True:
            # Keep the ``try`` limited to the call that signals exhaustion, so a ``ValueError`` thrown into the
            # generator at the ``yield`` is not mistaken for the end of the header chunks.
            try:
                chunk = await sync_to_thread(self._call_next)
            except ValueError:
                break
            yield chunk

        async for value in self.content_async_iterator:
            yield self.ensure_bytes(value, DEFAULT_SEPARATOR)


@dataclass
class ServerSentEventMessage:
    """A single server-sent event and its wire encoding.

    Attributes:
        data: Event payload. Each line of it is emitted as its own ``data:`` field; ``None`` omits the field.
        event: Event name, emitted as the ``event:`` field with line breaks removed.
        id: Event ID, emitted as the ``id:`` field with line breaks removed.
        retry: Value emitted as the ``retry:`` field.
        comment: Comment text, emitted as one ``:`` line per input line.
        sep: Line separator written after every field line and once more to terminate the event.
    """

    data: str | int | bytes | None = ""
    event: str | None = None
    id: int | str | None = None
    retry: int | None = None
    comment: str | None = None
    sep: str = DEFAULT_SEPARATOR

    def encode(self) -> bytes:
        """Serialize the message to UTF-8 bytes.

        Fields are written in the order comment, ``id``, ``event``, ``data``, ``retry``; fields set to ``None``
        are skipped. The event is closed by one extra separator, i.e. a blank line.

        Returns:
            The encoded event.
        """
        lines: list[str] = []

        if self.comment is not None:
            lines.extend(f": {part}" for part in _LINE_BREAK_RE.split(str(self.comment)))

        # ``id`` and ``event`` are single-line fields: drop embedded line breaks instead of splitting on them.
        if self.id is not None:
            lines.append(_LINE_BREAK_RE.sub("", f"id: {self.id}"))

        if self.event is not None:
            lines.append(_LINE_BREAK_RE.sub("", f"event: {self.event}"))

        if self.data is not None:
            payload = self.data
            text = payload.decode() if isinstance(payload, bytes) else str(payload)
            lines.extend(f"data: {part}" for part in _LINE_BREAK_RE.split(text))

        if self.retry is not None:
            lines.append(f"retry: {self.retry}")

        terminator = self.sep
        # Every field line gets its own terminator; the trailing one is the blank line that ends the event.
        encoded = "".join(line + terminator for line in lines) + terminator
        return encoded.encode("utf-8")


class ServerSentEvent(Stream):
    """Streaming response that delivers its content as ``text/event-stream`` server-sent events."""

    def __init__(
        self,
        content: str | bytes | StreamType[SSEData],
        *,
        background: BackgroundTask | BackgroundTasks | None = None,
        cookies: ResponseCookies | None = None,
        encoding: str = "utf-8",
        headers: ResponseHeaders | None = None,
        event_type: str | None = None,
        event_id: int | str | None = None,
        retry_duration: int | None = None,
        comment_message: str | None = None,
        status_code: int | None = None,
    ) -> None:
        """Initialize the response.

        Args:
            content: Bytes, string or a sync or async iterator or iterable.
            background: A :class:`BackgroundTask <.background_tasks.BackgroundTask>` instance or
                :class:`BackgroundTasks <.background_tasks.BackgroundTasks>` to execute after the response is finished.
                Defaults to None.
            cookies: A list of :class:`Cookie <.datastructures.Cookie>` instances to be set under the response
                ``Set-Cookie`` header.
            encoding: The encoding to be used for the response headers.
            headers: A string keyed dictionary of response headers. Header keys are insensitive.
            event_type: The type of the SSE event. If given, the browser will send the event to any 'event-listener'
                declared for it (e.g. via 'addEventListener' in JS).
            event_id: The event ID. This sets the event source's 'last event id'.
            retry_duration: Retry duration in milliseconds.
            comment_message: A comment message. This value is ignored by clients and is used mostly for pinging.
            status_code: The response status code. Defaults to 200.
        """
        event_stream = _ServerSentEventIterator(
            content=content,
            event_type=event_type,
            event_id=event_id,
            retry_duration=retry_duration,
            comment_message=comment_message,
        )
        super().__init__(
            content=event_stream,
            media_type="text/event-stream",
            background=background,
            cookies=cookies,
            encoding=encoding,
            headers=headers,
            status_code=status_code,
        )

        # A caller-supplied ``Cache-Control`` is respected; the two headers below are always overwritten.
        self.headers.setdefault("Cache-Control", "no-cache")
        for header_name, header_value in (("Connection", "keep-alive"), ("X-Accel-Buffering", "no")):
            self.headers[header_name] = header_value