1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
|
from __future__ import annotations
from itertools import chain
from typing import TYPE_CHECKING, Any, cast
from msgspec.msgpack import decode as _decode_msgpack_plain
from litestar.constants import DEFAULT_ALLOWED_CORS_HEADERS
from litestar.datastructures.headers import Headers
from litestar.datastructures.upload_file import UploadFile
from litestar.enums import HttpMethod, MediaType, ScopeType
from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException
from litestar.handlers.http_handlers import HTTPRouteHandler
from litestar.response import Response
from litestar.routes.base import BaseRoute
from litestar.status_codes import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
from litestar.types.empty import Empty
from litestar.utils.scope.state import ScopeState
if TYPE_CHECKING:
from litestar._kwargs import KwargsModel
from litestar._kwargs.cleanup import DependencyCleanupGroup
from litestar.connection import Request
from litestar.types import ASGIApp, HTTPScope, Method, Receive, Scope, Send
class HTTPRoute(BaseRoute):
    """An HTTP route, capable of handling multiple ``HTTPRouteHandler``\\ s.

    The route holds the handlers registered for one path and, when called through ``handle``,
    dispatches the request to the handler registered for the request's HTTP method.
    """ # noqa: D301
    __slots__ = (
        # Maps an HTTP method to its ``(route handler, kwargs model)`` pair.
        "route_handler_map",
        # The route handlers registered on this route's path.
        "route_handlers",
    )
def __init__(
    self,
    *,
    path: str,
    route_handlers: list[HTTPRouteHandler],
) -> None:
    """Initialize ``HTTPRoute``.

    When none of the given handlers serves ``OPTIONS``, an ``OPTIONS`` handler is created for
    the path, given the same ``owner`` as the first handler, and appended to ``route_handlers``
    (the list passed in by the caller is extended in place).

    Args:
        path: The path for the route.
        route_handlers: A list of :class:`~.handlers.HTTPRouteHandler`.
    """
    methods = [method for handler in route_handlers for method in handler.http_methods]

    if "OPTIONS" not in methods:
        methods.append("OPTIONS")
        generated_handler = self.create_options_handler(path)
        generated_handler.owner = route_handlers[0].owner
        route_handlers.append(generated_handler)

    self.route_handlers = route_handlers
    # Starts out empty; ``create_handler_map`` is what fills it in.
    self.route_handler_map: dict[Method, tuple[HTTPRouteHandler, KwargsModel]] = {}

    handler_names = [handler.handler_name for handler in route_handlers]
    super().__init__(
        methods=methods,
        path=path,
        scope_type=ScopeType.HTTP,
        handler_names=handler_names,
    )
async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None:  # type: ignore[override]
    """ASGI app that creates a Request from the passed in args, determines which handler function to call and then
    handles the call.

    The handler is looked up by ``scope["method"]``. Its guards (if any) are run, the response is
    produced and sent, and the after-response hook (if any) is invoked.

    Form data stored on the scope under ``"_form"`` is always handed to
    ``_cleanup_temporary_files`` - also when a guard, the handler, sending the response or the
    after-response hook raises - so uploaded files are not left open on error paths.

    Args:
        scope: The ASGI connection scope.
        receive: The ASGI receive function.
        send: The ASGI send function.

    Raises:
        KeyError: If no handler is registered for the request's HTTP method.

    Returns:
        None
    """
    route_handler, parameter_model = self.route_handler_map[scope["method"]]
    request: Request[Any, Any, Any] = route_handler.resolve_request_class()(scope=scope, receive=receive, send=send)

    try:
        if route_handler.resolve_guards():
            await route_handler.authorize_connection(connection=request)

        response = await self._get_response_for_request(
            scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model
        )

        await response(scope, receive, send)

        if after_response_handler := route_handler.resolve_after_response():
            await after_response_handler(request)
    finally:
        # Runs on success and on failure alike; the original exception (if any) still propagates.
        if form_data := scope.get("_form", {}):
            await self._cleanup_temporary_files(form_data=cast("dict[str, Any]", form_data))
def create_handler_map(self) -> None:
    """Populate ``route_handler_map`` from this route's ``route_handlers``.

    Each HTTP method of each handler is mapped to a ``(handler, kwargs model)`` pair; the kwargs
    model is created once per handler. Nothing is returned, the map is filled in place.

    Raises:
        ImproperlyConfiguredException: If an HTTP method already has a handler registered.
    """
    registry = self.route_handler_map
    for handler in self.route_handlers:
        model = self.create_handler_kwargs_model(route_handler=handler)
        for method in handler.http_methods:
            already_registered = registry.get(method)
            if already_registered:
                raise ImproperlyConfiguredException(
                    f"Handler already registered for path {self.path!r} and http method {method}"
                )
            registry[method] = (handler, model)
