Diffstat (limited to 'venv/lib/python3.11/site-packages/msgspec/_json_schema.py')
-rw-r--r--  venv/lib/python3.11/site-packages/msgspec/_json_schema.py  439
1 file changed, 439 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/msgspec/_json_schema.py b/venv/lib/python3.11/site-packages/msgspec/_json_schema.py
new file mode 100644
index 0000000..be506e3
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/msgspec/_json_schema.py
@@ -0,0 +1,439 @@
+from __future__ import annotations
+
+import re
+import textwrap
+from collections.abc import Iterable
+from typing import Any, Optional, Callable
+
+from . import inspect as mi, to_builtins
+
+__all__ = ("schema", "schema_components")
+
+
+def schema(
+ type: Any, *, schema_hook: Optional[Callable[[type], dict[str, Any]]] = None
+) -> dict[str, Any]:
+ """Generate a JSON Schema for a given type.
+
+ Any schemas for (potentially) shared components are extracted and stored in
+ a top-level ``"$defs"`` field.
+
+    If you want to generate schemas for multiple types, or to have more
+    control over the generated schema, you may want to use
+    ``schema_components`` instead.
+
+ Parameters
+ ----------
+ type : type
+ The type to generate the schema for.
+ schema_hook : callable, optional
+ An optional callback to use for generating JSON schemas of custom
+ types. Will be called with the custom type, and should return a dict
+ representation of the JSON schema for that type.
+
+ Returns
+ -------
+ schema : dict
+ The generated JSON Schema.
+
+ See Also
+ --------
+ schema_components
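+
+    Examples
+    --------
+    A minimal usage sketch; ``Point`` here is a hypothetical struct used
+    only for illustration:
+
+    >>> import msgspec
+    >>> class Point(msgspec.Struct):
+    ...     x: float
+    ...     y: float
+    >>> msgspec.json.schema(Point)
+    {'$ref': '#/$defs/Point', '$defs': {'Point': {'title': 'Point', 'type': 'object', 'properties': {'x': {'type': 'number'}, 'y': {'type': 'number'}}, 'required': ['x', 'y']}}}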
+ """
+ (out,), components = schema_components((type,), schema_hook=schema_hook)
+ if components:
+ out["$defs"] = components
+ return out
+
+
+def schema_components(
+ types: Iterable[Any],
+ *,
+ schema_hook: Optional[Callable[[type], dict[str, Any]]] = None,
+ ref_template: str = "#/$defs/{name}",
+) -> tuple[tuple[dict[str, Any], ...], dict[str, Any]]:
+ """Generate JSON Schemas for one or more types.
+
+ Any schemas for (potentially) shared components are extracted and returned
+ in a separate ``components`` dict.
+
+ Parameters
+ ----------
+ types : Iterable[type]
+ An iterable of one or more types to generate schemas for.
+ schema_hook : callable, optional
+ An optional callback to use for generating JSON schemas of custom
+ types. Will be called with the custom type, and should return a dict
+ representation of the JSON schema for that type.
+ ref_template : str, optional
+ A template to use when generating ``"$ref"`` fields. This template is
+ formatted with the type name as ``template.format(name=name)``. This
+ can be useful if you intend to store the ``components`` mapping
+ somewhere other than a top-level ``"$defs"`` field. For example, you
+ might use ``ref_template="#/components/{name}"`` if generating an
+ OpenAPI schema.
+
+ Returns
+ -------
+ schemas : tuple[dict]
+ A tuple of JSON Schemas, one for each type in ``types``.
+ components : dict
+ A mapping of name to schema for any shared components used by
+ ``schemas``.
+
+ See Also
+ --------
+ schema
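+
+    Examples
+    --------
+    A minimal usage sketch; ``Point`` here is a hypothetical struct, and the
+    ``ref_template`` mirrors an OpenAPI components layout:
+
+    >>> import msgspec
+    >>> class Point(msgspec.Struct):
+    ...     x: float
+    ...     y: float
+    >>> schemas, components = msgspec.json.schema_components(
+    ...     [Point], ref_template="#/components/schemas/{name}"
+    ... )
+    >>> schemas
+    ({'$ref': '#/components/schemas/Point'},)
+    >>> components
+    {'Point': {'title': 'Point', 'type': 'object', 'properties': {'x': {'type': 'number'}, 'y': {'type': 'number'}}, 'required': ['x', 'y']}}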
+ """
+ type_infos = mi.multi_type_info(types)
+
+ component_types = _collect_component_types(type_infos)
+
+ name_map = _build_name_map(component_types)
+
+ gen = _SchemaGenerator(name_map, schema_hook, ref_template)
+
+ schemas = tuple(gen.to_schema(t) for t in type_infos)
+
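+    # Generate the full schema body for each named component, passing
+    # check_ref=False so a component isn't rendered as a "$ref" to itself.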
+ components = {
+ name_map[cls]: gen.to_schema(t, False) for cls, t in component_types.items()
+ }
+ return schemas, components
+
+
+def _collect_component_types(type_infos: Iterable[mi.Type]) -> dict[Any, mi.Type]:
+ """Find all types in the type tree that are "nameable" and worthy of being
+ extracted out into a shared top-level components mapping.
+
+ Currently this looks for Struct, Dataclass, NamedTuple, TypedDict, and Enum
+ types.
+ """
+ components = {}
+
+ def collect(t):
+ if isinstance(
+ t, (mi.StructType, mi.TypedDictType, mi.DataclassType, mi.NamedTupleType)
+ ):
+ if t.cls not in components:
+ components[t.cls] = t
+ for f in t.fields:
+ collect(f.type)
+ elif isinstance(t, mi.EnumType):
+ components[t.cls] = t
+ elif isinstance(t, mi.Metadata):
+ collect(t.type)
+ elif isinstance(t, mi.CollectionType):
+ collect(t.item_type)
+ elif isinstance(t, mi.TupleType):
+ for st in t.item_types:
+ collect(st)
+ elif isinstance(t, mi.DictType):
+ collect(t.key_type)
+ collect(t.value_type)
+ elif isinstance(t, mi.UnionType):
+ for st in t.types:
+ collect(st)
+
+ for t in type_infos:
+ collect(t)
+
+ return components
+
+
+def _type_repr(obj):
+ return obj.__name__ if isinstance(obj, type) else repr(obj)
+
+
+def _get_class_name(cls: Any) -> str:
+ if hasattr(cls, "__origin__"):
+ name = cls.__origin__.__name__
+ args = ", ".join(_type_repr(a) for a in cls.__args__)
+ return f"{name}[{args}]"
+ return cls.__name__
+
+
+def _get_doc(t: mi.Type) -> str:
+ assert hasattr(t, "cls")
+ cls = getattr(t.cls, "__origin__", t.cls)
+ doc = getattr(cls, "__doc__", "")
+ if not doc:
+ return ""
+ doc = textwrap.dedent(doc).strip("\r\n")
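+    # Skip docstrings that Python auto-generates (e.g. "An enumeration." for
+    # enums, or the "Cls(field, ...)" signature for namedtuples/dataclasses).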
+ if isinstance(t, mi.EnumType):
+ if doc == "An enumeration.":
+ return ""
+ elif isinstance(t, (mi.NamedTupleType, mi.DataclassType)):
+ if doc.startswith(f"{cls.__name__}(") and doc.endswith(")"):
+ return ""
+ return doc
+
+
+def _build_name_map(component_types: dict[Any, mi.Type]) -> dict[Any, str]:
+ """A mapping from nameable subcomponents to a generated name.
+
+ The generated name is usually a normalized version of the class name. In
+ the case of conflicts, the name will be expanded to also include the full
+ import path.
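+
+    For example (a hypothetical conflict): two classes both named ``Point``,
+    one in module ``pkg_a`` and one in ``pkg_b``, would be named
+    ``pkg_a.Point`` and ``pkg_b.Point`` rather than a bare ``Point``.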
+ """
+
+ def normalize(name):
+ return re.sub(r"[^a-zA-Z0-9.\-_]", "_", name)
+
+ def fullname(cls):
+ return normalize(f"{cls.__module__}.{cls.__qualname__}")
+
+ conflicts = set()
+ names: dict[str, Any] = {}
+
+ for cls in component_types:
+ name = normalize(_get_class_name(cls))
+ if name in names:
+ old = names.pop(name)
+ conflicts.add(name)
+ names[fullname(old)] = old
+ if name in conflicts:
+ names[fullname(cls)] = cls
+ else:
+ names[name] = cls
+ return {v: k for k, v in names.items()}
+
+
+class _SchemaGenerator:
+ def __init__(
+ self,
+ name_map: dict[Any, str],
+ schema_hook: Optional[Callable[[type], dict[str, Any]]] = None,
+ ref_template: str = "#/$defs/{name}",
+ ):
+ self.name_map = name_map
+ self.schema_hook = schema_hook
+ self.ref_template = ref_template
+
+ def to_schema(self, t: mi.Type, check_ref: bool = True) -> dict[str, Any]:
+ """Converts a Type to a json-schema."""
+ schema: dict[str, Any] = {}
+
+ while isinstance(t, mi.Metadata):
+ schema = mi._merge_json(schema, t.extra_json_schema)
+ t = t.type
+
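+        # If this type maps to a named component, emit a "$ref" to it rather
+        # than inlining, unless this call is generating the component itself.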
+ if check_ref and hasattr(t, "cls"):
+ if name := self.name_map.get(t.cls):
+ schema["$ref"] = self.ref_template.format(name=name)
+ return schema
+
+ if isinstance(t, (mi.AnyType, mi.RawType)):
+ pass
+ elif isinstance(t, mi.NoneType):
+ schema["type"] = "null"
+ elif isinstance(t, mi.BoolType):
+ schema["type"] = "boolean"
+ elif isinstance(t, (mi.IntType, mi.FloatType)):
+ schema["type"] = "integer" if isinstance(t, mi.IntType) else "number"
+ if t.ge is not None:
+ schema["minimum"] = t.ge
+ if t.gt is not None:
+ schema["exclusiveMinimum"] = t.gt
+ if t.le is not None:
+ schema["maximum"] = t.le
+ if t.lt is not None:
+ schema["exclusiveMaximum"] = t.lt
+ if t.multiple_of is not None:
+ schema["multipleOf"] = t.multiple_of
+ elif isinstance(t, mi.StrType):
+ schema["type"] = "string"
+ if t.max_length is not None:
+ schema["maxLength"] = t.max_length
+ if t.min_length is not None:
+ schema["minLength"] = t.min_length
+ if t.pattern is not None:
+ schema["pattern"] = t.pattern
+ elif isinstance(t, (mi.BytesType, mi.ByteArrayType, mi.MemoryViewType)):
+ schema["type"] = "string"
+ schema["contentEncoding"] = "base64"
+ if t.max_length is not None:
+ schema["maxLength"] = 4 * ((t.max_length + 2) // 3)
+ if t.min_length is not None:
+ schema["minLength"] = 4 * ((t.min_length + 2) // 3)
+ elif isinstance(t, mi.DateTimeType):
+ schema["type"] = "string"
+ if t.tz is True:
+ schema["format"] = "date-time"
+ elif isinstance(t, mi.TimeType):
+ schema["type"] = "string"
+ if t.tz is True:
+ schema["format"] = "time"
+ elif t.tz is False:
+ schema["format"] = "partial-time"
+ elif isinstance(t, mi.DateType):
+ schema["type"] = "string"
+ schema["format"] = "date"
+ elif isinstance(t, mi.TimeDeltaType):
+ schema["type"] = "string"
+ schema["format"] = "duration"
+ elif isinstance(t, mi.UUIDType):
+ schema["type"] = "string"
+ schema["format"] = "uuid"
+ elif isinstance(t, mi.DecimalType):
+ schema["type"] = "string"
+ schema["format"] = "decimal"
+ elif isinstance(t, mi.CollectionType):
+ schema["type"] = "array"
+ if not isinstance(t.item_type, mi.AnyType):
+ schema["items"] = self.to_schema(t.item_type)
+ if t.max_length is not None:
+ schema["maxItems"] = t.max_length
+ if t.min_length is not None:
+ schema["minItems"] = t.min_length
+ elif isinstance(t, mi.TupleType):
+ schema["type"] = "array"
+ schema["minItems"] = schema["maxItems"] = len(t.item_types)
+ if t.item_types:
+ schema["prefixItems"] = [self.to_schema(i) for i in t.item_types]
+ schema["items"] = False
+ elif isinstance(t, mi.DictType):
+ schema["type"] = "object"
+ # If there are restrictions on the keys, specify them as propertyNames
+ if isinstance(key_type := t.key_type, mi.StrType):
+ property_names: dict[str, Any] = {}
+ if key_type.min_length is not None:
+ property_names["minLength"] = key_type.min_length
+ if key_type.max_length is not None:
+ property_names["maxLength"] = key_type.max_length
+ if key_type.pattern is not None:
+ property_names["pattern"] = key_type.pattern
+ if property_names:
+ schema["propertyNames"] = property_names
+ if not isinstance(t.value_type, mi.AnyType):
+ schema["additionalProperties"] = self.to_schema(t.value_type)
+ if t.max_length is not None:
+ schema["maxProperties"] = t.max_length
+ if t.min_length is not None:
+ schema["minProperties"] = t.min_length
+ elif isinstance(t, mi.UnionType):
+ structs = {}
+ other = []
+ tag_field = None
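+            # Split union members into tagged (non-array-like) structs, which
+            # can share a discriminator, and all other types.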
+ for subtype in t.types:
+ real_type = subtype
+ while isinstance(real_type, mi.Metadata):
+ real_type = real_type.type
+ if isinstance(real_type, mi.StructType) and not real_type.array_like:
+ tag_field = real_type.tag_field
+ structs[real_type.tag] = real_type
+ else:
+ other.append(subtype)
+
+ options = [self.to_schema(a) for a in other]
+
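+            # With two or more tagged structs, also emit an OpenAPI-style
+            # "discriminator" mapping each tag value to its component ref.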
+ if len(structs) >= 2:
+ mapping = {
+ k: self.ref_template.format(name=self.name_map[v.cls])
+ for k, v in structs.items()
+ }
+ struct_schema = {
+ "anyOf": [self.to_schema(v) for v in structs.values()],
+ "discriminator": {"propertyName": tag_field, "mapping": mapping},
+ }
+ if options:
+ options.append(struct_schema)
+ schema["anyOf"] = options
+ else:
+ schema.update(struct_schema)
+ elif len(structs) == 1:
+ _, subtype = structs.popitem()
+ options.append(self.to_schema(subtype))
+ schema["anyOf"] = options
+ else:
+ schema["anyOf"] = options
+ elif isinstance(t, mi.LiteralType):
+ schema["enum"] = sorted(t.values)
+ elif isinstance(t, mi.EnumType):
+ schema.setdefault("title", t.cls.__name__)
+ if doc := _get_doc(t):
+ schema.setdefault("description", doc)
+ schema["enum"] = sorted(e.value for e in t.cls)
+ elif isinstance(t, mi.StructType):
+ schema.setdefault("title", _get_class_name(t.cls))
+ if doc := _get_doc(t):
+ schema.setdefault("description", doc)
+ required = []
+ names = []
+ fields = []
+
+ if t.tag_field is not None:
+ required.append(t.tag_field)
+ names.append(t.tag_field)
+ fields.append({"enum": [t.tag]})
+
+ for field in t.fields:
+ field_schema = self.to_schema(field.type)
+ if field.required:
+ required.append(field.encode_name)
+ elif field.default is not mi.NODEFAULT:
+ field_schema["default"] = to_builtins(field.default, str_keys=True)
+ elif field.default_factory in (list, dict, set, bytearray):
+ field_schema["default"] = field.default_factory()
+ names.append(field.encode_name)
+ fields.append(field_schema)
+
+ if t.array_like:
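+                # Trailing fields with defaults may be omitted when a struct
+                # is encoded as an array, so exclude them from minItems.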
+ n_trailing_defaults = 0
+ for n_trailing_defaults, f in enumerate(reversed(t.fields)):
+ if f.required:
+ break
+ schema["type"] = "array"
+ schema["prefixItems"] = fields
+ schema["minItems"] = len(fields) - n_trailing_defaults
+ if t.forbid_unknown_fields:
+ schema["maxItems"] = len(fields)
+ else:
+ schema["type"] = "object"
+ schema["properties"] = dict(zip(names, fields))
+ schema["required"] = required
+ if t.forbid_unknown_fields:
+ schema["additionalProperties"] = False
+ elif isinstance(t, (mi.TypedDictType, mi.DataclassType, mi.NamedTupleType)):
+ schema.setdefault("title", _get_class_name(t.cls))
+ if doc := _get_doc(t):
+ schema.setdefault("description", doc)
+ names = []
+ fields = []
+ required = []
+ for field in t.fields:
+ field_schema = self.to_schema(field.type)
+ if field.required:
+ required.append(field.encode_name)
+ elif field.default is not mi.NODEFAULT:
+ field_schema["default"] = to_builtins(field.default, str_keys=True)
+ names.append(field.encode_name)
+ fields.append(field_schema)
+ if isinstance(t, mi.NamedTupleType):
+ schema["type"] = "array"
+ schema["prefixItems"] = fields
+ schema["minItems"] = len(required)
+ schema["maxItems"] = len(fields)
+ else:
+ schema["type"] = "object"
+ schema["properties"] = dict(zip(names, fields))
+ schema["required"] = required
+ elif isinstance(t, mi.ExtType):
+ raise TypeError("json-schema doesn't support msgpack Ext types")
+ elif isinstance(t, mi.CustomType):
+ if self.schema_hook:
+ try:
+ schema = mi._merge_json(self.schema_hook(t.cls), schema)
+ except NotImplementedError:
+ pass
+ if not schema:
+ raise TypeError(
+ "Generating JSON schema for custom types requires either:\n"
+ "- specifying a `schema_hook`\n"
+ "- annotating the type with `Meta(extra_json_schema=...)`\n"
+ "\n"
+ f"type {t.cls!r} is not supported"
+ )
+ else:
+ # This should be unreachable
+ raise TypeError(f"json-schema doesn't support type {t!r}")
+
+ return schema