diff options
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/_openapi/parameters.py')
-rw-r--r-- | venv/lib/python3.11/site-packages/litestar/_openapi/parameters.py | 233 |
1 files changed, 233 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/_openapi/parameters.py b/venv/lib/python3.11/site-packages/litestar/_openapi/parameters.py new file mode 100644 index 0000000..c3da5c4 --- /dev/null +++ b/venv/lib/python3.11/site-packages/litestar/_openapi/parameters.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from litestar._openapi.schema_generation import SchemaCreator +from litestar._openapi.schema_generation.utils import get_formatted_examples +from litestar.constants import RESERVED_KWARGS +from litestar.enums import ParamType +from litestar.exceptions import ImproperlyConfiguredException +from litestar.openapi.spec.parameter import Parameter +from litestar.openapi.spec.schema import Schema +from litestar.params import DependencyKwarg, ParameterKwarg +from litestar.types import Empty +from litestar.typing import FieldDefinition + +if TYPE_CHECKING: + from litestar._openapi.datastructures import OpenAPIContext + from litestar.handlers.base import BaseRouteHandler + from litestar.openapi.spec import Reference + from litestar.types.internal_types import PathParameterDefinition + +__all__ = ("create_parameters_for_handler",) + + +class ParameterCollection: + """Facilitates conditional deduplication of parameters. + + If multiple parameters with the same name are produced for a handler, the condition is ignored if the two + ``Parameter`` instances are the same (the first is retained and any duplicates are ignored). If the ``Parameter`` + instances are not the same, an exception is raised. + """ + + def __init__(self, route_handler: BaseRouteHandler) -> None: + """Initialize ``ParameterCollection``. + + Args: + route_handler: Associated route handler + """ + self.route_handler = route_handler + self._parameters: dict[tuple[str, str], Parameter] = {} + + def add(self, parameter: Parameter) -> None: + """Add a ``Parameter`` to the collection. + + If an existing parameter with the same name and type already exists, the + parameter is ignored. + + If an existing parameter with the same name but different type exists, raises + ``ImproperlyConfiguredException``. + """ + + if (parameter.name, parameter.param_in) not in self._parameters: + # because we are defining routes as unique per path, we have to handle here a situation when there is an optional + # path parameter. e.g. get(path=["/", "/{param:str}"]). When parsing the parameter for path, the route handler + # would still have a kwarg called param: + # def handler(param: str | None) -> ... + if parameter.param_in != ParamType.QUERY or all( + f"{{{parameter.name}:" not in path for path in self.route_handler.paths + ): + self._parameters[(parameter.name, parameter.param_in)] = parameter + return + + pre_existing = self._parameters[(parameter.name, parameter.param_in)] + if parameter == pre_existing: + return + + raise ImproperlyConfiguredException( + f"OpenAPI schema generation for handler `{self.route_handler}` detected multiple parameters named " + f"'{parameter.name}' with different types." + ) + + def list(self) -> list[Parameter]: + """Return a list of all ``Parameter``'s in the collection.""" + return list(self._parameters.values()) + + +class ParameterFactory: + """Factory for creating OpenAPI Parameters for a given route handler.""" + + def __init__( + self, + context: OpenAPIContext, + route_handler: BaseRouteHandler, + path_parameters: tuple[PathParameterDefinition, ...], + ) -> None: + """Initialize ParameterFactory. + + Args: + context: The OpenAPI context. + route_handler: The route handler. + path_parameters: The path parameters for the route. + """ + self.context = context + self.schema_creator = SchemaCreator.from_openapi_context(self.context, prefer_alias=True) + self.route_handler = route_handler + self.parameters = ParameterCollection(route_handler) + self.dependency_providers = route_handler.resolve_dependencies() + self.layered_parameters = route_handler.resolve_layered_parameters() + self.path_parameters_names = {p.name for p in path_parameters} + + def create_parameter(self, field_definition: FieldDefinition, parameter_name: str) -> Parameter: + """Create an OpenAPI Parameter instance for a field definition. + + Args: + field_definition: The field definition. + parameter_name: The name of the parameter. + """ + + result: Schema | Reference | None = None + kwarg_definition = ( + field_definition.kwarg_definition if isinstance(field_definition.kwarg_definition, ParameterKwarg) else None + ) + + if parameter_name in self.path_parameters_names: + param_in = ParamType.PATH + is_required = True + result = self.schema_creator.for_field_definition(field_definition) + elif kwarg_definition and kwarg_definition.header: + parameter_name = kwarg_definition.header + param_in = ParamType.HEADER + is_required = field_definition.is_required + elif kwarg_definition and kwarg_definition.cookie: + parameter_name = kwarg_definition.cookie + param_in = ParamType.COOKIE + is_required = field_definition.is_required + else: + is_required = field_definition.is_required + param_in = ParamType.QUERY + parameter_name = kwarg_definition.query if kwarg_definition and kwarg_definition.query else parameter_name + + if not result: + result = self.schema_creator.for_field_definition(field_definition) + + schema = result if isinstance(result, Schema) else self.context.schema_registry.from_reference(result).schema + + examples_list = kwarg_definition.examples or [] if kwarg_definition else [] + examples = get_formatted_examples(field_definition, examples_list) + + return Parameter( + description=schema.description, + name=parameter_name, + param_in=param_in, + required=is_required, + schema=result, + examples=examples or None, + ) + + def get_layered_parameter(self, field_name: str, field_definition: FieldDefinition) -> Parameter: + """Create a parameter for a field definition that has a KwargDefinition defined on the layers. + + Args: + field_name: The name of the field. + field_definition: The field definition. + """ + layer_field = self.layered_parameters[field_name] + + field = field_definition if field_definition.is_parameter_field else layer_field + default = layer_field.default if field_definition.has_default else field_definition.default + annotation = field_definition.annotation if field_definition is not Empty else layer_field.annotation + + parameter_name = field_name + if isinstance(field.kwarg_definition, ParameterKwarg): + parameter_name = ( + field.kwarg_definition.query + or field.kwarg_definition.header + or field.kwarg_definition.cookie + or field_name + ) + + field_definition = FieldDefinition.from_kwarg( + inner_types=field.inner_types, + default=default, + extra=field.extra, + annotation=annotation, + kwarg_definition=field.kwarg_definition, + name=field_name, + ) + return self.create_parameter(field_definition=field_definition, parameter_name=parameter_name) + + def create_parameters_for_field_definitions(self, fields: dict[str, FieldDefinition]) -> None: + """Add Parameter models to the handler's collection for the given field definitions. + + Args: + fields: The field definitions. + """ + unique_handler_fields = ( + (k, v) for k, v in fields.items() if k not in RESERVED_KWARGS and k not in self.layered_parameters + ) + unique_layered_fields = ( + (k, v) for k, v in self.layered_parameters.items() if k not in RESERVED_KWARGS and k not in fields + ) + intersection_fields = ( + (k, v) for k, v in fields.items() if k not in RESERVED_KWARGS and k in self.layered_parameters + ) + + for field_name, field_definition in unique_handler_fields: + if ( + isinstance(field_definition.kwarg_definition, DependencyKwarg) + and field_name not in self.dependency_providers + ): + # never document explicit dependencies + continue + + if provider := self.dependency_providers.get(field_name): + self.create_parameters_for_field_definitions(fields=provider.parsed_fn_signature.parameters) + else: + self.parameters.add(self.create_parameter(field_definition=field_definition, parameter_name=field_name)) + + for field_name, field_definition in unique_layered_fields: + self.parameters.add(self.create_parameter(field_definition=field_definition, parameter_name=field_name)) + + for field_name, field_definition in intersection_fields: + self.parameters.add(self.get_layered_parameter(field_name=field_name, field_definition=field_definition)) + + def create_parameters_for_handler(self) -> list[Parameter]: + """Create a list of path/query/header Parameter models for the given PathHandler.""" + handler_fields = self.route_handler.parsed_fn_signature.parameters + self.create_parameters_for_field_definitions(handler_fields) + return self.parameters.list() + + +def create_parameters_for_handler( + context: OpenAPIContext, + route_handler: BaseRouteHandler, + path_parameters: tuple[PathParameterDefinition, ...], +) -> list[Parameter]: + """Create a list of path/query/header Parameter models for the given PathHandler.""" + factory = ParameterFactory( + context=context, + route_handler=route_handler, + path_parameters=path_parameters, + ) + return factory.create_parameters_for_handler() |