summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/routes/http.py
blob: b1f70cb36bed8c3c95d9629918cad6520b4a9846 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
from __future__ import annotations

from itertools import chain
from typing import TYPE_CHECKING, Any, cast

from msgspec.msgpack import decode as _decode_msgpack_plain

from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS
from litestar.datastructures.headers import Headers
from litestar.datastructures.upload_file import UploadFile
from litestar.enums import HttpMethod, MediaType, ScopeType
from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException
from litestar.handlers.http_handlers import HTTPRouteHandler
from litestar.response import Response
from litestar.routes.base import BaseRoute
from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
from litestar.types.empty import Empty
from litestar.utils.scope.state import ScopeState

if TYPE_CHECKING:
    from litestar._kwargs import KwargsModel
    from litestar._kwargs.cleanup import DependencyCleanupGroup
    from litestar.connection import Request
    from litestar.types import ASGIApp, HTTPScope, Method, Receive, Scope, Send


class HTTPRoute(BaseRoute):
    """An HTTP route, capable of handling multiple ``HTTPRouteHandler``\\ s."""  # noqa: D301

    __slots__ = (
        "route_handler_map",
        "route_handlers",
    )

    def __init__(
        self,
        *,
        path: str,
        route_handlers: list[HTTPRouteHandler],
    ) -> None:
        """Initialize ``HTTPRoute``.

        If none of the given handlers serves ``OPTIONS``, an auto-generated OPTIONS handler is
        appended to ``route_handlers`` (the passed-in list itself is mutated) and given the same
        ``owner`` as the first handler.

        Args:
            path: The path for the route.
            route_handlers: A list of :class:`~.handlers.HTTPRouteHandler`.
        """
        methods = list(chain.from_iterable(route_handler.http_methods for route_handler in route_handlers))
        if "OPTIONS" not in methods:
            methods.append("OPTIONS")
            options_handler = self.create_options_handler(path)
            # The synthetic OPTIONS handler has no owner of its own; borrow the first handler's.
            options_handler.owner = route_handlers[0].owner
            route_handlers.append(options_handler)

        self.route_handlers = route_handlers
        # Filled in by ``create_handler_map``: HTTP method -> (handler, kwargs model).
        self.route_handler_map: dict[Method, tuple[HTTPRouteHandler, KwargsModel]] = {}

        super().__init__(
            methods=methods,
            path=path,
            scope_type=ScopeType.HTTP,
            handler_names=[route_handler.handler_name for route_handler in self.route_handlers],
        )

    async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None:  # type: ignore[override]
        """ASGI app that creates a Request from the passed in args, determines which handler function to call and then
        handles the call.

        The steps run in this order: guards (if any), obtaining the response (cached or freshly
        created), sending it, and then the after-response hook (if any). Temporary upload files
        stored under ``scope["_form"]`` are closed afterwards, even if one of those steps raised.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Returns:
            None
        """
        route_handler, parameter_model = self.route_handler_map[scope["method"]]
        request: Request[Any, Any, Any] = route_handler.resolve_request_class()(scope=scope, receive=receive, send=send)

        try:
            if route_handler.resolve_guards():
                await route_handler.authorize_connection(connection=request)

            response = await self._get_response_for_request(
                scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model
            )

            await response(scope, receive, send)

            if after_response_handler := route_handler.resolve_after_response():
                await after_response_handler(request)
        finally:
            # Close upload files on every exit path so an exception from the handler, the response
            # or the after-response hook does not leave them open until garbage collection.
            if form_data := scope.get("_form", {}):
                await self._cleanup_temporary_files(form_data=cast("dict[str, Any]", form_data))

    def create_handler_map(self) -> None:
        """Populate ``route_handler_map`` from the ``route_handlers`` of this route.

        Every HTTP method served by a handler is mapped to a ``(handler, kwargs_model)`` tuple.

        Raises:
            ImproperlyConfiguredException: If two handlers of this route serve the same HTTP method.
        """
        for route_handler in self.route_handlers:
            kwargs_model = self.create_handler_kwargs_model(route_handler=route_handler)
            for http_method in route_handler.http_methods:
                if http_method in self.route_handler_map:
                    raise ImproperlyConfiguredException(
                        f"Handler already registered for path {self.path!r} and http method {http_method}"
                    )
                self.route_handler_map[http_method] = (route_handler, kwargs_model)

    async def _get_response_for_request(
        self,
        scope: Scope,
        request: Request[Any, Any, Any],
        route_handler: HTTPRouteHandler,
        parameter_model: KwargsModel,
    ) -> ASGIApp:
        """Return a response for the request.

        If the handler has caching enabled and a response exists in the cache, the cached response
        is returned. Otherwise the handler is called to produce a new response. This method only
        looks the cache up; storing a newly created response is not done here directly (presumably
        it happens during ``route_handler.to_response`` — verify).

        Args:
            scope: The Request's scope
            request: The Request instance
            route_handler: The HTTPRouteHandler instance
            parameter_model: The Handler's KwargsModel

        Returns:
            An instance of Response or a compatible ASGIApp or a subclass of it
        """
        if route_handler.cache and (
            response := await self._get_cached_response(request=request, route_handler=route_handler)
        ):
            return response

        return await self._call_handler_function(
            scope=scope, request=request, parameter_model=parameter_model, route_handler=route_handler
        )

    async def _call_handler_function(
        self, scope: Scope, request: Request, parameter_model: KwargsModel, route_handler: HTTPRouteHandler
    ) -> ASGIApp:
        """Call the before-request hook, then (if needed) the route handler, and convert the result to a response.

        If a before-request hook is configured and returns a truthy value, that value is used as the
        response data and the route handler is not called. The resulting data is passed to the
        route handler's ``to_response``; any dependency cleanup group is cleaned up afterwards.

        Returns:
            The ASGI app produced by ``route_handler.to_response``.
        """
        response_data: Any = None
        cleanup_group: DependencyCleanupGroup | None = None

        if before_request_handler := route_handler.resolve_before_request():
            response_data = await before_request_handler(request)

        # NOTE(review): this is a truthiness check, so any falsy value returned by the
        # before-request hook (e.g. ``{}``, ``0``, ``""``) also falls through to the route handler.
        if not response_data:
            response_data, cleanup_group = await self._get_response_data(
                route_handler=route_handler, parameter_model=parameter_model, request=request
            )

        response: ASGIApp = await route_handler.to_response(app=scope["app"], data=response_data, request=request)

        # NOTE(review): if ``to_response`` raises, this cleanup is not reached — confirm whether
        # generator dependencies should also be cleaned up on that path.
        if cleanup_group:
            await cleanup_group.cleanup()

        return response

    @staticmethod
    async def _get_response_data(
        route_handler: HTTPRouteHandler, parameter_model: KwargsModel, request: Request
    ) -> tuple[Any, DependencyCleanupGroup | None]:
        """Determine what kwargs are required for the given route handler's ``fn`` and call it.

        Awaits the ``data`` and ``body`` kwargs when present, resolves dependencies, parses the
        kwargs through the handler's signature model and finally calls the handler function
        (directly if it is sync, awaited otherwise).

        Returns:
            A tuple of the handler's return value and the dependency cleanup group (or ``None``).

        Raises:
            ClientException: If decoding the request ``data`` raised a ``SerializationException``.
        """
        parsed_kwargs: dict[str, Any] = {}
        cleanup_group: DependencyCleanupGroup | None = None

        if parameter_model.has_kwargs and route_handler.signature_model:
            kwargs = parameter_model.to_kwargs(connection=request)

            if "data" in kwargs:
                try:
                    data = await kwargs["data"]
                except SerializationException as e:
                    # Malformed request payloads are the client's fault, not a server error.
                    raise ClientException(str(e)) from e

                if data is Empty:
                    del kwargs["data"]
                else:
                    kwargs["data"] = data

            if "body" in kwargs:
                kwargs["body"] = await kwargs["body"]

            if parameter_model.dependency_batches:
                cleanup_group = await parameter_model.resolve_dependencies(request, kwargs)

            parsed_kwargs = route_handler.signature_model.parse_values_from_connection_kwargs(
                connection=request, **kwargs
            )

        if cleanup_group:
            # Run the handler inside the cleanup group's context so that the group can react to an
            # exception raised by the handler.
            async with cleanup_group:
                data = (
                    route_handler.fn(**parsed_kwargs)
                    if route_handler.has_sync_callable
                    else await route_handler.fn(**parsed_kwargs)
                )
        elif route_handler.has_sync_callable:
            data = route_handler.fn(**parsed_kwargs)
        else:
            data = await route_handler.fn(**parsed_kwargs)

        return data, cleanup_group

    @staticmethod
    async def _get_cached_response(request: Request, route_handler: HTTPRouteHandler) -> ASGIApp | None:
        """Retrieve and msgpack-decode the cached response, if existing.

        The cached payload is decoded into a sequence of ASGI messages which are replayed verbatim
        through ``send``; the scope is marked as cached when the replay starts.

        Args:
            request: The :class:`Request <litestar.connection.Request>` instance
            route_handler: The :class:`~.handlers.HTTPRouteHandler` instance

        Returns:
            A cached response instance, if existing.
        """

        cache_config = request.app.response_cache_config
        # A handler-level key builder takes precedence over the app-wide one.
        cache_key = (route_handler.cache_key_builder or cache_config.key_builder)(request)
        store = cache_config.get_store_from_app(request.app)

        if not (cached_response_data := await store.get(key=cache_key)):
            return None

        # we use the regular msgspec.msgpack.decode here since we don't need any of
        # the added decoders
        messages = _decode_msgpack_plain(cached_response_data)

        async def cached_response(scope: Scope, receive: Receive, send: Send) -> None:
            ScopeState.from_scope(scope).is_cached = True
            for message in messages:
                await send(message)

        return cached_response

    def create_options_handler(self, path: str) -> HTTPRouteHandler:
        """Create the default OPTIONS handler for this route.

        For CORS preflight requests (an ``origin`` header is present and CORS is configured) the
        handler validates the requested method, origin and headers, answering ``400`` listing the
        disallowed parts or ``204`` with the preflight headers. Otherwise it answers ``204`` with
        an ``Allow`` header listing this route's methods.

        Args:
            path: The route path

        Returns:
            An HTTP route handler for OPTIONS requests.
        """

        def options_handler(scope: Scope) -> Response:
            """Handler function for OPTIONS requests.

            Args:
                scope: The ASGI Scope.

            Returns:
                Response
            """
            cors_config = scope["app"].cors_config
            request_headers = Headers.from_scope(scope=scope)
            origin = request_headers.get("origin")

            if cors_config and origin:
                pre_flight_method = request_headers.get("Access-Control-Request-Method")
                failures = []

                if not cors_config.is_allow_all_methods and (
                    pre_flight_method and pre_flight_method not in cors_config.allow_methods
                ):
                    failures.append("method")

                # Copy so the per-request origin does not leak into the shared config headers.
                response_headers = cors_config.preflight_headers.copy()

                if not cors_config.is_origin_allowed(origin):
                    failures.append("Origin")
                elif response_headers.get("Access-Control-Allow-Origin") != "*":
                    response_headers["Access-Control-Allow-Origin"] = origin

                pre_flight_requested_headers = [
                    header.strip()
                    for header in request_headers.get("Access-Control-Request-Headers", "").split(",")
                    if header.strip()
                ]

                if pre_flight_requested_headers:
                    if cors_config.is_allow_all_headers:
                        response_headers["Access-Control-Allow-Headers"] = ", ".join(
                            sorted(set(pre_flight_requested_headers) | DEFAULT_ALLOWED_CORS_HEADERS)  # pyright: ignore
                        )
                    elif any(
                        header.lower() not in cors_config.allow_headers for header in pre_flight_requested_headers
                    ):
                        failures.append("headers")

                return (
                    Response(
                        content=f"Disallowed CORS {', '.join(failures)}",
                        status_code=HTTP_400_BAD_REQUEST,
                        media_type=MediaType.TEXT,
                    )
                    if failures
                    else Response(
                        content=None,
                        status_code=HTTP_204_NO_CONTENT,
                        media_type=MediaType.TEXT,
                        headers=response_headers,
                    )
                )

            return Response(
                content=None,
                status_code=HTTP_204_NO_CONTENT,
                headers={"Allow": ", ".join(sorted(self.methods))},  # pyright: ignore
                media_type=MediaType.TEXT,
            )

        return HTTPRouteHandler(
            path=path,
            http_method=[HttpMethod.OPTIONS],
            include_in_schema=False,
            sync_to_thread=False,
        )(options_handler)

    @staticmethod
    async def _cleanup_temporary_files(form_data: dict[str, Any]) -> None:
        """Close every still-open :class:`UploadFile` found in ``form_data``.

        A form value may be a single item or a list/tuple of items (presumably when several
        values share one field name — verify against the form parser); upload files in either
        shape are closed. Non-``UploadFile`` values are ignored.
        """
        for value in form_data.values():
            items = value if isinstance(value, (list, tuple)) else (value,)
            for item in items:
                if isinstance(item, UploadFile) and not item.file.closed:
                    await item.close()