File: test_cursor_client_async.py

package info (click to toggle)
psycopg3 3.3.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,836 kB
  • sloc: python: 46,657; sh: 403; ansic: 149; makefile: 73
file content (149 lines) | stat: -rw-r--r-- 4,744 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import datetime as dt

import pytest

import psycopg
from psycopg import rows

from .fix_crdb import crdb_encoding


@pytest.fixture
async def aconn(aconn, anyio_backend):
    """Override the shared ``aconn`` fixture so its cursors are client-side.

    The connection from the outer ``aconn`` fixture is reused; only its
    ``cursor_factory`` is switched to ``psycopg.AsyncClientCursor``.
    ``anyio_backend`` is requested but not used in the body (presumably to
    bind this async fixture to the anyio backend parametrization).
    """
    conn = aconn
    conn.cursor_factory = psycopg.AsyncClientCursor
    return conn


async def test_default_cursor(aconn):
    """cursor() on the fixture connection yields exactly an AsyncClientCursor."""
    assert type(aconn.cursor()) is psycopg.AsyncClientCursor


async def test_str(aconn):
    """str() of a client cursor contains its qualified ``psycopg.`` class name."""
    cur = aconn.cursor()
    wanted = f"psycopg.{psycopg.AsyncClientCursor.__name__}"
    assert wanted in str(cur)


async def test_from_cursor_factory(aconn_cls, dsn):
    """Passing ``cursor_factory`` to connect() makes cursor() use that class."""
    conn_cm = await aconn_cls.connect(
        dsn, cursor_factory=psycopg.AsyncClientCursor
    )
    async with conn_cm as conn:
        assert type(conn.cursor()) is psycopg.AsyncClientCursor


async def test_execute_many_results_param(aconn):
    """A parametrized multi-statement query exposes each result via nextset()."""
    cur = aconn.cursor()
    # Nothing has been executed yet, so there is no further result set.
    assert cur.nextset() is None

    returned = await cur.execute(
        "select %s; select generate_series(1, %s)", ("foo", 3)
    )
    # execute() returns the cursor itself.
    assert returned is cur

    # Result of the first statement.
    first = await cur.fetchall()
    assert first == [("foo",)]
    assert cur.rowcount == 1

    # Move on to the second statement's result; it is the last one.
    assert cur.nextset()
    second = await cur.fetchall()
    assert second == [(1,), (2,), (3,)]
    assert cur.nextset() is None

    # After closing, nextset() still reports no further result.
    await cur.close()
    assert cur.nextset() is None


async def test_query_params_execute(aconn):
    """``cur._query`` records the merged query and the adapted params.

    The last executed query is kept even when execution fails, so it can be
    inspected after the error.
    """
    cur = aconn.cursor()
    assert cur._query is None

    await cur.execute("select %t, %s::text", [1, None])
    last = cur._query
    assert last is not None
    assert last.query == b"select 1, NULL::text"
    assert last.params == (b"1", b"NULL")

    await cur.execute("select 1")
    last = cur._query
    assert last.query == b"select 1"
    assert not last.params

    with pytest.raises(psycopg.DataError):
        await cur.execute("select %t::int", ["wat"])

    # The failing statement is still the one recorded.
    last = cur._query
    assert last.query == b"select 'wat'::int"
    assert last.params == (b"'wat'",)


async def test_query_params_executemany(aconn):
    """After executemany(), ``cur._query`` reflects only the last param set."""
    cur = aconn.cursor()
    param_sets = [[1, 2], [3, 4]]
    await cur.executemany("select %t, %t", param_sets)
    last = cur._query
    assert last.query == b"select 3, 4"
    assert last.params == (b"3", b"4")


@pytest.mark.slow
@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
async def test_leak(aconn_cls, dsn, faker, fetch, row_factory, gc):
    """Repeated connect/insert/select/fetch cycles must not grow ``gc.count()``.

    ``fetch`` picks how the result set is consumed ("one", "many", "all",
    "iter"); ``row_factory`` is the name of a factory in ``psycopg.rows``.
    """
    faker.choose_schema(ncols=5)
    faker.make_records(10)
    factory = getattr(rows, row_factory)

    async def drain(cur):
        # Read every row of the current result using the requested strategy.
        if fetch == "one":
            while await cur.fetchone() is not None:
                pass
        elif fetch == "many":
            while await cur.fetchmany(3):
                pass
        elif fetch == "all":
            await cur.fetchall()
        elif fetch == "iter":
            async for _ in cur:
                pass

    async def work():
        # One full cycle on a fresh connection; the transaction is rolled back
        # (force_rollback=True) so no table state survives between cycles.
        async with await aconn_cls.connect(dsn) as conn, conn.transaction(
            force_rollback=True
        ):
            async with psycopg.AsyncClientCursor(conn, row_factory=factory) as cur:
                await cur.execute(faker.drop_stmt)
                await cur.execute(faker.create_stmt)
                async with faker.find_insert_problem_async(conn):
                    await cur.executemany(faker.insert_stmt, faker.records)
                await cur.execute(faker.select_stmt)
                await drain(cur)

    counts = []
    gc.collect()
    for _ in range(3):
        await work()
        gc.collect()
        counts.append(gc.count())

    assert counts[0] == counts[1] == counts[2], (
        f"objects leaked: {counts[1] - counts[0]}, {counts[2] - counts[1]}"
    )


@pytest.mark.parametrize(
    "query, params, want",
    [
        # No params: the query is returned verbatim, "%%" included.
        ("select 'hello'", (), "select 'hello'"),
        ("select %s, %s", ([1, dt.date(2020, 1, 1)],), "select 1, '2020-01-01'::date"),
        ("select %(foo)s, %(foo)s", ({"foo": "x"},), "select 'x', 'x'"),
        ("select %%", (), "select %%"),
        # With params: "%%" is unescaped to a literal "%".
        ("select %%, %s", (["a"],), "select %, 'a'"),
        ("select %%, %(foo)s", ({"foo": "x"},), "select %, 'x'"),
        ("select %%s, %(foo)s", ({"foo": "x"},), "select %s, 'x'"),
    ],
)
async def test_mogrify(aconn, query, params, want):
    """mogrify() merges params into the query text on the client side.

    ``params`` is a tuple of positional arguments splatted into mogrify();
    an empty tuple means mogrify() is called with the query alone.
    """
    assert aconn.cursor().mogrify(query, *params) == want


@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
async def test_mogrify_encoding(aconn, encoding):
    """A euro sign is rendered by mogrify() under encodings that include it."""
    await aconn.execute(f"set client_encoding to {encoding}")
    cur = aconn.cursor()
    merged = cur.mogrify("select %(s)s", {"s": "\u20ac"})
    assert merged == "select '\u20ac'"


@pytest.mark.parametrize("encoding", [crdb_encoding("latin1")])
async def test_mogrify_badenc(aconn, encoding):
    """mogrify() raises UnicodeEncodeError if the client encoding lacks a char."""
    await aconn.execute(f"set client_encoding to {encoding}")
    cur = aconn.cursor()
    with pytest.raises(UnicodeEncodeError):
        cur.mogrify("select %(s)s", {"s": "\u20ac"})