summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/aiosqlite/context.py
blob: 316845fbaf27b13df31d568dc8844906bc7b3b6f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Copyright 2018
# Licensed under the MIT license


from functools import wraps
from typing import Any, AsyncContextManager, Callable, Coroutine, Generator, TypeVar

from .cursor import Cursor

_T = TypeVar("_T")


class Result(AsyncContextManager[_T], Coroutine[Any, Any, _T]):
    """Wrap a coroutine so it can be either awaited or used with ``async with``.

    ``await result`` awaits the wrapped coroutine and returns its value.
    ``async with result as obj`` awaits the coroutine on entry, binds its
    value to ``obj`` and, on exit, closes ``obj`` if it is a :class:`Cursor`.

    The coroutine protocol methods (``send``/``throw``/``close``) are
    forwarded to the wrapped coroutine, so a ``Result`` can also be driven
    exactly like the coroutine it wraps.
    """

    __slots__ = ("_coro", "_obj")

    def __init__(self, coro: Coroutine[Any, Any, _T]):
        self._coro = coro
        # Declared only (no value yet); __aenter__ assigns it after the
        # wrapped coroutine has produced its result.
        self._obj: _T

    def send(self, value: Any) -> Any:
        """Forward ``value`` into the wrapped coroutine.

        Returns the next value the coroutine yields. Once the coroutine
        finishes, ``StopIteration`` carrying its return value propagates.
        """
        return self._coro.send(value)

    def throw(self, typ, val=None, tb=None) -> Any:
        """Raise an exception inside the wrapped coroutine.

        Only the arguments actually supplied are forwarded. The
        multi-argument ``throw(type, value, traceback)`` signature is
        deprecated as of Python 3.12. Note that when ``val`` is None,
        ``tb`` is not forwarded either.

        Returns the next value the coroutine yields if it handles the
        exception; otherwise the exception propagates to the caller.
        """
        if val is None:
            return self._coro.throw(typ)

        if tb is None:
            return self._coro.throw(typ, val)

        return self._coro.throw(typ, val, tb)

    def close(self) -> None:
        """Close the wrapped coroutine."""
        return self._coro.close()

    def __await__(self) -> Generator[Any, None, _T]:
        """Delegate ``await result`` to the wrapped coroutine."""
        return self._coro.__await__()

    async def __aenter__(self) -> _T:
        """Await the wrapped coroutine, remember its result and return it."""
        self._obj = await self._coro
        return self._obj

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the entered object if it is a :class:`Cursor`.

        Returns ``None``, so an exception raised inside the ``async with``
        body is not suppressed.
        """
        if isinstance(self._obj, Cursor):
            await self._obj.close()


def contextmanager(
    method: Callable[..., Coroutine[Any, Any, _T]]
) -> Callable[..., Result[_T]]:
    """Make an ``async def`` method return a :class:`Result`.

    The decorated method no longer returns a bare coroutine. It returns a
    ``Result``, which can either be awaited directly or entered with
    ``async with obj.method(...) as value``. Metadata such as ``__name__``
    and ``__doc__`` is copied from ``method`` via :func:`functools.wraps`.
    """

    def _call(self, *args, **kwargs) -> Result[_T]:
        pending = method(self, *args, **kwargs)
        return Result(pending)

    return wraps(method)(_call)