summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/_asgi/asgi_router.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/_asgi/asgi_router.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/_asgi/asgi_router.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_asgi/asgi_router.py184
1 files changed, 184 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/_asgi/asgi_router.py b/venv/lib/python3.11/site-packages/litestar/_asgi/asgi_router.py
new file mode 100644
index 0000000..ebafaf0
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_asgi/asgi_router.py
@@ -0,0 +1,184 @@
+from __future__ import annotations
+
+import re
+from collections import defaultdict
+from functools import lru_cache
+from traceback import format_exc
+from typing import TYPE_CHECKING, Any, Pattern
+
+from litestar._asgi.routing_trie import validate_node
+from litestar._asgi.routing_trie.mapping import add_route_to_trie
+from litestar._asgi.routing_trie.traversal import parse_path_to_route
+from litestar._asgi.routing_trie.types import create_node
+from litestar._asgi.utils import get_route_handlers
+from litestar.exceptions import ImproperlyConfiguredException
+from litestar.utils import normalize_path
+
+__all__ = ("ASGIRouter",)
+
+
+if TYPE_CHECKING:
+ from litestar._asgi.routing_trie.types import RouteTrieNode
+ from litestar.app import Litestar
+ from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute
+ from litestar.routes.base import BaseRoute
+ from litestar.types import (
+ ASGIApp,
+ LifeSpanReceive,
+ LifeSpanSend,
+ LifeSpanShutdownCompleteEvent,
+ LifeSpanShutdownFailedEvent,
+ LifeSpanStartupCompleteEvent,
+ LifeSpanStartupFailedEvent,
+ Method,
+ Receive,
+ RouteHandlerType,
+ Scope,
+ Send,
+ )
+
+
+class ASGIRouter:
+ """Litestar ASGI router.
+
+ Handling both the ASGI lifespan events and routing of connection requests.
+ """
+
+ __slots__ = (
+ "_mount_paths_regex",
+ "_mount_routes",
+ "_plain_routes",
+ "_registered_routes",
+ "_static_routes",
+ "app",
+ "root_route_map_node",
+ "route_handler_index",
+ "route_mapping",
+ )
+
+ def __init__(self, app: Litestar) -> None:
+ """Initialize ``ASGIRouter``.
+
+ Args:
+ app: The Litestar app instance
+ """
+ self._mount_paths_regex: Pattern | None = None
+ self._mount_routes: dict[str, RouteTrieNode] = {}
+ self._plain_routes: set[str] = set()
+ self._registered_routes: set[HTTPRoute | WebSocketRoute | ASGIRoute] = set()
+ self.app = app
+ self.root_route_map_node: RouteTrieNode = create_node()
+ self.route_handler_index: dict[str, RouteHandlerType] = {}
+ self.route_mapping: dict[str, list[BaseRoute]] = defaultdict(list)
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ """ASGI callable.
+
+ The main entry point to the Router class.
+ """
+ scope.setdefault("path_params", {})
+
+ path = scope["path"]
+ if root_path := scope.get("root_path", ""):
+ path = path.split(root_path, maxsplit=1)[-1]
+ normalized_path = normalize_path(path)
+
+ asgi_app, scope["route_handler"], scope["path"], scope["path_params"] = self.handle_routing(
+ path=normalized_path, method=scope.get("method")
+ )
+ await asgi_app(scope, receive, send)
+
+ @lru_cache(1024) # noqa: B019
+ def handle_routing(self, path: str, method: Method | None) -> tuple[ASGIApp, RouteHandlerType, str, dict[str, Any]]:
+ """Handle routing for a given path / method combo. This method is meant to allow easy caching.
+
+ Args:
+ path: The path of the request.
+ method: The scope's method, if any.
+
+ Returns:
+ A tuple composed of the ASGIApp of the route, the route handler instance, the resolved and normalized path and any parsed path params.
+ """
+ return parse_path_to_route(
+ mount_paths_regex=self._mount_paths_regex,
+ mount_routes=self._mount_routes,
+ path=path,
+ plain_routes=self._plain_routes,
+ root_node=self.root_route_map_node,
+ method=method,
+ )
+
+ def _store_handler_to_route_mapping(self, route: BaseRoute) -> None:
+ """Store the mapping of route handlers to routes and to route handler names.
+
+ Args:
+ route: A Route instance.
+
+ Returns:
+ None
+ """
+
+ for handler in get_route_handlers(route):
+ if handler.name in self.route_handler_index and str(self.route_handler_index[handler.name]) != str(handler):
+ raise ImproperlyConfiguredException(
+ f"route handler names must be unique - {handler.name} is not unique."
+ )
+ identifier = handler.name or str(handler)
+ self.route_mapping[identifier].append(route)
+ self.route_handler_index[identifier] = handler
+
+ def construct_routing_trie(self) -> None:
+ """Create a map of the app's routes.
+
+ This map is used in the asgi router to route requests.
+ """
+ new_routes = [route for route in self.app.routes if route not in self._registered_routes]
+ for route in new_routes:
+ add_route_to_trie(
+ app=self.app,
+ mount_routes=self._mount_routes,
+ plain_routes=self._plain_routes,
+ root_node=self.root_route_map_node,
+ route=route,
+ )
+ self._store_handler_to_route_mapping(route)
+ self._registered_routes.add(route)
+
+ validate_node(node=self.root_route_map_node)
+ if self._mount_routes:
+ self._mount_paths_regex = re.compile("|".join(sorted(set(self._mount_routes)))) # pyright: ignore
+
+ async def lifespan(self, receive: LifeSpanReceive, send: LifeSpanSend) -> None:
+ """Handle the ASGI "lifespan" event on application startup and shutdown.
+
+ Args:
+ receive: The ASGI receive function.
+ send: The ASGI send function.
+
+ Returns:
+ None.
+ """
+
+ message = await receive()
+ shutdown_event: LifeSpanShutdownCompleteEvent = {"type": "lifespan.shutdown.complete"}
+ startup_event: LifeSpanStartupCompleteEvent = {"type": "lifespan.startup.complete"}
+
+ try:
+ async with self.app.lifespan():
+ await send(startup_event)
+ message = await receive()
+
+ except BaseException as e:
+ formatted_exception = format_exc()
+ failure_message: LifeSpanStartupFailedEvent | LifeSpanShutdownFailedEvent
+
+ if message["type"] == "lifespan.startup":
+ failure_message = {"type": "lifespan.startup.failed", "message": formatted_exception}
+ else:
+ failure_message = {"type": "lifespan.shutdown.failed", "message": formatted_exception}
+
+ await send(failure_message)
+
+ raise e
+
+ await send(shutdown_event)