summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/connection
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/connection')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/__init__.py37
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pycbin0 -> 2206 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pycbin0 -> 15545 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pycbin0 -> 12711 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pycbin0 -> 17142 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/base.py345
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/request.py263
-rw-r--r--venv/lib/python3.11/site-packages/litestar/connection/websocket.py343
8 files changed, 988 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__init__.py b/venv/lib/python3.11/site-packages/litestar/connection/__init__.py
new file mode 100644
index 0000000..6922e79
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/connection/__init__.py
@@ -0,0 +1,37 @@
+"""Some code in this module was adapted from https://github.com/encode/starlette/blob/master/starlette/requests.py and
+https://github.com/encode/starlette/blob/master/starlette/websockets.py.
+
+Copyright © 2018, [Encode OSS Ltd](https://www.encode.io/).
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""
+
+from litestar.connection.base import ASGIConnection
+from litestar.connection.request import Request
+from litestar.connection.websocket import WebSocket
+
+__all__ = ("ASGIConnection", "Request", "WebSocket")
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..89cb4db
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/__init__.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000..f1eff74
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/base.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc
new file mode 100644
index 0000000..f047a18
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/request.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc
new file mode 100644
index 0000000..49e294c
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/connection/__pycache__/websocket.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/base.py b/venv/lib/python3.11/site-packages/litestar/connection/base.py
new file mode 100644
index 0000000..d14c662
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/connection/base.py
@@ -0,0 +1,345 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
+
+from litestar._parsers import parse_cookie_string, parse_query_string
+from litestar.datastructures.headers import Headers
+from litestar.datastructures.multi_dicts import MultiDict
+from litestar.datastructures.state import State
+from litestar.datastructures.url import URL, Address, make_absolute_url
+from litestar.exceptions import ImproperlyConfiguredException
+from litestar.types.empty import Empty
+from litestar.utils.empty import value_or_default
+from litestar.utils.scope.state import ScopeState
+
+if TYPE_CHECKING:
+ from typing import NoReturn
+
+ from litestar.app import Litestar
+ from litestar.types import DataContainerType, EmptyType
+ from litestar.types.asgi_types import Message, Receive, Scope, Send
+ from litestar.types.protocols import Logger
+
+__all__ = ("ASGIConnection", "empty_receive", "empty_send")
+
+UserT = TypeVar("UserT")
+AuthT = TypeVar("AuthT")
+HandlerT = TypeVar("HandlerT")
+StateT = TypeVar("StateT", bound=State)
+
+
+async def empty_receive() -> NoReturn: # pragma: no cover
+ """Raise a ``RuntimeError``.
+
+ Serves as a placeholder ``send`` function.
+
+ Raises:
+ RuntimeError
+ """
+ raise RuntimeError()
+
+
+async def empty_send(_: Message) -> NoReturn: # pragma: no cover
+ """Raise a ``RuntimeError``.
+
+ Serves as a placeholder ``send`` function.
+
+ Args:
+ _: An ASGI message
+
+ Raises:
+ RuntimeError
+ """
+ raise RuntimeError()
+
+
+class ASGIConnection(Generic[HandlerT, UserT, AuthT, StateT]):
+ """The base ASGI connection container."""
+
+ __slots__ = (
+ "scope",
+ "receive",
+ "send",
+ "_base_url",
+ "_url",
+ "_parsed_query",
+ "_cookies",
+ "_server_extensions",
+ "_connection_state",
+ )
+
+ scope: Scope
+ """The ASGI scope attached to the connection."""
+ receive: Receive
+ """The ASGI receive function."""
+ send: Send
+ """The ASGI send function."""
+
+ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None:
+ """Initialize ``ASGIConnection``.
+
+ Args:
+ scope: The ASGI connection scope.
+ receive: The ASGI receive function.
+ send: The ASGI send function.
+ """
+ self.scope = scope
+ self.receive = receive
+ self.send = send
+ self._connection_state = ScopeState.from_scope(scope)
+ self._base_url: URL | EmptyType = Empty
+ self._url: URL | EmptyType = Empty
+ self._parsed_query: tuple[tuple[str, str], ...] | EmptyType = Empty
+ self._cookies: dict[str, str] | EmptyType = Empty
+ self._server_extensions = scope.get("extensions") or {} # extensions may be None
+
+ @property
+ def app(self) -> Litestar:
+ """Return the ``app`` for this connection.
+
+ Returns:
+ The :class:`Litestar <litestar.app.Litestar>` application instance
+ """
+ return self.scope["app"]
+
+ @property
+ def route_handler(self) -> HandlerT:
+ """Return the ``route_handler`` for this connection.
+
+ Returns:
+ The target route handler instance.
+ """
+ return cast("HandlerT", self.scope["route_handler"])
+
+ @property
+ def state(self) -> StateT:
+ """Return the ``State`` of this connection.
+
+ Returns:
+ A State instance constructed from the scope["state"] value.
+ """
+ return cast("StateT", State(self.scope.get("state")))
+
+ @property
+ def url(self) -> URL:
+ """Return the URL of this connection's ``Scope``.
+
+ Returns:
+ A URL instance constructed from the request's scope.
+ """
+ if self._url is Empty:
+ if (url := self._connection_state.url) is not Empty:
+ self._url = url
+ else:
+ self._connection_state.url = self._url = URL.from_scope(self.scope)
+
+ return self._url
+
+ @property
+ def base_url(self) -> URL:
+ """Return the base URL of this connection's ``Scope``.
+
+ Returns:
+ A URL instance constructed from the request's scope, representing only the base part
+ (host + domain + prefix) of the request.
+ """
+ if self._base_url is Empty:
+ if (base_url := self._connection_state.base_url) is not Empty:
+ self._base_url = base_url
+ else:
+ scope = cast(
+ "Scope",
+ {
+ **self.scope,
+ "path": "/",
+ "query_string": b"",
+ "root_path": self.scope.get("app_root_path") or self.scope.get("root_path", ""),
+ },
+ )
+ self._connection_state.base_url = self._base_url = URL.from_scope(scope)
+ return self._base_url
+
+ @property
+ def headers(self) -> Headers:
+ """Return the headers of this connection's ``Scope``.
+
+ Returns:
+ A Headers instance with the request's scope["headers"] value.
+ """
+ return Headers.from_scope(self.scope)
+
+ @property
+ def query_params(self) -> MultiDict[Any]:
+ """Return the query parameters of this connection's ``Scope``.
+
+ Returns:
+ A normalized dict of query parameters. Multiple values for the same key are returned as a list.
+ """
+ if self._parsed_query is Empty:
+ if (parsed_query := self._connection_state.parsed_query) is not Empty:
+ self._parsed_query = parsed_query
+ else:
+ self._connection_state.parsed_query = self._parsed_query = parse_query_string(
+ self.scope.get("query_string", b"")
+ )
+ return MultiDict(self._parsed_query)
+
+ @property
+ def path_params(self) -> dict[str, Any]:
+ """Return the ``path_params`` of this connection's ``Scope``.
+
+ Returns:
+ A string keyed dictionary of path parameter values.
+ """
+ return self.scope["path_params"]
+
+ @property
+ def cookies(self) -> dict[str, str]:
+ """Return the ``cookies`` of this connection's ``Scope``.
+
+ Returns:
+ Returns any cookies stored in the header as a parsed dictionary.
+ """
+ if self._cookies is Empty:
+ if (cookies := self._connection_state.cookies) is not Empty:
+ self._cookies = cookies
+ else:
+ self._connection_state.cookies = self._cookies = (
+ parse_cookie_string(cookie_header) if (cookie_header := self.headers.get("cookie")) else {}
+ )
+ return self._cookies
+
+ @property
+ def client(self) -> Address | None:
+ """Return the ``client`` data of this connection's ``Scope``.
+
+ Returns:
+ A two tuple of the host name and port number.
+ """
+ client = self.scope.get("client")
+ return Address(*client) if client else None
+
+ @property
+ def auth(self) -> AuthT:
+ """Return the ``auth`` data of this connection's ``Scope``.
+
+ Raises:
+ ImproperlyConfiguredException: If ``auth`` is not set in scope via an ``AuthMiddleware``, raises an exception
+
+ Returns:
+ A type correlating to the generic variable Auth.
+ """
+ if "auth" not in self.scope:
+ raise ImproperlyConfiguredException("'auth' is not defined in scope, install an AuthMiddleware to set it")
+
+ return cast("AuthT", self.scope["auth"])
+
+ @property
+ def user(self) -> UserT:
+ """Return the ``user`` data of this connection's ``Scope``.
+
+ Raises:
+ ImproperlyConfiguredException: If ``user`` is not set in scope via an ``AuthMiddleware``, raises an exception
+
+ Returns:
+ A type correlating to the generic variable User.
+ """
+ if "user" not in self.scope:
+ raise ImproperlyConfiguredException("'user' is not defined in scope, install an AuthMiddleware to set it")
+
+ return cast("UserT", self.scope["user"])
+
+ @property
+ def session(self) -> dict[str, Any]:
+ """Return the session for this connection if a session was previously set in the ``Scope``
+
+ Returns:
+ A dictionary representing the session value - if existing.
+
+ Raises:
+ ImproperlyConfiguredException: if session is not set in scope.
+ """
+ if "session" not in self.scope:
+ raise ImproperlyConfiguredException(
+ "'session' is not defined in scope, install a SessionMiddleware to set it"
+ )
+
+ return cast("dict[str, Any]", self.scope["session"])
+
+ @property
+ def logger(self) -> Logger:
+ """Return the ``Logger`` instance for this connection.
+
+ Returns:
+ A ``Logger`` instance.
+
+ Raises:
+ ImproperlyConfiguredException: if ``log_config`` has not been passed to the Litestar constructor.
+ """
+ return self.app.get_logger()
+
+ def set_session(self, value: dict[str, Any] | DataContainerType | EmptyType) -> None:
+ """Set the session in the connection's ``Scope``.
+
+ If the :class:`SessionMiddleware <.middleware.session.base.SessionMiddleware>` is enabled, the session will be added
+ to the response as a cookie header.
+
+ Args:
+ value: Dictionary or pydantic model instance for the session data.
+
+ Returns:
+ None
+ """
+ self.scope["session"] = value
+
+ def clear_session(self) -> None:
+ """Remove the session from the connection's ``Scope``.
+
+ If the :class:`Litestar SessionMiddleware <.middleware.session.base.SessionMiddleware>` is enabled, this will cause
+ the session data to be cleared.
+
+ Returns:
+ None.
+ """
+ self.scope["session"] = Empty
+ self._connection_state.session_id = Empty
+
+ def get_session_id(self) -> str | None:
+ return value_or_default(value=self._connection_state.session_id, default=None)
+
+ def url_for(self, name: str, **path_parameters: Any) -> str:
+ """Return the url for a given route handler name.
+
+ Args:
+ name: The ``name`` of the request route handler.
+ **path_parameters: Values for path parameters in the route
+
+ Raises:
+ NoRouteMatchFoundException: If route with ``name`` does not exist, path parameters are missing or have a
+ wrong type.
+
+ Returns:
+ A string representing the absolute url of the route handler.
+ """
+ litestar_instance = self.scope["app"]
+ url_path = litestar_instance.route_reverse(name, **path_parameters)
+
+ return make_absolute_url(url_path, self.base_url)
+
+ def url_for_static_asset(self, name: str, file_path: str) -> str:
+ """Receives a static files handler name, an asset file path and returns resolved absolute url to the asset.
+
+ Args:
+ name: A static handler unique name.
+ file_path: a string containing path to an asset.
+
+ Raises:
+ NoRouteMatchFoundException: If static files handler with ``name`` does not exist.
+
+ Returns:
+ A string representing absolute url to the asset.
+ """
+ litestar_instance = self.scope["app"]
+ url_path = litestar_instance.url_for_static_asset(name, file_path)
+
+ return make_absolute_url(url_path, self.base_url)
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/request.py b/venv/lib/python3.11/site-packages/litestar/connection/request.py
new file mode 100644
index 0000000..254c315
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/connection/request.py
@@ -0,0 +1,263 @@
+from __future__ import annotations
+
+import warnings
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic
+
+from litestar._multipart import parse_content_header, parse_multipart_form
+from litestar._parsers import parse_url_encoded_form_data
+from litestar.connection.base import (
+ ASGIConnection,
+ AuthT,
+ StateT,
+ UserT,
+ empty_receive,
+ empty_send,
+)
+from litestar.datastructures.headers import Accept
+from litestar.datastructures.multi_dicts import FormMultiDict
+from litestar.enums import ASGIExtension, RequestEncodingType
+from litestar.exceptions import (
+ InternalServerException,
+ LitestarException,
+ LitestarWarning,
+)
+from litestar.serialization import decode_json, decode_msgpack
+from litestar.types import Empty
+
+__all__ = ("Request",)
+
+
+if TYPE_CHECKING:
+ from litestar.handlers.http_handlers import HTTPRouteHandler # noqa: F401
+ from litestar.types.asgi_types import HTTPScope, Method, Receive, Scope, Send
+ from litestar.types.empty import EmptyType
+
+
+SERVER_PUSH_HEADERS = {
+ "accept",
+ "accept-encoding",
+ "accept-language",
+ "cache-control",
+ "user-agent",
+}
+
+
+class Request(Generic[UserT, AuthT, StateT], ASGIConnection["HTTPRouteHandler", UserT, AuthT, StateT]):
+ """The Litestar Request class."""
+
+ __slots__ = (
+ "_json",
+ "_form",
+ "_body",
+ "_msgpack",
+ "_content_type",
+ "_accept",
+ "is_connected",
+ "supports_push_promise",
+ )
+
+ scope: HTTPScope # pyright: ignore
+ """The ASGI scope attached to the connection."""
+ receive: Receive
+ """The ASGI receive function."""
+ send: Send
+ """The ASGI send function."""
+
+ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None:
+ """Initialize ``Request``.
+
+ Args:
+ scope: The ASGI connection scope.
+ receive: The ASGI receive function.
+ send: The ASGI send function.
+ """
+ super().__init__(scope, receive, send)
+ self.is_connected: bool = True
+ self._body: bytes | EmptyType = Empty
+ self._form: dict[str, str | list[str]] | EmptyType = Empty
+ self._json: Any = Empty
+ self._msgpack: Any = Empty
+ self._content_type: tuple[str, dict[str, str]] | EmptyType = Empty
+ self._accept: Accept | EmptyType = Empty
+ self.supports_push_promise = ASGIExtension.SERVER_PUSH in self._server_extensions
+
+ @property
+ def method(self) -> Method:
+ """Return the request method.
+
+ Returns:
+ The request :class:`Method <litestar.types.Method>`
+ """
+ return self.scope["method"]
+
+ @property
+ def content_type(self) -> tuple[str, dict[str, str]]:
+ """Parse the request's 'Content-Type' header, returning the header value and any options as a dictionary.
+
+ Returns:
+ A tuple with the parsed value and a dictionary containing any options send in it.
+ """
+ if self._content_type is Empty:
+ if (content_type := self._connection_state.content_type) is not Empty:
+ self._content_type = content_type
+ else:
+ self._content_type = self._connection_state.content_type = parse_content_header(
+ self.headers.get("Content-Type", "")
+ )
+ return self._content_type
+
+ @property
+ def accept(self) -> Accept:
+ """Parse the request's 'Accept' header, returning an :class:`Accept <litestar.datastructures.headers.Accept>` instance.
+
+ Returns:
+ An :class:`Accept <litestar.datastructures.headers.Accept>` instance, representing the list of acceptable media types.
+ """
+ if self._accept is Empty:
+ if (accept := self._connection_state.accept) is not Empty:
+ self._accept = accept
+ else:
+ self._accept = self._connection_state.accept = Accept(self.headers.get("Accept", "*/*"))
+ return self._accept
+
+ async def json(self) -> Any:
+ """Retrieve the json request body from the request.
+
+ Returns:
+ An arbitrary value
+ """
+ if self._json is Empty:
+ if (json_ := self._connection_state.json) is not Empty:
+ self._json = json_
+ else:
+ body = await self.body()
+ self._json = self._connection_state.json = decode_json(
+ body or b"null", type_decoders=self.route_handler.resolve_type_decoders()
+ )
+ return self._json
+
+ async def msgpack(self) -> Any:
+ """Retrieve the MessagePack request body from the request.
+
+ Returns:
+ An arbitrary value
+ """
+ if self._msgpack is Empty:
+ if (msgpack := self._connection_state.msgpack) is not Empty:
+ self._msgpack = msgpack
+ else:
+ body = await self.body()
+ self._msgpack = self._connection_state.msgpack = decode_msgpack(
+ body or b"\xc0", type_decoders=self.route_handler.resolve_type_decoders()
+ )
+ return self._msgpack
+
+ async def stream(self) -> AsyncGenerator[bytes, None]:
+ """Return an async generator that streams chunks of bytes.
+
+ Returns:
+ An async generator.
+
+ Raises:
+ RuntimeError: if the stream is already consumed
+ """
+ if self._body is Empty:
+ if not self.is_connected:
+ raise InternalServerException("stream consumed")
+ while event := await self.receive():
+ if event["type"] == "http.request":
+ if event["body"]:
+ yield event["body"]
+
+ if not event.get("more_body", False):
+ break
+
+ if event["type"] == "http.disconnect":
+ raise InternalServerException("client disconnected prematurely")
+
+ self.is_connected = False
+ yield b""
+
+ else:
+ yield self._body
+ yield b""
+ return
+
+ async def body(self) -> bytes:
+ """Return the body of the request.
+
+ Returns:
+ A byte-string representing the body of the request.
+ """
+ if self._body is Empty:
+ if (body := self._connection_state.body) is not Empty:
+ self._body = body
+ else:
+ self._body = self._connection_state.body = b"".join([c async for c in self.stream()])
+ return self._body
+
+ async def form(self) -> FormMultiDict:
+ """Retrieve form data from the request. If the request is either a 'multipart/form-data' or an
+ 'application/x-www-form- urlencoded', return a FormMultiDict instance populated with the values sent in the
+ request, otherwise, an empty instance.
+
+ Returns:
+ A FormMultiDict instance
+ """
+ if self._form is Empty:
+ if (form := self._connection_state.form) is not Empty:
+ self._form = form
+ else:
+ content_type, options = self.content_type
+ if content_type == RequestEncodingType.MULTI_PART:
+ self._form = parse_multipart_form(
+ body=await self.body(),
+ boundary=options.get("boundary", "").encode(),
+ multipart_form_part_limit=self.app.multipart_form_part_limit,
+ )
+ elif content_type == RequestEncodingType.URL_ENCODED:
+ self._form = parse_url_encoded_form_data(
+ await self.body(),
+ )
+ else:
+ self._form = {}
+
+ self._connection_state.form = self._form
+
+ return FormMultiDict(self._form)
+
+ async def send_push_promise(self, path: str, raise_if_unavailable: bool = False) -> None:
+ """Send a push promise.
+
+ This method requires the `http.response.push` extension to be sent from the ASGI server.
+
+ Args:
+ path: Path to send the promise to.
+ raise_if_unavailable: Raise an exception if server push is not supported by
+ the server
+
+ Returns:
+ None
+ """
+ if not self.supports_push_promise:
+ if raise_if_unavailable:
+ raise LitestarException("Attempted to send a push promise but the server does not support it")
+
+ warnings.warn(
+ "Attempted to send a push promise but the server does not support it. In a future version, this will "
+ "raise an exception. To enable this behaviour in the current version, set raise_if_unavailable=True. "
+ "To prevent this behaviour, make sure that the server you are using supports the 'http.response.push' "
+ "ASGI extension, or check this dynamically via "
+ ":attr:`~litestar.connection.Request.supports_push_promise`",
+ stacklevel=2,
+ category=LitestarWarning,
+ )
+
+ return
+
+ raw_headers = [
+ (header_name.encode("latin-1"), value.encode("latin-1"))
+ for header_name in (self.headers.keys() & SERVER_PUSH_HEADERS)
+ for value in self.headers.getall(header_name, [])
+ ]
+ await self.send({"type": "http.response.push", "path": path, "headers": raw_headers})
diff --git a/venv/lib/python3.11/site-packages/litestar/connection/websocket.py b/venv/lib/python3.11/site-packages/litestar/connection/websocket.py
new file mode 100644
index 0000000..0c7bc04
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/connection/websocket.py
@@ -0,0 +1,343 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Generic, Literal, cast, overload
+
+from litestar.connection.base import (
+ ASGIConnection,
+ AuthT,
+ StateT,
+ UserT,
+ empty_receive,
+ empty_send,
+)
+from litestar.datastructures.headers import Headers
+from litestar.exceptions import WebSocketDisconnect
+from litestar.serialization import decode_json, decode_msgpack, default_serializer, encode_json, encode_msgpack
+from litestar.status_codes import WS_1000_NORMAL_CLOSURE
+
+__all__ = ("WebSocket",)
+
+
+if TYPE_CHECKING:
+ from litestar.handlers.websocket_handlers import WebsocketRouteHandler # noqa: F401
+ from litestar.types import Message, Serializer, WebSocketScope
+ from litestar.types.asgi_types import (
+ Receive,
+ ReceiveMessage,
+ Scope,
+ Send,
+ WebSocketAcceptEvent,
+ WebSocketCloseEvent,
+ WebSocketDisconnectEvent,
+ WebSocketMode,
+ WebSocketReceiveEvent,
+ WebSocketSendEvent,
+ )
+
+DISCONNECT_MESSAGE = "connection is disconnected"
+
+
+class WebSocket(Generic[UserT, AuthT, StateT], ASGIConnection["WebsocketRouteHandler", UserT, AuthT, StateT]):
+ """The Litestar WebSocket class."""
+
+ __slots__ = ("connection_state",)
+
+ scope: WebSocketScope # pyright: ignore
+ """The ASGI scope attached to the connection."""
+ receive: Receive
+ """The ASGI receive function."""
+ send: Send
+ """The ASGI send function."""
+
+ def __init__(self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send) -> None:
+ """Initialize ``WebSocket``.
+
+ Args:
+ scope: The ASGI connection scope.
+ receive: The ASGI receive function.
+ send: The ASGI send function.
+ """
+ super().__init__(scope, self.receive_wrapper(receive), self.send_wrapper(send))
+ self.connection_state: Literal["init", "connect", "receive", "disconnect"] = "init"
+
+ def receive_wrapper(self, receive: Receive) -> Receive:
+ """Wrap ``receive`` to set 'self.connection_state' and validate events.
+
+ Args:
+ receive: The ASGI receive function.
+
+ Returns:
+ An ASGI receive function.
+ """
+
+ async def wrapped_receive() -> ReceiveMessage:
+ if self.connection_state == "disconnect":
+ raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE)
+ message = await receive()
+ if message["type"] == "websocket.connect":
+ self.connection_state = "connect"
+ elif message["type"] == "websocket.receive":
+ self.connection_state = "receive"
+ else:
+ self.connection_state = "disconnect"
+ return message
+
+ return wrapped_receive
+
+ def send_wrapper(self, send: Send) -> Send:
+ """Wrap ``send`` to ensure that state is not disconnected.
+
+ Args:
+ send: The ASGI send function.
+
+ Returns:
+ An ASGI send function.
+ """
+
+ async def wrapped_send(message: Message) -> None:
+ if self.connection_state == "disconnect":
+ raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover
+ await send(message)
+
+ return wrapped_send
+
+ async def accept(
+ self,
+ subprotocols: str | None = None,
+ headers: Headers | dict[str, Any] | list[tuple[bytes, bytes]] | None = None,
+ ) -> None:
+ """Accept the incoming connection. This method should be called before receiving data.
+
+ Args:
+ subprotocols: Websocket sub-protocol to use.
+ headers: Headers to set on the data sent.
+
+ Returns:
+ None
+ """
+ if self.connection_state == "init":
+ await self.receive()
+ _headers: list[tuple[bytes, bytes]] = headers if isinstance(headers, list) else []
+
+ if isinstance(headers, dict):
+ _headers = Headers(headers=headers).to_header_list()
+
+ if isinstance(headers, Headers):
+ _headers = headers.to_header_list()
+
+ event: WebSocketAcceptEvent = {
+ "type": "websocket.accept",
+ "subprotocol": subprotocols,
+ "headers": _headers,
+ }
+ await self.send(event)
+
+ async def close(self, code: int = WS_1000_NORMAL_CLOSURE, reason: str | None = None) -> None:
+ """Send an 'websocket.close' event.
+
+ Args:
+ code: Status code.
+ reason: Reason for closing the connection
+
+ Returns:
+ None
+ """
+ event: WebSocketCloseEvent = {"type": "websocket.close", "code": code, "reason": reason or ""}
+ await self.send(event)
+
+ @overload
+ async def receive_data(self, mode: Literal["text"]) -> str: ...
+
+ @overload
+ async def receive_data(self, mode: Literal["binary"]) -> bytes: ...
+
+ async def receive_data(self, mode: WebSocketMode) -> str | bytes:
+ """Receive an 'websocket.receive' event and returns the data stored on it.
+
+ Args:
+ mode: The respective event key to use.
+
+ Returns:
+ The event's data.
+ """
+ if self.connection_state == "init":
+ await self.accept()
+ event = cast("WebSocketReceiveEvent | WebSocketDisconnectEvent", await self.receive())
+ if event["type"] == "websocket.disconnect":
+ raise WebSocketDisconnect(detail="disconnect event", code=event["code"])
+ return event.get("text") or "" if mode == "text" else event.get("bytes") or b""
+
+ @overload
+ def iter_data(self, mode: Literal["text"]) -> AsyncGenerator[str, None]: ...
+
+ @overload
+ def iter_data(self, mode: Literal["binary"]) -> AsyncGenerator[bytes, None]: ...
+
+ async def iter_data(self, mode: WebSocketMode = "text") -> AsyncGenerator[str | bytes, None]:
+ """Continuously receive data and yield it
+
+ Args:
+ mode: Socket mode to use. Either ``text`` or ``binary``
+ """
+ try:
+ while True:
+ yield await self.receive_data(mode)
+ except WebSocketDisconnect:
+ pass
+
+ async def receive_text(self) -> str:
+ """Receive data as text.
+
+ Returns:
+ A string.
+ """
+ return await self.receive_data(mode="text")
+
+ async def receive_bytes(self) -> bytes:
+ """Receive data as bytes.
+
+ Returns:
+ A byte-string.
+ """
+ return await self.receive_data(mode="binary")
+
+ async def receive_json(self, mode: WebSocketMode = "text") -> Any:
+ """Receive data and decode it as JSON.
+
+ Args:
+ mode: Either ``text`` or ``binary``.
+
+ Returns:
+ An arbitrary value
+ """
+ data = await self.receive_data(mode=mode)
+ return decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders())
+
+ async def receive_msgpack(self) -> Any:
+ """Receive data and decode it as MessagePack.
+
+ Note that since MessagePack is a binary format, this method will always receive
+ data in ``binary`` mode.
+
+ Returns:
+ An arbitrary value
+ """
+ data = await self.receive_data(mode="binary")
+ return decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders())
+
+ async def iter_json(self, mode: WebSocketMode = "text") -> AsyncGenerator[Any, None]:
+ """Continuously receive data and yield it, decoding it as JSON in the process.
+
+ Args:
+ mode: Socket mode to use. Either ``text`` or ``binary``
+ """
+ async for data in self.iter_data(mode):
+ yield decode_json(value=data, type_decoders=self.route_handler.resolve_type_decoders())
+
+ async def iter_msgpack(self) -> AsyncGenerator[Any, None]:
+ """Continuously receive data and yield it, decoding it as MessagePack in the
+ process.
+
+ Note that since MessagePack is a binary format, this method will always receive
+ data in ``binary`` mode.
+
+ """
+ async for data in self.iter_data(mode="binary"):
+ yield decode_msgpack(value=data, type_decoders=self.route_handler.resolve_type_decoders())
+
+ async def send_data(self, data: str | bytes, mode: WebSocketMode = "text", encoding: str = "utf-8") -> None:
+ """Send a 'websocket.send' event.
+
+ Args:
+ data: Data to send.
+ mode: The respective event key to use.
+ encoding: Encoding to use when converting bytes / str.
+
+ Returns:
+ None
+ """
+ if self.connection_state == "init": # pragma: no cover
+ await self.accept()
+ event: WebSocketSendEvent = {"type": "websocket.send", "bytes": None, "text": None}
+ if mode == "binary":
+ event["bytes"] = data if isinstance(data, bytes) else data.encode(encoding)
+ else:
+ event["text"] = data if isinstance(data, str) else data.decode(encoding)
+ await self.send(event)
+
+ @overload
+ async def send_text(self, data: bytes, encoding: str = "utf-8") -> None: ...
+
+ @overload
+ async def send_text(self, data: str) -> None: ...
+
+ async def send_text(self, data: str | bytes, encoding: str = "utf-8") -> None:
+ """Send data using the ``text`` key of the send event.
+
+ Args:
+ data: Data to send
+ encoding: Encoding to use for binary data.
+
+ Returns:
+ None
+ """
+ await self.send_data(data=data, encoding=encoding)
+
+ @overload
+ async def send_bytes(self, data: bytes) -> None: ...
+
+ @overload
+ async def send_bytes(self, data: str, encoding: str = "utf-8") -> None: ...
+
+ async def send_bytes(self, data: str | bytes, encoding: str = "utf-8") -> None:
+ """Send data using the ``bytes`` key of the send event.
+
+ Args:
+ data: Data to send
+ encoding: Encoding to use for binary data.
+
+ Returns:
+ None
+ """
+ await self.send_data(data=data, mode="binary", encoding=encoding)
+
+ async def send_json(
+ self,
+ data: Any,
+ mode: WebSocketMode = "text",
+ encoding: str = "utf-8",
+ serializer: Serializer = default_serializer,
+ ) -> None:
+ """Send data as JSON.
+
+ Args:
+ data: A value to serialize.
+ mode: Either ``text`` or ``binary``.
+ encoding: Encoding to use for binary data.
+ serializer: A serializer function.
+
+ Returns:
+ None
+ """
+ await self.send_data(data=encode_json(data, serializer), mode=mode, encoding=encoding)
+
+ async def send_msgpack(
+ self,
+ data: Any,
+ encoding: str = "utf-8",
+ serializer: Serializer = default_serializer,
+ ) -> None:
+ """Send data as MessagePack.
+
+ Note that since MessagePack is a binary format, this method will always send
+ data in ``binary`` mode.
+
+ Args:
+ data: A value to serialize.
+ encoding: Encoding to use for binary data.
+ serializer: A serializer function.
+
+ Returns:
+ None
+ """
+ await self.send_data(data=encode_msgpack(data, serializer), mode="binary", encoding=encoding)