diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/openapi/spec/base.py | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/openapi/spec/base.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/openapi/spec/base.py | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/openapi/spec/base.py b/venv/lib/python3.11/site-packages/litestar/openapi/spec/base.py new file mode 100644 index 0000000..69cd3f3 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/openapi/spec/base.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from dataclasses import asdict, dataclass, fields, is_dataclass +from enum import Enum +from typing import Any + +__all__ = ("BaseSchemaObject",) + + +def _normalize_key(key: str) -> str: + if key.endswith("_in"): + return "in" + if key.startswith("schema_"): + return key.split("_")[1] + if "_" in key: + components = key.split("_") + return components[0] + "".join(component.title() for component in components[1:]) + return "$ref" if key == "ref" else key + + +def _normalize_value(value: Any) -> Any: + if isinstance(value, BaseSchemaObject): + return value.to_schema() + if is_dataclass(value): + return {_normalize_value(k): _normalize_value(v) for k, v in asdict(value).items() if v is not None} + if isinstance(value, dict): + return {_normalize_value(k): _normalize_value(v) for k, v in value.items() if v is not None} + if isinstance(value, list): + return [_normalize_value(v) for v in value] + return value.value if isinstance(value, Enum) else value + + +@dataclass +class BaseSchemaObject: + """Base class for schema spec objects""" + + def to_schema(self) -> dict[str, Any]: + """Transform the spec dataclass object into a string keyed dictionary. This method traverses all nested values + recursively. + """ + result: dict[str, Any] = {} + + for field in fields(self): + value = _normalize_value(getattr(self, field.name, None)) + + if value is not None: + if "alias" in field.metadata: + if not isinstance(field.metadata["alias"], str): + raise TypeError('metadata["alias"] must be a str') + key = field.metadata["alias"] + else: + key = _normalize_key(field.name) + + result[key] = value + + return result |