1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
|
# ruff: noqa: UP006, UP007
from __future__ import annotations
import re
from functools import partial
from pathlib import Path, PurePath
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Literal,
Optional,
Sequence,
Set,
Type,
TypedDict,
Union,
cast,
)
from uuid import UUID
from msgspec import NODEFAULT, Meta, Struct, ValidationError, convert, defstruct
from msgspec.structs import asdict
from typing_extensions import Annotated
from litestar._signature.types import ExtendedMsgSpecValidationError
from litestar._signature.utils import (
_get_decoder_for_type,
_normalize_annotation,
_validate_signature_dependencies,
)
from litestar.datastructures.state import ImmutableState
from litestar.datastructures.url import URL
from litestar.dto import AbstractDTO, DTOData
from litestar.enums import ParamType, ScopeType
from litestar.exceptions import InternalServerException, ValidationException
from litestar.params import KwargDefinition, ParameterKwarg
from litestar.typing import FieldDefinition # noqa
from litestar.utils import is_class_and_subclass
from litestar.utils.dataclass import simple_asdict
if TYPE_CHECKING:
from typing_extensions import NotRequired
from litestar.connection import ASGIConnection
from litestar.types import AnyCallable, TypeDecodersSequence
from litestar.utils.signature import ParsedSignature
# Public API of this module; everything else here is an implementation detail.
__all__ = ("ErrorMessage", "SignatureModel")
class ErrorMessage(TypedDict):
    """A single validation error entry, as placed in the ``extra`` list of the raised ``ValidationException``."""

    # Dotted location of the invalid value. It is omitted when the failure cannot be tied to a
    # location (``_build_error_message`` only sets it when it receives a non-empty ``keys``).
    # The original note gives an example: a query param that is set but doesn't match the
    # required length during validation — NOTE(review): that example predates msgspec; confirm
    # it still applies.
    key: NotRequired[str]
    # Human-readable message (the part of the validator's message before the first " - ").
    message: str
    # Where the invalid value came from: "body" for the request body, otherwise a ParamType.
    # Omitted when the origin could not be determined.
    source: NotRequired[Literal["body"] | ParamType]
# Constraint keyword names recognised by ``msgspec.Meta`` for value validation.
MSGSPEC_CONSTRAINT_FIELDS = ("gt", "ge", "lt", "le", "multiple_of", "pattern", "min_length", "max_length")

# Extracts the dotted path from the trailing `` `$.<path>` `` locator of a msgspec error message,
# e.g. "Expected `int`, got `str` - at `$.a.b`" -> "a.b".
ERR_RE = re.compile(r"`\$\.(.+)`$")


def _is_constructor_decodable(target_type: Any) -> bool:
    """Return whether ``target_type`` is one of the types decoded by calling it on the raw value."""
    return is_class_and_subclass(target_type, (Path, PurePath, ImmutableState, UUID))


def _decode_by_construction(target_type: Any, value: Any) -> Any:
    """Decode ``value`` by passing it straight to the ``target_type`` constructor."""
    return target_type(value)


# (predicate, decoder) pairs appended after any user-supplied type decoders.
DEFAULT_TYPE_DECODERS = [(_is_constructor_decodable, _decode_by_construction)]
def _deserializer(target_type: Any, value: Any, default_deserializer: Callable[[Any, Any], Any]) -> Any:
    """Decoding hook handed to msgspec (as ``dec_hook``) for types it cannot decode on its own.

    Args:
        target_type: The type msgspec is trying to produce.
        value: The raw value to decode.
        default_deserializer: Fallback used when no decoder is attached to ``target_type``.

    Returns:
        ``value`` unchanged if it is ``DTOData`` or already an instance of ``target_type``; otherwise the
        result of the ``_decoder`` attached to ``target_type`` (if any), else of ``default_deserializer``.
    """
    already_decoded = isinstance(value, DTOData) or isinstance(value, target_type)
    if already_decoded:
        return value

    # ``_decoder`` is attached to the type by ``SignatureModel._create_annotation``.
    attached_decoder = getattr(target_type, "_decoder", None)
    if not attached_decoder:
        return default_deserializer(target_type, value)
    return attached_decoder(target_type, value)
