summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/_openapi/responses.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/_openapi/responses.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/_openapi/responses.py335
1 files changed, 335 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/responses.py b/venv/lib/python3.11/site-packages/litestar/_openapi/responses.py
new file mode 100644
index 0000000..6b0f312
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/_openapi/responses.py
@@ -0,0 +1,335 @@
+from __future__ import annotations
+
+import contextlib
+import re
+from copy import copy
+from dataclasses import asdict
+from http import HTTPStatus
+from operator import attrgetter
+from typing import TYPE_CHECKING, Any, Iterator
+
+from litestar._openapi.schema_generation import SchemaCreator
+from litestar._openapi.schema_generation.utils import get_formatted_examples
+from litestar.enums import MediaType
+from litestar.exceptions import HTTPException, ValidationException
+from litestar.openapi.spec import Example, OpenAPIResponse, Reference
+from litestar.openapi.spec.enums import OpenAPIFormat, OpenAPIType
+from litestar.openapi.spec.header import OpenAPIHeader
+from litestar.openapi.spec.media_type import OpenAPIMediaType
+from litestar.openapi.spec.schema import Schema
+from litestar.response import (
+ File,
+ Redirect,
+ Stream,
+ Template,
+)
+from litestar.response import (
+ Response as LitestarResponse,
+)
+from litestar.response.base import ASGIResponse
+from litestar.types.builtin_types import NoneType
+from litestar.typing import FieldDefinition
+from litestar.utils import get_enum_string_value, get_name
+
+if TYPE_CHECKING:
+ from litestar._openapi.datastructures import OpenAPIContext
+ from litestar.datastructures.cookie import Cookie
+ from litestar.handlers.http_handlers import HTTPRouteHandler
+ from litestar.openapi.spec.responses import Responses
+
+
+__all__ = ("create_responses_for_handler",)
+
+CAPITAL_LETTERS_PATTERN = re.compile(r"(?=[A-Z])")
+
+
+def pascal_case_to_text(string: str) -> str:
+ """Given a 'PascalCased' string, return its split form- 'Pascal Cased'."""
+ return " ".join(re.split(CAPITAL_LETTERS_PATTERN, string)).strip()
+
+
+def create_cookie_schema(cookie: Cookie) -> Schema:
+ """Given a Cookie instance, return its corresponding OpenAPI schema.
+
+ Args:
+ cookie: Cookie
+
+ Returns:
+ Schema
+ """
+ cookie_copy = copy(cookie)
+ cookie_copy.value = "<string>"
+ value = cookie_copy.to_header(header="")
+ return Schema(description=cookie.description or "", example=value)
+
+
+class ResponseFactory:
+ """Factory for creating a Response instance for a given route handler."""
+
+ def __init__(self, context: OpenAPIContext, route_handler: HTTPRouteHandler) -> None:
+ """Initialize the factory.
+
+ Args:
+ context: An OpenAPIContext instance.
+ route_handler: An HTTPRouteHandler instance.
+ """
+ self.context = context
+ self.route_handler = route_handler
+ self.field_definition = route_handler.parsed_fn_signature.return_type
+ self.schema_creator = SchemaCreator.from_openapi_context(context, prefer_alias=False)
+
+ def create_responses(self, raises_validation_error: bool) -> Responses | None:
+ """Create the schema for responses, if any.
+
+ Args:
+ raises_validation_error: Boolean flag indicating whether the handler raises a ValidationException.
+
+ Returns:
+ Responses
+ """
+ responses: Responses = {
+ str(self.route_handler.status_code): self.create_success_response(),
+ }
+
+ exceptions = list(self.route_handler.raises or [])
+ if raises_validation_error and ValidationException not in exceptions:
+ exceptions.append(ValidationException)
+
+ for status_code, response in create_error_responses(exceptions=exceptions):
+ responses[status_code] = response
+
+ for status_code, response in self.create_additional_responses():
+ responses[status_code] = response
+
+ return responses or None
+
+ def create_description(self) -> str:
+ """Create the description for a success response."""
+ default_descriptions: dict[Any, str] = {
+ Stream: "Stream Response",
+ Redirect: "Redirect Response",
+ File: "File Download",
+ }
+ return (
+ self.route_handler.response_description
+ or default_descriptions.get(self.field_definition.annotation)
+ or HTTPStatus(self.route_handler.status_code).description
+ )
+
+ def create_success_response(self) -> OpenAPIResponse:
+ """Create the schema for a success response."""
+ if self.field_definition.is_subclass_of((NoneType, ASGIResponse)):
+ response = OpenAPIResponse(content=None, description=self.create_description())
+ elif self.field_definition.is_subclass_of(Redirect):
+ response = self.create_redirect_response()
+ elif self.field_definition.is_subclass_of((File, Stream)):
+ response = self.create_file_response()
+ else:
+ media_type = self.route_handler.media_type
+
+ if dto := self.route_handler.resolve_return_dto():
+ result = dto.create_openapi_schema(
+ field_definition=self.field_definition,
+ handler_id=self.route_handler.handler_id,
+ schema_creator=self.schema_creator,
+ )
+ else:
+ if self.field_definition.is_subclass_of(Template):
+ field_def = FieldDefinition.from_annotation(str)
+ media_type = media_type or MediaType.HTML
+ elif self.field_definition.is_subclass_of(LitestarResponse):
+ field_def = (
+ self.field_definition.inner_types[0]
+ if self.field_definition.inner_types
+ else FieldDefinition.from_annotation(Any)
+ )
+ media_type = media_type or MediaType.JSON
+ else:
+ field_def = self.field_definition
+
+ result = self.schema_creator.for_field_definition(field_def)
+
+ schema = (
+ result if isinstance(result, Schema) else self.context.schema_registry.from_reference(result).schema
+ )
+ schema.content_encoding = self.route_handler.content_encoding
+ schema.content_media_type = self.route_handler.content_media_type
+ response = OpenAPIResponse(
+ content={get_enum_string_value(media_type): OpenAPIMediaType(schema=result)},
+ description=self.create_description(),
+ )
+ self.set_success_response_headers(response)
+ return response
+
+ def create_redirect_response(self) -> OpenAPIResponse:
+ """Create the schema for a redirect response."""
+ return OpenAPIResponse(
+ content=None,
+ description=self.create_description(),
+ headers={
+ "location": OpenAPIHeader(
+ schema=Schema(type=OpenAPIType.STRING), description="target path for the redirect"
+ )
+ },
+ )
+
+ def create_file_response(self) -> OpenAPIResponse:
+ """Create the schema for a file/stream response."""
+ return OpenAPIResponse(
+ content={
+ self.route_handler.media_type: OpenAPIMediaType(
+ schema=Schema(
+ type=OpenAPIType.STRING,
+ content_encoding=self.route_handler.content_encoding,
+ content_media_type=self.route_handler.content_media_type or "application/octet-stream",
+ ),
+ )
+ },
+ description=self.create_description(),
+ headers={
+ "content-length": OpenAPIHeader(
+ schema=Schema(type=OpenAPIType.STRING), description="File size in bytes"
+ ),
+ "last-modified": OpenAPIHeader(
+ schema=Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.DATE_TIME),
+ description="Last modified data-time in RFC 2822 format",
+ ),
+ "etag": OpenAPIHeader(schema=Schema(type=OpenAPIType.STRING), description="Entity tag"),
+ },
+ )
+
+ def set_success_response_headers(self, response: OpenAPIResponse) -> None:
+ """Set the schema for success response headers, if any."""
+
+ if response.headers is None:
+ response.headers = {}
+
+ if not self.schema_creator.generate_examples:
+ schema_creator = self.schema_creator
+ else:
+ schema_creator = SchemaCreator.from_openapi_context(self.context, generate_examples=False)
+
+ for response_header in self.route_handler.resolve_response_headers():
+ header = OpenAPIHeader()
+ for attribute_name, attribute_value in (
+ (k, v) for k, v in asdict(response_header).items() if v is not None
+ ):
+ if attribute_name == "value":
+ header.schema = schema_creator.for_field_definition(
+ FieldDefinition.from_annotation(type(attribute_value))
+ )
+ elif attribute_name != "documentation_only":
+ setattr(header, attribute_name, attribute_value)
+
+ response.headers[response_header.name] = header
+
+ if cookies := self.route_handler.resolve_response_cookies():
+ response.headers["Set-Cookie"] = OpenAPIHeader(
+ schema=Schema(
+ all_of=[create_cookie_schema(cookie=cookie) for cookie in sorted(cookies, key=attrgetter("key"))]
+ )
+ )
+
+ def create_additional_responses(self) -> Iterator[tuple[str, OpenAPIResponse]]:
+ """Create the schema for additional responses, if any."""
+ if not self.route_handler.responses:
+ return
+
+ for status_code, additional_response in self.route_handler.responses.items():
+ schema_creator = SchemaCreator.from_openapi_context(
+ self.context,
+ prefer_alias=False,
+ generate_examples=additional_response.generate_examples,
+ )
+ field_def = FieldDefinition.from_annotation(additional_response.data_container)
+
+ examples: dict[str, Example | Reference] | None = (
+ dict(get_formatted_examples(field_def, additional_response.examples))
+ if additional_response.examples
+ else None
+ )
+
+ content: dict[str, OpenAPIMediaType] | None
+ if additional_response.data_container is not None:
+ schema = schema_creator.for_field_definition(field_def)
+ content = {additional_response.media_type: OpenAPIMediaType(schema=schema, examples=examples)}
+ else:
+ content = None
+
+ yield (
+ str(status_code),
+ OpenAPIResponse(
+ description=additional_response.description,
+ content=content,
+ ),
+ )
+
+
+def create_error_responses(exceptions: list[type[HTTPException]]) -> Iterator[tuple[str, OpenAPIResponse]]:
+ """Create the schema for error responses, if any."""
+ grouped_exceptions: dict[int, list[type[HTTPException]]] = {}
+ for exc in exceptions:
+ if not grouped_exceptions.get(exc.status_code):
+ grouped_exceptions[exc.status_code] = []
+ grouped_exceptions[exc.status_code].append(exc)
+ for status_code, exception_group in grouped_exceptions.items():
+ exceptions_schemas = []
+ group_description: str = ""
+ for exc in exception_group:
+ example_detail = ""
+ if hasattr(exc, "detail") and exc.detail:
+ group_description = exc.detail
+ example_detail = exc.detail
+
+ if not example_detail:
+ with contextlib.suppress(Exception):
+ example_detail = HTTPStatus(status_code).phrase
+
+ exceptions_schemas.append(
+ Schema(
+ type=OpenAPIType.OBJECT,
+ required=["detail", "status_code"],
+ properties={
+ "status_code": Schema(type=OpenAPIType.INTEGER),
+ "detail": Schema(type=OpenAPIType.STRING),
+ "extra": Schema(
+ type=[OpenAPIType.NULL, OpenAPIType.OBJECT, OpenAPIType.ARRAY],
+ additional_properties=Schema(),
+ ),
+ },
+ description=pascal_case_to_text(get_name(exc)),
+ examples=[{"status_code": status_code, "detail": example_detail, "extra": {}}],
+ )
+ )
+ if len(exceptions_schemas) > 1: # noqa: SIM108
+ schema = Schema(one_of=exceptions_schemas)
+ else:
+ schema = exceptions_schemas[0]
+
+ if not group_description:
+ with contextlib.suppress(Exception):
+ group_description = HTTPStatus(status_code).description
+
+ yield (
+ str(status_code),
+ OpenAPIResponse(
+ description=group_description,
+ content={MediaType.JSON: OpenAPIMediaType(schema=schema)},
+ ),
+ )
+
+
+def create_responses_for_handler(
+ context: OpenAPIContext, route_handler: HTTPRouteHandler, raises_validation_error: bool
+) -> Responses | None:
+ """Create the schema for responses, if any.
+
+ Args:
+ context: An OpenAPIContext instance.
+ route_handler: An HTTPRouteHandler instance.
+ raises_validation_error: Boolean flag indicating whether the handler raises a ValidationException.
+
+ Returns:
+ Responses
+ """
+ return ResponseFactory(context, route_handler).create_responses(raises_validation_error=raises_validation_error)