File: util.py

package info (click to toggle)
sqlfmt 0.29.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,580 kB
  • sloc: python: 10,007; sql: 5,626; makefile: 39
file content (108 lines) | stat: -rw-r--r-- 3,160 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import shutil
from pathlib import Path
from typing import Iterable, Iterator, List, Tuple, Union

# Directory holding this helper module (the tests/ directory).
TEST_DIR: Path = Path(__file__).parent
# Root of the SQL test fixtures; relative test-data paths resolve against it.
BASE_DIR: Path = TEST_DIR.joinpath("data")
# Scratch directory where mismatched formatter output is saved for inspection.
RESULTS_DIR: Path = TEST_DIR.joinpath(".results")


def read_test_data(relpath: Union[Path, str]) -> Tuple[str, str]:
    """Read a test data file and split it into (source, formatted) SQL strings.

    relpath is resolved against tests/data/. An absolute path is used as-is,
    because joining an absolute path onto BASE_DIR yields that absolute path.

    Lines before a line reading ``)))))__SQLFMT_OUTPUT__(((((`` form the
    unformatted source query; lines after it form the expected formatted
    output. If the file has no sentinel (or nothing follows the sentinel),
    the source lines are reused as the expected output, since the input is
    assumed to already be formatted.

    Returns:
        A tuple ``(source, formatted)``. ``source`` has surrounding whitespace
        stripped and exactly one trailing newline appended; ``formatted`` is
        the raw text, not stripped.
    """
    SENTINEL = ")))))__SQLFMT_OUTPUT__((((("

    test_path = BASE_DIR / relpath

    # Decode explicitly as UTF-8 so fixtures containing non-ASCII SQL read the
    # same on every platform, independent of the locale's default encoding.
    with open(test_path, "r", encoding="utf-8") as test_file:
        lines = test_file.readlines()

    source_query: List[str] = []
    formatted_query: List[str] = []

    # Lines go to the source section until the sentinel switches the target.
    target = source_query

    for line in lines:
        if line.rstrip() == SENTINEL:
            target = formatted_query
            continue
        target.append(line)

    # No (non-empty) output section: the file is its own expected output.
    if source_query and not formatted_query:
        formatted_query = source_query[:]

    return "".join(source_query).strip() + "\n", "".join(formatted_query)


def _safe_create_results_dir() -> Path:
    """Make sure tests/.results/ exists, then return its path.

    ``mkdir`` is called with ``exist_ok=True``, so an already-existing
    directory is not an error. Parent directories are not created.
    """
    RESULTS_DIR.mkdir(exist_ok=True)
    return RESULTS_DIR


def delete_results_dir() -> None:
    """Recursively remove tests/.results/ and everything inside it.

    Errors (for example, the directory not existing) are ignored, so this can
    be called unconditionally.
    """
    target = RESULTS_DIR
    shutil.rmtree(target, ignore_errors=True)


def check_formatting(expected: str, actual: str, ctx: str = "") -> None:
    """Assert that ``actual == expected``; on mismatch, save ``actual`` to disk.

    When the strings differ, ``actual`` is written into tests/.results/ under
    a name built from the calling function's name, plus ``-<ctx>`` when ctx is
    given (with "/" replaced by "-"), ending in ``.sql``. The original
    AssertionError is then re-raised.

    The check is an ``assert`` statement, so it is skipped entirely when
    Python runs with ``-O``.

    Args:
        expected: the expected formatted output.
        actual: the output actually produced.
        ctx: optional extra context (e.g. a test file's relative path) used
            to make the results filename unique.

    Raises:
        AssertionError: if ``expected != actual``.
    """
    try:
        assert expected == actual, (
            "Formatting error. Output file written to tests/.results/"
        )
    except AssertionError:
        import inspect

        results_dir = _safe_create_results_dir()

        # Name the results file after the test function that called us.
        # context=0 avoids reading source lines for every frame on the stack.
        caller = inspect.stack(context=0)[1].function
        if ctx:
            caller += "-"
            ctx = ctx.replace("/", "-")

        suffix = "" if ctx.endswith(".sql") else ".sql"

        p = results_dir / (caller + ctx + suffix)
        # Write as UTF-8 so non-ASCII output cannot raise UnicodeEncodeError
        # here and mask the real formatting failure.
        with open(p, "w", encoding="utf-8") as f:
            f.write(actual)
        # Bare raise re-raises the original AssertionError with its traceback.
        raise


def discover_test_files(relpaths: Iterable[Union[str, Path]]) -> Iterator[Path]:
    """Yield every ``.sql`` file found at or beneath the given paths.

    Each entry in relpaths is resolved against tests/data/ (absolute paths
    are used as-is, since joining an absolute path onto BASE_DIR yields it
    unchanged). A ``.sql`` file is yielded directly; a directory is searched
    recursively. Anything else (non-.sql files, missing paths) is skipped.

    Directory contents are visited in sorted order, so the sequence of
    yielded paths is reproducible across runs and platforms.
    """
    for relpath in relpaths:
        path = BASE_DIR / relpath
        if path.is_file() and path.suffix == ".sql":
            yield path
        elif path.is_dir():
            # Path.iterdir() order is arbitrary; sort for a stable test order.
            yield from discover_test_files(sorted(path.iterdir()))


def copy_test_data_to_tmp(
    relpaths: Iterable[Union[str, Path]], tmp_path: Path
) -> Path:
    """
    Copy the source (unformatted) query of each test data file into tmp_path.

    relpaths may name individual .sql files or directories under tests/data/;
    directories are searched recursively via discover_test_files. For every
    file found, only the source query (the part before the output sentinel,
    as returned by read_test_data) is written to ``tmp_path / <file name>``.
    Files are flattened into tmp_path by name, so two test files with the
    same name in different directories overwrite one another.

    Returns tmp_path, the directory that now contains the copied files.
    """

    for abspath in discover_test_files(relpaths):
        file_contents, _ = read_test_data(abspath)

        # UTF-8 to match how the test data is read.
        with open(tmp_path / abspath.name, "w", encoding="utf-8") as tmp_file:
            tmp_file.write(file_contents)

    return tmp_path


def copy_config_file_to_dst(file_name: str, dst_path: Path) -> Path:
    """Copy a config fixture into dst_path, renaming it pyproject.toml.

    file_name is looked up in tests/data/config/ and must be an existing
    file (checked with an ``assert``). Returns the path of the new copy.
    """
    source = BASE_DIR.joinpath("config", file_name)
    assert source.is_file()

    destination = dst_path.joinpath("pyproject.toml")
    shutil.copyfile(source, destination)
    return destination