diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/connection')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/connection/__init__.py | 37 | ||||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc | bin | 0 -> 2206 bytes | |||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc | bin | 0 -> 15545 bytes | |||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc | bin | 0 -> 12711 bytes | |||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc | bin | 0 -> 17142 bytes | |||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/connection/base.py | 345 | ||||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/connection/request.py | 263 | ||||
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/connection/websocket.py | 343 |
8 files changed, 988 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__init__.py b/venv/lib/python3.11/site-packages/litestar/connection/__init__.py new file mode 100644 index 0000000..6922e79 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/__init__.py @@ -0,0 +1,37 @@ +"""Some code in this module was adapted from https://github.com/encode/starlette/blob/master/starlette/requests.py and +https://github.com/encode/starlette/blob/master/starlette/websockets.py. + +Copyright © 2018, [Encode OSS Ltd](https://www.encode.io/). +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + +from litestar.connection.base import ASGIConnection +from litestar.connection.request import Request +from litestar.connection.websocket import WebSocket + +__all__ = ("ASGIConnection", "Request", "WebSocket") diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..89cb4db --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f1eff74 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f047a18 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..49e294c --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/connection/base.py b/venv/lib/python3.11/site-packages/litestar/connection/base.py new file mode 100644 index 0000000..d14c662 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/base.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast + +from litestar._parsers import parse_cookie_string, parse_query_string +from litestar.datastructures.headers import Headers +from litestar.datastructures.multi_dicts import MultiDict +from litestar.datastructures.state import State +from litestar.datastructures.url import URL, Address, make_absolute_url +from litestar.exceptions import ImproperlyConfiguredException +from litestar.types.empty import Empty +from litestar.utils.empty import value_or_default +from litestar.utils.scope.state import ScopeState + +if TYPE_CHECKING: + from typing import NoReturn + + from litestar.app import Litestar + from litestar.types import DataContainerType, EmptyType + from litestar.types.asgi_types import Message, Receive, Scope, Send + from litestar.types.protocols import Logger + +__all__ = ("ASGIConnection", "empty_receive", "empty_send") + +UserT = TypeVar("UserT") +AuthT = TypeVar("AuthT") +HandlerT = TypeVar("HandlerT") +StateT = TypeVar("StateT", bound=State) + + +async def empty_receive() -> NoReturn: # pragma: no cover + """Raise a ``RuntimeError``. + + Serves as a placeholder ``send`` function. + + Raises: + RuntimeError + """ + raise RuntimeError() + + +async def empty_send(_: Message) -> NoReturn: # pragma: no cover + """Raise a ``RuntimeError``. + + Serves as a placeholder ``send`` function. + + Args: + _: An ASGI message + + Raises: + RuntimeError + """ + raise RuntimeError() + + +class ASGIConnection(Generic[HandlerT, UserT, AuthT, StateT]): + """The base ASGI connection container.""" + + __slots__ = ( + "scope", + "receive", + "send", + "_base_url", + "_url", + "_parsed_query", + "_cookies", + "_server_extensions", + "_connection_state", + ) + + scope: Scope + """The ASGI scope attached to the connection.""" + receive: Receive + """The ASGI receive function.""" + send: Send + """The ASGI send function.""" + + def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None: + """Initialize ``ASGIConnection``. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + """ + self.scope = scope + self.receive = receive + self.send = send + self._connection_state = ScopeState.from_scope(scope) + self._base_url: URL | EmptyType = Empty + self._url: URL | EmptyType = Empty + self._parsed_query: tuple[tuple[str, str], ...] | EmptyType = Empty + self._cookies: dict[str, str] | EmptyType = Empty + self._server_extensions = scope.get("extensions") or {} # extensions may be None + + @property + def app(self) -> Litestar: + """Return the ``app`` for this connection. + + Returns: + The :class:`Litestar <litestar.app.Litestar>` application instance + """ + return self.scope["app"] + + @property + def route_handler(self) -> HandlerT: + """Return the ``route_handler`` for this connection. + + Returns: + The target route handler instance. + """ + return cast("HandlerT", self.scope["route_handler"]) + + @property + def state(self) -> StateT: + """Return the ``State`` of this connection. + + Returns: + A State instance constructed from the scope["state"] value. + """ + return cast("StateT", State(self.scope.get("state"))) + + @property + def url(self) -> URL: + """Return the URL of this connection's ``Scope``. + + Returns: + A URL instance constructed from the request's scope. + """ + if self._url is Empty: + if (url := self._connection_state.url) is not Empty: + self._url = url + else: + self._connection_state.url = self._url = URL.from_scope(self.scope) + + return self._url + + @property + def base_url(self) -> URL: + """Return the base URL of this connection's ``Scope``. + + Returns: + A URL instance constructed from the request's scope, representing only the base part + (host + domain + prefix) of the request. + """ + if self._base_url is Empty: + if (base_url := self._connection_state.base_url) is not Empty: + self._base_url = base_url + else: + scope = cast( + "Scope", + { + **self.scope, + "path": "/", + "query_string": b"", + "root_path": self.scope.get("app_root_path") or self.scope.get("root_path", ""), + }, + ) + self._connection_state.base_url = self._base_url = URL.from_scope(scope) + return self._base_url + + @property + def headers(self) -> Headers: + """Return the headers of this connection's ``Scope``. + + Returns: + A Headers instance with the request's scope["headers"] value. + """ + return Headers.from_scope(self.scope) + + @property + def query_params(self) -> MultiDict[Any]: + """Return the query parameters of this connection's ``Scope``. + + Returns: + A normalized dict of query parameters. Multiple values for the same key are returned as a list. + """ + if self._parsed_query is Empty: + if (parsed_query := self._connection_state.parsed_query) is not Empty: + self._parsed_query = parsed_query + else: + self._connection_state.parsed_query = self._parsed_query = parse_query_string( + self.scope.get("query_string", b"") + ) + return MultiDict(self._parsed_query) + + @property + def path_params(self) -> dict[str, Any]: + """Return the ``path_params`` of this connection's ``Scope``. + + Returns: + A string keyed dictionary of path parameter values. + """ + return self.scope["path_params"] + + @property + def cookies(self) -> dict[str, str]: + """Return the ``cookies`` of this connection's ``Scope``. + + Returns: + Returns any cookies stored in the header as a parsed dictionary. + """ + if self._cookies is Empty: + if (cookies := self._connection_state.cookies) is not Empty: + self._cookies = cookies + else: + self._connection_state.cookies = self._cookies = ( + parse_cookie_string(cookie_header) if (cookie_header := self.headers.get("cookie")) else {} + ) + return self._cookies + + @property + def client(self) -> Address | None: + """Return the ``client`` data of this connection's ``Scope``. + + Returns: + A two tuple of the host name and port number. + """ + client = self.scope.get("client") + return Address(*client) if client else None + + @property + def auth(self) -> AuthT: + """Return the ``auth`` data of this connection's ``Scope``. + + Raises: + ImproperlyConfiguredException: If ``auth`` is not set in scope via an ``AuthMiddleware``, raises an exception + + Returns: + A type correlating to the generic variable Auth. + """ + if "auth" not in self.scope: + raise ImproperlyConfiguredException("'auth' is not defined in scope, install an AuthMiddleware to set it") + + return cast("AuthT", self.scope["auth"]) + + @property + def user(self) -> UserT: + """Return the ``user`` data of this connection's ``Scope``. + + Raises: + ImproperlyConfiguredException: If ``user`` is not set in scope via an ``AuthMiddleware``, raises an exception + + Returns: + A type correlating to the generic variable User. + """ + if "user" not in self.scope: + raise ImproperlyConfiguredException("'user' is not defined in scope, install an AuthMiddleware to set it") + + return cast("UserT", self.scope["user"]) + + @property + def session(self) -> dict[str, Any]: + """Return the session for this connection if a session was previously set in the ``Scope`` + + Returns: + A dictionary representing the session value - if existing. + + Raises: + ImproperlyConfiguredException: if session is not set in scope. + """ + if "session" not in self.scope: + raise ImproperlyConfiguredException( + "'session' is not defined in scope, install a SessionMiddleware to set it" + ) + + return cast("dict[str, Any]", self.scope["session"]) + + @property + def logger(self) -> Logger: + """Return the ``Logger`` instance for this connection. + + Returns: + A ``Logger`` instance. + + Raises: + ImproperlyConfiguredException: if ``log_config`` has not been passed to the Litestar constructor. + """ + return self.app.get_logger() + + def set_session(self, value: dict[str, Any] | DataContainerType | EmptyType) -> None: + """Set the session in the connection's ``Scope``. + + If the :class:`SessionMiddleware <.middleware.session.base.SessionMiddleware>` is enabled, the session will be added + to the response as a cookie header. + + Args: + value: Dictionary or pydantic model instance for the session data. + + Returns: + None + """ + self.scope["session"] = value + + def clear_session(self) -> None: + """Remove the session from the connection's ``Scope``. + + If the :class:`Litestar SessionMiddleware <.middleware.session.base.SessionMiddleware>` is enabled, this will cause + the session data to be cleared. + + Returns: + None. + """ + self.scope["session"] = Empty + self._connection_state.session_id = Empty + + def get_session_id(self) -> str | None: + return value_or_default(value=self._connection_state.session_id, default=None) + + def url_for(self, name: str, **path_parameters: Any) -> str: + """Return the url for a given route handler name. + + Args: + name: The ``name`` of the request route handler. + **path_parameters: Values for path parameters in the route + + Raises: + NoRouteMatchFoundException: If route with ``name`` does not exist, path parameters are missing or have a + wrong type. + + Returns: + A string representing the absolute url of the route handler. + """ + litestar_instance = self.scope["app"] + url_path = litestar_instance.route_reverse(name, **path_parameters) + + return make_absolute_url(url_path, self.base_url) + + def url_for_static_asset(self, name: str, file_path: str) -> str: + """Receives a static files handler name, an asset file path and returns resolved absolute url to the asset. + + Args: + name: A static handler unique name. + file_path: a string containing path to an asset. + + Raises: + NoRouteMatchFoundException: If static files handler with ``name`` does not exist. + + Returns: + A string representing absolute url to the asset. + """ + litestar_instance = self.scope["app"] + url_path = litestar_instance.url_for_static_asset(name, file_path) + + return make_absolute_url(url_path, self.base_url) diff --git a/venv/lib/python3.11/site-packages/litestar/connection/request.py b/venv/lib/python3.11/site-packages/litestar/connection/request.py new file mode 100644 index 0000000..254c315 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/request.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic + +from litestar._multipart import parse_content_header, parse_multipart_form +from litestar._parsers import parse_url_encoded_form_data +from litestar.connection.base import ( + ASGIConnection, + AuthT, + StateT, + UserT, + empty_receive, + empty_send, +) +from litestar.datastructures.headers import Accept +from litestar.datastructures.multi_dicts import FormMultiDict +from litestar.enums import ASGIExtension, RequestEncodingType +from litestar.exceptions import ( + InternalServerException, + LitestarException, + LitestarWarning, +) +from litestar.serialization import decode_json, decode_msgpack +from litestar.types import Empty + +__all__ = ("Request",) + + +if TYPE_CHECKING: + from litestar.handlers.http_handlers import HTTPRouteHandler # noqa: F401 + from litestar.types.asgi_types import HTTPScope, Method, Receive, Scope, Send + from litestar.types.empty import EmptyType + + +SERVER_PUSH_HEADERS = { + "accept", + "accept-encoding", + "accept-language", + "cache-control", + "user-agent", +} + + +class Request(Generic[UserT, AuthT, StateT], ASGIConnection["HTTPRouteHandler", UserT, AuthT, StateT]): + """The Litestar Request class.""" + + __slots__ = ( + "_json", + "_form", + "_body", + "_msgpack", + "_content_type", + "_accept", + "is_connected", + "supports_push_promise", + ) + + scope: HTTPScope # pyright: ignore + """The ASGI scope attached to the connection.""" + receive: Receive + """The ASGI receive function.""" + send: Send + """The ASGI send function.""" + + def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None: + """Initialize ``Request``. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + """ + super().__init__(scope, receive, send) + self.is_connected: bool = True + self._body: bytes | EmptyType = Empty + self._form: dict[str, str | list[str]] | EmptyType = Empty + self._json: Any = Empty + self._msgpack: Any = Empty + self._content_type: tuple[str, dict[str, str]] | EmptyType = Empty + self._accept: Accept | EmptyType = Empty + self.supports_push_promise = ASGIExtension.SERVER_PUSH in self._server_extensions + + @property + def method(self) -> Method: + """Return the request method. + + Returns: + The request :class:`Method <litestar.types.Method>` + """ + return self.scope["method"] + + @property + def content_type(self) -> tuple[str, dict[str, str]]: + """Parse the request's 'Content-Type' header, returning the header value and any options as a dictionary. + + Returns: + A tuple with the parsed value and a dictionary containing any options send in it. + """ + if self._content_type is Empty: + if (content_type := self._connection_state.content_type) is not Empty: + self._content_type = content_type + else: + self._content_type = self._connection_state.content_type = parse_content_header( + self.headers.get("Content-Type", "") + ) + return self._content_type + + @property + def accept(self) -> Accept: + """Parse the request's 'Accept' header, returning an :class:`Accept <litestar.datastructures.headers.Accept>` instance. + + Returns: + An :class:`Accept <litestar.datastructures.headers.Accept>` instance, representing the list of acceptable media types. + """ + if self._accept is Empty: + if (accept := self._connection_state.accept) is not Empty: + self._accept = accept + else: + self._accept = self._connection_state.accept = Accept(self.headers.get("Accept", "*/*")) + return self._accept + + async def json(self) -> Any: + """Retrieve the json request body from the request. + + Returns: + An arbitrary value + """ + if self._json is Empty: + if (json_ := self._connection_state.json) is not Empty: + self._json = json_ + else: + body = await self.body() + self._json = self._connection_state.json = decode_json( + body or b"null", type_decoders=self.route_handler.resolve_type_decoders() + ) + return self._json + + async def msgpack(self) -> Any: + """Retrieve the MessagePack request body from the request. + + Returns: + An arbitrary value + """ + if self._msgpack is Empty: + if (msgpack := self._connection_state.msgpack) is not Empty: + self._msgpack = msgpack + else: + body = await self.body() + self._msgpack = self._connection_state.msgpack = decode_msgpack( + body or b"\xc0", type_decoders=self.route_handler.resolve_type_decoders() + ) + return self._msgpack + + async def stream(self) -> AsyncGenerator[bytes, None]: + """Return an async generator that streams chunks of bytes. + + Returns: + An async generator. + + Raises: + RuntimeError: if the stream is already consumed + """ + if self._body is Empty: + if not self.is_connected: + raise InternalServerException("stream consumed") + while event := await self.receive(): + if event["type"] == "http.request": + if event["body"]: + yield event["body"] + + if not event.get("more_body", False): + break + + if event["type"] == "http.disconnect": + raise InternalServerException("client disconnected prematurely") + + self.is_connected = False + yield b"" + + else: + yield self._body + yield b"" + return + + async def body(self) -> bytes: + """Return the body of the request. + + Returns: + A byte-string representing the body of the request. + """ + if self._body is Empty: + if (body := self._connection_state.body) is not Empty: + self._body = body + else: + self._body = self._connection_state.body = b"".join([c async for c in self.stream()]) + return self._body + + async def form(self) -> FormMultiDict: + """Retrieve form data from the request. If the request is either a 'multipart/form-data' or an + 'application/x-www-form- urlencoded', return a FormMultiDict instance populated with the values sent in the + request, otherwise, an empty instance. + + Returns: + A FormMultiDict instance + """ + if self._form is Empty: + if (form := self._connection_state.form) is not Empty: + self._form = form + else: + content_type, options = self.content_type + if content_type == RequestEncodingType.MULTI_PART: + self._form = parse_multipart_form( + body=await self.body(), + boundary=options.get("boundary", "").encode(), + multipart_form_part_limit=self.app.multipart_form_part_limit, + ) + elif content_type == RequestEncodingType.URL_ENCODED: + self._form = parse_url_encoded_form_data( + await self.body(), + ) + else: + self._form = {} + + self._connection_state.form = self._form + + return FormMultiDict(self._form) + + async def send_push_promise(self, path: str, raise_if_unavailable: bool = False) -> None: + """Send a push promise. + + This method requires the `http.response.push` extension to be sent from the ASGI server. + + Args: + path: Path to send the promise to. + raise_if_unavailable: Raise an exception if server push is not supported by + the server + + Returns: + None + """ + if not self.supports_push_promise: + if raise_if_unavailable: + raise LitestarException("Attempted to send a push promise but the server does not support it") + + warnings.warn( + "Attempted to send a push promise but the server does not support it. In a future version, this will " + "raise an exception. To enable this behaviour in the current version, set raise_if_unavailable=True. " + "To prevent this behaviour, make sure that the server you are using supports the 'http.response.push' " + "ASGI extension, or check this dynamically via " + ":attr:`~litestar.connection.Request.supports_push_promise`", + stacklevel=2, + category=LitestarWarning, + ) + + return + + raw_headers = [ + (header_name.encode("latin-1"), value.encode("latin-1")) + for header_name in (self.headers.keys() & SERVER_PUSH_HEADERS) + for value in self.headers.getall(header_name, []) + ] + await self.send({"type": "http.response.push", "path": path, "headers": raw_headers}) diff --git a/venv/lib/python3.11/site-packages/litestar/connection/websocket.py b/venv/lib/python3.11/site-packages/litestar/connection/websocket.py new file mode 100644 index 0000000..0c7bc04 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/connection/websocket.py @@ -0,0 +1,343 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, Literal, cast, overload + +from litestar.connection.base import ( + ASGIConnection, + AuthT, + StateT, + UserT, + empty_receive, + empty_send, +) +from litestar.datastructures.headers import Headers +from litestar.exceptions import WebSocketDisconnect +from litestar.serialization import decode_json, decode_msgpack, default_serializer, encode_json, encode_msgpack +from litestar.status_codes import WS_1000_NORMAL_CLOSURE + +__all__ = ("WebSocket",) + + +if TYPE_CHECKING: + from litestar.handlers.websocket_handlers import WebsocketRouteHandler # noqa: F401 + from litestar.types import Message, Serializer, WebSocketScope + from litestar.types.asgi_types import ( + Receive, + ReceiveMessage, + Scope, + Send, + WebSocketAcceptEvent, + WebSocketCloseEvent, + WebSocketDisconnectEvent, + WebSocketMode, + WebSocketReceiveEvent, + WebSocketSendEvent, + ) + +DISCONNECT_MESSAGE = "connection is disconnected" + + +class WebSocket(Generic[UserT, AuthT, StateT], ASGIConnection["WebsocketRouteHandler", UserT, AuthT, StateT]): + """The Litestar WebSocket class.""" + + __slots__ = ("connection_state",) + + scope: WebSocketScope # pyright: ignore + """The ASGI scope attached to the connection.""" + receive: Receive + """The ASGI receive function.""" + send: Send + """The ASGI send function.""" + + def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None: + """Initialize ``WebSocket``. + + Args: + scope: The ASGI connection scope. + receive: The ASGI receive function. + send: The ASGI send function. + """ + super().__init__(scope, self.receive_wrapper(receive), self.send_wrapper(send)) + self.connection_state: Literal["init", "connect", "receive", "disconnect"] = "init" + + def receive_wrapper(self, receive: Receive) -> Receive: + """Wrap ``receive`` to set 'self.connection_state' and validate events. + + Args: + receive: The ASGI receive function. + + Returns: + An ASGI receive function. + """ + + async def wrapped_receive() -> ReceiveMessage: + if self.connection_state == "disconnect": + raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) + message = await receive() + if message["type"] == "websocket.connect": + self.connection_state = "connect" + elif message["type"] == "websocket.receive": + self.connection_state = "receive" + else: + self.connection_state = "disconnect" + return message + + return wrapped_receive + + def send_wrapper(self, send: Send) -> Send: + """Wrap ``send`` to ensure that state is not disconnected. + + Args: + send: The ASGI send function. + + Returns: + An ASGI send function. + """ + + async def wrapped_send(message: Message) -> None: + if self.connection_state == "disconnect": + raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover + await send(message) + + return wrapped_send + + async def accept( + self, + subprotocols: str | None = None, + headers: Headers | dict[str, Any] | list[tuple[bytes, bytes]] | None = None, + ) -> None: + """Accept the incoming connection. This method should be called before receiving data. + + Args: + subprotocols: Websocket sub-protocol to use. + headers: Headers to set on the data sent. + + Returns: + None + """ + if self.connection_state == "init": + await self.receive() + _headers: list[tuple[bytes, bytes]] = headers if isinstance(headers, list) else [] + + if isinstance(headers, dict): + _headers = Headers(headers=headers).to_header_list() + + if isinstance(headers, Headers): + _headers = headers.to_header_list() + + event: WebSocketAcceptEvent = { + "type": "websocket.accept", + "subprotocol": subprotocols, + "headers": _headers, + } + await self.send(event) + + async def close(self, code: int = WS_1000_NORMAL_CLOSURE, reason: str | None = None) -> None: + """Send an 'websocket.close' event. + + Args: + code: Status code. + reason: Reason for closing the connection + + Returns: + None + """ + event: WebSocketCloseEvent = {"type": "websocket.close", "code": code, "reason": reason or ""} + await self.send(event) + + @overload + async def receive_data(self, mode: Literal["text"]) -> str: ... + + @overload + async def receive_data(self, mode: Literal["binary"]) -> bytes: ... + + async def receive_data(self, mode: WebSocketMode) -> str | bytes: + """Receive an 'websocket.receive' event and returns the data stored on it. + + Args: + mode: The respective event key to use. + + Returns: + The event's data. + """ + if self.connection_state == "init": + await self.accept() + event = cast("WebSocketReceiveEvent | WebSocketDisconnectEvent", await self.receive()) + if event["type"] == "websocket.disconnect": + raise WebSocketDisconnect(detail="disconnect event", code=event["code"]) + return event.get("text") or "" if mode == "text" else event.get("bytes") or b"" + + @overload + def iter_data(self, mode: Literal["text"]) -> AsyncGenerator[str, None]: ... + + @overload + def iter_data(self, mode: Literal["binary"]) -> AsyncGenerator[bytes, None]: ... + + async def iter_data(self, mode: WebSocketMode = "text") -> AsyncGenerator[str | bytes, None]: + """Continuously receive data and yield it + + Args: + mode: Socket mode to use. Either ``text`` or ``binary`` + """ + try: + while True: + yield await self.receive_data(mode) + except WebSocketDisconnect: + pass + + async def receive_text(self) -> str: + """Receive data as text. + + Returns: + A string. + """ + return await self.receive_data(mode="text") + + async def receive_bytes(self) -> bytes: + """Receive data as bytes. + + Returns: + A byte-string. + """ + return await self.receive_data(mode="binary") + + async def receive_json(self, mode: WebSocketMode = "text") -> Any: + """Receive data and decode it as JSON. + + Args: + mode: Either ``text`` or ``binary``. + + Returns: + An arbitrary value + """ + data = await self.receive_data(mode=mode) + return decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + + async def receive_msgpack(self) -> Any: + """Receive data and decode it as MessagePack. + + Note that since MessagePack is a binary format, this method will always receive + data in ``binary`` mode. + + Returns: + An arbitrary value + """ + data = await self.receive_data(mode="binary") + return decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + + async def iter_json(self, mode: WebSocketMode = "text") -> AsyncGenerator[Any, None]: + """Continuously receive data and yield it, decoding it as JSON in the process. + + Args: + mode: Socket mode to use. Either ``text`` or ``binary`` + """ + async for data in self.iter_data(mode): + yield decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + + async def iter_msgpack(self) -> AsyncGenerator[Any, None]: + """Continuously receive data and yield it, decoding it as MessagePack in the + process. + + Note that since MessagePack is a binary format, this method will always receive + data in ``binary`` mode. + + """ + async for data in self.iter_data(mode="binary"): + yield decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders()) + + async def send_data(self, data: str | bytes, mode: WebSocketMode = "text", encoding: str = "utf-8") -> None: + """Send a 'websocket.send' event. + + Args: + data: Data to send. + mode: The respective event key to use. + encoding: Encoding to use when converting bytes / str. + + Returns: + None + """ + if self.connection_state == "init": # pragma: no cover + await self.accept() + event: WebSocketSendEvent = {"type": "websocket.send", "bytes": None, "text": None} + if mode == "binary": + event["bytes"] = data if isinstance(data, bytes) else data.encode(encoding) + else: + event["text"] = data if isinstance(data, str) else data.decode(encoding) + await self.send(event) + + @overload + async def send_text(self, data: bytes, encoding: str = "utf-8") -> None: ... + + @overload + async def send_text(self, data: str) -> None: ... + + async def send_text(self, data: str | bytes, encoding: str = "utf-8") -> None: + """Send data using the ``text`` key of the send event. + + Args: + data: Data to send + encoding: Encoding to use for binary data. + + Returns: + None + """ + await self.send_data(data=data, encoding=encoding) + + @overload + async def send_bytes(self, data: bytes) -> None: ... + + @overload + async def send_bytes(self, data: str, encoding: str = "utf-8") -> None: ... + + async def send_bytes(self, data: str | bytes, encoding: str = "utf-8") -> None: + """Send data using the ``bytes`` key of the send event. + + Args: + data: Data to send + encoding: Encoding to use for binary data. + + Returns: + None + """ + await self.send_data(data=data, mode="binary", encoding=encoding) + + async def send_json( + self, + data: Any, + mode: WebSocketMode = "text", + encoding: str = "utf-8", + serializer: Serializer = default_serializer, + ) -> None: + """Send data as JSON. + + Args: + data: A value to serialize. + mode: Either ``text`` or ``binary``. + encoding: Encoding to use for binary data. + serializer: A serializer function. + + Returns: + None + """ + await self.send_data(data=encode_json(data, serializer), mode=mode, encoding=encoding) + + async def send_msgpack( + self, + data: Any, + encoding: str = "utf-8", + serializer: Serializer = default_serializer, + ) -> None: + """Send data as MessagePack. + + Note that since MessagePack is a binary format, this method will always send + data in ``binary`` mode. + + Args: + data: A value to serialize. + encoding: Encoding to use for binary data. + serializer: A serializer function. + + Returns: + None + """ + await self.send_data(data=encode_msgpack(data, serializer), mode="binary", encoding=encoding) |