File: conftest.py

package info (click to toggle)
pg-activity 3.6.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,144 kB
  • sloc: python: 3,902; sql: 1,067; sh: 5; makefile: 2
file content (109 lines) | stat: -rw-r--r-- 3,117 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from __future__ import annotations

import logging
import pathlib
import threading
from typing import Any

import psycopg
import psycopg.errors
import pytest
from psycopg import sql
from psycopg.conninfo import make_conninfo

from pgactivity import pg

LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.DEBUG)


def pytest_report_header(config: Any) -> list[str]:
    """Pytest hook: add one extra line to the test-session report header.

    The line is labelled ``psycopg`` and shows ``pgactivity.pg.__version__``.
    *config* is accepted to match the hook signature but is not used.
    """
    version = pg.__version__
    header_line = f"psycopg: {version}"
    return [header_line]


@pytest.fixture(scope="session")
def datadir() -> pathlib.Path:
    """Session-wide path to the ``data`` directory located beside this conftest."""
    here = pathlib.Path(__file__).parent
    return here.joinpath("data")


@pytest.fixture
def database_factory(postgresql):
    """Yield a ``createdb(dbname, encoding, locale=None)`` callable.

    Each call opens a short-lived autocommit connection to the ``postgresql``
    test server and issues ``CREATE DATABASE ... ENCODING ... TEMPLATE template0``.
    It appends a ``LOCALE`` clause when *locale* is truthy. The database name,
    encoding and locale are all interpolated as quoted identifiers.

    A database is remembered only after its CREATE statement succeeded. At
    teardown every remembered database is removed with
    ``DROP DATABASE IF EXISTS ... WITH (FORCE)``; that ``FORCE`` option needs
    PostgreSQL 13 or later.
    """
    created = set()

    def createdb(dbname: str, encoding: str, locale: str | None = None) -> None:
        with psycopg.connect(postgresql.info.dsn, autocommit=True) as conn:
            statement = sql.SQL(
                "CREATE DATABASE {dbname} ENCODING {encoding} TEMPLATE template0"
            ).format(dbname=sql.Identifier(dbname), encoding=sql.Identifier(encoding))
            if locale:
                locale_clause = sql.SQL("LOCALE {locale}").format(
                    locale=sql.Identifier(locale)
                )
                statement = sql.SQL(" ").join([statement, locale_clause])
            conn.execute(statement)
        # Record only after a successful CREATE so teardown never drops
        # something this fixture did not create.
        created.add(dbname)

    yield createdb

    drop_template = sql.SQL("DROP DATABASE IF EXISTS {dbname} WITH (FORCE)")
    with psycopg.connect(postgresql.info.dsn, autocommit=True) as conn:
        for name in created:
            conn.execute(drop_template.format(dbname=sql.Identifier(name)))


@pytest.fixture
def execute(postgresql):
    """Yield an ``execute()`` callable that runs each SQL query in its own thread.

    Every call to ``execute(query, commit=False, autocommit=False, dbname=None)``
    does the following:

    - Opens a new connection synchronously, in the calling thread. *dbname*,
      when given, overrides the database in the DSN.
    - Sets the connection's ``autocommit`` flag.
    - Runs *query* on that connection from a daemon background thread, so the
      test keeps going while the query runs or blocks.

    The transaction is committed only when *commit* is true and *autocommit* is
    false; otherwise a non-autocommit transaction stays open. At teardown each
    thread is joined for up to 2 seconds, then its connection is closed.
    """
    workers_and_conns = []

    def execute(
        query: str,
        commit: bool = False,
        autocommit: bool = False,
        dbname: str | None = None,
    ) -> None:
        base_dsn = postgresql.info.dsn
        overrides = {}
        if dbname:
            overrides["dbname"] = dbname
        conn = psycopg.connect(make_conninfo(base_dsn, **overrides))
        conn.autocommit = autocommit

        def _execute() -> None:
            LOGGER.info(
                "running query %s (commit=%s, autocommit=%s) using connection <%s>",
                query,
                commit,
                autocommit,
                id(conn),
            )
            with conn.cursor() as cur:
                try:
                    cur.execute(query)
                except (psycopg.errors.AdminShutdown, psycopg.errors.QueryCanceled):
                    # Treated as a normal end of the background query. Presumably
                    # tests cancel/terminate this backend on purpose — confirm.
                    return
                if commit and not autocommit:
                    conn.commit()
            LOGGER.info("query %s finished", query)

        worker = threading.Thread(target=_execute, daemon=True)
        worker.start()
        workers_and_conns.append((worker, conn))

    yield execute

    for worker, conn in workers_and_conns:
        worker.join(timeout=2)
        LOGGER.info("closing connection <%s>", id(conn))
        conn.close()