summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/routes
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/routes')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/routes/__init__.py6
-rw-r--r--venv/lib/python3.11/site-packages/litestar/routes/__pycache__/__init__.cpython-311.pycbin0 -> 461 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/routes/__pycache__/asgi.cpython-311.pycbin0 -> 2874 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/routes/__pycache__/base.cpython-311.pycbin0 -> 10362 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/routes/__pycache__/http.cpython-311.pycbin0 -> 17069 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/routes/__pycache__/websocket.cpython-311.pycbin0 -> 4417 bytes
-rw-r--r--venv/lib/python3.11/site-packages/litestar/routes/asgi.py54
-rw-r--r--venv/lib/python3.11/site-packages/litestar/routes/base.py195
-rw-r--r--venv/lib/python3.11/site-packages/litestar/routes/http.py327
-rw-r--r--venv/lib/python3.11/site-packages/litestar/routes/websocket.py86
10 files changed, 668 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/routes/__init__.py b/venv/lib/python3.11/site-packages/litestar/routes/__init__.py
new file mode 100644
index 0000000..c8b5d3d
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/routes/__init__.py
@@ -0,0 +1,6 @@
+from .asgi import ASGIRoute
+from .base import BaseRoute
+from .http import HTTPRoute
+from .websocket import WebSocketRoute
+
+__all__ = ("BaseRoute", "ASGIRoute", "WebSocketRoute", "HTTPRoute")
diff --git a/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/__init__.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000..ae8dee1
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/__init__.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/asgi.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/asgi.cpython-311.pyc
new file mode 100644
index 0000000..b3330b1
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/asgi.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/base.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/base.cpython-311.pyc
new file mode 100644
index 0000000..98b66b4
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/base.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/http.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/http.cpython-311.pyc
new file mode 100644
index 0000000..a430b70
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/http.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/websocket.cpython-311.pyc b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/websocket.cpython-311.pyc
new file mode 100644
index 0000000..79da4a4
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/routes/__pycache__/websocket.cpython-311.pyc
Binary files differ
diff --git a/venv/lib/python3.11/site-packages/litestar/routes/asgi.py b/venv/lib/python3.11/site-packages/litestar/routes/asgi.py
new file mode 100644
index 0000000..a8564d0
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/routes/asgi.py
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from litestar.connection import ASGIConnection
+from litestar.enums import ScopeType
+from litestar.routes.base import BaseRoute
+
+if TYPE_CHECKING:
+ from litestar.handlers.asgi_handlers import ASGIRouteHandler
+ from litestar.types import Receive, Scope, Send
+
+
+class ASGIRoute(BaseRoute):
+ """An ASGI route, handling a single ``ASGIRouteHandler``"""
+
+ __slots__ = ("route_handler",)
+
+ def __init__(
+ self,
+ *,
+ path: str,
+ route_handler: ASGIRouteHandler,
+ ) -> None:
+ """Initialize the route.
+
+ Args:
+ path: The path for the route.
+ route_handler: An instance of :class:`~.handlers.ASGIRouteHandler`.
+ """
+ self.route_handler = route_handler
+ super().__init__(
+ path=path,
+ scope_type=ScopeType.ASGI,
+ handler_names=[route_handler.handler_name],
+ )
+
+ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
+ """ASGI app that authorizes the connection and then awaits the handler function.
+
+ Args:
+ scope: The ASGI connection scope.
+ receive: The ASGI receive function.
+ send: The ASGI send function.
+
+ Returns:
+ None
+ """
+
+ if self.route_handler.resolve_guards():
+ connection = ASGIConnection["ASGIRouteHandler", Any, Any, Any](scope=scope, receive=receive)
+ await self.route_handler.authorize_connection(connection=connection)
+
+ await self.route_handler.fn(scope=scope, receive=receive, send=send)
diff --git a/venv/lib/python3.11/site-packages/litestar/routes/base.py b/venv/lib/python3.11/site-packages/litestar/routes/base.py
new file mode 100644
index 0000000..b9baab4
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/routes/base.py
@@ -0,0 +1,195 @@
+from __future__ import annotations
+
+import re
+from abc import ABC, abstractmethod
+from datetime import date, datetime, time, timedelta
+from decimal import Decimal
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable
+from uuid import UUID
+
+import msgspec
+
+from litestar._kwargs import KwargsModel
+from litestar.exceptions import ImproperlyConfiguredException
+from litestar.types.internal_types import PathParameterDefinition
+from litestar.utils import join_paths, normalize_path
+
+if TYPE_CHECKING:
+ from litestar.enums import ScopeType
+ from litestar.handlers.base import BaseRouteHandler
+ from litestar.types import Method, Receive, Scope, Send
+
+
+def _parse_datetime(value: str) -> datetime:
+ return msgspec.convert(value, datetime)
+
+
+def _parse_date(value: str) -> date:
+ return msgspec.convert(value, date)
+
+
+def _parse_time(value: str) -> time:
+ return msgspec.convert(value, time)
+
+
+def _parse_timedelta(value: str) -> timedelta:
+ try:
+ return msgspec.convert(value, timedelta)
+ except msgspec.ValidationError:
+ return timedelta(seconds=int(float(value)))
+
+
+param_match_regex = re.compile(r"{(.*?)}")
+
+param_type_map = {
+ "str": str,
+ "int": int,
+ "float": float,
+ "uuid": UUID,
+ "decimal": Decimal,
+ "date": date,
+ "datetime": datetime,
+ "time": time,
+ "timedelta": timedelta,
+ "path": Path,
+}
+
+
+parsers_map: dict[Any, Callable[[Any], Any]] = {
+ float: float,
+ int: int,
+ Decimal: Decimal,
+ UUID: UUID,
+ date: _parse_date,
+ datetime: _parse_datetime,
+ time: _parse_time,
+ timedelta: _parse_timedelta,
+}
+
+
+class BaseRoute(ABC):
+ """Base Route class used by Litestar.
+
+ It's an abstract class meant to be extended.
+ """
+
+ __slots__ = (
+ "app",
+ "handler_names",
+ "methods",
+ "path",
+ "path_format",
+ "path_parameters",
+ "path_components",
+ "scope_type",
+ )
+
+ def __init__(
+ self,
+ *,
+ handler_names: list[str],
+ path: str,
+ scope_type: ScopeType,
+ methods: list[Method] | None = None,
+ ) -> None:
+ """Initialize the route.
+
+ Args:
+ handler_names: Names of the associated handler functions
+ path: Base path of the route
+ scope_type: Type of the ASGI scope
+ methods: Supported methods
+ """
+ self.path, self.path_format, self.path_components = self._parse_path(path)
+ self.path_parameters: tuple[PathParameterDefinition, ...] = tuple(
+ component for component in self.path_components if isinstance(component, PathParameterDefinition)
+ )
+ self.handler_names = handler_names
+ self.scope_type = scope_type
+ self.methods = set(methods or [])
+
+ @abstractmethod
+ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
+ """ASGI App of the route.
+
+ Args:
+ scope: The ASGI connection scope.
+ receive: The ASGI receive function.
+ send: The ASGI send function.
+
+ Returns:
+ None
+ """
+ raise NotImplementedError("Route subclasses must implement handle which serves as the ASGI app entry point")
+
+ def create_handler_kwargs_model(self, route_handler: BaseRouteHandler) -> KwargsModel:
+ """Create a `KwargsModel` for a given route handler."""
+
+ path_parameters = set()
+ for param in self.path_parameters:
+ if param.name in path_parameters:
+ raise ImproperlyConfiguredException(f"Duplicate parameter '{param.name}' detected in '{self.path}'.")
+ path_parameters.add(param.name)
+
+ return KwargsModel.create_for_signature_model(
+ signature_model=route_handler.signature_model,
+ parsed_signature=route_handler.parsed_fn_signature,
+ dependencies=route_handler.resolve_dependencies(),
+ path_parameters=path_parameters,
+ layered_parameters=route_handler.resolve_layered_parameters(),
+ )
+
+ @staticmethod
+ def _validate_path_parameter(param: str, path: str) -> None:
+ """Validate that a path parameter adheres to the required format and datatypes.
+
+ Raises:
+ ImproperlyConfiguredException: If the parameter has an invalid format.
+ """
+ if len(param.split(":")) != 2:
+ raise ImproperlyConfiguredException(
+ f"Path parameters should be declared with a type using the following pattern: '{{parameter_name:type}}', e.g. '/my-path/{{my_param:int}}' in path: '{path}'"
+ )
+ param_name, param_type = (p.strip() for p in param.split(":"))
+ if not param_name:
+ raise ImproperlyConfiguredException("Path parameter names should be of length greater than zero")
+ if param_type not in param_type_map:
+ raise ImproperlyConfiguredException(
+ f"Path parameters should be declared with an allowed type, i.e. one of {', '.join(param_type_map.keys())} in path: '{path}'"
+ )
+
+ @classmethod
+ def _parse_path(cls, path: str) -> tuple[str, str, list[str | PathParameterDefinition]]:
+ """Normalize and parse a path.
+
+ Splits the path into a list of components, parsing any that are path parameters. Also builds the OpenAPI
+ compatible path, which does not include the type of the path parameters.
+
+ Returns:
+ A 3-tuple of the normalized path, the OpenAPI formatted path, and the list of parsed components.
+ """
+ path = normalize_path(path)
+
+ parsed_components: list[str | PathParameterDefinition] = []
+ path_format_components = []
+
+ components = [component for component in path.split("/") if component]
+ for component in components:
+ if param_match := param_match_regex.fullmatch(component):
+ param = param_match.group(1)
+ cls._validate_path_parameter(param, path)
+ param_name, param_type = (p.strip() for p in param.split(":"))
+ type_class = param_type_map[param_type]
+ parser = parsers_map[type_class] if type_class not in {str, Path} else None
+ parsed_components.append(
+ PathParameterDefinition(name=param_name, type=type_class, full=param, parser=parser)
+ )
+ path_format_components.append("{" + param_name + "}")
+ else:
+ parsed_components.append(component)
+ path_format_components.append(component)
+
+ path_format = join_paths(path_format_components)
+
+ return path, path_format, parsed_components
diff --git a/venv/lib/python3.11/site-packages/litestar/routes/http.py b/venv/lib/python3.11/site-packages/litestar/routes/http.py
new file mode 100644
index 0000000..b1f70cb
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/routes/http.py
@@ -0,0 +1,327 @@
+from __future__ import annotations
+
+from itertools import chain
+from typing import TYPE_CHECKING, Any, cast
+
+from msgspec.msgpack import decode as _decode_msgpack_plain
+
+from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS
+from litestar.datastructures.headers import Headers
+from litestar.datastructures.upload_file import UploadFile
+from litestar.enums import HttpMethod, MediaType, ScopeType
+from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException
+from litestar.handlers.http_handlers import HTTPRouteHandler
+from litestar.response import Response
+from litestar.routes.base import BaseRoute
+from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
+from litestar.types.empty import Empty
+from litestar.utils.scope.state import ScopeState
+
+if TYPE_CHECKING:
+ from litestar._kwargs import KwargsModel
+ from litestar._kwargs.cleanup import DependencyCleanupGroup
+ from litestar.connection import Request
+ from litestar.types import ASGIApp, HTTPScope, Method, Receive, Scope, Send
+
+
+class HTTPRoute(BaseRoute):
+ """An HTTP route, capable of handling multiple ``HTTPRouteHandler``\\ s.""" # noqa: D301
+
+ __slots__ = (
+ "route_handler_map",
+ "route_handlers",
+ )
+
+ def __init__(
+ self,
+ *,
+ path: str,
+ route_handlers: list[HTTPRouteHandler],
+ ) -> None:
+ """Initialize ``HTTPRoute``.
+
+ Args:
+ path: The path for the route.
+ route_handlers: A list of :class:`~.handlers.HTTPRouteHandler`.
+ """
+ methods = list(chain.from_iterable([route_handler.http_methods for route_handler in route_handlers]))
+ if "OPTIONS" not in methods:
+ methods.append("OPTIONS")
+ options_handler = self.create_options_handler(path)
+ options_handler.owner = route_handlers[0].owner
+ route_handlers.append(options_handler)
+
+ self.route_handlers = route_handlers
+ self.route_handler_map: dict[Method, tuple[HTTPRouteHandler, KwargsModel]] = {}
+
+ super().__init__(
+ methods=methods,
+ path=path,
+ scope_type=ScopeType.HTTP,
+ handler_names=[route_handler.handler_name for route_handler in self.route_handlers],
+ )
+
+ async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None: # type: ignore[override]
+ """ASGI app that creates a Request from the passed in args, determines which handler function to call and then
+ handles the call.
+
+ Args:
+ scope: The ASGI connection scope.
+ receive: The ASGI receive function.
+ send: The ASGI send function.
+
+ Returns:
+ None
+ """
+ route_handler, parameter_model = self.route_handler_map[scope["method"]]
+ request: Request[Any, Any, Any] = route_handler.resolve_request_class()(scope=scope, receive=receive, send=send)
+
+ if route_handler.resolve_guards():
+ await route_handler.authorize_connection(connection=request)
+
+ response = await self._get_response_for_request(
+ scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model
+ )
+
+ await response(scope, receive, send)
+
+ if after_response_handler := route_handler.resolve_after_response():
+ await after_response_handler(request)
+
+ if form_data := scope.get("_form", {}):
+ await self._cleanup_temporary_files(form_data=cast("dict[str, Any]", form_data))
+
+ def create_handler_map(self) -> None:
+ """Parse the ``router_handlers`` of this route and return a mapping of
+ http- methods and route handlers.
+ """
+ for route_handler in self.route_handlers:
+ kwargs_model = self.create_handler_kwargs_model(route_handler=route_handler)
+ for http_method in route_handler.http_methods:
+ if self.route_handler_map.get(http_method):
+ raise ImproperlyConfiguredException(
+ f"Handler already registered for path {self.path!r} and http method {http_method}"
+ )
+ self.route_handler_map[http_method] = (route_handler, kwargs_model)
+
+ async def _get_response_for_request(
+ self,
+ scope: Scope,
+ request: Request[Any, Any, Any],
+ route_handler: HTTPRouteHandler,
+ parameter_model: KwargsModel,
+ ) -> ASGIApp:
+ """Return a response for the request.
+
+ If caching is enabled and a response exist in the cache, the cached response will be returned.
+ If caching is enabled and a response does not exist in the cache, the newly created
+ response will be cached.
+
+ Args:
+ scope: The Request's scope
+ request: The Request instance
+ route_handler: The HTTPRouteHandler instance
+ parameter_model: The Handler's KwargsModel
+
+ Returns:
+ An instance of Response or a compatible ASGIApp or a subclass of it
+ """
+ if route_handler.cache and (
+ response := await self._get_cached_response(request=request, route_handler=route_handler)
+ ):
+ return response
+
+ return await self._call_handler_function(
+ scope=scope, request=request, parameter_model=parameter_model, route_handler=route_handler
+ )
+
+ async def _call_handler_function(
+ self, scope: Scope, request: Request, parameter_model: KwargsModel, route_handler: HTTPRouteHandler
+ ) -> ASGIApp:
+ """Call the before request handlers, retrieve any data required for the route handler, and call the route
+ handler's ``to_response`` method.
+
+ This is wrapped in a try except block - and if an exception is raised,
+ it tries to pass it to an appropriate exception handler - if defined.
+ """
+ response_data: Any = None
+ cleanup_group: DependencyCleanupGroup | None = None
+
+ if before_request_handler := route_handler.resolve_before_request():
+ response_data = await before_request_handler(request)
+
+ if not response_data:
+ response_data, cleanup_group = await self._get_response_data(
+ route_handler=route_handler, parameter_model=parameter_model, request=request
+ )
+
+ response: ASGIApp = await route_handler.to_response(app=scope["app"], data=response_data, request=request)
+
+ if cleanup_group:
+ await cleanup_group.cleanup()
+
+ return response
+
+ @staticmethod
+ async def _get_response_data(
+ route_handler: HTTPRouteHandler, parameter_model: KwargsModel, request: Request
+ ) -> tuple[Any, DependencyCleanupGroup | None]:
+ """Determine what kwargs are required for the given route handler's ``fn`` and calls it."""
+ parsed_kwargs: dict[str, Any] = {}
+ cleanup_group: DependencyCleanupGroup | None = None
+
+ if parameter_model.has_kwargs and route_handler.signature_model:
+ kwargs = parameter_model.to_kwargs(connection=request)
+
+ if "data" in kwargs:
+ try:
+ data = await kwargs["data"]
+ except SerializationException as e:
+ raise ClientException(str(e)) from e
+
+ if data is Empty:
+ del kwargs["data"]
+ else:
+ kwargs["data"] = data
+
+ if "body" in kwargs:
+ kwargs["body"] = await kwargs["body"]
+
+ if parameter_model.dependency_batches:
+ cleanup_group = await parameter_model.resolve_dependencies(request, kwargs)
+
+ parsed_kwargs = route_handler.signature_model.parse_values_from_connection_kwargs(
+ connection=request, **kwargs
+ )
+
+ if cleanup_group:
+ async with cleanup_group:
+ data = (
+ route_handler.fn(**parsed_kwargs)
+ if route_handler.has_sync_callable
+ else await route_handler.fn(**parsed_kwargs)
+ )
+ elif route_handler.has_sync_callable:
+ data = route_handler.fn(**parsed_kwargs)
+ else:
+ data = await route_handler.fn(**parsed_kwargs)
+
+ return data, cleanup_group
+
+ @staticmethod
+ async def _get_cached_response(request: Request, route_handler: HTTPRouteHandler) -> ASGIApp | None:
+ """Retrieve and un-pickle the cached response, if existing.
+
+ Args:
+ request: The :class:`Request <litestar.connection.Request>` instance
+ route_handler: The :class:`~.handlers.HTTPRouteHandler` instance
+
+ Returns:
+ A cached response instance, if existing.
+ """
+
+ cache_config = request.app.response_cache_config
+ cache_key = (route_handler.cache_key_builder or cache_config.key_builder)(request)
+ store = cache_config.get_store_from_app(request.app)
+
+ if not (cached_response_data := await store.get(key=cache_key)):
+ return None
+
+ # we use the regular msgspec.msgpack.decode here since we don't need any of
+ # the added decoders
+ messages = _decode_msgpack_plain(cached_response_data)
+
+ async def cached_response(scope: Scope, receive: Receive, send: Send) -> None:
+ ScopeState.from_scope(scope).is_cached = True
+ for message in messages:
+ await send(message)
+
+ return cached_response
+
+ def create_options_handler(self, path: str) -> HTTPRouteHandler:
+ """Args:
+ path: The route path
+
+ Returns:
+ An HTTP route handler for OPTIONS requests.
+ """
+
+ def options_handler(scope: Scope) -> Response:
+ """Handler function for OPTIONS requests.
+
+ Args:
+ scope: The ASGI Scope.
+
+ Returns:
+ Response
+ """
+ cors_config = scope["app"].cors_config
+ request_headers = Headers.from_scope(scope=scope)
+ origin = request_headers.get("origin")
+
+ if cors_config and origin:
+ pre_flight_method = request_headers.get("Access-Control-Request-Method")
+ failures = []
+
+ if not cors_config.is_allow_all_methods and (
+ pre_flight_method and pre_flight_method not in cors_config.allow_methods
+ ):
+ failures.append("method")
+
+ response_headers = cors_config.preflight_headers.copy()
+
+ if not cors_config.is_origin_allowed(origin):
+ failures.append("Origin")
+ elif response_headers.get("Access-Control-Allow-Origin") != "*":
+ response_headers["Access-Control-Allow-Origin"] = origin
+
+ pre_flight_requested_headers = [
+ header.strip()
+ for header in request_headers.get("Access-Control-Request-Headers", "").split(",")
+ if header.strip()
+ ]
+
+ if pre_flight_requested_headers:
+ if cors_config.is_allow_all_headers:
+ response_headers["Access-Control-Allow-Headers"] = ", ".join(
+ sorted(set(pre_flight_requested_headers) | DEFAULT_ALLOWED_CORS_HEADERS) # pyright: ignore
+ )
+ elif any(
+ header.lower() not in cors_config.allow_headers for header in pre_flight_requested_headers
+ ):
+ failures.append("headers")
+
+ return (
+ Response(
+ content=f"Disallowed CORS {', '.join(failures)}",
+ status_code=HTTP_400_BAD_REQUEST,
+ media_type=MediaType.TEXT,
+ )
+ if failures
+ else Response(
+ content=None,
+ status_code=HTTP_204_NO_CONTENT,
+ media_type=MediaType.TEXT,
+ headers=response_headers,
+ )
+ )
+
+ return Response(
+ content=None,
+ status_code=HTTP_204_NO_CONTENT,
+ headers={"Allow": ", ".join(sorted(self.methods))}, # pyright: ignore
+ media_type=MediaType.TEXT,
+ )
+
+ return HTTPRouteHandler(
+ path=path,
+ http_method=[HttpMethod.OPTIONS],
+ include_in_schema=False,
+ sync_to_thread=False,
+ )(options_handler)
+
+ @staticmethod
+ async def _cleanup_temporary_files(form_data: dict[str, Any]) -> None:
+ for v in form_data.values():
+ if isinstance(v, UploadFile) and not v.file.closed:
+ await v.close()
diff --git a/venv/lib/python3.11/site-packages/litestar/routes/websocket.py b/venv/lib/python3.11/site-packages/litestar/routes/websocket.py
new file mode 100644
index 0000000..ebf4959
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/routes/websocket.py
@@ -0,0 +1,86 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from litestar.enums import ScopeType
+from litestar.exceptions import ImproperlyConfiguredException
+from litestar.routes.base import BaseRoute
+
+if TYPE_CHECKING:
+ from litestar._kwargs import KwargsModel
+ from litestar._kwargs.cleanup import DependencyCleanupGroup
+ from litestar.connection import WebSocket
+ from litestar.handlers.websocket_handlers import WebsocketRouteHandler
+ from litestar.types import Receive, Send, WebSocketScope
+
+
+class WebSocketRoute(BaseRoute):
+ """A websocket route, handling a single ``WebsocketRouteHandler``"""
+
+ __slots__ = (
+ "route_handler",
+ "handler_parameter_model",
+ )
+
+ def __init__(
+ self,
+ *,
+ path: str,
+ route_handler: WebsocketRouteHandler,
+ ) -> None:
+ """Initialize the route.
+
+ Args:
+ path: The path for the route.
+ route_handler: An instance of :class:`~.handlers.WebsocketRouteHandler`.
+ """
+ self.route_handler = route_handler
+ self.handler_parameter_model: KwargsModel | None = None
+
+ super().__init__(
+ path=path,
+ scope_type=ScopeType.WEBSOCKET,
+ handler_names=[route_handler.handler_name],
+ )
+
+ async def handle(self, scope: WebSocketScope, receive: Receive, send: Send) -> None: # type: ignore[override]
+ """ASGI app that creates a WebSocket from the passed in args, and then awaits the handler function.
+
+ Args:
+ scope: The ASGI connection scope.
+ receive: The ASGI receive function.
+ send: The ASGI send function.
+
+ Returns:
+ None
+ """
+
+ if not self.handler_parameter_model: # pragma: no cover
+ raise ImproperlyConfiguredException("handler parameter model not defined")
+
+ websocket: WebSocket[Any, Any, Any] = self.route_handler.resolve_websocket_class()(
+ scope=scope, receive=receive, send=send
+ )
+
+ if self.route_handler.resolve_guards():
+ await self.route_handler.authorize_connection(connection=websocket)
+
+ parsed_kwargs: dict[str, Any] = {}
+ cleanup_group: DependencyCleanupGroup | None = None
+
+ if self.handler_parameter_model.has_kwargs and self.route_handler.signature_model:
+ parsed_kwargs = self.handler_parameter_model.to_kwargs(connection=websocket)
+
+ if self.handler_parameter_model.dependency_batches:
+ cleanup_group = await self.handler_parameter_model.resolve_dependencies(websocket, parsed_kwargs)
+
+ parsed_kwargs = self.route_handler.signature_model.parse_values_from_connection_kwargs(
+ connection=websocket, **parsed_kwargs
+ )
+
+ if cleanup_group:
+ async with cleanup_group:
+ await self.route_handler.fn(**parsed_kwargs)
+ await cleanup_group.cleanup()
+ else:
+ await self.route_handler.fn(**parsed_kwargs)