1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
|
# flake8: builtins=reveal_type
from __future__ import annotations
from typing import Any
from dataclasses import dataclass
from collections.abc import Callable, Sequence
from psycopg import AsyncConnection, AsyncCursor, AsyncServerCursor, Connection, Cursor
from psycopg import ServerCursor, connect, rows
def int_row_factory(
    cursor: Cursor[Any] | AsyncCursor[Any],
) -> Callable[[Sequence[int]], int]:
    """Build a row maker that reduces a row of ints to a single int.

    The *cursor* argument is accepted to satisfy the row-factory call
    signature but is not read here.

    The returned callable yields the first element of the row it is given,
    or ``42`` when the row is empty.
    """

    def first_or_default(values: Sequence[int]) -> int:
        # An empty row has no first column: fall back to the fixed default.
        if not values:
            return 42
        return values[0]

    return first_or_default
@dataclass
class Person:
    """Two-field record used as the target type of a custom row factory."""

    # First column of the row consumed by ``row_factory``.
    name: str
    # Second column of the row consumed by ``row_factory``.
    address: str

    @classmethod
    def row_factory(
        cls, cursor: Cursor[Any] | AsyncCursor[Any]
    ) -> Callable[[Sequence[str]], Person]:
        """Return a row maker turning a two-item row into an instance of *cls*.

        The *cursor* argument is not read here. The returned callable
        unpacks exactly two values (name, then address); a row of any other
        length fails at the unpacking step.
        """

        def build_person(values: Sequence[str]) -> Person:
            person_name, person_address = values
            return cls(person_name, person_address)

        return build_person
def kwargsf(*, foo: int, bar: int, baz: int) -> int:
    """Return the constant ``42``, ignoring the three keyword-only ints.

    Only the signature matters: this callable is handed to
    ``rows.kwargs_row`` in ``check_row_factories``.
    """
    return 42
def argsf(foo: int, bar: int, baz: int) -> float:
    """Return the constant ``42.0``, ignoring the three int arguments.

    Only the signature matters: this callable is handed to
    ``rows.args_row`` in ``check_row_factories``.
    """
    return 42.0
def check_row_factory_cursor() -> None:
    """Type-check the ``connection.cursor(..., row_factory=<MyRowFactory>)`` case.

    Each section binds a cursor and what it fetches to explicitly annotated
    names, so a static type checker has to accept the assignments. The
    trailing bare expressions have no effect beyond being evaluated;
    presumably they exist so the fetched value is used at its declared type.
    """
    conn = connect()

    # No row_factory argument: cursor and row are annotated with Any.
    plain_cur: Cursor[Any] = conn.cursor()
    plain_row: Any | None = plain_cur.fetchone()
    plain_row is not None

    # Custom factory producing ints: annotated as a cursor of int rows.
    int_cur: Cursor[int]
    int_row: int | None
    with conn.cursor(row_factory=int_row_factory) as int_cur:
        int_cur.execute("select 1")
        int_row = int_cur.fetchone()
        int_row and int_row > 0

    # Passing a name as well: annotated as a server-side cursor of Person rows.
    server_cur: ServerCursor[Person]
    people: Sequence[Person]
    with conn.cursor(name="s", row_factory=Person.row_factory) as server_cur:
        server_cur.execute("select * from persons where name like 'al%'")
        people = server_cur.fetchall()
        people[0].address
async def async_check_row_factory_cursor() -> None:
    """Type-check the async ``connection.cursor(..., row_factory=...)`` case.

    Same three sections as the synchronous check, using ``AsyncConnection``
    and awaiting every cursor operation. The trailing bare expressions have
    no effect beyond being evaluated; presumably they exist so the fetched
    value is used at its declared type.
    """
    conn = await AsyncConnection.connect()

    # No row_factory argument: cursor and row are annotated with Any.
    plain_cur: AsyncCursor[Any] = conn.cursor()
    plain_row: Any | None = await plain_cur.fetchone()
    plain_row is not None

    # Custom factory producing ints: annotated as a cursor of int rows.
    int_cur: AsyncCursor[int]
    int_row: int | None
    async with conn.cursor(row_factory=int_row_factory) as int_cur:
        await int_cur.execute("select 1")
        int_row = await int_cur.fetchone()
        int_row and int_row > 0

    # Passing a name as well: annotated as a server-side cursor of Person rows.
    server_cur: AsyncServerCursor[Person]
    people: Sequence[Person]
    async with conn.cursor(name="s", row_factory=Person.row_factory) as server_cur:
        await server_cur.execute("select * from persons where name like 'al%'")
        people = await server_cur.fetchall()
        people[0].address
