diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers')
8 files changed, 686 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__init__.py b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__init__.py new file mode 100644 index 0000000..5b24734 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__init__.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from litestar.handlers.websocket_handlers.listener import ( + WebsocketListener, + WebsocketListenerRouteHandler, + websocket_listener, +) +from litestar.handlers.websocket_handlers.route_handler import WebsocketRouteHandler, websocket + +__all__ = ( + "WebsocketListener", + "WebsocketListenerRouteHandler", + "WebsocketRouteHandler", + "websocket", + "websocket_listener", +) diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/__init__.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..f6d1115 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/__init__.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/_utils.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/_utils.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..c5ae4c8 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/_utils.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/listener.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/listener.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..38b8219 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/listener.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/route_handler.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/route_handler.cpython-311.pyc Binary files differnew file mode 100644 index 0000000..0e92ccd --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/__pycache__/route_handler.cpython-311.pyc diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/_utils.py b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/_utils.py new file mode 100644 index 0000000..bcd90ac --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/_utils.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from functools import wraps +from inspect import Parameter, Signature +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict + +from msgspec.json import Encoder as JsonEncoder + +from litestar.di import Provide +from litestar.serialization import decode_json +from litestar.types.builtin_types import NoneType +from litestar.utils import ensure_async_callable +from litestar.utils.helpers import unwrap_partial + +if TYPE_CHECKING: + from litestar import WebSocket + from litestar.handlers.websocket_handlers.listener import WebsocketListenerRouteHandler + from litestar.types import AnyCallable + from litestar.utils.signature import ParsedSignature + + +def create_handle_receive(listener: WebsocketListenerRouteHandler) -> Callable[[WebSocket], Coroutine[Any, None, None]]: + if data_dto := listener.resolve_data_dto(): + + async def handle_receive(socket: WebSocket) -> Any: + received_data = await socket.receive_data(mode=listener._receive_mode) + return data_dto(socket).decode_bytes( + received_data.encode("utf-8") if isinstance(received_data, str) else received_data + ) + + elif listener.parsed_data_field and listener.parsed_data_field.annotation is str: + + async def handle_receive(socket: WebSocket) -> Any: + received_data = await socket.receive_data(mode=listener._receive_mode) + return received_data.decode("utf-8") if isinstance(received_data, bytes) else received_data + + elif listener.parsed_data_field and listener.parsed_data_field.annotation is bytes: + + async def handle_receive(socket: WebSocket) -> Any: + received_data = await socket.receive_data(mode=listener._receive_mode) + return received_data.encode("utf-8") if isinstance(received_data, str) else received_data + + else: + + async def handle_receive(socket: WebSocket) -> Any: + received_data = await socket.receive_data(mode=listener._receive_mode) + return decode_json(value=received_data, type_decoders=socket.route_handler.resolve_type_decoders()) + + return handle_receive + + +def create_handle_send( + listener: WebsocketListenerRouteHandler, +) -> Callable[[WebSocket, Any], Coroutine[None, None, None]]: + json_encoder = JsonEncoder(enc_hook=listener.default_serializer) + + if return_dto := listener.resolve_return_dto(): + + async def handle_send(socket: WebSocket, data: Any) -> None: + encoded_data = return_dto(socket).data_to_encodable_type(data) + data = json_encoder.encode(encoded_data) + await socket.send_data(data=data, mode=listener._send_mode) + + elif listener.parsed_return_field.is_subclass_of((str, bytes)) or ( + listener.parsed_return_field.is_optional and listener.parsed_return_field.has_inner_subclass_of((str, bytes)) + ): + + async def handle_send(socket: WebSocket, data: Any) -> None: + await socket.send_data(data=data, mode=listener._send_mode) + + else: + + async def handle_send(socket: WebSocket, data: Any) -> None: + data = json_encoder.encode(data) + await socket.send_data(data=data, mode=listener._send_mode) + + return handle_send + + +class ListenerHandler: + __slots__ = ("_can_send_data", "_fn", "_listener", "_pass_socket") + + def __init__( + self, + listener: WebsocketListenerRouteHandler, + fn: AnyCallable, + parsed_signature: ParsedSignature, + namespace: dict[str, Any], + ) -> None: + self._can_send_data = not parsed_signature.return_type.is_subclass_of(NoneType) + self._fn = ensure_async_callable(fn) + self._listener = listener + self._pass_socket = "socket" in parsed_signature.parameters + + async def __call__( + self, + *args: Any, + socket: WebSocket, + connection_lifespan_dependencies: Dict[str, Any], # noqa: UP006 + **kwargs: Any, + ) -> None: + lifespan_mananger = self._listener._connection_lifespan or self._listener.default_connection_lifespan + handle_send = self._listener.resolve_send_handler() if self._can_send_data else None + handle_receive = self._listener.resolve_receive_handler() + + if self._pass_socket: + kwargs["socket"] = socket + + async with lifespan_mananger(**connection_lifespan_dependencies): + while True: + received_data = await handle_receive(socket) + data = await self._fn(*args, data=received_data, **kwargs) + if handle_send: + await handle_send(socket, data) + + +def create_handler_signature(callback_signature: Signature) -> Signature: + """Creates a :class:`Signature` for the handler function for signature modelling. + + This is required for two reasons: + + 1. the :class:`.handlers.WebsocketHandler` signature model cannot contain the ``data`` parameter, which is + required for :class:`.handlers.websocket_listener` handlers. + 2. the :class;`.handlers.WebsocketHandler` signature model must include the ``socket`` parameter, which is + optional for :class:`.handlers.websocket_listener` handlers. + + Args: + callback_signature: The :class:`Signature` of the listener callback. + + Returns: + The :class:`Signature` for the listener callback as required for signature modelling. + """ + new_params = [p for p in callback_signature.parameters.values() if p.name != "data"] + if "socket" not in callback_signature.parameters: + new_params.append(Parameter(name="socket", kind=Parameter.KEYWORD_ONLY, annotation="WebSocket")) + + new_params.append( + Parameter(name="connection_lifespan_dependencies", kind=Parameter.KEYWORD_ONLY, annotation="Dict[str, Any]") + ) + + return callback_signature.replace(parameters=new_params) + + +def create_stub_dependency(src: AnyCallable) -> Provide: + """Create a stub dependency, accepting any kwargs defined in ``src``, and + wrap it in ``Provide`` + """ + src = unwrap_partial(src) + + @wraps(src) + async def stub(**kwargs: Any) -> Dict[str, Any]: # noqa: UP006 + return kwargs + + return Provide(stub) diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/listener.py b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/listener.py new file mode 100644 index 0000000..86fefc9 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/listener.py @@ -0,0 +1,417 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Callable, + Dict, + Mapping, + Optional, + cast, + overload, +) + +from litestar._signature import SignatureModel +from litestar.connection import WebSocket +from litestar.exceptions import ImproperlyConfiguredException, WebSocketDisconnect +from litestar.types import ( + AnyCallable, + Dependencies, + Empty, + EmptyType, + ExceptionHandler, + Guard, + Middleware, + TypeEncodersMap, +) +from litestar.utils import ensure_async_callable +from litestar.utils.signature import ParsedSignature, get_fn_type_hints + +from ._utils import ( + ListenerHandler, + create_handle_receive, + create_handle_send, + create_handler_signature, + create_stub_dependency, +) +from .route_handler import WebsocketRouteHandler + +if TYPE_CHECKING: + from typing import Coroutine + + from typing_extensions import Self + + from litestar import Router + from litestar.dto import AbstractDTO + from litestar.types.asgi_types import WebSocketMode + from litestar.types.composite_types import TypeDecodersSequence + +__all__ = ("WebsocketListener", "WebsocketListenerRouteHandler", "websocket_listener") + + +class WebsocketListenerRouteHandler(WebsocketRouteHandler): + """A websocket listener that automatically accepts a connection, handles disconnects, + invokes a callback function every time new data is received and sends any data + returned + """ + + __slots__ = { + "connection_accept_handler": "Callback to accept a WebSocket connection. By default, calls WebSocket.accept", + "on_accept": "Callback invoked after a WebSocket connection has been accepted", + "on_disconnect": "Callback invoked after a WebSocket connection has been closed", + "weboscket_class": "WebSocket class", + "_connection_lifespan": None, + "_handle_receive": None, + "_handle_send": None, + "_receive_mode": None, + "_send_mode": None, + } + + @overload + def __init__( + self, + path: str | list[str] | None = None, + *, + connection_lifespan: Callable[..., AbstractAsyncContextManager[Any]] | None = None, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, + guards: list[Guard] | None = None, + middleware: list[Middleware] | None = None, + receive_mode: WebSocketMode = "text", + send_mode: WebSocketMode = "text", + name: str | None = None, + opt: dict[str, Any] | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, + **kwargs: Any, + ) -> None: ... + + @overload + def __init__( + self, + path: str | list[str] | None = None, + *, + connection_accept_handler: Callable[[WebSocket], Coroutine[Any, Any, None]] = WebSocket.accept, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, + guards: list[Guard] | None = None, + middleware: list[Middleware] | None = None, + receive_mode: WebSocketMode = "text", + send_mode: WebSocketMode = "text", + name: str | None = None, + on_accept: AnyCallable | None = None, + on_disconnect: AnyCallable | None = None, + opt: dict[str, Any] | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, + **kwargs: Any, + ) -> None: ... + + def __init__( + self, + path: str | list[str] | None = None, + *, + connection_accept_handler: Callable[[WebSocket], Coroutine[Any, Any, None]] = WebSocket.accept, + connection_lifespan: Callable[..., AbstractAsyncContextManager[Any]] | None = None, + dependencies: Dependencies | None = None, + dto: type[AbstractDTO] | None | EmptyType = Empty, + exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, + guards: list[Guard] | None = None, + middleware: list[Middleware] | None = None, + receive_mode: WebSocketMode = "text", + send_mode: WebSocketMode = "text", + name: str | None = None, + on_accept: AnyCallable | None = None, + on_disconnect: AnyCallable | None = None, + opt: dict[str, Any] | None = None, + return_dto: type[AbstractDTO] | None | EmptyType = Empty, + signature_namespace: Mapping[str, Any] | None = None, + type_decoders: TypeDecodersSequence | None = None, + type_encoders: TypeEncodersMap | None = None, + websocket_class: type[WebSocket] | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``WebsocketRouteHandler`` + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. If not given defaults + to ``/`` + connection_accept_handler: A callable that accepts a :class:`WebSocket <.connection.WebSocket>` instance + and returns a coroutine that when awaited, will accept the connection. Defaults to ``WebSocket.accept``. + connection_lifespan: An asynchronous context manager, handling the lifespan of the connection. By default, + it calls the ``connection_accept_handler``, ``on_connect`` and ``on_disconnect``. Can request any + dependencies, for example the :class:`WebSocket <.connection.WebSocket>` connection + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and + validation of request data. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + receive_mode: Websocket mode to receive data in, either `text` or `binary`. + send_mode: Websocket mode to receive data in, either `text` or `binary`. + name: A string identifying the route handler. + on_accept: Callback invoked after a connection has been accepted. Can request any dependencies, for example + the :class:`WebSocket <.connection.WebSocket>` connection + on_disconnect: Callback invoked after a connection has been closed. Can request any dependencies, for + example the :class:`WebSocket <.connection.WebSocket>` connection + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or + :class:`ASGI Scope <.types.Scope>`. + return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing + outbound response data. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature + modelling. + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. + """ + if connection_lifespan and any([on_accept, on_disconnect, connection_accept_handler is not WebSocket.accept]): + raise ImproperlyConfiguredException( + "connection_lifespan can not be used with connection hooks " + "(on_accept, on_disconnect, connection_accept_handler)", + ) + + self._receive_mode: WebSocketMode = receive_mode + self._send_mode: WebSocketMode = send_mode + self._connection_lifespan = connection_lifespan + self._send_handler: Callable[[WebSocket, Any], Coroutine[None, None, None]] | EmptyType = Empty + self._receive_handler: Callable[[WebSocket], Any] | EmptyType = Empty + + self.connection_accept_handler = connection_accept_handler + self.on_accept = ensure_async_callable(on_accept) if on_accept else None + self.on_disconnect = ensure_async_callable(on_disconnect) if on_disconnect else None + self.type_decoders = type_decoders + self.type_encoders = type_encoders + self.websocket_class = websocket_class + + listener_dependencies = dict(dependencies or {}) + + listener_dependencies["connection_lifespan_dependencies"] = create_stub_dependency( + connection_lifespan or self.default_connection_lifespan + ) + + if self.on_accept: + listener_dependencies["on_accept_dependencies"] = create_stub_dependency(self.on_accept) + + if self.on_disconnect: + listener_dependencies["on_disconnect_dependencies"] = create_stub_dependency(self.on_disconnect) + + super().__init__( + path=path, + dependencies=listener_dependencies, + exception_handlers=exception_handlers, + guards=guards, + middleware=middleware, + name=name, + opt=opt, + signature_namespace=signature_namespace, + dto=dto, + return_dto=return_dto, + type_decoders=type_decoders, + type_encoders=type_encoders, + websocket_class=websocket_class, + **kwargs, + ) + + def __call__(self, fn: AnyCallable) -> Self: + parsed_signature = ParsedSignature.from_fn(fn, self.resolve_signature_namespace()) + + if "data" not in parsed_signature.parameters: + raise ImproperlyConfiguredException("Websocket listeners must accept a 'data' parameter") + + for param in ("request", "body"): + if param in parsed_signature.parameters: + raise ImproperlyConfiguredException(f"The {param} kwarg is not supported with websocket listeners") + + # we are manipulating the signature of the decorated function below, so we must store the original values for + # use elsewhere. + self._parsed_return_field = parsed_signature.return_type + self._parsed_data_field = parsed_signature.parameters.get("data") + self._parsed_fn_signature = ParsedSignature.from_signature( + create_handler_signature(parsed_signature.original_signature), + fn_type_hints={ + **get_fn_type_hints(fn, namespace=self.resolve_signature_namespace()), + **get_fn_type_hints(ListenerHandler.__call__, namespace=self.resolve_signature_namespace()), + }, + ) + + return super().__call__( + ListenerHandler( + listener=self, fn=fn, parsed_signature=parsed_signature, namespace=self.resolve_signature_namespace() + ) + ) + + def _validate_handler_function(self) -> None: + """Validate the route handler function once it's set by inspecting its return annotations.""" + # validation occurs in the call method + + @property + def signature_model(self) -> type[SignatureModel]: + """Get the signature model for the route handler. + + Returns: + A signature model for the route handler. + + """ + if self._signature_model is Empty: + self._signature_model = SignatureModel.create( + dependency_name_set=self.dependency_name_set, + fn=cast("AnyCallable", self.fn), + parsed_signature=self.parsed_fn_signature, + type_decoders=self.resolve_type_decoders(), + ) + return self._signature_model + + @asynccontextmanager + async def default_connection_lifespan( + self, + socket: WebSocket, + on_accept_dependencies: Optional[Dict[str, Any]] = None, # noqa: UP006, UP007 + on_disconnect_dependencies: Optional[Dict[str, Any]] = None, # noqa: UP006, UP007 + ) -> AsyncGenerator[None, None]: + """Handle the connection lifespan of a :class:`WebSocket <.connection.WebSocket>`. + + Args: + socket: The :class:`WebSocket <.connection.WebSocket>` connection + on_accept_dependencies: Dependencies requested by the :attr:`on_accept` hook + on_disconnect_dependencies: Dependencies requested by the :attr:`on_disconnect` hook + + By, default this will + + - Call :attr:`connection_accept_handler` to accept a connection + - Call :attr:`on_accept` if defined after a connection has been accepted + - Call :attr:`on_disconnect` upon leaving the context + """ + await self.connection_accept_handler(socket) + + if self.on_accept: + await self.on_accept(**(on_accept_dependencies or {})) + + try: + yield + except WebSocketDisconnect: + pass + finally: + if self.on_disconnect: + await self.on_disconnect(**(on_disconnect_dependencies or {})) + + def resolve_receive_handler(self) -> Callable[[WebSocket], Any]: + if self._receive_handler is Empty: + self._receive_handler = create_handle_receive(self) + return self._receive_handler + + def resolve_send_handler(self) -> Callable[[WebSocket, Any], Coroutine[None, None, None]]: + if self._send_handler is Empty: + self._send_handler = create_handle_send(self) + return self._send_handler + + +websocket_listener = WebsocketListenerRouteHandler + + +class WebsocketListener(ABC): + path: str | list[str] | None = None + """A path fragment for the route handler function or a sequence of path fragments. If not given defaults to ``/``""" + dependencies: Dependencies | None = None + """A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances.""" + dto: type[AbstractDTO] | None | EmptyType = Empty + """:class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and validation of request data""" + exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None + """A mapping of status codes and/or exception types to handler functions.""" + guards: list[Guard] | None = None + """A sequence of :class:`Guard <.types.Guard>` callables.""" + middleware: list[Middleware] | None = None + """A sequence of :class:`Middleware <.types.Middleware>`.""" + on_accept: AnyCallable | None = None + """Called after a :class:`WebSocket <.connection.WebSocket>` connection has been accepted. Can receive any dependencies""" + on_disconnect: AnyCallable | None = None + """Called after a :class:`WebSocket <.connection.WebSocket>` connection has been disconnected. Can receive any dependencies""" + receive_mode: WebSocketMode = "text" + """:class:`WebSocket <.connection.WebSocket>` mode to receive data in, either ``text`` or ``binary``.""" + send_mode: WebSocketMode = "text" + """Websocket mode to send data in, either `text` or `binary`.""" + name: str | None = None + """A string identifying the route handler.""" + opt: dict[str, Any] | None = None + """ + A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or wherever you + have access to :class:`Request <.connection.Request>` or :class:`ASGI Scope <.types.Scope>`. + """ + return_dto: type[AbstractDTO] | None | EmptyType = Empty + """:class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing outbound response data.""" + signature_namespace: Mapping[str, Any] | None = None + """ + A mapping of names to types for use in forward reference resolution during signature modelling. + """ + type_decoders: TypeDecodersSequence | None = None + """ + type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec + hook for deserialization. + """ + type_encoders: TypeEncodersMap | None = None + """ + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + """ + websocket_class: type[WebSocket] | None = None + """ + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. + """ + + def __init__(self, owner: Router) -> None: + """Initialize a WebsocketListener instance. + + Args: + owner: The :class:`Router <.router.Router>` instance that owns this listener. + """ + self._owner = owner + + def to_handler(self) -> WebsocketListenerRouteHandler: + handler = WebsocketListenerRouteHandler( + dependencies=self.dependencies, + dto=self.dto, + exception_handlers=self.exception_handlers, + guards=self.guards, + middleware=self.middleware, + send_mode=self.send_mode, + receive_mode=self.receive_mode, + name=self.name, + on_accept=self.on_accept, + on_disconnect=self.on_disconnect, + opt=self.opt, + path=self.path, + return_dto=self.return_dto, + signature_namespace=self.signature_namespace, + type_decoders=self.type_decoders, + type_encoders=self.type_encoders, + websocket_class=self.websocket_class, + )(self.on_receive) + handler.owner = self._owner + return handler + + @abstractmethod + def on_receive(self, *args: Any, **kwargs: Any) -> Any: + """Called after data has been received from the WebSocket. + + This should take a ``data`` argument, receiving the processed WebSocket data, + and can additionally include handler dependencies such as ``state``, or other + regular dependencies. + + Data returned from this function will be serialized and sent via the socket + according to handler configuration. + """ + raise NotImplementedError diff --git a/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/route_handler.py b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/route_handler.py new file mode 100644 index 0000000..edb49c3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/handlers/websocket_handlers/route_handler.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping + +from litestar.connection import WebSocket +from litestar.exceptions import ImproperlyConfiguredException +from litestar.handlers import BaseRouteHandler +from litestar.types.builtin_types import NoneType +from litestar.utils.predicates import is_async_callable + +if TYPE_CHECKING: + from litestar.types import Dependencies, ExceptionHandler, Guard, Middleware + + +class WebsocketRouteHandler(BaseRouteHandler): + """Websocket route handler decorator. + + Use this decorator to decorate websocket handler functions. + """ + + def __init__( + self, + path: str | list[str] | None = None, + *, + dependencies: Dependencies | None = None, + exception_handlers: dict[int | type[Exception], ExceptionHandler] | None = None, + guards: list[Guard] | None = None, + middleware: list[Middleware] | None = None, + name: str | None = None, + opt: dict[str, Any] | None = None, + signature_namespace: Mapping[str, Any] | None = None, + websocket_class: type[WebSocket] | None = None, + **kwargs: Any, + ) -> None: + """Initialize ``WebsocketRouteHandler`` + + Args: + path: A path fragment for the route handler function or a sequence of path fragments. If not given defaults + to ``/`` + dependencies: A string keyed mapping of dependency :class:`Provider <.di.Provide>` instances. + exception_handlers: A mapping of status codes and/or exception types to handler functions. + guards: A sequence of :class:`Guard <.types.Guard>` callables. + middleware: A sequence of :class:`Middleware <.types.Middleware>`. + name: A string identifying the route handler. + opt: A string keyed mapping of arbitrary values that can be accessed in :class:`Guards <.types.Guard>` or + wherever you have access to :class:`Request <.connection.Request>` or + :class:`ASGI Scope <.types.Scope>`. + signature_namespace: A mapping of names to types for use in forward reference resolution during signature modelling. + type_encoders: A mapping of types to callables that transform them into types supported for serialization. + **kwargs: Any additional kwarg - will be set in the opt dictionary. + websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as route handler's + default websocket class. + """ + self.websocket_class = websocket_class + + super().__init__( + path=path, + dependencies=dependencies, + exception_handlers=exception_handlers, + guards=guards, + middleware=middleware, + name=name, + opt=opt, + signature_namespace=signature_namespace, + **kwargs, + ) + + def resolve_websocket_class(self) -> type[WebSocket]: + """Return the closest custom WebSocket class in the owner graph or the default Websocket class. + + This method is memoized so the computation occurs only once. + + Returns: + The default :class:`WebSocket <.connection.WebSocket>` class for the route handler. + """ + return next( + (layer.websocket_class for layer in reversed(self.ownership_layers) if layer.websocket_class is not None), + WebSocket, + ) + + def _validate_handler_function(self) -> None: + """Validate the route handler function once it's set by inspecting its return annotations.""" + super()._validate_handler_function() + + if not self.parsed_fn_signature.return_type.is_subclass_of(NoneType): + raise ImproperlyConfiguredException("Websocket handler functions should return 'None'") + + if "socket" not in self.parsed_fn_signature.parameters: + raise ImproperlyConfiguredException("Websocket handlers must set a 'socket' kwarg") + + for param in ("request", "body", "data"): + if param in self.parsed_fn_signature.parameters: + raise ImproperlyConfiguredException(f"The {param} kwarg is not supported with websocket handlers") + + if not is_async_callable(self.fn): + raise ImproperlyConfiguredException("Functions decorated with 'websocket' must be async functions") + + +websocket = WebsocketRouteHandler |