summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/di.py
diff options
context:
space:
mode:
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/di.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/di.py117
1 files changed, 117 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/di.py b/venv/lib/python3.11/site-packages/litestar/di.py
new file mode 100644
index 0000000..066a128
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/di.py
@@ -0,0 +1,117 @@
+from __future__ import annotations
+
+from inspect import isasyncgenfunction, isclass, isgeneratorfunction
+from typing import TYPE_CHECKING, Any
+
+from litestar.exceptions import ImproperlyConfiguredException
+from litestar.types import Empty
+from litestar.utils import ensure_async_callable
+from litestar.utils.predicates import is_async_callable
+from litestar.utils.warnings import (
+ warn_implicit_sync_to_thread,
+ warn_sync_to_thread_with_async_callable,
+ warn_sync_to_thread_with_generator,
+)
+
+if TYPE_CHECKING:
+ from litestar._signature import SignatureModel
+ from litestar.types import AnyCallable
+ from litestar.utils.signature import ParsedSignature
+
+__all__ = ("Provide",)
+
+
+class Provide:
+ """Wrapper class for dependency injection"""
+
+ __slots__ = (
+ "dependency",
+ "has_sync_callable",
+ "has_sync_generator_dependency",
+ "has_async_generator_dependency",
+ "parsed_fn_signature",
+ "signature_model",
+ "sync_to_thread",
+ "use_cache",
+ "value",
+ )
+
+ parsed_fn_signature: ParsedSignature
+ signature_model: type[SignatureModel]
+ dependency: AnyCallable
+
+ def __init__(
+ self,
+ dependency: AnyCallable | type[Any],
+ use_cache: bool = False,
+ sync_to_thread: bool | None = None,
+ ) -> None:
+ """Initialize ``Provide``
+
+ Args:
+ dependency: Callable to call or class to instantiate. The result is then injected as a dependency.
+ use_cache: Cache the dependency return value. Defaults to False.
+ sync_to_thread: Run sync code in an async thread. Defaults to False.
+ """
+ if not callable(dependency):
+ raise ImproperlyConfiguredException("Provider dependency must a callable value")
+
+ is_class_dependency = isclass(dependency)
+ self.has_sync_generator_dependency = isgeneratorfunction(
+ dependency if not is_class_dependency else dependency.__call__ # type: ignore[operator]
+ )
+ self.has_async_generator_dependency = isasyncgenfunction(
+ dependency if not is_class_dependency else dependency.__call__ # type: ignore[operator]
+ )
+ has_generator_dependency = self.has_sync_generator_dependency or self.has_async_generator_dependency
+
+ if has_generator_dependency and use_cache:
+ raise ImproperlyConfiguredException(
+ "Cannot cache generator dependency, consider using Lifespan Context instead."
+ )
+
+ has_sync_callable = is_class_dependency or not is_async_callable(dependency) # pyright: ignore
+
+ if sync_to_thread is not None:
+ if has_generator_dependency:
+ warn_sync_to_thread_with_generator(dependency, stacklevel=3) # type: ignore[arg-type]
+ elif not has_sync_callable:
+ warn_sync_to_thread_with_async_callable(dependency, stacklevel=3) # pyright: ignore
+ elif has_sync_callable and not has_generator_dependency:
+ warn_implicit_sync_to_thread(dependency, stacklevel=3) # pyright: ignore
+
+ if sync_to_thread and has_sync_callable:
+ self.dependency = ensure_async_callable(dependency) # pyright: ignore
+ self.has_sync_callable = False
+ else:
+ self.dependency = dependency # pyright: ignore
+ self.has_sync_callable = has_sync_callable
+
+ self.sync_to_thread = bool(sync_to_thread)
+ self.use_cache = use_cache
+ self.value: Any = Empty
+
+ async def __call__(self, **kwargs: Any) -> Any:
+ """Call the provider's dependency."""
+
+ if self.use_cache and self.value is not Empty:
+ return self.value
+
+ if self.has_sync_callable:
+ value = self.dependency(**kwargs)
+ else:
+ value = await self.dependency(**kwargs)
+
+ if self.use_cache:
+ self.value = value
+
+ return value
+
+ def __eq__(self, other: Any) -> bool:
+ # check if memory address is identical, otherwise compare attributes
+ return other is self or (
+ isinstance(other, self.__class__)
+ and other.dependency == self.dependency
+ and other.use_cache == self.use_cache
+ and other.value == self.value
+ )