diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/dto/dataclass_dto.py | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/dto/dataclass_dto.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/dto/dataclass_dto.py | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/dto/dataclass_dto.py b/venv/lib/python3.11/site-packages/litestar/dto/dataclass_dto.py new file mode 100644 index 0000000..554b0f3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/dto/dataclass_dto.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from dataclasses import MISSING, fields, replace +from typing import TYPE_CHECKING, Generic, TypeVar + +from litestar.dto.base_dto import AbstractDTO +from litestar.dto.data_structures import DTOFieldDefinition +from litestar.dto.field import DTO_FIELD_META_KEY, DTOField +from litestar.params import DependencyKwarg, KwargDefinition +from litestar.types.empty import Empty + +if TYPE_CHECKING: + from typing import Collection, Generator + + from litestar.types.protocols import DataclassProtocol + from litestar.typing import FieldDefinition + + +__all__ = ("DataclassDTO", "T") + +T = TypeVar("T", bound="DataclassProtocol | Collection[DataclassProtocol]") +AnyDataclass = TypeVar("AnyDataclass", bound="DataclassProtocol") + + +class DataclassDTO(AbstractDTO[T], Generic[T]): + """Support for domain modelling with dataclasses.""" + + @classmethod + def generate_field_definitions( + cls, model_type: type[DataclassProtocol] + ) -> Generator[DTOFieldDefinition, None, None]: + dc_fields = {f.name: f for f in fields(model_type)} + for key, field_definition in cls.get_model_type_hints(model_type).items(): + if not (dc_field := dc_fields.get(key)): + continue + + default = dc_field.default if dc_field.default is not MISSING else Empty + default_factory = dc_field.default_factory if dc_field.default_factory is not MISSING else None + field_defintion = replace( + DTOFieldDefinition.from_field_definition( + field_definition=field_definition, + default_factory=default_factory, + dto_field=dc_field.metadata.get(DTO_FIELD_META_KEY, DTOField()), + model_name=model_type.__name__, + ), + name=key, + default=default, + ) + + yield ( + replace(field_defintion, default=Empty, kwarg_definition=default) + if isinstance(default, (KwargDefinition, DependencyKwarg)) + else field_defintion + ) + + @classmethod + def detect_nested_field(cls, field_definition: FieldDefinition) -> bool: + return hasattr(field_definition.annotation, "__dataclass_fields__") |