1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
|
import datetime as dt
import pytest
import psycopg
from psycopg import rows
from .fix_crdb import crdb_encoding
@pytest.fixture
async def aconn(aconn, anyio_backend):
    """Return the incoming ``aconn`` with client-side cursors as its default.

    The connection received from the same-named fixture is modified in
    place: its ``cursor_factory`` is set to `psycopg.AsyncClientCursor`, so
    tests using this fixture get that class from a plain ``aconn.cursor()``.
    """
    # NOTE(review): ``anyio_backend`` is not used in the body; presumably it
    # is requested so this async fixture runs under the anyio plugin --
    # confirm against the test suite's conftest.
    client_conn = aconn
    client_conn.cursor_factory = psycopg.AsyncClientCursor
    return client_conn
async def test_default_cursor(aconn):
    """A cursor from the fixture connection is exactly ``AsyncClientCursor``."""
    # Identity check on the type: a subclass would not satisfy this test.
    cursor_type = type(aconn.cursor())
    assert cursor_type is psycopg.AsyncClientCursor
async def test_str(aconn):
    """``str()`` of a client cursor contains ``psycopg.<class name>``."""
    expected_fragment = "psycopg.%s" % psycopg.AsyncClientCursor.__name__
    rendered = str(aconn.cursor())
    assert expected_fragment in rendered
async def test_from_cursor_factory(aconn_cls, dsn):
    """``cursor_factory`` passed to ``connect()`` decides the cursor class.

    Unlike the other tests here, this one opens its own connection from
    ``aconn_cls`` and ``dsn`` instead of relying on the ``aconn`` fixture.
    """
    factory = psycopg.AsyncClientCursor
    connection = await aconn_cls.connect(dsn, cursor_factory=factory)
    async with connection as aconn:
        assert type(aconn.cursor()) is factory
async def test_execute_many_results_param(aconn):
    """Walk the result sets of a parameterised multi-statement query.

    A single ``execute()`` call runs two ``select`` statements, each with one
    placeholder. The test checks that ``execute()`` returns the cursor
    itself, that each result set can be fetched in turn via ``nextset()``,
    that ``rowcount`` matches the current result set, and that ``nextset()``
    returns ``None`` before any query, after the last result set, and after
    the cursor has been closed.
    """
    cur = aconn.cursor()
    # Nothing has been executed yet, so there is no result set to move to.
    assert cur.nextset() is None

    rv = await cur.execute("select %s; select generate_series(1, %s)", ("foo", 3))
    assert rv is cur

    # First statement: a single row holding the first parameter.
    assert (await cur.fetchall()) == [("foo",)]
    assert cur.rowcount == 1

    # Second statement: one row per value generated up to the second parameter.
    assert cur.nextset()
    assert (await cur.fetchall()) == [(1,), (2,), (3,)]
    # rowcount must follow the result set selected by nextset(), mirroring
    # the check made on the first result set above.
    assert cur.rowcount == 3
    assert cur.nextset() is None

    await cur.close()
    assert cur.nextset() is None
async def test_query_params_execute(aconn):
    """Inspect the cursor's private ``_query`` after each ``execute()``.

    The expected values show the parameters merged into the query bytes and
    also kept, already rendered as SQL literals, in ``params``.
    """
    cur = aconn.cursor()
    assert cur._query is None

    # Both the %t and the %s placeholder end up as literals in the query.
    await cur.execute("select %t, %s::text", [1, None])
    with_params = cur._query
    assert with_params is not None
    assert with_params.query == b"select 1, NULL::text"
    assert with_params.params == (b"1", b"NULL")

    # A query executed without arguments leaves ``params`` falsy.
    await cur.execute("select 1")
    without_params = cur._query
    assert without_params.query == b"select 1"
    assert not without_params.params

    with pytest.raises(psycopg.DataError):
        await cur.execute("select %t::int", ["wat"])

    # The statement that raised is still the one recorded on the cursor.
    failed = cur._query
    assert failed.query == b"select 'wat'::int"
    assert failed.params == (b"'wat'",)
