diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/connection/websocket.py | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/connection/websocket.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/connection/websocket.py | 343 |
1 files changed, 343 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/websocket.py b/venv/lib/python3.11/site-packages/litestar/connection/websocket.py new file mode 100644 index 0000000..0c7bc04 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/websocket.py @@ -0,0 +1,343 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, Literal, cast, overload + +from litestar.connection.base import ( + ASGIConnection, + AuthT, + StateT, + UserT, + empty_receive, + empty_send, +) +from litestar.datastructures.headers import Headers +from litestar.exceptions import WebSocketDisconnect +from litestar.serialization import decode_json, decode_msgpack, default_serializer, encode_json, encode_msgpack +from litestar.status_codes import WS_1000_NORMAL_CLOSURE + +__all__ = ("WebSocket",) + + +if TYPE_CHECKING: + from litestar.handlers.websocket_handlers import WebsocketRouteHandler # noqa: F401 + from litestar.types import Message, Serializer, WebSocketScope + from litestar.types.asgi_types import ( + Receive, + ReceiveMessage, + Scope, + Send, + WebSocketAcceptEvent, + WebSocketCloseEvent, + WebSocketDisconnectEvent, + WebSocketMode, + WebSocketReceiveEvent, + WebSocketSendEvent, + ) + +DISCONNECT_MESSAGE = "connection is disconnected" + + +class WebSocket(Generic[UserT, AuthT, StateT], ASGIConnection["WebsocketRouteHandler", UserT, AuthT, StateT]): + """The Litestar WebSocket class.""" + + __slots__ = ("connection_state",) + + scope: WebSocketScope # pyright: ignore + """The ASGI scope attached to the connection.""" + receive: Receive + """The ASGI receive function.""" + send: Send + """The ASGI send function.""" + + def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None: + """Initialize ``WebSocket``. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + """ + super().__init__(scope, self.receive_wrapper(receive), self.send_wrapper(send)) + self.connection_state: Literal["init", "connect", "receive", "disconnect"] = "init" + + def receive_wrapper(self, receive: Receive) -> Receive: + """Wrap ``receive`` to set 'self.connection_state' and validate events. + + Args: + receive: The ASGI receive function. + + Returns: + An ASGI receive function. + """ + + async def wrapped_receive() -> ReceiveMessage: + if self.connection_state == "disconnect": + raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) + message = await receive() + if message["type"] == "websocket.connect": + self.connection_state = "connect" + elif message["type"] == "websocket.receive": + self.connection_state = "receive" + else: + self.connection_state = "disconnect" + return message + + return wrapped_receive + + def send_wrapper(self, send: Send) -> Send: + """Wrap ``send`` to ensure that state is not disconnected. + + Args: + send: The ASGI send function. + + Returns: + An ASGI send function. + """ + + async def wrapped_send(message: Message) -> None: + if self.connection_state == "disconnect": + raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover + await send(message) + + return wrapped_send + + async def accept( + self, + subprotocols: str | None = None, + headers: Headers | dict[str, Any] | list[tuple[bytes, bytes]] | None = None, + ) -> None: + """Accept the incoming connection. This method should be called before receiving data. + + Args: + subprotocols: Websocket sub-protocol to use. + headers: Headers to set on the data sent. + + Returns: + None + """ + if self.connection_state == "init": + await self.receive() + _headers: list[tuple[bytes, bytes]] = headers if isinstance(headers, list) else [] + + if isinstance(headers, dict): + _headers = Headers(headers=headers).to_header_list() + + if isinstance(headers, Headers): + _headers = headers.to_header_list() + + event: WebSocketAcceptEvent = { + "type": "websocket.accept", + "subprotocol": subprotocols, + "headers": _headers, + } + await self.send(event) + + async def close(self, code: int = WS_1000_NORMAL_CLOSURE, reason: str | None = None) -> None: + """Send an 'websocket.close' event. + + Args: + code: Status code. + reason: Reason for closing the connection + + Returns: + None + """ + event: WebSocketCloseEvent = {"type": "websocket.close", "code": code, "reason": reason or ""} + await self.send(event) + + @overload + async def receive_data(self, mode: Literal["text"]) -> str: ... + + @overload + async def receive_data(self, mode: Literal["binary"]) -> bytes: ... + + async def receive_data(self, mode: WebSocketMode) -> str | bytes: + """Receive an 'websocket.receive' event and returns the data stored on it. + + Args: + mode: The respective event key to use. + + Returns: + The event's data. + """ + if self.connection_state == "init": + await self.accept() + event = cast("WebSocketReceiveEvent | WebSocketDisconnectEvent", await self.receive()) + if event["type"] == "websocket.disconnect": + raise WebSocketDisconnect(detail="disconnect event", code=event["code"]) + return event.get("text") or "" if mode == "text" else event.get("bytes") or b"" + + @overload + def iter_data(self, mode: Literal["text"]) -> AsyncGenerator[str, None]: ... + + @overload + def iter_data(self, mode: Literal["binary"]) -> AsyncGenerator[bytes, None]: ... + + async def iter_data(self, mode: WebSocketMode = "text") -> AsyncGenerator[str | bytes, None]: + """Continuously receive data and yield it + + Args: + mode: Socket mode to use. Either ``text`` or ``binary`` + """ + try: + while True: + yield await self.receive_data(mode) + except WebSocketDisconnect: + pass + + async def receive_text(self) -> str: + """Receive data as text. + + Returns: + A string. + """ + return await self.receive_data(mode="text") + + async def receive_bytes(self) -> bytes: + """Receive data as bytes. + + Returns: + A byte-string. + """ + return await self.receive_data(mode="binary") + + async def receive_json(self, mode: WebSocketMode = "text") -> Any: + """Receive data and decode it as JSON. + + Args: + mode: Either ``text`` or ``binary``. + + Returns: + An arbitrary value + """ + data = await self.receive_data(mode=mode) + return decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + + async def receive_msgpack(self) -> Any: + """Receive data and decode it as MessagePack. + + Note that since MessagePack is a binary format, this method will always receive + data in ``binary`` mode. + + Returns: + An arbitrary value + """ + data = await self.receive_data(mode="binary") + return decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + + async def iter_json(self, mode: WebSocketMode = "text") -> AsyncGenerator[Any, None]: + """Continuously receive data and yield it, decoding it as JSON in the process. + + Args: + mode: Socket mode to use. Either ``text`` or ``binary`` + """ + async for data in self.iter_data(mode): + yield decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + + async def iter_msgpack(self) -> AsyncGenerator[Any, None]: + """Continuously receive data and yield it, decoding it as MessagePack in the + process. + + Note that since MessagePack is a binary format, this method will always receive + data in ``binary`` mode. + + """ + async for data in self.iter_data(mode="binary"): + yield decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + + async def send_data(self, data: str | bytes, mode: WebSocketMode = "text", encoding: str = "utf-8") -> None: + """Send a 'websocket.send' event. + + Args: + data: Data to send. + mode: The respective event key to use. + encoding: Encoding to use when converting bytes / str. + + Returns: + None + """ + if self.connection_state == "init": # pragma: no cover + await self.accept() + event: WebSocketSendEvent = {"type": "websocket.send", "bytes": None, "text": None} + if mode == "binary": + event["bytes"] = data if isinstance(data, bytes) else data.encode(encoding) + else: + event["text"] = data if isinstance(data, str) else data.decode(encoding) + await self.send(event) + + @overload + async def send_text(self, data: bytes, encoding: str = "utf-8") -> None: ... + + @overload + async def send_text(self, data: str) -> None: ... + + async def send_text(self, data: str | bytes, encoding: str = "utf-8") -> None: + """Send data using the ``text`` key of the send event. + + Args: + data: Data to send + encoding: Encoding to use for binary data. + + Returns: + None + """ + await self.send_data(data=data, encoding=encoding) + + @overload + async def send_bytes(self, data: bytes) -> None: ... + + @overload + async def send_bytes(self, data: str, encoding: str = "utf-8") -> None: ... + + async def send_bytes(self, data: str | bytes, encoding: str = "utf-8") -> None: + """Send data using the ``bytes`` key of the send event. + + Args: + data: Data to send + encoding: Encoding to use for binary data. + + Returns: + None + """ + await self.send_data(data=data, mode="binary", encoding=encoding) + + async def send_json( + self, + data: Any, + mode: WebSocketMode = "text", + encoding: str = "utf-8", + serializer: Serializer = default_serializer, + ) -> None: + """Send data as JSON. + + Args: + data: A value to serialize. + mode: Either ``text`` or ``binary``. + encoding: Encoding to use for binary data. + serializer: A serializer function. + + Returns: + None + """ + await self.send_data(data=encode_json(data, serializer), mode=mode, encoding=encoding) + + async def send_msgpack( + self, + data: Any, + encoding: str = "utf-8", + serializer: Serializer = default_serializer, + ) -> None: + """Send data as MessagePack. + + Note that since MessagePack is a binary format, this method will always send + data in ``binary`` mode. + + Args: + data: A value to serialize. + encoding: Encoding to use for binary data. + serializer: A serializer function. + + Returns: + None + """ + await self.send_data(data=encode_msgpack(data, serializer), mode="binary", encoding=encoding) |