summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/connection/websocket.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/connection/websocket.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/websocket.py343
1 files changed, 343 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/websocket.py b/venv/lib/python3.11/site-packages/litestar/connection/websocket.py
new file mode 100644
index 0000000..0c7bc04
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/connection/websocket.py
@@ -0,0 +1,343 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, Literal, cast, overload
+
+from litestar.connection.base import (
+ ASGIConnection,
+ AuthT,
+ StateT,
+ UserT,
+ empty_receive,
+ empty_send,
+)
+from litestar.datastructures.headers import Headers
+from litestar.exceptions import WebSocketDisconnect
+from litestar.serialization import decode_json, decode_msgpack, default_serializer, encode_json, encode_msgpack
+from litestar.status_codes import WS_1000_NORMAL_CLOSURE
+
+__all__ = ("WebSocket",)
+
+
+if TYPE_CHECKING:
+ from litestar.handlers.websocket_handlers import WebsocketRouteHandler # noqa: F401
+ from litestar.types import Message, Serializer, WebSocketScope
+ from litestar.types.asgi_types import (
+ Receive,
+ ReceiveMessage,
+ Scope,
+ Send,
+ WebSocketAcceptEvent,
+ WebSocketCloseEvent,
+ WebSocketDisconnectEvent,
+ WebSocketMode,
+ WebSocketReceiveEvent,
+ WebSocketSendEvent,
+ )
+
+DISCONNECT_MESSAGE = "connection is disconnected"
+
+
+class WebSocket(Generic[UserT, AuthT, StateT], ASGIConnection["WebsocketRouteHandler", UserT, AuthT, StateT]):
+ """The Litestar WebSocket class."""
+
+ __slots__ = ("connection_state",)
+
+ scope: WebSocketScope # pyright: ignore
+ """The ASGI scope attached to the connection."""
+ receive: Receive
+ """The ASGI receive function."""
+ send: Send
+ """The ASGI send function."""
+
+ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None:
+ """Initialize ``WebSocket``.
+
+ Args:
+ scope: The ASGI connection scope.
+ receive: The ASGI receive function.
+ send: The ASGI send function.
+ """
+ super().__init__(scope, self.receive_wrapper(receive), self.send_wrapper(send))
+ self.connection_state: Literal["init", "connect", "receive", "disconnect"] = "init"
+
+ def receive_wrapper(self, receive: Receive) -> Receive:
+ """Wrap ``receive`` to set 'self.connection_state' and validate events.
+
+ Args:
+ receive: The ASGI receive function.
+
+ Returns:
+ An ASGI receive function.
+ """
+
+ async def wrapped_receive() -> ReceiveMessage:
+ if self.connection_state == "disconnect":
+ raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE)
+ message = await receive()
+ if message["type"] == "websocket.connect":
+ self.connection_state = "connect"
+ elif message["type"] == "websocket.receive":
+ self.connection_state = "receive"
+ else:
+ self.connection_state = "disconnect"
+ return message
+
+ return wrapped_receive
+
+ def send_wrapper(self, send: Send) -> Send:
+ """Wrap ``send`` to ensure that state is not disconnected.
+
+ Args:
+ send: The ASGI send function.
+
+ Returns:
+ An ASGI send function.
+ """
+
+ async def wrapped_send(message: Message) -> None:
+ if self.connection_state == "disconnect":
+ raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover
+ await send(message)
+
+ return wrapped_send
+
+ async def accept(
+ self,
+ subprotocols: str | None = None,
+ headers: Headers | dict[str, Any] | list[tuple[bytes, bytes]] | None = None,
+ ) -> None:
+ """Accept the incoming connection. This method should be called before receiving data.
+
+ Args:
+ subprotocols: Websocket sub-protocol to use.
+ headers: Headers to set on the data sent.
+
+ Returns:
+ None
+ """
+ if self.connection_state == "init":
+ await self.receive()
+ _headers: list[tuple[bytes, bytes]] = headers if isinstance(headers, list) else []
+
+ if isinstance(headers, dict):
+ _headers = Headers(headers=headers).to_header_list()
+
+ if isinstance(headers, Headers):
+ _headers = headers.to_header_list()
+
+ event: WebSocketAcceptEvent = {
+ "type": "websocket.accept",
+ "subprotocol": subprotocols,
+ "headers": _headers,
+ }
+ await self.send(event)
+
+ async def close(self, code: int = WS_1000_NORMAL_CLOSURE, reason: str | None = None) -> None:
+ """Send an 'websocket.close' event.
+
+ Args:
+ code: Status code.
+ reason: Reason for closing the connection
+
+ Returns:
+ None
+ """
+ event: WebSocketCloseEvent = {"type": "websocket.close", "code": code, "reason": reason or ""}
+ await self.send(event)
+
+ @overload
+ async def receive_data(self, mode: Literal["text"]) -> str: ...
+
+ @overload
+ async def receive_data(self, mode: Literal["binary"]) -> bytes: ...
+
+ async def receive_data(self, mode: WebSocketMode) -> str | bytes:
+ """Receive an 'websocket.receive' event and returns the data stored on it.
+
+ Args:
+ mode: The respective event key to use.
+
+ Returns:
+ The event's data.
+ """
+ if self.connection_state == "init":
+ await self.accept()
+ event = cast("WebSocketReceiveEvent | WebSocketDisconnectEvent", await self.receive())
+ if event["type"] == "websocket.disconnect":
+ raise WebSocketDisconnect(detail="disconnect event", code=event["code"])
+ return event.get("text") or "" if mode == "text" else event.get("bytes") or b""
+
+ @overload
+ def iter_data(self, mode: Literal["text"]) -> AsyncGenerator[str, None]: ...
+
+ @overload
+ def iter_data(self, mode: Literal["binary"]) -> AsyncGenerator[bytes, None]: ...
+
+ async def iter_data(self, mode: WebSocketMode = "text") -> AsyncGenerator[str | bytes, None]:
+ """Continuously receive data and yield it
+
+ Args:
+ mode: Socket mode to use. Either ``text`` or ``binary``
+ """
+ try:
+ while True:
+ yield await self.receive_data(mode)
+ except WebSocketDisconnect:
+ pass
+
+ async def receive_text(self) -> str:
+ """Receive data as text.
+
+ Returns:
+ A string.
+ """
+ return await self.receive_data(mode="text")
+
+ async def receive_bytes(self) -> bytes:
+ """Receive data as bytes.
+
+ Returns:
+ A byte-string.
+ """
+ return await self.receive_data(mode="binary")
+
+ async def receive_json(self, mode: WebSocketMode = "text") -> Any:
+ """Receive data and decode it as JSON.
+
+ Args:
+ mode: Either ``text`` or ``binary``.
+
+ Returns:
+ An arbitrary value
+ """
+ data = await self.receive_data(mode=mode)
+ return decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders())
+
+ async def receive_msgpack(self) -> Any:
+ """Receive data and decode it as MessagePack.
+
+ Note that since MessagePack is a binary format, this method will always receive
+ data in ``binary`` mode.
+
+ Returns:
+ An arbitrary value
+ """
+ data = await self.receive_data(mode="binary")
+ return decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders())
+
+ async def iter_json(self, mode: WebSocketMode = "text") -> AsyncGenerator[Any, None]:
+ """Continuously receive data and yield it, decoding it as JSON in the process.
+
+ Args:
+ mode: Socket mode to use. Either ``text`` or ``binary``
+ """
+ async for data in self.iter_data(mode):
+ yield decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders())
+
+ async def iter_msgpack(self) -> AsyncGenerator[Any, None]:
+ """Continuously receive data and yield it, decoding it as MessagePack in the
+ process.
+
+ Note that since MessagePack is a binary format, this method will always receive
+ data in ``binary`` mode.
+
+ """
+ async for data in self.iter_data(mode="binary"):
+ yield decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders())
+
+ async def send_data(self, data: str | bytes, mode: WebSocketMode = "text", encoding: str = "utf-8") -> None:
+ """Send a 'websocket.send' event.
+
+ Args:
+ data: Data to send.
+ mode: The respective event key to use.
+ encoding: Encoding to use when converting bytes / str.
+
+ Returns:
+ None
+ """
+ if self.connection_state == "init": # pragma: no cover
+ await self.accept()
+ event: WebSocketSendEvent = {"type": "websocket.send", "bytes": None, "text": None}
+ if mode == "binary":
+ event["bytes"] = data if isinstance(data, bytes) else data.encode(encoding)
+ else:
+ event["text"] = data if isinstance(data, str) else data.decode(encoding)
+ await self.send(event)
+
+ @overload
+ async def send_text(self, data: bytes, encoding: str = "utf-8") -> None: ...
+
+ @overload
+ async def send_text(self, data: str) -> None: ...
+
+ async def send_text(self, data: str | bytes, encoding: str = "utf-8") -> None:
+ """Send data using the ``text`` key of the send event.
+
+ Args:
+ data: Data to send
+ encoding: Encoding to use for binary data.
+
+ Returns:
+ None
+ """
+ await self.send_data(data=data, encoding=encoding)
+
+ @overload
+ async def send_bytes(self, data: bytes) -> None: ...
+
+ @overload
+ async def send_bytes(self, data: str, encoding: str = "utf-8") -> None: ...
+
+ async def send_bytes(self, data: str | bytes, encoding: str = "utf-8") -> None:
+ """Send data using the ``bytes`` key of the send event.
+
+ Args:
+ data: Data to send
+ encoding: Encoding to use for binary data.
+
+ Returns:
+ None
+ """
+ await self.send_data(data=data, mode="binary", encoding=encoding)
+
+ async def send_json(
+ self,
+ data: Any,
+ mode: WebSocketMode = "text",
+ encoding: str = "utf-8",
+ serializer: Serializer = default_serializer,
+ ) -> None:
+ """Send data as JSON.
+
+ Args:
+ data: A value to serialize.
+ mode: Either ``text`` or ``binary``.
+ encoding: Encoding to use for binary data.
+ serializer: A serializer function.
+
+ Returns:
+ None
+ """
+ await self.send_data(data=encode_json(data, serializer), mode=mode, encoding=encoding)
+
+ async def send_msgpack(
+ self,
+ data: Any,
+ encoding: str = "utf-8",
+ serializer: Serializer = default_serializer,
+ ) -> None:
+ """Send data as MessagePack.
+
+ Note that since MessagePack is a binary format, this method will always send
+ data in ``binary`` mode.
+
+ Args:
+ data: A value to serialize.
+ encoding: Encoding to use for binary data.
+ serializer: A serializer function.
+
+ Returns:
+ None
+ """
+ await self.send_data(data=encode_msgpack(data, serializer), mode="binary", encoding=encoding)