summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/_openapi/parameters.py
blob: c3da5c431a8ac7cd98965f676e64df4a9cc1b19d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
from __future__ import annotations

from typing import TYPE_CHECKING

from litestar._openapi.schema_generation import SchemaCreator
from litestar._openapi.schema_generation.utils import get_formatted_examples
from litestar.constants import RESERVED_KWARGS
from litestar.enums import ParamType
from litestar.exceptions import ImproperlyConfiguredException
from litestar.openapi.spec.parameter import Parameter
from litestar.openapi.spec.schema import Schema
from litestar.params import DependencyKwarg, ParameterKwarg
from litestar.types import Empty
from litestar.typing import FieldDefinition

if TYPE_CHECKING:
    from litestar._openapi.datastructures import OpenAPIContext
    from litestar.handlers.base import BaseRouteHandler
    from litestar.openapi.spec import Reference
    from litestar.types.internal_types import PathParameterDefinition

__all__ = ("create_parameters_for_handler",)


class ParameterCollection:
    """Accumulate the OpenAPI ``Parameter`` objects produced for one route handler, de-duplicating them.

    Parameters are keyed by ``(name, location)``. Adding a parameter whose key is already present is a no-op when
    the new ``Parameter`` compares equal to the stored one; when it differs, ``ImproperlyConfiguredException`` is
    raised because the two definitions conflict.
    """

    def __init__(self, route_handler: BaseRouteHandler) -> None:
        """Create an empty collection bound to a route handler.

        Args:
            route_handler: The handler the parameters belong to. Its ``paths`` are consulted when adding query
                parameters, and it is named in the conflict error message.
        """
        self.route_handler = route_handler
        self._parameters: dict[tuple[str, str], Parameter] = {}

    def add(self, parameter: Parameter) -> None:
        """Insert ``parameter`` unless an equal one with the same name and location is already stored.

        A query parameter is silently dropped when any of the handler's paths declares a path parameter with the
        same name (see the inline comment below).

        Raises:
            ImproperlyConfiguredException: If a different ``Parameter`` with the same name and location was
                already added.
        """
        key = (parameter.name, parameter.param_in)

        if key in self._parameters:
            if parameter == self._parameters[key]:
                return
            raise ImproperlyConfiguredException(
                f"OpenAPI schema generation for handler `{self.route_handler}` detected multiple parameters named "
                f"'{parameter.name}' with different types."
            )

        # Routes are unique per path, but a handler may be mounted on several paths where only some of them carry
        # a given path parameter, e.g. get(path=["/", "/{param:str}"]) with ``def handler(param: str | None)``.
        # For the path without ``{param:...}`` the kwarg would otherwise be documented as a query parameter, so
        # skip it when any sibling path declares it as a path parameter.
        shadowed_by_path = parameter.param_in == ParamType.QUERY and any(
            f"{{{parameter.name}:" in path for path in self.route_handler.paths
        )
        if not shadowed_by_path:
            self._parameters[key] = parameter

    def list(self) -> list[Parameter]:
        """Return the collected parameters, in the order they were first added."""
        return [*self._parameters.values()]