class SignatureModel(Struct):
    """Model that represents a function signature that uses a msgspec specific type or types.

    Concrete subclasses are generated by :meth:`create`, which also fills in the class variables declared below
    through the ``defstruct`` namespace.
    """

    _data_dto: ClassVar[Optional[Type[AbstractDTO]]]
    _dependency_name_set: ClassVar[Set[str]]
    _fields: ClassVar[Dict[str, FieldDefinition]]
    _return_annotation: ClassVar[Any]

    @classmethod
    def _create_exception(cls, connection: ASGIConnection, messages: list[ErrorMessage]) -> Exception:
        """Create an exception class - either a ValidationException or an InternalServerException, depending on whether
        the failure is in client provided values or injected dependencies.

        Args:
            connection: An ASGI connection instance.
            messages: A list of error messages.

        Returns:
            An Exception
        """
        method = connection.method if hasattr(connection, "method") else ScopeType.WEBSOCKET  # pyright: ignore
        # A message counts as a client error unless its key names an injected dependency.
        if client_errors := [
            err_message
            for err_message in messages
            if ("key" in err_message and err_message["key"] not in cls._dependency_name_set) or "key" not in err_message
        ]:
            path = URL.from_components(
                path=connection.url.path,
                query=connection.url.query,
            )
            return ValidationException(detail=f"Validation failed for {method} {path}", extra=client_errors)
        return InternalServerException()

    @classmethod
    def _build_error_message(cls, keys: Sequence[str], exc_msg: str, connection: ASGIConnection) -> ErrorMessage:
        """Build an error message.

        Args:
            keys: Path segments locating the invalid value; the first segment is the signature field name.
            exc_msg: The raw validation message; anything after the first ``" - "`` is dropped.
            connection: An ASGI connection instance.

        Returns:
            An ErrorMessage
        """
        message: ErrorMessage = {"message": exc_msg.split(" - ")[0]}

        if keys:
            message["key"] = key = ".".join(keys)
            if keys[0] == "data" or keys[0].startswith("data."):
                # Body errors are reported relative to the body, so strip only the *leading* "data." segment.
                # A blanket str.replace would also mangle nested keys such as "data.metadata.x" -> "metax",
                # and a bare prefix test would misclassify fields like "data_id" as body errors.
                if key.startswith("data."):
                    message["key"] = key[len("data.") :]
                message["source"] = "body"
            elif key in connection.query_params:
                message["source"] = ParamType.QUERY
            elif key in cls._fields and isinstance(cls._fields[key].kwarg_definition, ParameterKwarg):
                kwarg_definition = cast(ParameterKwarg, cls._fields[key].kwarg_definition)
                if kwarg_definition.cookie:
                    message["source"] = ParamType.COOKIE
                elif kwarg_definition.header:
                    message["source"] = ParamType.HEADER
                else:
                    message["source"] = ParamType.QUERY
        return message

    @classmethod
    def _collect_errors(cls, deserializer: Callable[[Any, Any], Any], **kwargs: Any) -> list[tuple[str, Exception]]:
        """Re-validate every signature field on its own and collect the failures.

        Args:
            deserializer: The ``dec_hook`` passed to ``msgspec.convert``.
            **kwargs: The raw values, keyed by field name.

        Returns:
            A list of ``(field_name, exception)`` pairs, one per failing field.
        """
        exceptions: list[tuple[str, Exception]] = []
        for field_name, field_definition in cls._fields.items():
            if field_name not in kwargs and field_definition.has_default:
                # msgspec fills in (and never validates) defaults for absent fields, so there is nothing to
                # report; indexing kwargs here would only produce a spurious KeyError "error".
                continue
            try:
                raw_value = kwargs[field_name]
                annotation = cls.__annotations__[field_name]
                convert(raw_value, type=annotation, strict=False, dec_hook=deserializer, str_keys=True)
            except Exception as e:  # noqa: BLE001
                exceptions.append((field_name, e))

        return exceptions

    @classmethod
    def parse_values_from_connection_kwargs(cls, connection: ASGIConnection, **kwargs: Any) -> dict[str, Any]:
        """Extract values from the connection instance and return a dict of parsed values.

        Args:
            connection: The ASGI connection instance.
            **kwargs: A dictionary of kwargs.

        Raises:
            ValidationException: If validation failed.
            InternalServerException: If another exception has been raised.

        Returns:
            A dictionary of parsed values
        """
        messages: list[ErrorMessage] = []
        deserializer = partial(_deserializer, default_deserializer=connection.route_handler.default_deserializer)
        try:
            return convert(kwargs, cls, strict=False, dec_hook=deserializer, str_keys=True).to_dict()
        except ExtendedMsgSpecValidationError as e:
            # This error type already carries a list of structured errors with their locations.
            for exc in e.errors:
                keys = [str(loc) for loc in exc["loc"]]
                message = cls._build_error_message(keys=keys, exc_msg=exc["msg"], connection=connection)
                messages.append(message)
            raise cls._create_exception(messages=messages, connection=connection) from e
        except ValidationError as e:
            # A plain ValidationError describes a single failure; re-validate field by field
            # (presumably so that every failing field is reported, not just the first).
            for field_name, exc in cls._collect_errors(deserializer=deserializer, **kwargs):  # type: ignore[assignment]
                match = ERR_RE.search(str(exc))
                keys = [field_name, str(match.group(1))] if match else [field_name]
                message = cls._build_error_message(keys=keys, exc_msg=str(exc), connection=connection)
                messages.append(message)
            raise cls._create_exception(messages=messages, connection=connection) from e

    def to_dict(self) -> dict[str, Any]:
        """Normalize access to the signature model's dictionary method, because different backends use different methods
        for this.

        Returns: A dictionary of string keyed values.
        """
        return asdict(self)

    @classmethod
    def create(
        cls,
        dependency_name_set: set[str],
        fn: AnyCallable,
        parsed_signature: ParsedSignature,
        type_decoders: TypeDecodersSequence,
        data_dto: type[AbstractDTO] | None = None,
    ) -> type[SignatureModel]:
        """Generate a ``SignatureModel`` subclass whose fields mirror ``parsed_signature.parameters``.

        Args:
            dependency_name_set: Dependency names, validated against the signature.
            fn: The callable being modelled; its ``__name__`` and ``__module__`` name the generated class.
            parsed_signature: The parsed signature of ``fn``.
            type_decoders: Extra ``(predicate, decoder)`` pairs, consulted before ``DEFAULT_TYPE_DECODERS``.
            data_dto: If given, the ``data`` field is typed as ``Any`` because the DTO validates it.

        Returns:
            The generated ``SignatureModel`` subclass.
        """
        fn_name = getattr(fn, "__name__", "anonymous")
        if not fn_name or fn_name == "<lambda>":
            fn_name = "anonymous"

        dependency_names = _validate_signature_dependencies(
            dependency_name_set=dependency_name_set, fn_name=fn_name, parsed_signature=parsed_signature
        )
        struct_fields: list[tuple[str, Any, Any]] = []

        for field_definition in parsed_signature.parameters.values():
            meta_data: Meta | None = None

            if isinstance(field_definition.kwarg_definition, KwargDefinition):
                meta_kwargs: dict[str, Any] = {"extra": {}}

                kwarg_definition = simple_asdict(field_definition.kwarg_definition, exclude_empty=True)
                # Compare against None rather than truthiness so that an explicit bound of 0 is kept.
                if (min_items := kwarg_definition.pop("min_items", None)) is not None:
                    meta_kwargs["min_length"] = min_items
                if (max_items := kwarg_definition.pop("max_items", None)) is not None:
                    meta_kwargs["max_length"] = max_items

                for k, v in kwarg_definition.items():
                    # Keys that msgspec.Meta knows become real Meta arguments; everything else
                    # (including None-valued known keys) is carried along in ``extra``.
                    if hasattr(Meta, k) and v is not None:
                        meta_kwargs[k] = v
                    else:
                        meta_kwargs["extra"][k] = v

                meta_data = Meta(**meta_kwargs)

            annotation = cls._create_annotation(
                field_definition=field_definition,
                type_decoders=[*(type_decoders or []), *DEFAULT_TYPE_DECODERS],
                meta_data=meta_data,
                data_dto=data_dto,
            )

            default = field_definition.default if field_definition.has_default else NODEFAULT
            struct_fields.append((field_definition.name, annotation, default))

        return defstruct(  # type:ignore[return-value]
            f"{fn_name}_signature_model",
            struct_fields,
            bases=(cls,),
            module=getattr(fn, "__module__", None),
            namespace={
                "_return_annotation": parsed_signature.return_type.annotation,
                "_dependency_name_set": dependency_names,
                "_fields": parsed_signature.parameters,
                "_data_dto": data_dto,
            },
            kw_only=True,
        )

    @classmethod
    def _create_annotation(
        cls,
        field_definition: FieldDefinition,
        type_decoders: TypeDecodersSequence,
        meta_data: Meta | None = None,
        data_dto: type[AbstractDTO] | None = None,
    ) -> Any:
        """Build the msgspec field annotation for ``field_definition``.

        Union members (other than ``None``) are processed recursively, a matching type decoder is attached
        to the type, and ``meta_data`` (if any) is applied via ``Annotated``.
        """
        # DTOs have already validated their data, so we can just use Any here
        if field_definition.name == "data" and data_dto:
            return Any

        annotation = _normalize_annotation(field_definition=field_definition)

        if annotation is Any:
            return annotation

        if field_definition.is_union:
            types = [
                cls._create_annotation(
                    field_definition=inner_type,
                    type_decoders=type_decoders,
                    meta_data=meta_data,
                )
                for inner_type in field_definition.inner_types
                if not inner_type.is_none_type
            ]
            return Optional[Union[tuple(types)]] if field_definition.is_optional else Union[tuple(types)]  # pyright: ignore

        if decoder := _get_decoder_for_type(annotation, type_decoders=type_decoders):
            # FIXME: temporary (hopefully) hack, see: https://github.com/jcrist/msgspec/issues/497
            # NOTE(review): this mutates the type globally and will raise for types that forbid attribute
            # assignment (e.g. builtins) — confirm registered decoders only target assignable classes.
            setattr(annotation, "_decoder", decoder)

        if meta_data:
            annotation = Annotated[annotation, meta_data]  # pyright: ignore

        return annotation
|