diff options
author | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
---|---|---|
committer | cyfraeviolae <cyfraeviolae> | 2024-04-03 03:10:44 -0400 |
commit | 6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch) | |
tree | b1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/_openapi/responses.py | |
parent | 4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff) |
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/_openapi/responses.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/_openapi/responses.py | 335 |
1 files changed, 335 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/responses.py b/venv/lib/python3.11/site-packages/litestar/_openapi/responses.py new file mode 100644 index 0000000..6b0f312 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/responses.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import contextlib +import re +from copy import copy +from dataclasses import asdict +from http import HTTPStatus +from operator import attrgetter +from typing import TYPE_CHECKING, Any, Iterator + +from litestar._openapi.schema_generation import SchemaCreator +from litestar._openapi.schema_generation.utils import get_formatted_examples +from litestar.enums import MediaType +from litestar.exceptions import HTTPException, ValidationException +from litestar.openapi.spec import Example, OpenAPIResponse, Reference +from litestar.openapi.spec.enums import OpenAPIFormat, OpenAPIType +from litestar.openapi.spec.header import OpenAPIHeader +from litestar.openapi.spec.media_type import OpenAPIMediaType +from litestar.openapi.spec.schema import Schema +from litestar.response import ( + File, + Redirect, + Stream, + Template, +) +from litestar.response import ( + Response as LitestarResponse, +) +from litestar.response.base import ASGIResponse +from litestar.types.builtin_types import NoneType +from litestar.typing import FieldDefinition +from litestar.utils import get_enum_string_value, get_name + +if TYPE_CHECKING: + from litestar._openapi.datastructures import OpenAPIContext + from litestar.datastructures.cookie import Cookie + from litestar.handlers.http_handlers import HTTPRouteHandler + from litestar.openapi.spec.responses import Responses + + +__all__ = ("create_responses_for_handler",) + +CAPITAL_LETTERS_PATTERN = re.compile(r"(?=[A-Z])") + + +def pascal_case_to_text(string: str) -> str: + """Given a 'PascalCased' string, return its split form- 'Pascal Cased'.""" + return " ".join(re.split(CAPITAL_LETTERS_PATTERN, string)).strip() + + +def create_cookie_schema(cookie: Cookie) -> Schema: + """Given a Cookie instance, return its corresponding OpenAPI schema. + + Args: + cookie: Cookie + + Returns: + Schema + """ + cookie_copy = copy(cookie) + cookie_copy.value = "<string>" + value = cookie_copy.to_header(header="") + return Schema(description=cookie.description or "", example=value) + + +class ResponseFactory: + """Factory for creating a Response instance for a given route handler.""" + + def __init__(self, context: OpenAPIContext, route_handler: HTTPRouteHandler) -> None: + """Initialize the factory. + + Args: + context: An OpenAPIContext instance. + route_handler: An HTTPRouteHandler instance. + """ + self.context = context + self.route_handler = route_handler + self.field_definition = route_handler.parsed_fn_signature.return_type + self.schema_creator = SchemaCreator.from_openapi_context(context, prefer_alias=False) + + def create_responses(self, raises_validation_error: bool) -> Responses | None: + """Create the schema for responses, if any. + + Args: + raises_validation_error: Boolean flag indicating whether the handler raises a ValidationException. + + Returns: + Responses + """ + responses: Responses = { + str(self.route_handler.status_code): self.create_success_response(), + } + + exceptions = list(self.route_handler.raises or []) + if raises_validation_error and ValidationException not in exceptions: + exceptions.append(ValidationException) + + for status_code, response in create_error_responses(exceptions=exceptions): + responses[status_code] = response + + for status_code, response in self.create_additional_responses(): + responses[status_code] = response + + return responses or None + + def create_description(self) -> str: + """Create the description for a success response.""" + default_descriptions: dict[Any, str] = { + Stream: "Stream Response", + Redirect: "Redirect Response", + File: "File Download", + } + return ( + self.route_handler.response_description + or default_descriptions.get(self.field_definition.annotation) + or HTTPStatus(self.route_handler.status_code).description + ) + + def create_success_response(self) -> OpenAPIResponse: + """Create the schema for a success response.""" + if self.field_definition.is_subclass_of((NoneType, ASGIResponse)): + response = OpenAPIResponse(content=None, description=self.create_description()) + elif self.field_definition.is_subclass_of(Redirect): + response = self.create_redirect_response() + elif self.field_definition.is_subclass_of((File, Stream)): + response = self.create_file_response() + else: + media_type = self.route_handler.media_type + + if dto := self.route_handler.resolve_return_dto(): + result = dto.create_openapi_schema( + field_definition=self.field_definition, + handler_id=self.route_handler.handler_id, + schema_creator=self.schema_creator, + ) + else: + if self.field_definition.is_subclass_of(Template): + field_def = FieldDefinition.from_annotation(str) + media_type = media_type or MediaType.HTML + elif self.field_definition.is_subclass_of(LitestarResponse): + field_def = ( + self.field_definition.inner_types[0] + if self.field_definition.inner_types + else FieldDefinition.from_annotation(Any) + ) + media_type = media_type or MediaType.JSON + else: + field_def = self.field_definition + + result = self.schema_creator.for_field_definition(field_def) + + schema = ( + result if isinstance(result, Schema) else self.context.schema_registry.from_reference(result).schema + ) + schema.content_encoding = self.route_handler.content_encoding + schema.content_media_type = self.route_handler.content_media_type + response = OpenAPIResponse( + content={get_enum_string_value(media_type): OpenAPIMediaType(schema=result)}, + description=self.create_description(), + ) + self.set_success_response_headers(response) + return response + + def create_redirect_response(self) -> OpenAPIResponse: + """Create the schema for a redirect response.""" + return OpenAPIResponse( + content=None, + description=self.create_description(), + headers={ + "location": OpenAPIHeader( + schema=Schema(type=OpenAPIType.STRING), description="target path for the redirect" + ) + }, + ) + + def create_file_response(self) -> OpenAPIResponse: + """Create the schema for a file/stream response.""" + return OpenAPIResponse( + content={ + self.route_handler.media_type: OpenAPIMediaType( + schema=Schema( + type=OpenAPIType.STRING, + content_encoding=self.route_handler.content_encoding, + content_media_type=self.route_handler.content_media_type or "application/octet-stream", + ), + ) + }, + description=self.create_description(), + headers={ + "content-length": OpenAPIHeader( + schema=Schema(type=OpenAPIType.STRING), description="File size in bytes" + ), + "last-modified": OpenAPIHeader( + schema=Schema(type=OpenAPIType.STRING, format=OpenAPIFormat.DATE_TIME), + description="Last modified data-time in RFC 2822 format", + ), + "etag": OpenAPIHeader(schema=Schema(type=OpenAPIType.STRING), description="Entity tag"), + }, + ) + + def set_success_response_headers(self, response: OpenAPIResponse) -> None: + """Set the schema for success response headers, if any.""" + + if response.headers is None: + response.headers = {} + + if not self.schema_creator.generate_examples: + schema_creator = self.schema_creator + else: + schema_creator = SchemaCreator.from_openapi_context(self.context, generate_examples=False) + + for response_header in self.route_handler.resolve_response_headers(): + header = OpenAPIHeader() + for attribute_name, attribute_value in ( + (k, v) for k, v in asdict(response_header).items() if v is not None + ): + if attribute_name == "value": + header.schema = schema_creator.for_field_definition( + FieldDefinition.from_annotation(type(attribute_value)) + ) + elif attribute_name != "documentation_only": + setattr(header, attribute_name, attribute_value) + + response.headers[response_header.name] = header + + if cookies := self.route_handler.resolve_response_cookies(): + response.headers["Set-Cookie"] = OpenAPIHeader( + schema=Schema( + all_of=[create_cookie_schema(cookie=cookie) for cookie in sorted(cookies, key=attrgetter("key"))] + ) + ) + + def create_additional_responses(self) -> Iterator[tuple[str, OpenAPIResponse]]: + """Create the schema for additional responses, if any.""" + if not self.route_handler.responses: + return + + for status_code, additional_response in self.route_handler.responses.items(): + schema_creator = SchemaCreator.from_openapi_context( + self.context, + prefer_alias=False, + generate_examples=additional_response.generate_examples, + ) + field_def = FieldDefinition.from_annotation(additional_response.data_container) + + examples: dict[str, Example | Reference] | None = ( + dict(get_formatted_examples(field_def, additional_response.examples)) + if additional_response.examples + else None + ) + + content: dict[str, OpenAPIMediaType] | None + if additional_response.data_container is not None: + schema = schema_creator.for_field_definition(field_def) + content = {additional_response.media_type: OpenAPIMediaType(schema=schema, examples=examples)} + else: + content = None + + yield ( + str(status_code), + OpenAPIResponse( + description=additional_response.description, + content=content, + ), + ) + + +def create_error_responses(exceptions: list[type[HTTPException]]) -> Iterator[tuple[str, OpenAPIResponse]]: + """Create the schema for error responses, if any.""" + grouped_exceptions: dict[int, list[type[HTTPException]]] = {} + for exc in exceptions: + if not grouped_exceptions.get(exc.status_code): + grouped_exceptions[exc.status_code] = [] + grouped_exceptions[exc.status_code].append(exc) + for status_code, exception_group in grouped_exceptions.items(): + exceptions_schemas = [] + group_description: str = "" + for exc in exception_group: + example_detail = "" + if hasattr(exc, "detail") and exc.detail: + group_description = exc.detail + example_detail = exc.detail + + if not example_detail: + with contextlib.suppress(Exception): + example_detail = HTTPStatus(status_code).phrase + + exceptions_schemas.append( + Schema( + type=OpenAPIType.OBJECT, + required=["detail", "status_code"], + properties={ + "status_code": Schema(type=OpenAPIType.INTEGER), + "detail": Schema(type=OpenAPIType.STRING), + "extra": Schema( + type=[OpenAPIType.NULL, OpenAPIType.OBJECT, OpenAPIType.ARRAY], + additional_properties=Schema(), + ), + }, + description=pascal_case_to_text(get_name(exc)), + examples=[{"status_code": status_code, "detail": example_detail, "extra": {}}], + ) + ) + if len(exceptions_schemas) > 1: # noqa: SIM108 + schema = Schema(one_of=exceptions_schemas) + else: + schema = exceptions_schemas[0] + + if not group_description: + with contextlib.suppress(Exception): + group_description = HTTPStatus(status_code).description + + yield ( + str(status_code), + OpenAPIResponse( + description=group_description, + content={MediaType.JSON: OpenAPIMediaType(schema=schema)}, + ), + ) + + +def create_responses_for_handler( + context: OpenAPIContext, route_handler: HTTPRouteHandler, raises_validation_error: bool +) -> Responses | None: + """Create the schema for responses, if any. + + Args: + context: An OpenAPIContext instance. + route_handler: An HTTPRouteHandler instance. + raises_validation_error: Boolean flag indicating whether the handler raises a ValidationException. + + Returns: + Responses + """ + return ResponseFactory(context, route_handler).create_responses(raises_validation_error=raises_validation_error) |