summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/plugins/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/plugins/base.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/plugins/base.py323
1 files changed, 323 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/plugins/base.py b/venv/lib/python3.11/site-packages/litestar/plugins/base.py
new file mode 100644
index 0000000..afc571e
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/plugins/base.py
@@ -0,0 +1,323 @@
+from __future__ import annotations
+
+import abc
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Any, Iterator, Protocol, TypeVar, Union, cast, runtime_checkable
+
+if TYPE_CHECKING:
+ from inspect import Signature
+
+ from click import Group
+
+ from litestar._openapi.schema_generation import SchemaCreator
+ from litestar.app import Litestar
+ from litestar.config.app import AppConfig
+ from litestar.dto import AbstractDTO
+ from litestar.openapi.spec import Schema
+ from litestar.routes import BaseRoute
+ from litestar.typing import FieldDefinition
+
+__all__ = (
+ "SerializationPluginProtocol",
+ "InitPluginProtocol",
+ "OpenAPISchemaPluginProtocol",
+ "OpenAPISchemaPlugin",
+ "PluginProtocol",
+ "CLIPlugin",
+ "CLIPluginProtocol",
+ "PluginRegistry",
+ "DIPlugin",
+)
+
+
+@runtime_checkable
+class InitPluginProtocol(Protocol):
+ """Protocol used to define plugins that affect the application's init process."""
+
+ __slots__ = ()
+
+ def on_app_init(self, app_config: AppConfig) -> AppConfig:
+ """Receive the :class:`AppConfig<.config.app.AppConfig>` instance after `on_app_init` hooks have been called.
+
+ Examples:
+ .. code-block:: python
+
+ from litestar import Litestar, get
+ from litestar.di import Provide
+ from litestar.plugins import InitPluginProtocol
+
+
+ def get_name() -> str:
+ return "world"
+
+
+ @get("/my-path")
+ def my_route_handler(name: str) -> dict[str, str]:
+ return {"hello": name}
+
+
+ class MyPlugin(InitPluginProtocol):
+ def on_app_init(self, app_config: AppConfig) -> AppConfig:
+ app_config.dependencies["name"] = Provide(get_name)
+ app_config.route_handlers.append(my_route_handler)
+ return app_config
+
+
+ app = Litestar(plugins=[MyPlugin()])
+
+ Args:
+ app_config: The :class:`AppConfig <litestar.config.app.AppConfig>` instance.
+
+ Returns:
+ The app config object.
+ """
+ return app_config # pragma: no cover
+
+
+class ReceiveRoutePlugin:
+ """Receive routes as they are added to the application."""
+
+ __slots__ = ()
+
+ def receive_route(self, route: BaseRoute) -> None:
+ """Receive routes as they are registered on an application."""
+
+
+@runtime_checkable
+class CLIPluginProtocol(Protocol):
+ """Plugin protocol to extend the CLI."""
+
+ __slots__ = ()
+
+ def on_cli_init(self, cli: Group) -> None:
+ """Called when the CLI is initialized.
+
+ This can be used to extend or override existing commands.
+
+ Args:
+ cli: The root :class:`click.Group` of the Litestar CLI
+
+ Examples:
+ .. code-block:: python
+
+ from litestar import Litestar
+ from litestar.plugins import CLIPluginProtocol
+ from click import Group
+
+
+ class CLIPlugin(CLIPluginProtocol):
+ def on_cli_init(self, cli: Group) -> None:
+ @cli.command()
+ def is_debug_mode(app: Litestar):
+ print(app.debug)
+
+
+ app = Litestar(plugins=[CLIPlugin()])
+ """
+
+
+class CLIPlugin(CLIPluginProtocol):
+ """Plugin protocol to extend the CLI Server Lifespan."""
+
+ __slots__ = ()
+
+ def on_cli_init(self, cli: Group) -> None:
+ return super().on_cli_init(cli)
+
+ @contextmanager
+ def server_lifespan(self, app: Litestar) -> Iterator[None]:
+ yield
+
+
+@runtime_checkable
+class SerializationPluginProtocol(Protocol):
+ """Protocol used to define a serialization plugin for DTOs."""
+
+ __slots__ = ()
+
+ def supports_type(self, field_definition: FieldDefinition) -> bool:
+ """Given a value of indeterminate type, determine if this value is supported by the plugin.
+
+ Args:
+ field_definition: A parsed type.
+
+ Returns:
+ Whether the type is supported by the plugin.
+ """
+ raise NotImplementedError()
+
+ def create_dto_for_type(self, field_definition: FieldDefinition) -> type[AbstractDTO]:
+ """Given a parsed type, create a DTO class.
+
+ Args:
+ field_definition: A parsed type.
+
+ Returns:
+ A DTO class.
+ """
+ raise NotImplementedError()
+
+
+class DIPlugin(abc.ABC):
+ """Extend dependency injection"""
+
+ @abc.abstractmethod
+ def has_typed_init(self, type_: Any) -> bool:
+ """Return ``True`` if ``type_`` has type information available for its
+ :func:`__init__` method that cannot be extracted from this method's type
+ annotations (e.g. a Pydantic BaseModel subclass), and
+ :meth:`DIPlugin.get_typed_init` supports extraction of these annotations.
+ """
+ ...
+
+ @abc.abstractmethod
+ def get_typed_init(self, type_: Any) -> tuple[Signature, dict[str, Any]]:
+ r"""Return signature and type information about the ``type_``\ s :func:`__init__`
+ method.
+ """
+ ...
+
+
+@runtime_checkable
+class OpenAPISchemaPluginProtocol(Protocol):
+ """Plugin protocol to extend the support of OpenAPI schema generation for non-library types."""
+
+ __slots__ = ()
+
+ @staticmethod
+ def is_plugin_supported_type(value: Any) -> bool:
+ """Given a value of indeterminate type, determine if this value is supported by the plugin.
+
+ Args:
+ value: An arbitrary value.
+
+ Returns:
+ A typeguard dictating whether the value is supported by the plugin.
+ """
+ raise NotImplementedError()
+
+ def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema:
+ """Given a type annotation, transform it into an OpenAPI schema class.
+
+ Args:
+ field_definition: An :class:`OpenAPI <litestar.openapi.spec.schema.Schema>` instance.
+ schema_creator: An instance of the openapi SchemaCreator.
+
+ Returns:
+ An :class:`OpenAPI <litestar.openapi.spec.schema.Schema>` instance.
+ """
+ raise NotImplementedError()
+
+
+class OpenAPISchemaPlugin(OpenAPISchemaPluginProtocol):
+ """Plugin to extend the support of OpenAPI schema generation for non-library types."""
+
+ @staticmethod
+ def is_plugin_supported_type(value: Any) -> bool:
+ """Given a value of indeterminate type, determine if this value is supported by the plugin.
+
+ This is called by the default implementation of :meth:`is_plugin_supported_field` for
+ backwards compatibility. User's should prefer to override that method instead.
+
+ Args:
+ value: An arbitrary value.
+
+ Returns:
+ A bool indicating whether the value is supported by the plugin.
+ """
+ raise NotImplementedError(
+ "One of either is_plugin_supported_type or is_plugin_supported_field should be defined. "
+ "The default implementation of is_plugin_supported_field calls is_plugin_supported_type "
+ "for backwards compatibility. Users should prefer to override is_plugin_supported_field "
+ "as it receives a 'FieldDefinition' instance which is more useful than a raw type."
+ )
+
+ def is_plugin_supported_field(self, field_definition: FieldDefinition) -> bool:
+ """Given a :class:`FieldDefinition <litestar.typing.FieldDefinition>` that represents an indeterminate type,
+ determine if this value is supported by the plugin
+
+ Args:
+ field_definition: A parsed type.
+
+ Returns:
+ Whether the type is supported by the plugin.
+ """
+ return self.is_plugin_supported_type(field_definition.annotation)
+
+ @staticmethod
+ def is_undefined_sentinel(value: Any) -> bool:
+ """Return ``True`` if ``value`` should be treated as an undefined field"""
+ return False
+
+ @staticmethod
+ def is_constrained_field(field_definition: FieldDefinition) -> bool:
+ """Return ``True`` if the field should be treated as constrained. If returning
+ ``True``, constraints should be defined in the field's extras
+ """
+ return False
+
+
+PluginProtocol = Union[
+ CLIPlugin,
+ CLIPluginProtocol,
+ InitPluginProtocol,
+ OpenAPISchemaPlugin,
+ OpenAPISchemaPluginProtocol,
+ ReceiveRoutePlugin,
+ SerializationPluginProtocol,
+ DIPlugin,
+]
+
+PluginT = TypeVar("PluginT", bound=PluginProtocol)
+
+
+class PluginRegistry:
+ __slots__ = {
+ "init": "Plugins that implement the InitPluginProtocol",
+ "openapi": "Plugins that implement the OpenAPISchemaPluginProtocol",
+ "receive_route": "ReceiveRoutePlugin instances",
+ "serialization": "Plugins that implement the SerializationPluginProtocol",
+ "cli": "Plugins that implement the CLIPluginProtocol",
+ "di": "DIPlugin instances",
+ "_plugins_by_type": None,
+ "_plugins": None,
+ "_get_plugins_of_type": None,
+ }
+
+ def __init__(self, plugins: list[PluginProtocol]) -> None:
+ self._plugins_by_type = {type(p): p for p in plugins}
+ self._plugins = frozenset(plugins)
+ self.init = tuple(p for p in plugins if isinstance(p, InitPluginProtocol))
+ self.openapi = tuple(p for p in plugins if isinstance(p, OpenAPISchemaPluginProtocol))
+ self.receive_route = tuple(p for p in plugins if isinstance(p, ReceiveRoutePlugin))
+ self.serialization = tuple(p for p in plugins if isinstance(p, SerializationPluginProtocol))
+ self.cli = tuple(p for p in plugins if isinstance(p, CLIPluginProtocol))
+ self.di = tuple(p for p in plugins if isinstance(p, DIPlugin))
+
+ def get(self, type_: type[PluginT] | str) -> PluginT:
+ """Return the registered plugin of ``type_``.
+
+ This should be used with subclasses of the plugin protocols.
+ """
+ if isinstance(type_, str):
+ for plugin in self._plugins:
+ _name = plugin.__class__.__name__
+ _module = plugin.__class__.__module__
+ _qualname = (
+ f"{_module}.{plugin.__class__.__qualname__}"
+ if _module is not None and _module != "__builtin__"
+ else plugin.__class__.__qualname__
+ )
+ if type_ in {_name, _qualname}:
+ return cast(PluginT, plugin)
+ raise KeyError(f"No plugin of type {type_!r} registered")
+ try:
+ return cast(PluginT, self._plugins_by_type[type_]) # type: ignore[index]
+ except KeyError as e:
+ raise KeyError(f"No plugin of type {type_.__name__!r} registered") from e
+
+ def __iter__(self) -> Iterator[PluginProtocol]:
+ return iter(self._plugins)
+
+ def __contains__(self, item: PluginProtocol) -> bool:
+ return item in self._plugins