def check_row_factory_connection() -> None:
    """Type-check ``connect(..., row_factory=<MyRowFactory>)`` and the cursors
    obtained from the resulting connection.

    For each connection, both ``conn.execute()`` and ``conn.cursor()`` are
    bound to the same annotated cursor name, so a static type checker has to
    accept both results under one declared type.
    """
    # Connection created with the int factory.
    int_conn: Connection[int] = connect(row_factory=int_row_factory)
    int_cur: Cursor[int] = int_conn.execute("select 1")
    int_row: int | None = int_cur.fetchone()
    int_row != 0
    with int_conn.cursor() as int_cur:
        int_cur.execute("select 2")

    # Connection created with the Person factory.
    person_conn: Connection[Person] = connect(row_factory=Person.row_factory)
    person_cur: Cursor[Person] = person_conn.execute("select * from persons")
    person_row: Person | None = person_cur.fetchone()
    person_row and person_row.name
    with person_conn.cursor() as person_cur:
        person_cur.execute("select 2")

    # No factory given: the cursor is annotated as yielding plain tuples.
    tuple_row: tuple[Any, ...] | None
    tuple_conn = connect()
    tuple_cur: Cursor[tuple[Any, ...]] = tuple_conn.execute("select 3")
    with tuple_conn.cursor() as tuple_cur:
        tuple_cur.execute("select 42")
        tuple_row = tuple_cur.fetchone()
        tuple_row and len(tuple_row)
async def async_check_row_factory_connection() -> None:
    """Type-check ``AsyncConnection.connect(..., row_factory=<MyRowFactory>)``
    and the cursors obtained from the resulting connection.

    For each connection, both ``conn.execute()`` and ``conn.cursor()`` are
    bound to the same annotated cursor name, so a static type checker has to
    accept both results under one declared type.
    """
    # Connection created with the int factory.
    int_conn: AsyncConnection[int] = await AsyncConnection.connect(
        row_factory=int_row_factory
    )
    int_cur: AsyncCursor[int] = await int_conn.execute("select 1")
    int_row: int | None = await int_cur.fetchone()
    int_row != 0
    async with int_conn.cursor() as int_cur:
        await int_cur.execute("select 2")

    # Connection created with the Person factory.
    person_conn: AsyncConnection[Person] = await AsyncConnection.connect(
        row_factory=Person.row_factory
    )
    person_cur: AsyncCursor[Person] = await person_conn.execute(
        "select * from persons"
    )
    person_row: Person | None = await person_cur.fetchone()
    person_row and person_row.name
    async with person_conn.cursor() as person_cur:
        await person_cur.execute("select 2")

    # No factory given: the cursor is annotated as yielding plain tuples.
    tuple_row: tuple[Any, ...] | None
    tuple_conn = await AsyncConnection.connect()
    tuple_cur: AsyncCursor[tuple[Any, ...]] = await tuple_conn.execute("select 3")
    async with tuple_conn.cursor() as tuple_cur:
        await tuple_cur.execute("select 42")
        tuple_row = await tuple_cur.fetchone()
        tuple_row and len(tuple_row)
def check_row_factories() -> None:
    """Type-check the row type obtained with each ``rows`` factory used here.

    For every factory a connection is opened with it, an empty query is
    executed, and the first fetched row is bound to a name annotated with
    the type a static checker must accept for that factory.
    """
    tuple_conn = connect(row_factory=rows.tuple_row)
    as_tuple: tuple[Any, ...] = tuple_conn.execute("").fetchall()[0]

    dict_conn = connect(row_factory=rows.dict_row)
    as_dict: dict[str, Any] = dict_conn.execute("").fetchall()[0]

    class_conn = connect(row_factory=rows.class_row(Person))
    as_person: Person = class_conn.execute("").fetchall()[0]

    # args_row(argsf): the row is annotated with argsf's return type.
    args_conn = connect(row_factory=rows.args_row(argsf))
    as_float: float = args_conn.execute("").fetchall()[0]

    # kwargs_row(kwargsf): the row is annotated with kwargsf's return type.
    kwargs_conn = connect(row_factory=rows.kwargs_row(kwargsf))
    as_int: int = kwargs_conn.execute("").fetchall()[0]

    # Bare tuple expression; presumably only here so every local is used.
    as_tuple, as_dict, as_person, as_float, as_int
|