# Copyright Amethyst Reese
# Licensed under the MIT license
import sqlite3
from collections.abc import AsyncIterator, Iterable
from typing import Any, Callable, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from .core import Connection
class Cursor:
    """Async proxy around a ``sqlite3.Cursor``.

    Coroutine methods hand the corresponding ``sqlite3.Cursor`` method to the
    owning connection's ``_execute`` and await the result; plain attributes
    are exposed as properties that read or write the wrapped cursor directly.
    """

    def __init__(self, conn: "Connection", cursor: sqlite3.Cursor) -> None:
        """Bind the proxy to its connection and the wrapped cursor.

        The chunk size used for async iteration is copied from the connection
        once, here, and kept as a public attribute on the proxy.
        """
        self.iter_chunk_size = conn._iter_chunk_size
        self._conn = conn
        self._cursor = cursor

    def __aiter__(self) -> AsyncIterator[sqlite3.Row]:
        """Return an async iterator over the remaining rows."""
        return self._fetch_chunked()

    async def _fetch_chunked(self):
        """Yield rows one by one, fetching ``iter_chunk_size`` rows at a time."""
        # An empty (falsy) batch from fetchmany ends the iteration.
        while batch := await self.fetchmany(self.iter_chunk_size):
            for row in batch:
                yield row

    async def _execute(self, fn, *args, **kwargs):
        """Forward ``fn`` and its arguments to the connection's ``_execute``."""
        return await self._conn._execute(fn, *args, **kwargs)

    async def execute(
        self, sql: str, parameters: Optional[Iterable[Any]] = None
    ) -> "Cursor":
        """Run a single SQL statement and return this cursor.

        A missing ``parameters`` argument is replaced by an empty list before
        the call is forwarded.
        """
        bound = [] if parameters is None else parameters
        await self._execute(self._cursor.execute, sql, bound)
        return self

    async def executemany(
        self, sql: str, parameters: Iterable[Iterable[Any]]
    ) -> "Cursor":
        """Run one SQL statement against each parameter set; return this cursor."""
        await self._execute(self._cursor.executemany, sql, parameters)
        return self

    async def executescript(self, sql_script: str) -> "Cursor":
        """Run a script of SQL statements and return this cursor."""
        await self._execute(self._cursor.executescript, sql_script)
        return self

    async def fetchone(self) -> Optional[sqlite3.Row]:
        """Fetch the next row via the wrapped cursor's ``fetchone``."""
        return await self._execute(self._cursor.fetchone)

    async def fetchmany(self, size: Optional[int] = None) -> Iterable[sqlite3.Row]:
        """Fetch the next batch of rows.

        When ``size`` is ``None`` no size argument is forwarded, so the
        wrapped cursor applies its own default batch size.
        """
        if size is None:
            return await self._execute(self._cursor.fetchmany)
        return await self._execute(self._cursor.fetchmany, size)

    async def fetchall(self) -> Iterable[sqlite3.Row]:
        """Fetch every row that has not been consumed yet."""
        return await self._execute(self._cursor.fetchall)

    async def close(self) -> None:
        """Close the wrapped cursor."""
        await self._execute(self._cursor.close)

    @property
    def rowcount(self) -> int:
        """The wrapped cursor's ``rowcount``."""
        return self._cursor.rowcount

    @property
    def lastrowid(self) -> Optional[int]:
        """The wrapped cursor's ``lastrowid``."""
        return self._cursor.lastrowid

    @property
    def arraysize(self) -> int:
        """The wrapped cursor's ``arraysize``."""
        return self._cursor.arraysize

    @arraysize.setter
    def arraysize(self, value: int) -> None:
        self._cursor.arraysize = value

    @property
    def description(self) -> tuple[tuple[str, None, None, None, None, None, None], ...]:
        """The wrapped cursor's ``description``."""
        return self._cursor.description

    @property
    def row_factory(self) -> Optional[Callable[[sqlite3.Cursor, sqlite3.Row], object]]:
        """The wrapped cursor's ``row_factory``."""
        return self._cursor.row_factory

    @row_factory.setter
    def row_factory(self, factory: Optional[type]) -> None:
        self._cursor.row_factory = factory

    @property
    def connection(self) -> sqlite3.Connection:
        """The ``connection`` attribute of the wrapped cursor."""
        return self._cursor.connection

    async def __aenter__(self):
        """Enter ``async with``; the cursor itself is the managed object."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the cursor on exit; nothing is returned, so exceptions propagate."""
        await self.close()
|