summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/litestar/utils/sync.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/litestar/utils/sync.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/litestar/utils/sync.py')
-rw-r--r--venv/lib/python3.11/site-packages/litestar/utils/sync.py79
1 files changed, 79 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/litestar/utils/sync.py b/venv/lib/python3.11/site-packages/litestar/utils/sync.py
new file mode 100644
index 0000000..02acabf
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/litestar/utils/sync.py
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+from typing import (
+ AsyncGenerator,
+ Awaitable,
+ Callable,
+ Generic,
+ Iterable,
+ Iterator,
+ TypeVar,
+)
+
+from typing_extensions import ParamSpec
+
+from litestar.concurrency import sync_to_thread
+from litestar.utils.predicates import is_async_callable
+
+__all__ = ("ensure_async_callable", "AsyncIteratorWrapper", "AsyncCallable", "is_async_callable")
+
+
+P = ParamSpec("P")
+T = TypeVar("T")
+
+
+def ensure_async_callable(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]:
+ """Ensure that ``fn`` is an asynchronous callable.
+ If it is an asynchronous, return the original object, else wrap it in an
+ ``AsyncCallable``
+ """
+ if is_async_callable(fn):
+ return fn
+ return AsyncCallable(fn) # pyright: ignore
+
+
+class AsyncCallable:
+ """Wrap a given callable to be called in a thread pool using
+ ``anyio.to_thread.run_sync``, keeping a reference to the original callable as
+ :attr:`func`
+ """
+
+ def __init__(self, fn: Callable[P, T]) -> None: # pyright: ignore
+ self.func = fn
+
+ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T]: # pyright: ignore
+ return sync_to_thread(self.func, *args, **kwargs) # pyright: ignore
+
+
+class AsyncIteratorWrapper(Generic[T]):
+ """Asynchronous generator, wrapping an iterable or iterator."""
+
+ __slots__ = ("iterator", "generator")
+
+ def __init__(self, iterator: Iterator[T] | Iterable[T]) -> None:
+ """Take a sync iterator or iterable and yields values from it asynchronously.
+
+ Args:
+ iterator: A sync iterator or iterable.
+ """
+ self.iterator = iterator if isinstance(iterator, Iterator) else iter(iterator)
+ self.generator = self._async_generator()
+
+ def _call_next(self) -> T:
+ try:
+ return next(self.iterator)
+ except StopIteration as e:
+ raise ValueError from e
+
+ async def _async_generator(self) -> AsyncGenerator[T, None]:
+ while True:
+ try:
+ yield await sync_to_thread(self._call_next)
+ except ValueError:
+ return
+
+ def __aiter__(self) -> AsyncIteratorWrapper[T]:
+ return self
+
+ async def __anext__(self) -> T:
+ return await self.generator.__anext__()