1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
|
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from litestar.enums import ScopeType
from litestar.exceptions import ImproperlyConfiguredException
from litestar.routes.base import BaseRoute
if TYPE_CHECKING:
from litestar._kwargs import KwargsModel
from litestar._kwargs.cleanup import DependencyCleanupGroup
from litestar.connection import WebSocket
from litestar.handlers.websocket_handlers import WebsocketRouteHandler
from litestar.types import Receive, Send, WebSocketScope
class WebSocketRoute(BaseRoute):
    """Route that serves exactly one ``WebsocketRouteHandler`` for websocket-scoped connections."""

    # Per-instance attributes assigned in ``__init__``.
    __slots__ = (
        "route_handler",
        "handler_parameter_model",
    )

    def __init__(
        self,
        *,
        path: str,
        route_handler: WebsocketRouteHandler,
    ) -> None:
        """Store the handler and register the route under the websocket scope type.

        Args:
            path: The path for the route.
            route_handler: An instance of :class:`~.handlers.WebsocketRouteHandler`.
        """
        self.route_handler = route_handler
        # Starts out unset; ``handle`` raises until a kwargs model has been assigned.
        self.handler_parameter_model: KwargsModel | None = None
        super().__init__(
            path=path,
            scope_type=ScopeType.WEBSOCKET,
            handler_names=[route_handler.handler_name],
        )

    async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> None:  # type: ignore[override]
        """Wrap the ASGI triple in a websocket connection object and await the handler function.

        If the handler resolves any guards, ``authorize_connection`` is awaited
        first. The handler's keyword arguments are then built from the kwargs
        model and the signature model. When dependency resolution produced a
        cleanup group, the handler runs inside that group's async context.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Raises:
            ImproperlyConfiguredException: If ``handler_parameter_model`` has not been set.
        """
        kwargs_model = self.handler_parameter_model
        if not kwargs_model:  # pragma: no cover
            raise ImproperlyConfiguredException("handler parameter model not defined")

        handler = self.route_handler
        websocket_cls = handler.resolve_websocket_class()
        websocket: WebSocket[Any, Any, Any] = websocket_cls(scope=scope, receive=receive, send=send)

        if handler.resolve_guards():
            await handler.authorize_connection(connection=websocket)

        handler_kwargs: dict[str, Any] = {}
        cleanup_group: DependencyCleanupGroup | None = None

        # The signature model is only looked up when the kwargs model reports kwargs.
        if kwargs_model.has_kwargs and (signature_model := handler.signature_model):
            raw_kwargs = kwargs_model.to_kwargs(connection=websocket)
            if kwargs_model.dependency_batches:
                # ``raw_kwargs`` is handed to dependency resolution and then unpacked below.
                cleanup_group = await kwargs_model.resolve_dependencies(websocket, raw_kwargs)
            handler_kwargs = signature_model.parse_values_from_connection_kwargs(
                connection=websocket, **raw_kwargs
            )

        if not cleanup_group:
            await handler.fn(**handler_kwargs)
            return

        async with cleanup_group:
            await handler.fn(**handler_kwargs)
            # Reached only if the handler returned normally; if it raised, this
            # call is skipped and the exception propagates through the group's
            # async context manager exit.
            await cleanup_group.cleanup()
|