File: test_cache.py

package info (click to toggle)
sqlfmt 0.29.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,580 kB
  • sloc: python: 10,007; sql: 5,626; makefile: 39
file content (167 lines) | stat: -rw-r--r-- 4,779 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import pickle
from pathlib import Path
from typing import Dict, Generator, List, Tuple

import pytest

from sqlfmt.cache import (
    Cache,
    check_cache,
    clear_cache,
    get_cache_file,
    load_cache,
    write_cache,
)
from sqlfmt.exception import SqlfmtError
from sqlfmt.mode import Mode
from sqlfmt.report import SqlFormatResult
from tests.util import BASE_DIR


@pytest.fixture(autouse=True)
def auto_clear_cache() -> Generator[None, None, None]:
    """Autouse fixture: delete any on-disk sqlfmt cache before and after
    each test in this module, so tests never see each other's cache file."""
    clear_cache()
    yield
    clear_cache()


@pytest.fixture
def sample_paths() -> Dict[str, Path]:
    """Map short fixture ids to the sample SQL files used by these tests.

    Ids 001-005 point at well-formatted files; 900 points at a file that
    triggers a formatting error.
    """
    preformatted = BASE_DIR / "preformatted"
    return {
        "001": preformatted / "001_select_1.sql",
        "002": preformatted / "002_select_from_where.sql",
        "003": preformatted / "003_literals.sql",
        "004": preformatted / "004_with_select.sql",
        "005": preformatted / "005_fmt_off.sql",
        "900": BASE_DIR / "errors" / "900_bad_token.sql",
    }


@pytest.fixture
def sample_stat() -> Tuple[float, int]:
    """A fabricated file-stat tuple (presumably mtime, size — values are
    arbitrary) shared by the cache fixtures below."""
    fake_stat = (1000000000.100000, 1)
    return fake_stat


@pytest.fixture
def small_cache(sample_paths: Dict[str, Path], sample_stat: Tuple[float, int]) -> Cache:
    """Build a cache mapping every non-error sample path to the fake stat."""
    entries: Cache = {}
    for path in sample_paths.values():
        # The error file (under "errors/") is deliberately left out.
        if "errors" not in str(path):
            entries[path] = sample_stat
    return entries


@pytest.fixture
def results_for_caching(sample_paths: Dict[str, Path]) -> List[SqlFormatResult]:
    """Four format results covering the cases write_cache must handle:
    an unchanged file, a cache hit, a reformatted file, and an error."""
    # 001: source and formatted text are identical (no-op format).
    unchanged = SqlFormatResult(
        sample_paths["001"],
        "select 1\n",
        "select 1\n",
        encoding="utf-8",
        utf_bom="",
    )
    # 002: result was served from the cache, so it carries no formatted text.
    cached = SqlFormatResult(
        sample_paths["002"],
        "select 1\n",
        "",
        encoding="utf-8",
        utf_bom="",
        from_cache=True,
    )
    # 003: formatting changed the source.
    changed = SqlFormatResult(
        sample_paths["003"],
        "select 'abc'\n",
        "select\n    'abc'\n",
        encoding="utf-8",
        utf_bom="",
    )
    # 900: formatting raised an error.
    errored = SqlFormatResult(
        sample_paths["900"],
        "!\n",
        "",
        encoding="utf-8",
        utf_bom="",
        exception=SqlfmtError("oops"),
    )
    return [unchanged, cached, changed, errored]


def test_get_cache_file() -> None:
    """get_cache_file returns a truthy Path."""
    result = get_cache_file()
    assert isinstance(result, Path)
    assert result


def test_write_cache(
    small_cache: Cache,
    results_for_caching: List[SqlFormatResult],
    default_mode: Mode,
    sample_paths: Dict[str, Path],
    sample_stat: Tuple[float, int],
) -> None:
    """write_cache creates the cache file and refreshes stats selectively."""
    cache_file = get_cache_file()
    assert not cache_file.exists()

    write_cache(cache=small_cache, results=results_for_caching, mode=default_mode)

    assert cache_file.exists()
    written_cache = pickle.loads(cache_file.read_bytes())
    assert isinstance(written_cache, dict)
    assert written_cache.keys() == small_cache.keys()
    # 001 was formatted (unchanged result): its fake stat must be replaced.
    assert written_cache[sample_paths["001"]] != sample_stat, (
        "Should write new stat to cache for unchanged files"
    )
    # 002 came from the cache: its old stat must be preserved.
    assert written_cache[sample_paths["002"]] == sample_stat, (
        "Should not write new stat to cache for results from cache"
    )
    # 003 was reformatted: it also gets a fresh stat.
    assert written_cache[sample_paths["003"]] != sample_stat, (
        "Should write new stat to cache for changed files in default mode"
    )
    assert sample_paths["900"] not in written_cache, "Should not write errors to cache"


def test_load_cache(
    small_cache: Cache,
    results_for_caching: List[SqlFormatResult],
    default_mode: Mode,
) -> None:
    """load_cache yields {} with no cache file, and the stored dict after a write."""
    assert load_cache() == {}

    write_cache(cache=small_cache, results=results_for_caching, mode=default_mode)

    reloaded = load_cache()
    assert isinstance(reloaded, dict)
    assert reloaded.keys() == small_cache.keys()
    # write_cache refreshed the stats, so the values no longer match.
    assert reloaded != small_cache


def test_check_cache(
    small_cache: Cache,
    results_for_caching: List[SqlFormatResult],
    default_mode: Mode,
    sample_paths: Dict[str, Path],
) -> None:
    """check_cache misses against fabricated stats, then hits only the
    files that write_cache stamped with fresh on-disk stats (001 and 003).

    Improvements over the previous version: ``all([not ...])`` built a
    throwaway list (ruff C419) and the expected hits were an unlabeled
    positional list that relied on iteration order — a dict keyed by
    fixture id makes a failure self-describing.
    """
    # The in-memory cache holds fabricated stats, so nothing matches yet.
    assert not any(check_cache(small_cache, p) for p in sample_paths.values())

    write_cache(cache=small_cache, results=results_for_caching, mode=default_mode)
    new_cache = load_cache()

    # Only 001 (unchanged) and 003 (reformatted) received fresh stats from
    # write_cache, so only they should register as cache hits.
    expected_cache_hits = {
        "001": True,
        "002": False,
        "003": True,
        "004": False,
        "005": False,
        "900": False,
    }
    actual_cache_hits = {
        key: check_cache(new_cache, path) for key, path in sample_paths.items()
    }
    assert actual_cache_hits == expected_cache_hits


def test_clear_cache(
    small_cache: Cache,
    results_for_caching: List[SqlFormatResult],
    default_mode: Mode,
) -> None:
    """clear_cache deletes the file that write_cache created."""
    target = get_cache_file()
    write_cache(cache=small_cache, results=results_for_caching, mode=default_mode)
    assert target.exists()

    clear_cache()
    assert not target.exists()