async def _get_response_for_request(
    self,
    scope: Scope,
    request: Request[Any, Any, Any],
    route_handler: HTTPRouteHandler,
    parameter_model: KwargsModel,
) -> ASGIApp:
    """Return a response for the request.

    When caching is enabled on the handler, the cache is consulted first and a cached response
    is returned if one is found. Otherwise the handler function is called to build a fresh
    response. This method only reads from the cache; it does not write to it.

    Args:
        scope: The Request's scope
        request: The Request instance
        route_handler: The HTTPRouteHandler instance
        parameter_model: The Handler's KwargsModel

    Returns:
        An instance of Response or a compatible ASGIApp or a subclass of it
    """
    if route_handler.cache:
        cached = await self._get_cached_response(request=request, route_handler=route_handler)
        if cached:
            return cached

    return await self._call_handler_function(
        scope=scope, request=request, parameter_model=parameter_model, route_handler=route_handler
    )
async def _call_handler_function(
    self, scope: Scope, request: Request, parameter_model: KwargsModel, route_handler: HTTPRouteHandler
) -> ASGIApp:
    """Run the before-request hook, obtain the handler's data and turn it into a response.

    If the handler has a before-request hook and it returns a truthy value, that value is used
    as the response data and the handler function is not called. Otherwise the data comes from
    ``_get_response_data``. The data is converted with the handler's ``to_response`` and, when
    a dependency cleanup group was produced, its ``cleanup`` is awaited afterwards.

    Args:
        scope: The Request's scope
        request: The Request instance
        parameter_model: The Handler's KwargsModel
        route_handler: The HTTPRouteHandler instance

    Returns:
        The ASGI app produced by the handler's ``to_response``.
    """
    data: Any = None
    dependency_cleanup: DependencyCleanupGroup | None = None

    before_request = route_handler.resolve_before_request()
    if before_request:
        data = await before_request(request)

    # Any falsy hook result (not only ``None``) falls through to the handler function.
    if not data:
        data, dependency_cleanup = await self._get_response_data(
            route_handler=route_handler, parameter_model=parameter_model, request=request
        )

    asgi_response: ASGIApp = await route_handler.to_response(app=scope["app"], data=data, request=request)

    if dependency_cleanup:
        await dependency_cleanup.cleanup()

    return asgi_response
@staticmethod
async def _get_response_data(
    route_handler: HTTPRouteHandler, parameter_model: KwargsModel, request: Request
) -> tuple[Any, DependencyCleanupGroup | None]:
    """Build the keyword arguments for the handler's ``fn`` from the request and call it.

    Keyword arguments are only collected when the parameter model has kwargs and the handler
    has a signature model; otherwise ``fn`` is called without arguments. Awaitable ``data`` and
    ``body`` entries are awaited, dependencies are resolved when the parameter model has
    dependency batches, and the result is parsed by the handler's signature model.

    Args:
        route_handler: The HTTPRouteHandler instance
        parameter_model: The Handler's KwargsModel
        request: The Request instance

    Raises:
        ClientException: If awaiting the ``data`` kwarg raises a ``SerializationException``.

    Returns:
        A tuple of the handler function's return value and the dependency cleanup group
        (``None`` when no dependencies were resolved).
    """
    handler_kwargs: dict[str, Any] = {}
    dependency_cleanup: DependencyCleanupGroup | None = None

    if parameter_model.has_kwargs and route_handler.signature_model:
        raw_kwargs = parameter_model.to_kwargs(connection=request)

        if "data" in raw_kwargs:
            try:
                data = await raw_kwargs["data"]
            except SerializationException as exc:
                raise ClientException(str(exc)) from exc

            # ``Empty`` marks "no data": drop the key instead of passing the sentinel on.
            if data is Empty:
                del raw_kwargs["data"]
            else:
                raw_kwargs["data"] = data

        if "body" in raw_kwargs:
            raw_kwargs["body"] = await raw_kwargs["body"]

        if parameter_model.dependency_batches:
            dependency_cleanup = await parameter_model.resolve_dependencies(request, raw_kwargs)

        handler_kwargs = route_handler.signature_model.parse_values_from_connection_kwargs(
            connection=request, **raw_kwargs
        )

    if not dependency_cleanup:
        if route_handler.has_sync_callable:
            data = route_handler.fn(**handler_kwargs)
        else:
            data = await route_handler.fn(**handler_kwargs)
        return data, dependency_cleanup

    # With dependencies to clean up, the handler runs inside the cleanup group's context.
    async with dependency_cleanup:
        if route_handler.has_sync_callable:
            data = route_handler.fn(**handler_kwargs)
        else:
            data = await route_handler.fn(**handler_kwargs)

    return data, dependency_cleanup
