diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/connection/websocket.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/connection/websocket.py | 343 |
1 files changed, 0 insertions, 343 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/websocket.py b/venv/lib/python3.11/site-packages/litestar/connection/websocket.py deleted file mode 100644 index 0c7bc04..0000000 --- a/venv/lib/python3.11/site-packages/litestar/connection/websocket.py +++ /dev/null @@ -1,343 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, Literal, cast, overload - -from litestar.connection.base import ( - ASGIConnection, - AuthT, - StateT, - UserT, - empty_receive, - empty_send, -) -from litestar.datastructures.headers import Headers -from litestar.exceptions import WebSocketDisconnect -from litestar.serialization import decode_json, decode_msgpack, default_serializer, encode_json, encode_msgpack -from litestar.status_codes import WS_1000_NORMAL_CLOSURE - -__all__ = ("WebSocket",) - - -if TYPE_CHECKING: - from litestar.handlers.websocket_handlers import WebsocketRouteHandler # noqa: F401 - from litestar.types import Message, Serializer, WebSocketScope - from litestar.types.asgi_types import ( - Receive, - ReceiveMessage, - Scope, - Send, - WebSocketAcceptEvent, - WebSocketCloseEvent, - WebSocketDisconnectEvent, - WebSocketMode, - WebSocketReceiveEvent, - WebSocketSendEvent, - ) - -DISCONNECT_MESSAGE = "connection is disconnected" - - -class WebSocket(Generic[UserT, AuthT, StateT], ASGIConnection["WebsocketRouteHandler", UserT, AuthT, StateT]): - """The Litestar WebSocket class.""" - - __slots__ = ("connection_state",) - - scope: WebSocketScope # pyright: ignore - """The ASGI scope attached to the connection.""" - receive: Receive - """The ASGI receive function.""" - send: Send - """The ASGI send function.""" - - def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None: - """Initialize ``WebSocket``. - - Args: - scope: The ASGI connection scope. - receive: The ASGI receive function. - send: The ASGI send function. - """ - super().__init__(scope, self.receive_wrapper(receive), self.send_wrapper(send)) - self.connection_state: Literal["init", "connect", "receive", "disconnect"] = "init" - - def receive_wrapper(self, receive: Receive) -> Receive: - """Wrap ``receive`` to set 'self.connection_state' and validate events. - - Args: - receive: The ASGI receive function. - - Returns: - An ASGI receive function. - """ - - async def wrapped_receive() -> ReceiveMessage: - if self.connection_state == "disconnect": - raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) - message = await receive() - if message["type"] == "websocket.connect": - self.connection_state = "connect" - elif message["type"] == "websocket.receive": - self.connection_state = "receive" - else: - self.connection_state = "disconnect" - return message - - return wrapped_receive - - def send_wrapper(self, send: Send) -> Send: - """Wrap ``send`` to ensure that state is not disconnected. - - Args: - send: The ASGI send function. - - Returns: - An ASGI send function. - """ - - async def wrapped_send(message: Message) -> None: - if self.connection_state == "disconnect": - raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover - await send(message) - - return wrapped_send - - async def accept( - self, - subprotocols: str | None = None, - headers: Headers | dict[str, Any] | list[tuple[bytes, bytes]] | None = None, - ) -> None: - """Accept the incoming connection. This method should be called before receiving data. - - Args: - subprotocols: Websocket sub-protocol to use. - headers: Headers to set on the data sent. - - Returns: - None - """ - if self.connection_state == "init": - await self.receive() - _headers: list[tuple[bytes, bytes]] = headers if isinstance(headers, list) else [] - - if isinstance(headers, dict): - _headers = Headers(headers=headers).to_header_list() - - if isinstance(headers, Headers): - _headers = headers.to_header_list() - - event: WebSocketAcceptEvent = { - "type": "websocket.accept", - "subprotocol": subprotocols, - "headers": _headers, - } - await self.send(event) - - async def close(self, code: int = WS_1000_NORMAL_CLOSURE, reason: str | None = None) -> None: - """Send an 'websocket.close' event. - - Args: - code: Status code. - reason: Reason for closing the connection - - Returns: - None - """ - event: WebSocketCloseEvent = {"type": "websocket.close", "code": code, "reason": reason or ""} - await self.send(event) - - @overload - async def receive_data(self, mode: Literal["text"]) -> str: ... - - @overload - async def receive_data(self, mode: Literal["binary"]) -> bytes: ... - - async def receive_data(self, mode: WebSocketMode) -> str | bytes: - """Receive an 'websocket.receive' event and returns the data stored on it. - - Args: - mode: The respective event key to use. - - Returns: - The event's data. - """ - if self.connection_state == "init": - await self.accept() - event = cast("WebSocketReceiveEvent | WebSocketDisconnectEvent", await self.receive()) - if event["type"] == "websocket.disconnect": - raise WebSocketDisconnect(detail="disconnect event", code=event["code"]) - return event.get("text") or "" if mode == "text" else event.get("bytes") or b"" - - @overload - def iter_data(self, mode: Literal["text"]) -> AsyncGenerator[str, None]: ... - - @overload - def iter_data(self, mode: Literal["binary"]) -> AsyncGenerator[bytes, None]: ... - - async def iter_data(self, mode: WebSocketMode = "text") -> AsyncGenerator[str | bytes, None]: - """Continuously receive data and yield it - - Args: - mode: Socket mode to use. Either ``text`` or ``binary`` - """ - try: - while True: - yield await self.receive_data(mode) - except WebSocketDisconnect: - pass - - async def receive_text(self) -> str: - """Receive data as text. - - Returns: - A string. - """ - return await self.receive_data(mode="text") - - async def receive_bytes(self) -> bytes: - """Receive data as bytes. - - Returns: - A byte-string. - """ - return await self.receive_data(mode="binary") - - async def receive_json(self, mode: WebSocketMode = "text") -> Any: - """Receive data and decode it as JSON. - - Args: - mode: Either ``text`` or ``binary``. - - Returns: - An arbitrary value - """ - data = await self.receive_data(mode=mode) - return decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders()) - - async def receive_msgpack(self) -> Any: - """Receive data and decode it as MessagePack. - - Note that since MessagePack is a binary format, this method will always receive - data in ``binary`` mode. - - Returns: - An arbitrary value - """ - data = await self.receive_data(mode="binary") - return decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders()) - - async def iter_json(self, mode: WebSocketMode = "text") -> AsyncGenerator[Any, None]: - """Continuously receive data and yield it, decoding it as JSON in the process. - - Args: - mode: Socket mode to use. Either ``text`` or ``binary`` - """ - async for data in self.iter_data(mode): - yield decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders()) - - async def iter_msgpack(self) -> AsyncGenerator[Any, None]: - """Continuously receive data and yield it, decoding it as MessagePack in the - process. - - Note that since MessagePack is a binary format, this method will always receive - data in ``binary`` mode. - - """ - async for data in self.iter_data(mode="binary"): - yield decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders()) - - async def send_data(self, data: str | bytes, mode: WebSocketMode = "text", encoding: str = "utf-8") -> None: - """Send a 'websocket.send' event. - - Args: - data: Data to send. - mode: The respective event key to use. - encoding: Encoding to use when converting bytes / str. - - Returns: - None - """ - if self.connection_state == "init": # pragma: no cover - await self.accept() - event: WebSocketSendEvent = {"type": "websocket.send", "bytes": None, "text": None} - if mode == "binary": - event["bytes"] = data if isinstance(data, bytes) else data.encode(encoding) - else: - event["text"] = data if isinstance(data, str) else data.decode(encoding) - await self.send(event) - - @overload - async def send_text(self, data: bytes, encoding: str = "utf-8") -> None: ... - - @overload - async def send_text(self, data: str) -> None: ... - - async def send_text(self, data: str | bytes, encoding: str = "utf-8") -> None: - """Send data using the ``text`` key of the send event. - - Args: - data: Data to send - encoding: Encoding to use for binary data. - - Returns: - None - """ - await self.send_data(data=data, encoding=encoding) - - @overload - async def send_bytes(self, data: bytes) -> None: ... - - @overload - async def send_bytes(self, data: str, encoding: str = "utf-8") -> None: ... - - async def send_bytes(self, data: str | bytes, encoding: str = "utf-8") -> None: - """Send data using the ``bytes`` key of the send event. - - Args: - data: Data to send - encoding: Encoding to use for binary data. - - Returns: - None - """ - await self.send_data(data=data, mode="binary", encoding=encoding) - - async def send_json( - self, - data: Any, - mode: WebSocketMode = "text", - encoding: str = "utf-8", - serializer: Serializer = default_serializer, - ) -> None: - """Send data as JSON. - - Args: - data: A value to serialize. - mode: Either ``text`` or ``binary``. - encoding: Encoding to use for binary data. - serializer: A serializer function. - - Returns: - None - """ - await self.send_data(data=encode_json(data, serializer), mode=mode, encoding=encoding) - - async def send_msgpack( - self, - data: Any, - encoding: str = "utf-8", - serializer: Serializer = default_serializer, - ) -> None: - """Send data as MessagePack. - - Note that since MessagePack is a binary format, this method will always send - data in ``binary`` mode. - - Args: - data: A value to serialize. - encoding: Encoding to use for binary data. - serializer: A serializer function. - - Returns: - None - """ - await self.send_data(data=encode_msgpack(data, serializer), mode="binary", encoding=encoding) |