summaryrefslogtreecommitdiff
path: root/venv/lib/python3.11/site-packages/aiosqlite/core.py
diff options
context:
space:
mode:
authorcyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
committercyfraeviolae <cyfraeviolae>2024-04-03 03:10:44 -0400
commit6d7ba58f880be618ade07f8ea080fe8c4bf8a896 (patch)
treeb1c931051ffcebd2bd9d61d98d6233ffa289bbce /venv/lib/python3.11/site-packages/aiosqlite/core.py
parent4f884c9abc32990b4061a1bb6997b4b37e58ea0b (diff)
venv
Diffstat (limited to 'venv/lib/python3.11/site-packages/aiosqlite/core.py')
-rw-r--r--venv/lib/python3.11/site-packages/aiosqlite/core.py394
1 files changed, 394 insertions, 0 deletions
diff --git a/venv/lib/python3.11/site-packages/aiosqlite/core.py b/venv/lib/python3.11/site-packages/aiosqlite/core.py
new file mode 100644
index 0000000..58c3fec
--- /dev/null
+++ b/venv/lib/python3.11/site-packages/aiosqlite/core.py
@@ -0,0 +1,394 @@
+# Copyright 2022 Amethyst Reese
+# Licensed under the MIT license
+
+"""
+Core implementation of aiosqlite proxies
+"""
+
+import asyncio
+import logging
+import sqlite3
+from functools import partial
+from pathlib import Path
+from queue import Empty, Queue, SimpleQueue
+from threading import Thread
+from typing import (
+ Any,
+ AsyncIterator,
+ Callable,
+ Generator,
+ Iterable,
+ Literal,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
+from warnings import warn
+
+from .context import contextmanager
+from .cursor import Cursor
+
+__all__ = ["connect", "Connection", "Cursor"]
+
+LOG = logging.getLogger("aiosqlite")
+
+
+IsolationLevel = Optional[Literal["DEFERRED", "IMMEDIATE", "EXCLUSIVE"]]
+
+
+def set_result(fut: asyncio.Future, result: Any) -> None:
+ """Set the result of a future if it hasn't been set already."""
+ if not fut.done():
+ fut.set_result(result)
+
+
+def set_exception(fut: asyncio.Future, e: BaseException) -> None:
+ """Set the exception of a future if it hasn't been set already."""
+ if not fut.done():
+ fut.set_exception(e)
+
+
+_STOP_RUNNING_SENTINEL = object()
+
+
+class Connection(Thread):
+ def __init__(
+ self,
+ connector: Callable[[], sqlite3.Connection],
+ iter_chunk_size: int,
+ loop: Optional[asyncio.AbstractEventLoop] = None,
+ ) -> None:
+ super().__init__()
+ self._running = True
+ self._connection: Optional[sqlite3.Connection] = None
+ self._connector = connector
+ self._tx: SimpleQueue[Tuple[asyncio.Future, Callable[[], Any]]] = SimpleQueue()
+ self._iter_chunk_size = iter_chunk_size
+
+ if loop is not None:
+ warn(
+ "aiosqlite.Connection no longer uses the `loop` parameter",
+ DeprecationWarning,
+ )
+
+ def _stop_running(self):
+ self._running = False
+ # PEP 661 is not accepted yet, so we cannot type a sentinel
+ self._tx.put_nowait(_STOP_RUNNING_SENTINEL) # type: ignore[arg-type]
+
+ @property
+ def _conn(self) -> sqlite3.Connection:
+ if self._connection is None:
+ raise ValueError("no active connection")
+
+ return self._connection
+
+ def _execute_insert(self, sql: str, parameters: Any) -> Optional[sqlite3.Row]:
+ cursor = self._conn.execute(sql, parameters)
+ cursor.execute("SELECT last_insert_rowid()")
+ return cursor.fetchone()
+
+ def _execute_fetchall(self, sql: str, parameters: Any) -> Iterable[sqlite3.Row]:
+ cursor = self._conn.execute(sql, parameters)
+ return cursor.fetchall()
+
+ def run(self) -> None:
+ """
+ Execute function calls on a separate thread.
+
+ :meta private:
+ """
+ while True:
+ # Continues running until all queue items are processed,
+ # even after connection is closed (so we can finalize all
+ # futures)
+
+ tx_item = self._tx.get()
+ if tx_item is _STOP_RUNNING_SENTINEL:
+ break
+
+ future, function = tx_item
+
+ try:
+ LOG.debug("executing %s", function)
+ result = function()
+ LOG.debug("operation %s completed", function)
+ future.get_loop().call_soon_threadsafe(set_result, future, result)
+ except BaseException as e: # noqa B036
+ LOG.debug("returning exception %s", e)
+ future.get_loop().call_soon_threadsafe(set_exception, future, e)
+
+ async def _execute(self, fn, *args, **kwargs):
+ """Queue a function with the given arguments for execution."""
+ if not self._running or not self._connection:
+ raise ValueError("Connection closed")
+
+ function = partial(fn, *args, **kwargs)
+ future = asyncio.get_event_loop().create_future()
+
+ self._tx.put_nowait((future, function))
+
+ return await future
+
+ async def _connect(self) -> "Connection":
+ """Connect to the actual sqlite database."""
+ if self._connection is None:
+ try:
+ future = asyncio.get_event_loop().create_future()
+ self._tx.put_nowait((future, self._connector))
+ self._connection = await future
+ except Exception:
+ self._stop_running()
+ self._connection = None
+ raise
+
+ return self
+
+ def __await__(self) -> Generator[Any, None, "Connection"]:
+ self.start()
+ return self._connect().__await__()
+
+ async def __aenter__(self) -> "Connection":
+ return await self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
+ await self.close()
+
+ @contextmanager
+ async def cursor(self) -> Cursor:
+ """Create an aiosqlite cursor wrapping a sqlite3 cursor object."""
+ return Cursor(self, await self._execute(self._conn.cursor))
+
+ async def commit(self) -> None:
+ """Commit the current transaction."""
+ await self._execute(self._conn.commit)
+
+ async def rollback(self) -> None:
+ """Roll back the current transaction."""
+ await self._execute(self._conn.rollback)
+
+ async def close(self) -> None:
+ """Complete queued queries/cursors and close the connection."""
+
+ if self._connection is None:
+ return
+
+ try:
+ await self._execute(self._conn.close)
+ except Exception:
+ LOG.info("exception occurred while closing connection")
+ raise
+ finally:
+ self._stop_running()
+ self._connection = None
+
+ @contextmanager
+ async def execute(
+ self, sql: str, parameters: Optional[Iterable[Any]] = None
+ ) -> Cursor:
+ """Helper to create a cursor and execute the given query."""
+ if parameters is None:
+ parameters = []
+ cursor = await self._execute(self._conn.execute, sql, parameters)
+ return Cursor(self, cursor)
+
+ @contextmanager
+ async def execute_insert(
+ self, sql: str, parameters: Optional[Iterable[Any]] = None
+ ) -> Optional[sqlite3.Row]:
+ """Helper to insert and get the last_insert_rowid."""
+ if parameters is None:
+ parameters = []
+ return await self._execute(self._execute_insert, sql, parameters)
+
+ @contextmanager
+ async def execute_fetchall(
+ self, sql: str, parameters: Optional[Iterable[Any]] = None
+ ) -> Iterable[sqlite3.Row]:
+ """Helper to execute a query and return all the data."""
+ if parameters is None:
+ parameters = []
+ return await self._execute(self._execute_fetchall, sql, parameters)
+
+ @contextmanager
+ async def executemany(
+ self, sql: str, parameters: Iterable[Iterable[Any]]
+ ) -> Cursor:
+ """Helper to create a cursor and execute the given multiquery."""
+ cursor = await self._execute(self._conn.executemany, sql, parameters)
+ return Cursor(self, cursor)
+
+ @contextmanager
+ async def executescript(self, sql_script: str) -> Cursor:
+ """Helper to create a cursor and execute a user script."""
+ cursor = await self._execute(self._conn.executescript, sql_script)
+ return Cursor(self, cursor)
+
+ async def interrupt(self) -> None:
+ """Interrupt pending queries."""
+ return self._conn.interrupt()
+
+ async def create_function(
+ self, name: str, num_params: int, func: Callable, deterministic: bool = False
+ ) -> None:
+ """
+ Create user-defined function that can be later used
+ within SQL statements. Must be run within the same thread
+ that query executions take place so instead of executing directly
+ against the connection, we defer this to `run` function.
+
+ If ``deterministic`` is true, the created function is marked as deterministic,
+ which allows SQLite to perform additional optimizations. This flag is supported
+ by SQLite 3.8.3 or higher, ``NotSupportedError`` will be raised if used with
+ older versions.
+ """
+ await self._execute(
+ self._conn.create_function,
+ name,
+ num_params,
+ func,
+ deterministic=deterministic,
+ )
+
+ @property
+ def in_transaction(self) -> bool:
+ return self._conn.in_transaction
+
+ @property
+ def isolation_level(self) -> Optional[str]:
+ return self._conn.isolation_level
+
+ @isolation_level.setter
+ def isolation_level(self, value: IsolationLevel) -> None:
+ self._conn.isolation_level = value
+
+ @property
+ def row_factory(self) -> Optional[Type]:
+ return self._conn.row_factory
+
+ @row_factory.setter
+ def row_factory(self, factory: Optional[Type]) -> None:
+ self._conn.row_factory = factory
+
+ @property
+ def text_factory(self) -> Callable[[bytes], Any]:
+ return self._conn.text_factory
+
+ @text_factory.setter
+ def text_factory(self, factory: Callable[[bytes], Any]) -> None:
+ self._conn.text_factory = factory
+
+ @property
+ def total_changes(self) -> int:
+ return self._conn.total_changes
+
+ async def enable_load_extension(self, value: bool) -> None:
+ await self._execute(self._conn.enable_load_extension, value) # type: ignore
+
+ async def load_extension(self, path: str):
+ await self._execute(self._conn.load_extension, path) # type: ignore
+
+ async def set_progress_handler(
+ self, handler: Callable[[], Optional[int]], n: int
+ ) -> None:
+ await self._execute(self._conn.set_progress_handler, handler, n)
+
+ async def set_trace_callback(self, handler: Callable) -> None:
+ await self._execute(self._conn.set_trace_callback, handler)
+
+ async def iterdump(self) -> AsyncIterator[str]:
+ """
+ Return an async iterator to dump the database in SQL text format.
+
+ Example::
+
+ async for line in db.iterdump():
+ ...
+
+ """
+ dump_queue: Queue = Queue()
+
+ def dumper():
+ try:
+ for line in self._conn.iterdump():
+ dump_queue.put_nowait(line)
+ dump_queue.put_nowait(None)
+
+ except Exception:
+ LOG.exception("exception while dumping db")
+ dump_queue.put_nowait(None)
+ raise
+
+ fut = self._execute(dumper)
+ task = asyncio.ensure_future(fut)
+
+ while True:
+ try:
+ line: Optional[str] = dump_queue.get_nowait()
+ if line is None:
+ break
+ yield line
+
+ except Empty:
+ if task.done():
+ LOG.warning("iterdump completed unexpectedly")
+ break
+
+ await asyncio.sleep(0.01)
+
+ await task
+
+ async def backup(
+ self,
+ target: Union["Connection", sqlite3.Connection],
+ *,
+ pages: int = 0,
+ progress: Optional[Callable[[int, int, int], None]] = None,
+ name: str = "main",
+ sleep: float = 0.250,
+ ) -> None:
+ """
+ Make a backup of the current database to the target database.
+
+ Takes either a standard sqlite3 or aiosqlite Connection object as the target.
+ """
+ if isinstance(target, Connection):
+ target = target._conn
+
+ await self._execute(
+ self._conn.backup,
+ target,
+ pages=pages,
+ progress=progress,
+ name=name,
+ sleep=sleep,
+ )
+
+
+def connect(
+ database: Union[str, Path],
+ *,
+ iter_chunk_size=64,
+ loop: Optional[asyncio.AbstractEventLoop] = None,
+ **kwargs: Any,
+) -> Connection:
+ """Create and return a connection proxy to the sqlite database."""
+
+ if loop is not None:
+ warn(
+ "aiosqlite.connect() no longer uses the `loop` parameter",
+ DeprecationWarning,
+ )
+
+ def connector() -> sqlite3.Connection:
+ if isinstance(database, str):
+ loc = database
+ elif isinstance(database, bytes):
+ loc = database.decode("utf-8")
+ else:
+ loc = str(database)
+
+ return sqlite3.connect(loc, **kwargs)
+
+ return Connection(connector, iter_chunk_size)