class ParameterFactory:
    """Factory for creating OpenAPI Parameters for a given route handler.

    Parameters are gathered from three sources:
    - the handler's own signature;
    - the signatures of the dependency providers it uses, walked recursively;
    - parameters declared on the handler's layers.

    They are accumulated in a ``ParameterCollection``, which de-duplicates them.
    """

    def __init__(
        self,
        context: OpenAPIContext,
        route_handler: BaseRouteHandler,
        path_parameters: tuple[PathParameterDefinition, ...],
    ) -> None:
        """Initialize ParameterFactory.

        Args:
            context: The OpenAPI context.
            route_handler: The route handler.
            path_parameters: The path parameters for the route.
        """
        self.context = context
        self.schema_creator = SchemaCreator.from_openapi_context(self.context, prefer_alias=True)
        self.route_handler = route_handler
        self.parameters = ParameterCollection(route_handler)
        self.dependency_providers = route_handler.resolve_dependencies()
        self.layered_parameters = route_handler.resolve_layered_parameters()
        # Only the names matter: membership decides whether a kwarg is documented as a path parameter.
        self.path_parameters_names = {p.name for p in path_parameters}

    def create_parameter(self, field_definition: FieldDefinition, parameter_name: str) -> Parameter:
        """Create an OpenAPI Parameter instance for a field definition.

        The location is chosen in this order:
        1. path, if ``parameter_name`` is one of the route's path parameters;
        2. header;
        3. cookie;
        4. query.

        For header, cookie and query parameters, an alias set on the field's ``ParameterKwarg`` replaces
        ``parameter_name``.

        Args:
            field_definition: The field definition.
            parameter_name: The name of the parameter.

        Returns:
            The OpenAPI ``Parameter`` describing the field.
        """
        kwarg_definition = (
            field_definition.kwarg_definition if isinstance(field_definition.kwarg_definition, ParameterKwarg) else None
        )

        if parameter_name in self.path_parameters_names:
            param_in = ParamType.PATH
            # Path parameters are always marked as required.
            is_required = True
        elif kwarg_definition and kwarg_definition.header:
            parameter_name = kwarg_definition.header
            param_in = ParamType.HEADER
            is_required = field_definition.is_required
        elif kwarg_definition and kwarg_definition.cookie:
            parameter_name = kwarg_definition.cookie
            param_in = ParamType.COOKIE
            is_required = field_definition.is_required
        else:
            param_in = ParamType.QUERY
            is_required = field_definition.is_required
            if kwarg_definition and kwarg_definition.query:
                parameter_name = kwarg_definition.query

        result: Schema | Reference = self.schema_creator.for_field_definition(field_definition)
        # A Reference points into the schema registry; resolve it so the description can be read.
        schema = result if isinstance(result, Schema) else self.context.schema_registry.from_reference(result).schema

        examples_list = (kwarg_definition.examples or []) if kwarg_definition else []
        examples = get_formatted_examples(field_definition, examples_list)

        return Parameter(
            description=schema.description,
            name=parameter_name,
            param_in=param_in,
            required=is_required,
            schema=result,
            examples=examples or None,
        )

    def get_layered_parameter(self, field_name: str, field_definition: FieldDefinition) -> Parameter:
        """Create a parameter for a field definition that has a KwargDefinition defined on the layers.

        The handler's own declaration is merged with the one from the layers:
        - the kwarg definition and type metadata come from the handler when it declares a ``Parameter``,
          otherwise from the layer;
        - a default on the handler wins over the layer's default;
        - the handler's annotation is used unless it is ``Empty``.

        Args:
            field_name: The name of the field.
            field_definition: The field definition.

        Returns:
            The OpenAPI ``Parameter`` for the merged field.
        """
        layer_field = self.layered_parameters[field_name]

        field = field_definition if field_definition.is_parameter_field else layer_field
        # Prefer a default declared on the handler; otherwise fall back to the layer's default (possibly Empty).
        default = field_definition.default if field_definition.has_default else layer_field.default
        # Fall back to the layer's annotation only when the handler field carries none.
        annotation = (
            field_definition.annotation if field_definition.annotation is not Empty else layer_field.annotation
        )

        parameter_name = field_name
        if isinstance(field.kwarg_definition, ParameterKwarg):
            parameter_name = (
                field.kwarg_definition.query
                or field.kwarg_definition.header
                or field.kwarg_definition.cookie
                or field_name
            )

        field_definition = FieldDefinition.from_kwarg(
            inner_types=field.inner_types,
            default=default,
            extra=field.extra,
            annotation=annotation,
            kwarg_definition=field.kwarg_definition,
            name=field_name,
        )
        return self.create_parameter(field_definition=field_definition, parameter_name=parameter_name)

    def create_parameters_for_field_definitions(self, fields: dict[str, FieldDefinition]) -> None:
        """Add Parameter models to the handler's collection for the given field definitions.

        Reserved kwargs are skipped. Fields that name a dependency provider are expanded into that provider's own
        signature parameters, recursively. Layered parameters are added whether or not they also appear in
        ``fields``; when they do, the handler and layer declarations are merged.

        Args:
            fields: The field definitions.
        """
        unique_handler_fields = (
            (k, v) for k, v in fields.items() if k not in RESERVED_KWARGS and k not in self.layered_parameters
        )
        unique_layered_fields = (
            (k, v) for k, v in self.layered_parameters.items() if k not in RESERVED_KWARGS and k not in fields
        )
        intersection_fields = (
            (k, v) for k, v in fields.items() if k not in RESERVED_KWARGS and k in self.layered_parameters
        )

        for field_name, field_definition in unique_handler_fields:
            if (
                isinstance(field_definition.kwarg_definition, DependencyKwarg)
                and field_name not in self.dependency_providers
            ):
                # never document explicit dependencies
                continue

            if provider := self.dependency_providers.get(field_name):
                # The dependency's value is computed server-side; document the parameters its provider consumes.
                self.create_parameters_for_field_definitions(fields=provider.parsed_fn_signature.parameters)
            else:
                self.parameters.add(self.create_parameter(field_definition=field_definition, parameter_name=field_name))

        for field_name, field_definition in unique_layered_fields:
            self.parameters.add(self.create_parameter(field_definition=field_definition, parameter_name=field_name))

        for field_name, field_definition in intersection_fields:
            self.parameters.add(self.get_layered_parameter(field_name=field_name, field_definition=field_definition))

    def create_parameters_for_handler(self) -> list[Parameter]:
        """Create a list of path/query/header Parameter models for the given PathHandler."""
        handler_fields = self.route_handler.parsed_fn_signature.parameters
        self.create_parameters_for_field_definitions(handler_fields)
        return self.parameters.list()


def create_parameters_for_handler(
    context: OpenAPIContext,
    route_handler: BaseRouteHandler,
    path_parameters: tuple[PathParameterDefinition, ...],
) -> list[Parameter]:
    """Build the OpenAPI path/query/header/cookie ``Parameter`` list for one route handler.

    This is a convenience wrapper that constructs a ``ParameterFactory`` and runs it once.

    Args:
        context: The OpenAPI context.
        route_handler: The route handler whose parameters are documented.
        path_parameters: The path parameters of the route the handler is mounted on.

    Returns:
        The de-duplicated list of ``Parameter`` objects.
    """
    return ParameterFactory(
        context=context,
        route_handler=route_handler,
        path_parameters=path_parameters,
    ).create_parameters_for_handler()