summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/router.py
blob: 85346d8c1d71d585c89c00cb999ea49e69cbbcca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
from __future__ import annotations

from collections import defaultdict
from copy import copy, deepcopy
from typing import TYPE_CHECKING, Any, Mapping, Sequence, cast

from litestar._layers.utils import narrow_response_cookies, narrow_response_headers
from litestar.controller import Controller
from litestar.exceptions import ImproperlyConfiguredException
from litestar.handlers.asgi_handlers import ASGIRouteHandler
from litestar.handlers.http_handlers import HTTPRouteHandler
from litestar.handlers.websocket_handlers import WebsocketListener, WebsocketRouteHandler
from litestar.routes import ASGIRoute, HTTPRoute, WebSocketRoute
from litestar.types.empty import Empty
from litestar.utils import find_index, is_class_and_subclass, join_paths, normalize_path, unique
from litestar.utils.signature import add_types_to_signature_namespace
from litestar.utils.sync import ensure_async_callable

__all__ = ("Router",)


if TYPE_CHECKING:
    from litestar.connection import Request, WebSocket
    from litestar.datastructures import CacheControlHeader, ETag
    from litestar.dto import AbstractDTO
    from litestar.openapi.spec import SecurityRequirement
    from litestar.response import Response
    from litestar.routes import BaseRoute
    from litestar.types import (
        AfterRequestHookHandler,
        AfterResponseHookHandler,
        BeforeRequestHookHandler,
        ControllerRouterHandler,
        ExceptionHandlersMap,
        Guard,
        Middleware,
        ParametersMap,
        ResponseCookies,
        RouteHandlerMapItem,
        RouteHandlerType,
        TypeEncodersMap,
    )
    from litestar.types.composite_types import Dependencies, ResponseHeaders, TypeDecodersSequence
    from litestar.types.empty import EmptyType


