summary refs log tree commit diff
path: root/venv/lib/python3.11/site-packages/litestar/openapi/spec/base.py
blob: 69cd3f3605f21d1baa15288baa20147c42d0260f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import annotations

from dataclasses import asdict, dataclass, fields, is_dataclass
from enum import Enum
from typing import Any

# Public API of this module: only ``BaseSchemaObject`` is picked up by ``from <module> import *``.
__all__ = ("BaseSchemaObject",)


def _normalize_key(key: str) -> str:
    """Translate a dataclass field name into the key used in the emitted schema.

    Rules, applied in order:

    * a name ending in ``_in`` becomes ``"in"``;
    * a name starting with ``schema_`` becomes the segment that directly follows that prefix
      (up to the next underscore, if any);
    * any other name containing underscores is camel-cased: the first segment is kept as-is and each
      following segment is passed through ``str.title``;
    * ``"ref"`` becomes ``"$ref"``;
    * everything else is returned unchanged.

    Args:
        key: The field name to convert.

    Returns:
        The converted key.
    """
    if key.endswith("_in"):
        return "in"

    prefix = "schema_"
    if key.startswith(prefix):
        # Only the first segment after the prefix is kept; anything after a further underscore is dropped.
        return key[len(prefix):].partition("_")[0]

    head, *tail = key.split("_")
    if tail:
        return head + "".join(map(str.title, tail))

    if key == "ref":
        return "$ref"
    return key


def _normalize_value(value: Any) -> Any:
    """Recursively convert ``value`` into plain data suitable for a schema dictionary.

    Args:
        value: Any object encountered while walking a spec object.

    Returns:
        * the result of ``to_schema()`` for ``BaseSchemaObject`` instances;
        * a dict for other dataclass instances (via ``dataclasses.asdict``) and for dicts, with keys and
          values normalized and entries whose value is ``None`` dropped;
        * a list with every item normalized for lists;
        * the ``.value`` attribute for ``Enum`` members;
        * ``value`` itself in every other case.
    """
    if isinstance(value, BaseSchemaObject):
        return value.to_schema()
    # ``is_dataclass`` is also true for dataclass *classes*, which ``asdict`` rejects with a TypeError.
    # Only instances are expanded; a class falls through and is returned like any other opaque value.
    if is_dataclass(value) and not isinstance(value, type):
        # NOTE(review): ``asdict`` recurses into nested dataclasses itself, so a ``BaseSchemaObject`` nested
        # inside a plain dataclass is flattened without going through ``to_schema()`` - confirm this is intended.
        return {_normalize_value(k): _normalize_value(v) for k, v in asdict(value).items() if v is not None}
    if isinstance(value, dict):
        return {_normalize_value(k): _normalize_value(v) for k, v in value.items() if v is not None}
    if isinstance(value, list):
        return [_normalize_value(v) for v in value]
    return value.value if isinstance(value, Enum) else value


@dataclass
class BaseSchemaObject:
    """Base class shared by schema spec objects."""

    def to_schema(self) -> dict[str, Any]:
        """Serialize this spec object into a dictionary with string keys.

        Each dataclass field is normalized recursively. A field whose normalized value is ``None`` is left
        out. The output key is the field's ``alias`` metadata entry when one is present, otherwise the field
        name converted by ``_normalize_key``.

        Returns:
            The dictionary representation of this object.

        Raises:
            TypeError: If an emitted field has an ``alias`` metadata entry that is not a string.
        """
        schema: dict[str, Any] = {}

        for spec_field in fields(self):
            normalized = _normalize_value(getattr(self, spec_field.name, None))
            if normalized is None:
                # Unset fields are omitted from the output entirely.
                continue

            metadata = spec_field.metadata
            if "alias" not in metadata:
                schema[_normalize_key(spec_field.name)] = normalized
                continue

            # An explicit alias takes precedence over the derived key, but it has to be a string.
            alias = metadata["alias"]
            if not isinstance(alias, str):
                raise TypeError('metadata["alias"] must be a str')
            schema[alias] = normalized

        return schema