summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/connection
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:17:55 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:17:55 -0400
commit12cf076118570eebbff08c6b3090e0d4798447a1 (patch)
tree3ba25e17e3c3a5e82316558ba3864b955919ff72 /venv/lib/python3.11/site-packages/litestar/connection
parentc45662ff3923b34614ddcc8feb9195541166dcc5 (diff)
no venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/connection')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/__init__.py37
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pycbin2206 -> 0 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pycbin15545 -> 0 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pycbin12711 -> 0 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pycbin17142 -> 0 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/base.py345
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/request.py263
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/websocket.py343
8 files changed, 0 insertions, 988 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__init__.py b/venv/lib/python3.11/site-packages/litestar/connection/__init__.py
deleted file mode 100644
index 6922e79..0000000
--- a/venv/lib/python3.11/site-packages/litestar/connection/__init__.py
+++ /dev/null
@@ -1,37 +0,0 @@
-"""Some code in this module was adapted from https://github.com/encode/starlette/blob/master/starlette/requests.py and
-https://github.com/encode/starlette/blob/master/starlette/websockets.py.
-
-Copyright © 2018, [Encode OSS Ltd](https://www.encode.io/).
-All rights reserved.
-
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions are met:
-
-* Redistributions of source code must retain the above copyright notice, this
- list of conditions and the following disclaimer.
-
-* Redistributions in binary form must reproduce the above copyright notice,
- this list of conditions and the following disclaimer in the documentation
- and/or other materials provided with the distribution.
-
-* Neither the name of the copyright holder nor the names of its
- contributors may be used to endorse or promote products derived from
- this software without specific prior written permission.
-
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-"""
-
-from litestar.connection.base import ASGIConnection
-from litestar.connection.request import Request
-from litestar.connection.websocket import WebSocket
-
-__all__ = ("ASGIConnection", "Request", "WebSocket")
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc
deleted file mode 100644
index 89cb4db..0000000
--- a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc
+++ /dev/null
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc
deleted file mode 100644
index f1eff74..0000000
--- a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc
+++ /dev/null
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc
deleted file mode 100644
index f047a18..0000000
--- a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc
+++ /dev/null
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc
deleted file mode 100644
index 49e294c..0000000
--- a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc
+++ /dev/null
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/base.py b/venv/lib/python3.11/site-packages/litestar/connection/base.py
deleted file mode 100644
index d14c662..0000000
--- a/venv/lib/python3.11/site-packages/litestar/connection/base.py
+++ /dev/null
@@ -1,345 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
-
-from litestar._parsers import parse_cookie_string, parse_query_string
-from litestar.datastructures.headers import Headers
-from litestar.datastructures.multi_dicts import MultiDict
-from litestar.datastructures.state import State
-from litestar.datastructures.url import URL, Address, make_absolute_url
-from litestar.exceptions import ImproperlyConfiguredException
-from litestar.types.empty import Empty
-from litestar.utils.empty import value_or_default
-from litestar.utils.scope.state import ScopeState
-
-if TYPE_CHECKING:
- from typing import NoReturn
-
- from litestar.app import Litestar
- from litestar.types import DataContainerType, EmptyType
- from litestar.types.asgi_types import Message, Receive, Scope, Send
- from litestar.types.protocols import Logger
-
-__all__ = ("ASGIConnection", "empty_receive", "empty_send")
-
-UserT = TypeVar("UserT")
-AuthT = TypeVar("AuthT")
-HandlerT = TypeVar("HandlerT")
-StateT = TypeVar("StateT", bound=State)
-
-
-async def empty_receive() -> NoReturn: # pragma: no cover
- """Raise a ``RuntimeError``.
-
- Serves as a placeholder ``send`` function.
-
- Raises:
- RuntimeError
- """
- raise RuntimeError()
-
-
-async def empty_send(_: Message) -> NoReturn: # pragma: no cover
- """Raise a ``RuntimeError``.
-
- Serves as a placeholder ``send`` function.
-
- Args:
- _: An ASGI message
-
- Raises:
- RuntimeError
- """
- raise RuntimeError()
-
-
-class ASGIConnection(Generic[HandlerT, UserT, AuthT, StateT]):
- """The base ASGI connection container."""
-
- __slots__ = (
- "scope",
- "receive",
- "send",
- "_base_url",
- "_url",
- "_parsed_query",
- "_cookies",
- "_server_extensions",
- "_connection_state",
- )
-
- scope: Scope
- """The ASGI scope attached to the connection."""
- receive: Receive
- """The ASGI receive function."""
- send: Send
- """The ASGI send function."""
-
- def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None:
- """Initialize ``ASGIConnection``.
-
- Args:
- scope: The ASGI connection scope.
- receive: The ASGI receive function.
- send: The ASGI send function.
- """
- self.scope = scope
- self.receive = receive
- self.send = send
- self._connection_state = ScopeState.from_scope(scope)
- self._base_url: URL | EmptyType = Empty
- self._url: URL | EmptyType = Empty
- self._parsed_query: tuple[tuple[str, str], ...] | EmptyType = Empty
- self._cookies: dict[str, str] | EmptyType = Empty
- self._server_extensions = scope.get("extensions") or {} # extensions may be None
-
- @property
- def app(self) -> Litestar:
- """Return the ``app`` for this connection.
-
- Returns:
- The :class:`Litestar <litestar.app.Litestar>` application instance
- """
- return self.scope["app"]
-
- @property
- def route_handler(self) -> HandlerT:
- """Return the ``route_handler`` for this connection.
-
- Returns:
- The target route handler instance.
- """
- return cast("HandlerT", self.scope["route_handler"])
-
- @property
- def state(self) -> StateT:
- """Return the ``State`` of this connection.
-
- Returns:
- A State instance constructed from the scope["state"] value.
- """
- return cast("StateT", State(self.scope.get("state")))
-
- @property
- def url(self) -> URL:
- """Return the URL of this connection's ``Scope``.
-
- Returns:
- A URL instance constructed from the request's scope.
- """
- if self._url is Empty:
- if (url := self._connection_state.url) is not Empty:
- self._url = url
- else:
- self._connection_state.url = self._url = URL.from_scope(self.scope)
-
- return self._url
-
- @property
- def base_url(self) -> URL:
- """Return the base URL of this connection's ``Scope``.
-
- Returns:
- A URL instance constructed from the request's scope, representing only the base part
- (host + domain + prefix) of the request.
- """
- if self._base_url is Empty:
- if (base_url := self._connection_state.base_url) is not Empty:
- self._base_url = base_url
- else:
- scope = cast(
- "Scope",
- {
- **self.scope,
- "path": "/",
- "query_string": b"",
- "root_path": self.scope.get("app_root_path") or self.scope.get("root_path", ""),
- },
- )
- self._connection_state.base_url = self._base_url = URL.from_scope(scope)
- return self._base_url
-
- @property
- def headers(self) -> Headers:
- """Return the headers of this connection's ``Scope``.
-
- Returns:
- A Headers instance with the request's scope["headers"] value.
- """
- return Headers.from_scope(self.scope)
-
- @property
- def query_params(self) -> MultiDict[Any]:
- """Return the query parameters of this connection's ``Scope``.
-
- Returns:
- A normalized dict of query parameters. Multiple values for the same key are returned as a list.
- """
- if self._parsed_query is Empty:
- if (parsed_query := self._connection_state.parsed_query) is not Empty:
- self._parsed_query = parsed_query
- else:
- self._connection_state.parsed_query = self._parsed_query = parse_query_string(
- self.scope.get("query_string", b"")
- )
- return MultiDict(self._parsed_query)
-
- @property
- def path_params(self) -> dict[str, Any]:
- """Return the ``path_params`` of this connection's ``Scope``.
-
- Returns:
- A string keyed dictionary of path parameter values.
- """
- return self.scope["path_params"]
-
- @property
- def cookies(self) -> dict[str, str]:
- """Return the ``cookies`` of this connection's ``Scope``.
-
- Returns:
- Returns any cookies stored in the header as a parsed dictionary.
- """
- if self._cookies is Empty:
- if (cookies := self._connection_state.cookies) is not Empty:
- self._cookies = cookies
- else:
- self._connection_state.cookies = self._cookies = (
- parse_cookie_string(cookie_header) if (cookie_header := self.headers.get("cookie")) else {}
- )
- return self._cookies
-
- @property
- def client(self) -> Address | None:
- """Return the ``client`` data of this connection's ``Scope``.
-
- Returns:
- A two tuple of the host name and port number.
- """
- client = self.scope.get("client")
- return Address(*client) if client else None
-
- @property
- def auth(self) -> AuthT:
- """Return the ``auth`` data of this connection's ``Scope``.
-
- Raises:
- ImproperlyConfiguredException: If ``auth`` is not set in scope via an ``AuthMiddleware``, raises an exception
-
- Returns:
- A type correlating to the generic variable Auth.
- """
- if "auth" not in self.scope:
- raise ImproperlyConfiguredException("'auth' is not defined in scope, install an AuthMiddleware to set it")
-
- return cast("AuthT", self.scope["auth"])
-
- @property
- def user(self) -> UserT:
- """Return the ``user`` data of this connection's ``Scope``.
-
- Raises:
- ImproperlyConfiguredException: If ``user`` is not set in scope via an ``AuthMiddleware``, raises an exception
-
- Returns:
- A type correlating to the generic variable User.
- """
- if "user" not in self.scope:
- raise ImproperlyConfiguredException("'user' is not defined in scope, install an AuthMiddleware to set it")
-
- return cast("UserT", self.scope["user"])
-
- @property
- def session(self) -> dict[str, Any]:
- """Return the session for this connection if a session was previously set in the ``Scope``
-
- Returns:
- A dictionary representing the session value - if existing.
-
- Raises:
- ImproperlyConfiguredException: if session is not set in scope.
- """
- if "session" not in self.scope:
- raise ImproperlyConfiguredException(
- "'session' is not defined in scope, install a SessionMiddleware to set it"
- )
-
- return cast("dict[str, Any]", self.scope["session"])
-
- @property
- def logger(self) -> Logger:
- """Return the ``Logger`` instance for this connection.
-
- Returns:
- A ``Logger`` instance.
-
- Raises:
- ImproperlyConfiguredException: if ``log_config`` has not been passed to the Litestar constructor.
- """
- return self.app.get_logger()
-
- def set_session(self, value: dict[str, Any] | DataContainerType | EmptyType) -> None:
- """Set the session in the connection's ``Scope``.
-
- If the :class:`SessionMiddleware <.middleware.session.base.SessionMiddleware>` is enabled, the session will be added
- to the response as a cookie header.
-
- Args:
- value: Dictionary or pydantic model instance for the session data.
-
- Returns:
- None
- """
- self.scope["session"] = value
-
- def clear_session(self) -> None:
- """Remove the session from the connection's ``Scope``.
-
- If the :class:`Litestar SessionMiddleware <.middleware.session.base.SessionMiddleware>` is enabled, this will cause
- the session data to be cleared.
-
- Returns:
- None.
- """
- self.scope["session"] = Empty
- self._connection_state.session_id = Empty
-
- def get_session_id(self) -> str | None:
- return value_or_default(value=self._connection_state.session_id, default=None)
-
- def url_for(self, name: str, **path_parameters: Any) -> str:
- """Return the url for a given route handler name.
-
- Args:
- name: The ``name`` of the request route handler.
- **path_parameters: Values for path parameters in the route
-
- Raises:
- NoRouteMatchFoundException: If route with ``name`` does not exist, path parameters are missing or have a
- wrong type.
-
- Returns:
- A string representing the absolute url of the route handler.
- """
- litestar_instance = self.scope["app"]
- url_path = litestar_instance.route_reverse(name, **path_parameters)
-
- return make_absolute_url(url_path, self.base_url)
-
- def url_for_static_asset(self, name: str, file_path: str) -> str:
- """Receives a static files handler name, an asset file path and returns resolved absolute url to the asset.
-
- Args:
- name: A static handler unique name.
- file_path: a string containing path to an asset.
-
- Raises:
- NoRouteMatchFoundException: If static files handler with ``name`` does not exist.
-
- Returns:
- A string representing absolute url to the asset.
- """
- litestar_instance = self.scope["app"]
- url_path = litestar_instance.url_for_static_asset(name, file_path)
-
- return make_absolute_url(url_path, self.base_url)
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/request.py b/venv/lib/python3.11/site-packages/litestar/connection/request.py
deleted file mode 100644
index 254c315..0000000
--- a/venv/lib/python3.11/site-packages/litestar/connection/request.py
+++ /dev/null
@@ -1,263 +0,0 @@
-from __future__ import annotations
-
-import warnings
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic
-
-from litestar._multipart import parse_content_header, parse_multipart_form
-from litestar._parsers import parse_url_encoded_form_data
-from litestar.connection.base import (
- ASGIConnection,
- AuthT,
- StateT,
- UserT,
- empty_receive,
- empty_send,
-)
-from litestar.datastructures.headers import Accept
-from litestar.datastructures.multi_dicts import FormMultiDict
-from litestar.enums import ASGIExtension, RequestEncodingType
-from litestar.exceptions import (
- InternalServerException,
- LitestarException,
- LitestarWarning,
-)
-from litestar.serialization import decode_json, decode_msgpack
-from litestar.types import Empty
-
-__all__ = ("Request",)
-
-
-if TYPE_CHECKING:
- from litestar.handlers.http_handlers import HTTPRouteHandler # noqa: F401
- from litestar.types.asgi_types import HTTPScope, Method, Receive, Scope, Send
- from litestar.types.empty import EmptyType
-
-
-SERVER_PUSH_HEADERS = {
- "accept",
- "accept-encoding",
- "accept-language",
- "cache-control",
- "user-agent",
-}
-
-
-class Request(Generic[UserT, AuthT, StateT], ASGIConnection["HTTPRouteHandler", UserT, AuthT, StateT]):
- """The Litestar Request class."""
-
- __slots__ = (
- "_json",
- "_form",
- "_body",
- "_msgpack",
- "_content_type",
- "_accept",
- "is_connected",
- "supports_push_promise",
- )
-
- scope: HTTPScope # pyright: ignore
- """The ASGI scope attached to the connection."""
- receive: Receive
- """The ASGI receive function."""
- send: Send
- """The ASGI send function."""
-
- def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None:
- """Initialize ``Request``.
-
- Args:
- scope: The ASGI connection scope.
- receive: The ASGI receive function.
- send: The ASGI send function.
- """
- super().__init__(scope, receive, send)
- self.is_connected: bool = True
- self._body: bytes | EmptyType = Empty
- self._form: dict[str, str | list[str]] | EmptyType = Empty
- self._json: Any = Empty
- self._msgpack: Any = Empty
- self._content_type: tuple[str, dict[str, str]] | EmptyType = Empty
- self._accept: Accept | EmptyType = Empty
- self.supports_push_promise = ASGIExtension.SERVER_PUSH in self._server_extensions
-
- @property
- def method(self) -> Method:
- """Return the request method.
-
- Returns:
- The request :class:`Method <litestar.types.Method>`
- """
- return self.scope["method"]
-
- @property
- def content_type(self) -> tuple[str, dict[str, str]]:
- """Parse the request's 'Content-Type' header, returning the header value and any options as a dictionary.
-
- Returns:
- A tuple with the parsed value and a dictionary containing any options send in it.
- """
- if self._content_type is Empty:
- if (content_type := self._connection_state.content_type) is not Empty:
- self._content_type = content_type
- else:
- self._content_type = self._connection_state.content_type = parse_content_header(
- self.headers.get("Content-Type", "")
- )
- return self._content_type
-
- @property
- def accept(self) -> Accept:
- """Parse the request's 'Accept' header, returning an :class:`Accept <litestar.datastructures.headers.Accept>` instance.
-
- Returns:
- An :class:`Accept <litestar.datastructures.headers.Accept>` instance, representing the list of acceptable media types.
- """
- if self._accept is Empty:
- if (accept := self._connection_state.accept) is not Empty:
- self._accept = accept
- else:
- self._accept = self._connection_state.accept = Accept(self.headers.get("Accept", "*/*"))
- return self._accept
-
- async def json(self) -> Any:
- """Retrieve the json request body from the request.
-
- Returns:
- An arbitrary value
- """
- if self._json is Empty:
- if (json_ := self._connection_state.json) is not Empty:
- self._json = json_
- else:
- body = await self.body()
- self._json = self._connection_state.json = decode_json(
- body or b"null", type_decoders=self.route_handler.resolve_type_decoders()
- )
- return self._json
-
- async def msgpack(self) -> Any:
- """Retrieve the MessagePack request body from the request.
-
- Returns:
- An arbitrary value
- """
- if self._msgpack is Empty:
- if (msgpack := self._connection_state.msgpack) is not Empty:
- self._msgpack = msgpack
- else:
- body = await self.body()
- self._msgpack = self._connection_state.msgpack = decode_msgpack(
- body or b"\xc0", type_decoders=self.route_handler.resolve_type_decoders()
- )
- return self._msgpack
-
- async def stream(self) -> AsyncGenerator[bytes, None]:
- """Return an async generator that streams chunks of bytes.
-
- Returns:
- An async generator.
-
- Raises:
- RuntimeError: if the stream is already consumed
- """
- if self._body is Empty:
- if not self.is_connected:
- raise InternalServerException("stream consumed")
- while event := await self.receive():
- if event["type"] == "http.request":
- if event["body"]:
- yield event["body"]
-
- if not event.get("more_body", False):
- break
-
- if event["type"] == "http.disconnect":
- raise InternalServerException("client disconnected prematurely")
-
- self.is_connected = False
- yield b""
-
- else:
- yield self._body
- yield b""
- return
-
- async def body(self) -> bytes:
- """Return the body of the request.
-
- Returns:
- A byte-string representing the body of the request.
- """
- if self._body is Empty:
- if (body := self._connection_state.body) is not Empty:
- self._body = body
- else:
- self._body = self._connection_state.body = b"".join([c async for c in self.stream()])
- return self._body
-
- async def form(self) -> FormMultiDict:
- """Retrieve form data from the request. If the request is either a 'multipart/form-data' or an
- 'application/x-www-form- urlencoded', return a FormMultiDict instance populated with the values sent in the
- request, otherwise, an empty instance.
-
- Returns:
- A FormMultiDict instance
- """
- if self._form is Empty:
- if (form := self._connection_state.form) is not Empty:
- self._form = form
- else:
- content_type, options = self.content_type
- if content_type == RequestEncodingType.MULTI_PART:
- self._form = parse_multipart_form(
- body=await self.body(),
- boundary=options.get("boundary", "").encode(),
- multipart_form_part_limit=self.app.multipart_form_part_limit,
- )
- elif content_type == RequestEncodingType.URL_ENCODED:
- self._form = parse_url_encoded_form_data(
- await self.body(),
- )
- else:
- self._form = {}
-
- self._connection_state.form = self._form
-
- return FormMultiDict(self._form)
-
- async def send_push_promise(self, path: str, raise_if_unavailable: bool = False) -> None:
- """Send a push promise.
-
- This method requires the `http.response.push` extension to be sent from the ASGI server.
-
- Args:
- path: Path to send the promise to.
- raise_if_unavailable: Raise an exception if server push is not supported by
- the server
-
- Returns:
- None
- """
- if not self.supports_push_promise:
- if raise_if_unavailable:
- raise LitestarException("Attempted to send a push promise but the server does not support it")
-
- warnings.warn(
- "Attempted to send a push promise but the server does not support it. In a future version, this will "
- "raise an exception. To enable this behaviour in the current version, set raise_if_unavailable=True. "
- "To prevent this behaviour, make sure that the server you are using supports the 'http.response.push' "
- "ASGI extension, or check this dynamically via "
- ":attr:`~litestar.connection.Request.supports_push_promise`",
- stacklevel=2,
- category=LitestarWarning,
- )
-
- return
-
- raw_headers = [
- (header_name.encode("latin-1"), value.encode("latin-1"))
- for header_name in (self.headers.keys() & SERVER_PUSH_HEADERS)
- for value in self.headers.getall(header_name, [])
- ]
- await self.send({"type": "http.response.push", "path": path, "headers": raw_headers})
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/websocket.py b/venv/lib/python3.11/site-packages/litestar/connection/websocket.py
deleted file mode 100644
index 0c7bc04..0000000
--- a/venv/lib/python3.11/site-packages/litestar/connection/websocket.py
+++ /dev/null
@@ -1,343 +0,0 @@
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, Literal, cast, overload
-
-from litestar.connection.base import (
- ASGIConnection,
- AuthT,
- StateT,
- UserT,
- empty_receive,
- empty_send,
-)
-from litestar.datastructures.headers import Headers
-from litestar.exceptions import WebSocketDisconnect
-from litestar.serialization import decode_json, decode_msgpack, default_serializer, encode_json, encode_msgpack
-from litestar.status_codes import WS_1000_NORMAL_CLOSURE
-
-__all__ = ("WebSocket",)
-
-
-if TYPE_CHECKING:
- from litestar.handlers.websocket_handlers import WebsocketRouteHandler # noqa: F401
- from litestar.types import Message, Serializer, WebSocketScope
- from litestar.types.asgi_types import (
- Receive,
- ReceiveMessage,
- Scope,
- Send,
- WebSocketAcceptEvent,
- WebSocketCloseEvent,
- WebSocketDisconnectEvent,
- WebSocketMode,
- WebSocketReceiveEvent,
- WebSocketSendEvent,
- )
-
-DISCONNECT_MESSAGE = "connection is disconnected"
-
-
-class WebSocket(Generic[UserT, AuthT, StateT], ASGIConnection["WebsocketRouteHandler", UserT, AuthT, StateT]):
- """The Litestar WebSocket class."""
-
- __slots__ = ("connection_state",)
-
- scope: WebSocketScope # pyright: ignore
- """The ASGI scope attached to the connection."""
- receive: Receive
- """The ASGI receive function."""
- send: Send
- """The ASGI send function."""
-
- def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None:
- """Initialize ``WebSocket``.
-
- Args:
- scope: The ASGI connection scope.
- receive: The ASGI receive function.
- send: The ASGI send function.
- """
- super().__init__(scope, self.receive_wrapper(receive), self.send_wrapper(send))
- self.connection_state: Literal["init", "connect", "receive", "disconnect"] = "init"
-
- def receive_wrapper(self, receive: Receive) -> Receive:
- """Wrap ``receive`` to set 'self.connection_state' and validate events.
-
- Args:
- receive: The ASGI receive function.
-
- Returns:
- An ASGI receive function.
- """
-
- async def wrapped_receive() -> ReceiveMessage:
- if self.connection_state == "disconnect":
- raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE)
- message = await receive()
- if message["type"] == "websocket.connect":
- self.connection_state = "connect"
- elif message["type"] == "websocket.receive":
- self.connection_state = "receive"
- else:
- self.connection_state = "disconnect"
- return message
-
- return wrapped_receive
-
- def send_wrapper(self, send: Send) -> Send:
- """Wrap ``send`` to ensure that state is not disconnected.
-
- Args:
- send: The ASGI send function.
-
- Returns:
- An ASGI send function.
- """
-
- async def wrapped_send(message: Message) -> None:
- if self.connection_state == "disconnect":
- raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover
- await send(message)
-
- return wrapped_send
-
- async def accept(
- self,
- subprotocols: str | None = None,
- headers: Headers | dict[str, Any] | list[tuple[bytes, bytes]] | None = None,
- ) -> None:
- """Accept the incoming connection. This method should be called before receiving data.
-
- Args:
- subprotocols: Websocket sub-protocol to use.
- headers: Headers to set on the data sent.
-
- Returns:
- None
- """
- if self.connection_state == "init":
- await self.receive()
- _headers: list[tuple[bytes, bytes]] = headers if isinstance(headers, list) else []
-
- if isinstance(headers, dict):
- _headers = Headers(headers=headers).to_header_list()
-
- if isinstance(headers, Headers):
- _headers = headers.to_header_list()
-
- event: WebSocketAcceptEvent = {
- "type": "websocket.accept",
- "subprotocol": subprotocols,
- "headers": _headers,
- }
- await self.send(event)
-
- async def close(self, code: int = WS_1000_NORMAL_CLOSURE, reason: str | None = None) -> None:
- """Send an 'websocket.close' event.
-
- Args:
- code: Status code.
- reason: Reason for closing the connection
-
- Returns:
- None
- """
- event: WebSocketCloseEvent = {"type": "websocket.close", "code": code, "reason": reason or ""}
- await self.send(event)
-
- @overload
- async def receive_data(self, mode: Literal["text"]) -> str: ...
-
- @overload
- async def receive_data(self, mode: Literal["binary"]) -> bytes: ...
-
- async def receive_data(self, mode: WebSocketMode) -> str | bytes:
- """Receive an 'websocket.receive' event and returns the data stored on it.
-
- Args:
- mode: The respective event key to use.
-
- Returns:
- The event's data.
- """
- if self.connection_state == "init":
- await self.accept()
- event = cast("WebSocketReceiveEvent | WebSocketDisconnectEvent", await self.receive())
- if event["type"] == "websocket.disconnect":
- raise WebSocketDisconnect(detail="disconnect event", code=event["code"])
- return event.get("text") or "" if mode == "text" else event.get("bytes") or b""
-
- @overload
- def iter_data(self, mode: Literal["text"]) -> AsyncGenerator[str, None]: ...
-
- @overload
- def iter_data(self, mode: Literal["binary"]) -> AsyncGenerator[bytes, None]: ...
-
- async def iter_data(self, mode: WebSocketMode = "text") -> AsyncGenerator[str | bytes, None]:
- """Continuously receive data and yield it
-
- Args:
- mode: Socket mode to use. Either ``text`` or ``binary``
- """
- try:
- while True:
- yield await self.receive_data(mode)
- except WebSocketDisconnect:
- pass
-
- async def receive_text(self) -> str:
- """Receive data as text.
-
- Returns:
- A string.
- """
- return await self.receive_data(mode="text")
-
- async def receive_bytes(self) -> bytes:
- """Receive data as bytes.
-
- Returns:
- A byte-string.
- """
- return await self.receive_data(mode="binary")
-
- async def receive_json(self, mode: WebSocketMode = "text") -> Any:
- """Receive data and decode it as JSON.
-
- Args:
- mode: Either ``text`` or ``binary``.
-
- Returns:
- An arbitrary value
- """
- data = await self.receive_data(mode=mode)
- return decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders())
-
- async def receive_msgpack(self) -> Any:
- """Receive data and decode it as MessagePack.
-
- Note that since MessagePack is a binary format, this method will always receive
- data in ``binary`` mode.
-
- Returns:
- An arbitrary value
- """
- data = await self.receive_data(mode="binary")
- return decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders())
-
- async def iter_json(self, mode: WebSocketMode = "text") -> AsyncGenerator[Any, None]:
- """Continuously receive data and yield it, decoding it as JSON in the process.
-
- Args:
- mode: Socket mode to use. Either ``text`` or ``binary``
- """
- async for data in self.iter_data(mode):
- yield decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders())
-
- async def iter_msgpack(self) -> AsyncGenerator[Any, None]:
- """Continuously receive data and yield it, decoding it as MessagePack in the
- process.
-
- Note that since MessagePack is a binary format, this method will always receive
- data in ``binary`` mode.
-
- """
- async for data in self.iter_data(mode="binary"):
- yield decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders())
-
- async def send_data(self, data: str | bytes, mode: WebSocketMode = "text", encoding: str = "utf-8") -> None:
- """Send a 'websocket.send' event.
-
- Args:
- data: Data to send.
- mode: The respective event key to use.
- encoding: Encoding to use when converting bytes / str.
-
- Returns:
- None
- """
- if self.connection_state == "init": # pragma: no cover
- await self.accept()
- event: WebSocketSendEvent = {"type": "websocket.send", "bytes": None, "text": None}
- if mode == "binary":
- event["bytes"] = data if isinstance(data, bytes) else data.encode(encoding)
- else:
- event["text"] = data if isinstance(data, str) else data.decode(encoding)
- await self.send(event)
-
- @overload
- async def send_text(self, data: bytes, encoding: str = "utf-8") -> None: ...
-
- @overload
- async def send_text(self, data: str) -> None: ...
-
- async def send_text(self, data: str | bytes, encoding: str = "utf-8") -> None:
- """Send data using the ``text`` key of the send event.
-
- Args:
- data: Data to send
- encoding: Encoding to use for binary data.
-
- Returns:
- None
- """
- await self.send_data(data=data, encoding=encoding)
-
- @overload
- async def send_bytes(self, data: bytes) -> None: ...
-
- @overload
- async def send_bytes(self, data: str, encoding: str = "utf-8") -> None: ...
-
- async def send_bytes(self, data: str | bytes, encoding: str = "utf-8") -> None:
- """Send data using the ``bytes`` key of the send event.
-
- Args:
- data: Data to send
- encoding: Encoding to use for binary data.
-
- Returns:
- None
- """
- await self.send_data(data=data, mode="binary", encoding=encoding)
-
- async def send_json(
- self,
- data: Any,
- mode: WebSocketMode = "text",
- encoding: str = "utf-8",
- serializer: Serializer = default_serializer,
- ) -> None:
- """Send data as JSON.
-
- Args:
- data: A value to serialize.
- mode: Either ``text`` or ``binary``.
- encoding: Encoding to use for binary data.
- serializer: A serializer function.
-
- Returns:
- None
- """
- await self.send_data(data=encode_json(data, serializer), mode=mode, encoding=encoding)
-
- async def send_msgpack(
- self,
- data: Any,
- encoding: str = "utf-8",
- serializer: Serializer = default_serializer,
- ) -> None:
- """Send data as MessagePack.
-
- Note that since MessagePack is a binary format, this method will always send
- data in ``binary`` mode.
-
- Args:
- data: A value to serialize.
- encoding: Encoding to use for binary data.
- serializer: A serializer function.
-
- Returns:
- None
- """
- await self.send_data(data=encode_msgpack(data, serializer), mode="binary", encoding=encoding)