class Router:
    """The Litestar Router class.

    A Router instance is used to group controller, routers and route handler functions under a shared path fragment
    """

    __slots__ = (
        "after_request",
        "after_response",
        "before_request",
        "cache_control",
        "dependencies",
        "dto",
        "etag",
        "exception_handlers",
        "guards",
        "include_in_schema",
        "middleware",
        "opt",
        "owner",
        "parameters",
        "path",
        "registered_route_handler_ids",
        "request_class",
        "response_class",
        "response_cookies",
        "response_headers",
        "return_dto",
        "routes",
        "security",
        "signature_namespace",
        "tags",
        "type_decoders",
        "type_encoders",
        "websocket_class",
    )

    def __init__(
        self,
        path: str,
        *,
        after_request: AfterRequestHookHandler | None = None,
        after_response: AfterResponseHookHandler | None = None,
        before_request: BeforeRequestHookHandler | None = None,
        cache_control: CacheControlHeader | None = None,
        dependencies: Dependencies | None = None,
        dto: type[AbstractDTO] | None | EmptyType = Empty,
        etag: ETag | None = None,
        exception_handlers: ExceptionHandlersMap | None = None,
        guards: Sequence[Guard] | None = None,
        include_in_schema: bool | EmptyType = Empty,
        middleware: Sequence[Middleware] | None = None,
        opt: Mapping[str, Any] | None = None,
        parameters: ParametersMap | None = None,
        request_class: type[Request] | None = None,
        response_class: type[Response] | None = None,
        response_cookies: ResponseCookies | None = None,
        response_headers: ResponseHeaders | None = None,
        return_dto: type[AbstractDTO] | None | EmptyType = Empty,
        route_handlers: Sequence[ControllerRouterHandler],
        security: Sequence[SecurityRequirement] | None = None,
        signature_namespace: Mapping[str, Any] | None = None,
        signature_types: Sequence[Any] | None = None,
        tags: Sequence[str] | None = None,
        type_decoders: TypeDecodersSequence | None = None,
        type_encoders: TypeEncodersMap | None = None,
        websocket_class: type[WebSocket] | None = None,
    ) -> None:
        """Initialize a ``Router``.

        Args:
            after_request: A sync or async function executed before a :class:`Request <.connection.Request>` is passed
                to any route handler. If this function returns a value, the request will not reach the route handler,
                and instead this value will be used.
            after_response: A sync or async function called after the response has been awaited. It receives the
                :class:`Request <.connection.Request>` object and should not return any values.
            before_request: A sync or async function called immediately before calling the route handler. Receives
                the :class:`litestar.connection.Request` instance and any non-``None`` return value is used for the
                response, bypassing the route handler.
            cache_control: A ``cache-control`` header of type
                :class:`CacheControlHeader <.datastructures.CacheControlHeader>` to add to route handlers of
                this router. Can be overridden by route handlers.
            dependencies: A string keyed mapping of dependency :class:`Provide <.di.Provide>` instances.
            dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for (de)serializing and
                validation of request data.
            etag: An ``etag`` header of type :class:`ETag <.datastructures.ETag>` to add to route handlers of this app.
            exception_handlers: A mapping of status codes and/or exception types to handler functions.
            guards: A sequence of :data:`Guard <.types.Guard>` callables.
            include_in_schema: A boolean flag dictating whether  the route handler should be documented in the OpenAPI schema.
            middleware: A sequence of :data:`Middleware <.types.Middleware>`.
            opt: A string keyed mapping of arbitrary values that can be accessed in :data:`Guards <.types.Guard>` or
                wherever you have access to :class:`Request <.connection.Request>` or
                :data:`ASGI Scope <.types.Scope>`.
            parameters: A mapping of :func:`Parameter <.params.Parameter>` definitions available to all application
                paths.
            path: A path fragment that is prefixed to all route handlers, controllers and other routers associated
                with the router instance.
            request_class: A custom subclass of :class:`Request <.connection.Request>` to be used as the default for
                all route handlers, controllers and other routers associated with the router instance.
            response_class: A custom subclass of :class:`Response <.response.Response>` to be used as the default for
                all route handlers, controllers and other routers associated with the router instance.
            response_cookies: A sequence of :class:`Cookie <.datastructures.Cookie>` instances.
            response_headers: A string keyed mapping of :class:`ResponseHeader <.datastructures.ResponseHeader>`
                instances.
            return_dto: :class:`AbstractDTO <.dto.base_dto.AbstractDTO>` to use for serializing
                outbound response data.
            route_handlers: A required sequence of route handlers, which can include instances of
                :class:`Router <.router.Router>`, subclasses of :class:`Controller <.controller.Controller>` or any
                function decorated by the route handler decorators.
            security: A sequence of dicts that will be added to the schema of all route handlers in the application.
                See :data:`SecurityRequirement <.openapi.spec.SecurityRequirement>`
                for details.
            signature_namespace: A mapping of names to types for use in forward reference resolution during signature modeling.
            signature_types: A sequence of types for use in forward reference resolution during signature modeling.
                These types will be added to the signature namespace using their ``__name__`` attribute.
            tags: A sequence of string tags that will be appended to the schema of all route handlers under the
                application.
            type_decoders: A sequence of tuples, each composed of a predicate testing for type identity and a msgspec hook for deserialization.
            type_encoders: A mapping of types to callables that transform them into types supported for serialization.
            websocket_class: A custom subclass of :class:`WebSocket <.connection.WebSocket>` to be used as the default for
                all route handlers, controllers and other routers associated with the router instance.
        """

        self.after_request = ensure_async_callable(after_request) if after_request else None  # pyright: ignore
        self.after_response = ensure_async_callable(after_response) if after_response else None
        self.before_request = ensure_async_callable(before_request) if before_request else None
        self.cache_control = cache_control
        self.dto = dto
        self.etag = etag
        self.dependencies = dict(dependencies or {})
        self.exception_handlers = dict(exception_handlers or {})
        self.guards = list(guards or [])
        self.include_in_schema = include_in_schema
        self.middleware = list(middleware or [])
        self.opt = dict(opt or {})
        self.owner: Router | None = None
        self.parameters = dict(parameters or {})
        self.path = normalize_path(path)
        self.request_class = request_class
        self.response_class = response_class
        self.response_cookies = narrow_response_cookies(response_cookies)
        self.response_headers = narrow_response_headers(response_headers)
        self.return_dto = return_dto
        self.routes: list[HTTPRoute | ASGIRoute | WebSocketRoute] = []
        self.security = list(security or [])
        self.signature_namespace = add_types_to_signature_namespace(
            signature_types or [], dict(signature_namespace or {})
        )
        self.tags = list(tags or [])
        self.registered_route_handler_ids: set[int] = set()
        self.type_encoders = dict(type_encoders) if type_encoders is not None else None
        self.type_decoders = list(type_decoders) if type_decoders is not None else None
        self.websocket_class = websocket_class

        for route_handler in route_handlers or []:
            self.register(value=route_handler)

    def register(self, value: ControllerRouterHandler) -> list[BaseRoute]:
        """Register a Controller, Route instance or RouteHandler on the router.

        Args:
            value: a subclass or instance of Controller, an instance of :class:`Router` or a function/method that has
                been decorated by any of the routing decorators, e.g. :class:`get <.handlers.get>`,
                :class:`post <.handlers.post>`.

        Returns:
            Collection of handlers added to the router.
        """
        validated_value = self._validate_registration_value(value)

        routes: list[BaseRoute] = []

        for route_path, handlers_map in self.get_route_handler_map(value=validated_value).items():
            path = join_paths([self.path, route_path])
            if http_handlers := unique(
                [handler for handler in handlers_map.values() if isinstance(handler, HTTPRouteHandler)]
            ):
                if existing_handlers := unique(
                    [
                        handler
                        for handler in self.route_handler_method_map.get(path, {}).values()
                        if isinstance(handler, HTTPRouteHandler)
                    ]
                ):
                    http_handlers.extend(existing_handlers)
                    existing_route_index = find_index(self.routes, lambda x: x.path == path)  # noqa: B023

                    if existing_route_index == -1:  # pragma: no cover
                        raise ImproperlyConfiguredException("unable to find_index existing route index")

                    route: WebSocketRoute | ASGIRoute | HTTPRoute = HTTPRoute(
                        path=path,
                        route_handlers=http_handlers,
                    )
                    self.routes[existing_route_index] = route
                else:
                    route = HTTPRoute(path=path, route_handlers=http_handlers)
                    self.routes.append(route)

                routes.append(route)

            if websocket_handler := handlers_map.get("websocket"):
                route = WebSocketRoute(path=path, route_handler=cast("WebsocketRouteHandler", websocket_handler))
                self.routes.append(route)
                routes.append(route)

            if asgi_handler := handlers_map.get("asgi"):
                route = ASGIRoute(path=path, route_handler=cast("ASGIRouteHandler", asgi_handler))
                self.routes.append(route)
                routes.append(route)

        return routes

    @property
    def route_handler_method_map(self) -> dict[str, RouteHandlerMapItem]:
        """Map route paths to :class:`RouteHandlerMapItem <litestar.types.internal_typ es.RouteHandlerMapItem>`

        Returns:
             A dictionary mapping paths to route handlers
        """
        route_map: dict[str, RouteHandlerMapItem] = defaultdict(dict)
        for route in self.routes:
            if isinstance(route, HTTPRoute):
                for route_handler in route.route_handlers:
                    for method in route_handler.http_methods:
                        route_map[route.path][method] = route_handler
            else:
                route_map[route.path]["websocket" if isinstance(route, WebSocketRoute) else "asgi"] = (
                    route.route_handler
                )

        return route_map

    @classmethod
    def get_route_handler_map(
        cls,
        value: Controller | RouteHandlerType | Router,
    ) -> dict[str, RouteHandlerMapItem]:
        """Map route handlers to HTTP methods."""
        if isinstance(value, Router):
            return value.route_handler_method_map

        if isinstance(value, (HTTPRouteHandler, ASGIRouteHandler, WebsocketRouteHandler)):
            copied_value = copy(value)
            if isinstance(value, HTTPRouteHandler):
                return {path: {http_method: copied_value for http_method in value.http_methods} for path in value.paths}

            return {
                path: {"websocket" if isinstance(value, WebsocketRouteHandler) else "asgi": copied_value}
                for path in value.paths
            }

        handlers_map: defaultdict[str, RouteHandlerMapItem] = defaultdict(dict)
        for route_handler in value.get_route_handlers():
            for handler_path in route_handler.paths:
                path = join_paths([value.path, handler_path]) if handler_path else value.path
                if isinstance(route_handler, HTTPRouteHandler):
                    for http_method in route_handler.http_methods:
                        handlers_map[path][http_method] = route_handler
                else:
                    handlers_map[path]["websocket" if isinstance(route_handler, WebsocketRouteHandler) else "asgi"] = (
                        cast("WebsocketRouteHandler | ASGIRouteHandler", route_handler)
                    )

        return handlers_map

    def _validate_registration_value(self, value: ControllerRouterHandler) -> Controller | RouteHandlerType | Router:
        """Ensure values passed to the register method are supported."""
        if is_class_and_subclass(value, Controller):
            return value(owner=self)

        # this narrows down to an ABC, but we assume a non-abstract subclass of the ABC superclass
        if is_class_and_subclass(value, WebsocketListener):
            return value(owner=self).to_handler()  # pyright: ignore

        if isinstance(value, Router):
            if value is self:
                raise ImproperlyConfiguredException("Cannot register a router on itself")

            router_copy = deepcopy(value)
            router_copy.owner = self
            return router_copy

        if isinstance(value, (ASGIRouteHandler, HTTPRouteHandler, WebsocketRouteHandler)):
            value.owner = self
            return value

        raise ImproperlyConfiguredException(
            "Unsupported value passed to `Router.register`. "
            "If you passed in a function or method, "
            "make sure to decorate it first with one of the routing decorators"
        )