@staticmethod
async def _get_cached_response(request: Request, route_handler: HTTPRouteHandler) -> ASGIApp | None:
    """Look up a cached response for the request and, if found, return an ASGI app replaying it.

    The cache key is built by the handler's ``cache_key_builder`` when set, otherwise by the
    key builder of the app's response cache config. The stored value is decoded with msgpack
    into the messages that are sent again by the returned app.

    Args:
        request: The :class:`Request <litestar.connection.Request>` instance
        route_handler: The :class:`~.handlers.HTTPRouteHandler` instance

    Returns:
        An ASGI app that replays the cached messages, or ``None`` when nothing is cached.
    """
    cache_config = request.app.response_cache_config
    build_key = route_handler.cache_key_builder or cache_config.key_builder
    cache_key = build_key(request)
    store = cache_config.get_store_from_app(request.app)

    stored = await store.get(key=cache_key)
    if not stored:
        return None

    # The plain msgspec msgpack decoder is enough here: none of the added decoders are needed.
    messages = _decode_msgpack_plain(stored)

    async def cached_response(scope: Scope, receive: Receive, send: Send) -> None:
        # Flag the scope so the rest of the stack can tell this response came from the cache.
        ScopeState.from_scope(scope).is_cached = True

        for message in messages:
            await send(message)

    return cached_response
def create_options_handler(self, path: str) -> HTTPRouteHandler:
    """Create the route handler that answers ``OPTIONS`` requests for this route.

    Args:
        path: The route path

    Returns:
        An HTTP route handler for OPTIONS requests.
    """

    def options_handler(scope: Scope) -> Response:
        """Handler function for OPTIONS requests.

        Without a CORS config or an ``Origin`` request header, the answer is an empty 204
        response listing this route's methods in the ``Allow`` header. Otherwise the request
        is treated as a CORS pre-flight: the requested method, origin and headers are checked
        against the CORS config, giving a 400 response naming what was disallowed or a 204
        response carrying the pre-flight headers.

        Args:
            scope: The ASGI Scope.

        Returns:
            Response
        """
        cors_config = scope["app"].cors_config
        request_headers = Headers.from_scope(scope=scope)
        origin = request_headers.get("origin")

        # Plain OPTIONS request: just advertise the allowed methods.
        if not (cors_config and origin):
            return Response(
                content=None,
                status_code=HTTP_204_NO_CONTENT,
                headers={"Allow": ", ".join(sorted(self.methods))},  # pyright: ignore
                media_type=MediaType.TEXT,
            )

        requested_method = request_headers.get("Access-Control-Request-Method")
        failures = []

        if (
            not cors_config.is_allow_all_methods
            and requested_method
            and requested_method not in cors_config.allow_methods
        ):
            failures.append("method")

        response_headers = cors_config.preflight_headers.copy()

        if not cors_config.is_origin_allowed(origin):
            failures.append("Origin")
        elif response_headers.get("Access-Control-Allow-Origin") != "*":
            # Echo the concrete origin unless the config already allows any origin.
            response_headers["Access-Control-Allow-Origin"] = origin

        raw_requested_headers = request_headers.get("Access-Control-Request-Headers", "")
        stripped_names = (part.strip() for part in raw_requested_headers.split(","))
        requested_headers = [name for name in stripped_names if name]

        if requested_headers:
            if cors_config.is_allow_all_headers:
                allowed_header_names = set(requested_headers) | DEFAULT_ALLOWED_CORS_HEADERS  # pyright: ignore
                response_headers["Access-Control-Allow-Headers"] = ", ".join(sorted(allowed_header_names))
            elif not all(name.lower() in cors_config.allow_headers for name in requested_headers):
                failures.append("headers")

        if failures:
            return Response(
                content=f"Disallowed CORS {', '.join(failures)}",
                status_code=HTTP_400_BAD_REQUEST,
                media_type=MediaType.TEXT,
            )

        return Response(
            content=None,
            status_code=HTTP_204_NO_CONTENT,
            media_type=MediaType.TEXT,
            headers=response_headers,
        )

    as_options_route = HTTPRouteHandler(
        path=path,
        http_method=[HttpMethod.OPTIONS],
        include_in_schema=False,
        sync_to_thread=False,
    )
    return as_options_route(options_handler)
@staticmethod
async def _cleanup_temporary_files(form_data: dict[str, Any]) -> None:
    """Close every still-open ``UploadFile`` found in the parsed form data.

    A form value may be a single upload or a list/tuple of values (for example when a field
    name is repeated), so sequences are searched one level deep as well. Values that are not
    ``UploadFile`` instances, and uploads whose file is already closed, are left alone.

    Args:
        form_data: The parsed form data, keyed by field name.

    Returns:
        None
    """
    for value in form_data.values():
        # Normalise to an iterable so single uploads and multi-file fields share one code path.
        candidates = value if isinstance(value, (list, tuple)) else (value,)
        for candidate in candidates:
            if isinstance(candidate, UploadFile) and not candidate.file.closed:
                await candidate.close()
|