async def test_query_params_executemany(aconn):
    """After ``executemany()``, ``_query`` reflects the last parameter set."""
    cur = aconn.cursor()
    param_sets = [[1, 2], [3, 4]]
    await cur.executemany("select %t, %t", param_sets)

    last_query = cur._query
    assert last_query.query == b"select 3, 4"
    assert last_query.params == (b"3", b"4")
@pytest.mark.slow
@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
async def test_leak(aconn_cls, dsn, faker, fetch, row_factory, gc):
    """Run the same cursor workload three times and compare ``gc.count()``.

    The workload opens a connection, creates and fills the ``faker`` table
    inside a transaction that is always rolled back, then reads every row
    back using the ``fetch`` strategy and the ``row_factory`` under test.
    The test fails if the value reported by the ``gc`` fixture differs
    between any two of the three runs.
    """
    faker.choose_schema(ncols=5)
    faker.make_records(10)
    # The parameter is the factory's name; resolve it to the callable.
    factory = getattr(rows, row_factory)

    async def work():
        async with await aconn_cls.connect(dsn) as conn:
            async with conn.transaction(force_rollback=True):
                async with psycopg.AsyncClientCursor(
                    conn, row_factory=factory
                ) as cur:
                    await cur.execute(faker.drop_stmt)
                    await cur.execute(faker.create_stmt)
                    async with faker.find_insert_problem_async(conn):
                        await cur.executemany(faker.insert_stmt, faker.records)
                    await cur.execute(faker.select_stmt)

                    # Consume the whole result with the requested strategy.
                    if fetch == "all":
                        await cur.fetchall()
                    elif fetch == "many":
                        while await cur.fetchmany(3):
                            pass
                    elif fetch == "one":
                        row = await cur.fetchone()
                        while row is not None:
                            row = await cur.fetchone()
                    elif fetch == "iter":
                        async for _ in cur:
                            pass

    counts = []
    gc.collect()
    for _ in range(3):
        await work()
        gc.collect()
        counts.append(gc.count())

    assert (
        counts[0] == counts[1] == counts[2]
    ), f"objects leaked: {counts[1] - counts[0]}, {counts[2] - counts[1]}"
@pytest.mark.parametrize(
    "query, params, want",
    [
        ("select 'hello'", (), "select 'hello'"),
        ("select %s, %s", ([1, dt.date(2020, 1, 1)],), "select 1, '2020-01-01'::date"),
        ("select %(foo)s, %(foo)s", ({"foo": "x"},), "select 'x', 'x'"),
        ("select %%", (), "select %%"),
        ("select %%, %s", (["a"],), "select %, 'a'"),
        ("select %%, %(foo)s", ({"foo": "x"},), "select %, 'x'"),
        ("select %%s, %(foo)s", ({"foo": "x"},), "select %s, 'x'"),
    ],
)
async def test_mogrify(aconn, query, params, want):
    """``mogrify()`` returns the query text with the arguments merged in.

    ``params`` holds the positional arguments that follow the query in the
    ``mogrify()`` call: an empty tuple means the query is passed alone,
    otherwise it carries a single sequence or mapping of values.
    """
    rendered = aconn.cursor().mogrify(query, *params)
    assert rendered == want
@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
async def test_mogrify_encoding(aconn, encoding):
    """The euro sign survives ``mogrify()`` under the given client encoding."""
    # NOTE(review): ``crdb_encoding`` comes from ``fix_crdb``; presumably it
    # marks the encoding as needing special handling on CockroachDB -- confirm.
    await aconn.execute(f"set client_encoding to {encoding}")
    cursor = aconn.cursor()
    rendered = cursor.mogrify("select %(s)s", {"s": "\u20ac"})
    assert rendered == "select '\u20ac'"
@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
async def test_mogrify_badenc(aconn, encoding):
    """``mogrify()`` raises ``UnicodeEncodeError`` for the euro sign here.

    The connection's client encoding is switched to the parametrized value
    before the call.
    """
    await aconn.execute(f"set client_encoding to {encoding}")
    cursor = aconn.cursor()
    with pytest.raises(UnicodeEncodeError):
        cursor.mogrify("select %(s)s", {"s": "\u20ac"})
|