File: db.py

package info (click to toggle)
python-pytest-djangoapp 1.2.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 396 kB
  • sloc: python: 1,116; makefile: 114; sh: 6
file content (124 lines) | stat: -rw-r--r-- 3,264 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from typing import List

import pytest
from decimal import Decimal
from django.db import connections, DEFAULT_DB_ALIAS, reset_queries

if False:  # pragma: nocover
    from collections import deque  # noqa


@pytest.fixture
def db_queries(settings) -> 'Queries':
    """Allows access to executed DB queries.

    Example::

        def test_db(db_queries):

            # Previous queries cleared at the beginning.
            assert len(db_queries) == 0

            ...  # Do some DB-related stuff.

            # Assert total queries on all DBs.
            assert len(db_queries) == 10

            # Default DBs SQLs with auxiliary commands filtered out by default.
            sqls = db_queries.sql()

            # Assert total execution time is less than a second.
            assert db_queries.time() < 1

            # Drop SQL gathered so far on default DB.
            db_queries.clear()


    .. warning:: Requires Django 1.9+ to work.

    """
    queries = Queries()

    debug_values_prev = {}

    for connection in connections.all():
        debug_values_prev[connection.alias] = connection.force_debug_cursor
        connection.force_debug_cursor = True

    try:
        queries.clear_all()

        yield queries

    finally:

        for connection in connections.all():
            prev_debug_value = debug_values_prev.get(connection.alias, None)

            if prev_debug_value is not None:
                connection.force_debug_cursor = prev_debug_value

        queries.clear_all()


class Queries:
    """Allows access to executed DB queries."""

    sql_drop = {
        'BEGIN',
        'COMMIT',
        'END',
    }

    def __len__(self) -> int:
        return len(self.get_log())

    def get_log(self, db_alias: str = None) -> 'deque':
        """
        :param db_alias:

        """
        return connections[db_alias or DEFAULT_DB_ALIAS].queries_log

    def clear_all(self):
        """Clears all queries logged for all DBs."""
        reset_queries()

    def clear(self, db_alias: str = None):
        """Clear queries for the given or default DB.

        :param db_alias: Database alias. Default is used if not given.

        """
        self.get_log(db_alias=db_alias).clear()

    def sql(self, db_alias: str = None, *, drop_auxiliary: bool = True) -> List[str]:
        """Returns a list of queries executed using the given or default DB.

        :param db_alias: Database alias. Default is used if not given.

        :param drop_auxiliary: Filter out auxiliary SQL like:
            * BEGIN
            * COMMIT
            * END

        """
        sqls = []

        auxiliary = self.sql_drop

        for log_entry in self.get_log(db_alias=db_alias):
            sql = ' '.join(sql_line.strip() for sql_line in log_entry['sql'].splitlines())
            if not drop_auxiliary or sql not in auxiliary:
                sqls.append(sql)

        return sqls

    def time(self, db_alias: str = None) -> Decimal:
        """Returns total time executing queries (in seconds) using the given or default DB.

        :param db_alias: Database alias. Default is used if not given.

        """
        times = [Decimal(log_entry['time']) for log_entry in self.get_log(db_alias=db_alias)]
        return sum(times)