1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
|
import shutil
from pathlib import Path
from typing import Iterable, Iterator, List, Tuple, Union
# Directory that contains this helper module; the other paths hang off it.
TEST_DIR: Path = Path(__file__).parent
# Root of the on-disk test fixtures; read_test_data resolves paths against it.
BASE_DIR: Path = TEST_DIR.joinpath("data")
# Scratch directory where check_formatting dumps mismatched output.
RESULTS_DIR: Path = TEST_DIR.joinpath(".results")
def read_test_data(relpath: Union[Path, str]) -> Tuple[str, str]:
    """Split a test data file into its (unformatted, formatted) halves.

    ``relpath`` is resolved against BASE_DIR (tests/data/). Lines before the
    first line equal (ignoring trailing whitespace) to the sentinel
    ')))))__SQLFMT_OUTPUT__(((((' form the unformatted half; lines after it
    form the formatted half. Any further sentinel lines are dropped.

    The unformatted half is returned stripped of leading/trailing whitespace
    and terminated by exactly one newline. The formatted half is returned
    verbatim. When the file has no formatted half (no sentinel, or nothing
    after it) but does have source lines, the raw source lines are reused as
    the formatted half, since the input is assumed to be pre-formatted.

    The file is opened without an explicit encoding, so the platform default
    applies.
    """
    SENTINEL = ")))))__SQLFMT_OUTPUT__((((("

    with open(BASE_DIR / relpath, "r") as handle:
        all_lines = handle.readlines()

    before: List[str] = []
    after: List[str] = []
    seen_sentinel = False
    for text in all_lines:
        if text.rstrip() == SENTINEL:
            seen_sentinel = True
        elif seen_sentinel:
            after.append(text)
        else:
            before.append(text)

    if before and not after:
        after = list(before)

    unformatted = "".join(before).strip() + "\n"
    formatted = "".join(after)
    return unformatted, formatted
def _safe_create_results_dir() -> Path:
    """Make sure RESULTS_DIR exists, then return it.

    Only the final directory level is created (no ``parents=True``); an
    already-existing directory is not an error.
    """
    RESULTS_DIR.mkdir(exist_ok=True)
    return RESULTS_DIR
def delete_results_dir() -> None:
    """Recursively remove RESULTS_DIR and everything inside it.

    ``ignore_errors=True``: any failure during removal, including the
    directory not existing, is silently ignored.
    """
    target = RESULTS_DIR
    shutil.rmtree(target, ignore_errors=True)
def check_formatting(expected: str, actual: str, ctx: str = "") -> None:
    """Assert that the formatter's output matches the expected text.

    On a mismatch, ``actual`` is dumped to a ``.sql`` file in the results
    directory before the AssertionError is re-raised, so the bad output can
    be inspected. The dump is named after the calling function, followed by
    ``-<ctx>`` when ``ctx`` is non-empty (with ``/`` and ``\\`` in ``ctx``
    replaced by ``-``). ``.sql`` is appended unless ``ctx`` already ends with
    it.

    Args:
        expected: the expected formatted text.
        actual: the text produced by the code under test.
        ctx: optional label appended to the dump file name, e.g. a test data
            path.

    Raises:
        AssertionError: if ``expected != actual``.
    """
    try:
        assert expected == actual, (
            "Formatting error. Output file written to tests/.results/"
        )
    except AssertionError:
        import inspect

        results_dir = _safe_create_results_dir()
        # Only the caller's name is needed; context=0 avoids loading source
        # lines for every frame on the stack.
        dump_name = inspect.stack(context=0)[1].function
        if ctx:
            # Flatten both separator styles so ctx cannot point into a
            # (non-existent) subdirectory of the results dir.
            dump_name += "-" + ctx.replace("/", "-").replace("\\", "-")
        if not dump_name.endswith(".sql"):
            dump_name += ".sql"
        # Explicit UTF-8: under a non-UTF-8 locale default, a non-encodable
        # character would raise UnicodeEncodeError here and mask the real
        # formatting failure.
        with open(results_dir / dump_name, "w", encoding="utf-8") as f:
            f.write(actual)
        # Bare raise re-raises the original AssertionError unchanged.
        raise
def discover_test_files(relpaths: Iterable[Union[str, Path]]) -> Iterator[Path]:
    """Yield every ``.sql`` file reachable from ``relpaths``.

    Each entry is resolved as ``BASE_DIR / entry``. An absolute entry is
    therefore used as-is, because pathlib discards BASE_DIR when joined with
    an absolute path. Regular files with a ``.sql`` suffix are yielded
    directly, and directories are searched recursively. Anything else
    (non-.sql files, paths that do not exist) is skipped silently.

    Directory contents are visited in sorted order so the sequence of
    yielded paths is deterministic; ``Path.iterdir`` itself makes no
    ordering guarantee.
    """
    for entry in relpaths:
        candidate = BASE_DIR / entry
        if candidate.is_file() and candidate.suffix == ".sql":
            yield candidate
        elif candidate.is_dir():
            # Children are absolute paths, which BASE_DIR / ... leaves intact.
            yield from discover_test_files(sorted(candidate.iterdir()))
def copy_test_data_to_tmp(relpaths: List[str], tmp_path: Path) -> Path:
    """
    Write the unformatted (source) half of every test data file found via
    ``relpaths`` into ``tmp_path``, keeping each file's base name.

    ``relpaths`` may name ``.sql`` files or directories (searched
    recursively). Output is flattened, so files from different directories
    that share a base name overwrite one another. Returns ``tmp_path``, the
    directory, not an individual file.
    """
    for data_file in discover_test_files(relpaths):
        source_sql = read_test_data(data_file)[0]
        destination = tmp_path / data_file.name
        with open(destination, "w") as out:
            out.write(source_sql)
    return tmp_path
def copy_config_file_to_dst(file_name: str, dst_path: Path) -> Path:
    """Copy a fixture config file into ``dst_path`` as ``pyproject.toml``.

    Args:
        file_name: name of the file inside ``BASE_DIR / "config"``.
        dst_path: destination directory; an existing ``pyproject.toml``
            there is replaced (``shutil.copyfile`` semantics).

    Returns:
        The path of the newly written ``pyproject.toml``.

    Raises:
        AssertionError: if the source config file does not exist. The
            message names the missing path, because pytest's assertion
            rewriting does not apply to helper modules like this one.
    """
    config_dir = BASE_DIR / "config"
    file_path = config_dir / file_name
    assert file_path.is_file(), f"Test config file not found: {file_path}"
    new_file_path = dst_path / "pyproject.toml"
    shutil.copyfile(file_path, new_file_path)
    return new_file_path
|