1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304
|
"""Common Test Fixtures."""
import hashlib
import io
import os
from typing import NamedTuple
import pytest
import yaml
from yaml import CDumper, CLoader
from sqlfluff.cli.commands import quoted_presenter
from sqlfluff.core import FluffConfig
from sqlfluff.core.linter import Linter
from sqlfluff.core.parser import Lexer, Parser
from sqlfluff.core.parser.markers import PositionMarker
from sqlfluff.core.parser.segments import (
BaseSegment,
CodeSegment,
CommentSegment,
Dedent,
Indent,
NewlineSegment,
SymbolSegment,
WhitespaceSegment,
)
from sqlfluff.core.rules import BaseRule
from sqlfluff.core.templaters import TemplatedFile
# Register `quoted_presenter` as the YAML representer for `str` so that, when
# writing YAML files, string values which need escaping are double-quoted.
yaml.add_representer(str, quoted_presenter)
class ParseExample(NamedTuple):
    """A tuple representing an example SQL file to parse."""

    # Name of the dialect directory the file was found in.
    dialect: str
    # File name of the SQL file inside that dialect directory.
    sqlfile: str


def get_parse_fixtures(
    fail_on_missing_yml=False,
) -> tuple[list[ParseExample], list[tuple[str, str, bool, str]]]:
    """Search for all parsing fixtures.

    Walks ``test/fixtures/dialects`` (relative to the current working
    directory) and treats every sub-directory as a dialect. Entries ending in
    ``.md`` and entries which are not directories are skipped.

    Args:
        fail_on_missing_yml: When true, raise if a ``.sql`` file has neither a
            ``<root>.yml`` nor a ``<root>_nc.yml`` file alongside it.

    Returns:
        A 2-tuple of:

        * a list of ``ParseExample(dialect, sqlfile)``, one per ``.sql`` file;
        * a list of ``(dialect, sqlfile, code_only, ymlfile)`` tuples, where
          ``code_only`` is ``True`` for ``<root>.yml`` and ``False`` for
          ``<root>_nc.yml``.

        Both lists are ordered by dialect name and then by file name.

    Raises:
        Exception: If ``fail_on_missing_yml`` is set and a ``.sql`` file has
            no matching yml file.
    """
    parse_success_examples = []
    parse_structure_examples = []
    fixtures_root = os.path.join("test", "fixtures", "dialects")
    # (suffix appended to the sql file's root name, code_only flag)
    yml_variants = ((".yml", True), ("_nc.yml", False))

    # os.listdir returns entries in arbitrary order; sort so that the
    # generated fixture lists are deterministic between runs.
    for d in sorted(os.listdir(fixtures_root)):
        dialect_dir = os.path.join(fixtures_root, d)
        # Ignore documentation and any other stray non-directory entries,
        # which would otherwise make the nested listdir call fail.
        if d.endswith(".md") or not os.path.isdir(dialect_dir):
            continue
        # assume that d is now the name of a dialect
        dirlist = sorted(os.listdir(dialect_dir))
        # Set for O(1) membership tests instead of scanning the list per file.
        available = set(dirlist)
        for f in dirlist:
            # only look for sql files
            if not f.endswith(".sql"):
                continue
            root = f[:-4]
            parse_success_examples.append(ParseExample(d, f))
            has_yml = False
            for suffix, code_only in yml_variants:
                y = root + suffix
                if y in available:
                    parse_structure_examples.append((d, f, code_only, y))
                    has_yml = True
            if not has_yml and fail_on_missing_yml:
                raise Exception(
                    f"Missing .yml file for {os.path.join(d, f)}. Run the "
                    "test/generate_parse_fixture_yml.py script!"
                )
    return parse_success_examples, parse_structure_examples
def make_dialect_path(dialect, fname):
    """Build the relative path of a fixture file belonging to a dialect."""
    path_parts = ("test", "fixtures", "dialects", dialect, fname)
    return os.path.join(*path_parts)
def load_file(dialect, fname):
    """Read a dialect fixture file and return its contents as a string.

    The file is located via ``make_dialect_path`` and decoded as UTF-8.
    """
    fixture_path = make_dialect_path(dialect, fname)
    with open(fixture_path, encoding="utf8") as handle:
        return handle.read()
def process_struct(obj):
    """Process a nested dict or dict-like into a check tuple.

    Conversion rules:

    * ``dict`` -> tuple of ``(key, processed_value)`` pairs, in dict order.
    * ``list`` -> tuple holding the single ``(key, value)`` pair of each
      element. Every element must be a dict with exactly one key. An empty
      list becomes an empty tuple.
    * ``str`` / ``int`` / ``float`` -> ``str(obj)``.
    * ``None`` -> ``None``.

    Raises:
        TypeError: If ``obj`` is of an unsupported type, or a list contains
            an element which is not a dict.
        ValueError: If a list contains a dict which does not have exactly
            one key.
    """
    if isinstance(obj, dict):
        return tuple((k, process_struct(v)) for k, v in obj.items())
    elif isinstance(obj, list):
        # Validate *every* element (not just the first), so that a mixed list
        # fails loudly instead of yielding a malformed tuple.
        for elem in obj:
            if not isinstance(elem, dict):
                raise TypeError(f"Did not expect a list of {type(elem)}: {elem!r}")
        buff = [process_struct(elem) for elem in obj]
        if any(len(elem) > 1 for elem in buff):
            raise ValueError(f"Not sure how to deal with multi key dict: {buff!r}")
        # An empty dict has no pair to extract; report it explicitly rather
        # than failing on the index below.
        if any(not elem for elem in buff):
            raise ValueError(f"Not sure how to deal with empty dict: {obj!r}")
        # An empty list falls through to here and gives an empty tuple.
        return tuple(elem[0] for elem in buff)
    elif isinstance(obj, (str, int, float)):
        return str(obj)
    elif obj is None:
        return None
    else:
        raise TypeError(f"Not sure how to deal with type {type(obj)}: {obj!r}")
def parse_example_file(dialect: str, sqlfile: str):
    """Lex and parse a dialect fixture SQL file and return the parse tree.

    A ``FluffConfig`` is built with only the dialect overridden. The file is
    read with ``load_file`` and the parser is given ``"<dialect>/<sqlfile>"``
    as the file name.
    """
    config = FluffConfig(overrides={"dialect": dialect})
    sql_text = load_file(dialect, sqlfile)
    # The second element of the lexer result is not needed here.
    tokens, _ = Lexer(config=config).lex(sql_text)
    parser = Parser(config=config)
    return parser.parse(tokens, fname=f"{dialect}/{sqlfile}")
def compute_parse_tree_hash(tree):
    """Return a hex digest for a parse tree, or ``None`` if nothing to hash.

    The tree's code-only record (including raw values) is dumped to YAML
    without key sorting, and the UTF-8 bytes of that dump are hashed with
    BLAKE2s. ``None`` is returned when either the tree or its record is falsy.
    """
    if not tree:
        return None
    record = tree.as_record(code_only=True, show_raw=True)
    if not record:
        return None
    buffer = io.StringIO()
    yaml.dump(record, buffer, sort_keys=False, allow_unicode=True, Dumper=CDumper)
    digest = hashlib.blake2s(buffer.getvalue().encode("utf-8"))
    return digest.hexdigest()
def load_yaml(fpath):
    """Load a yaml structure and process it into a tuple.

    Args:
        fpath: Path of the YAML file to read (decoded as UTF-8).

    Returns:
        A ``(hash, structure)`` 2-tuple. ``hash`` is the value of the
        top-level ``_hash`` key (removed from the document before
        processing), or ``None`` if absent. ``structure`` is the first
        element of the document after ``process_struct``. If the processed
        document is empty, ``(None, None)`` is returned.
    """
    # Load raw file
    with open(fpath, encoding="utf8") as f:
        raw = f.read()
    # Parse the yaml
    # NOTE(review): `yaml.load` with `CLoader` is not a safe loader; this
    # presumably only ever reads fixture files from the repository - confirm
    # before pointing it at untrusted input.
    obj = yaml.load(raw, Loader=CLoader)
    # Split the stored hash off from the structure itself.
    _hash = None
    if obj:
        _hash = obj.pop("_hash", None)
    # Process once and reuse the result rather than processing twice.
    processed = process_struct(obj)
    if processed:
        return _hash, processed[0]
    return None, None
@pytest.fixture()
def yaml_loader():
    """Return a yaml loading function.

    The fixture value is the ``load_yaml`` function itself (not its result),
    so tests call it with a file path.
    """
    # Return a function
    return load_yaml
def _generate_test_segments_func(elems):
    """Roughly build a tuple of segments from a sequence of raw strings.

    Each element is turned into one segment, chosen from its content:

    * ``"<indent>"`` / ``"<dedent>"`` become zero-width ``Indent`` /
      ``Dedent`` segments at the current position;
    * strings made up only of spaces and tabs become whitespace;
    * strings made up only of newlines become newline segments;
    * single round or square brackets become typed symbol segments;
    * strings starting with ``--`` become inline comments;
    * strings starting with a double or single quote become typed code
      segments;
    * anything else becomes a plain code segment.

    This isn't totally robust, but it is good enough for testing. Use with
    caution.
    """
    # NOTE(review): the templated file is built from *all* elements, so any
    # "<indent>"/"<dedent>" markers are part of its text even though they do
    # not advance the running offset below.
    templated_file = TemplatedFile.from_string("".join(elems))

    meta_classes = {"<indent>": Indent, "<dedent>": Dedent}
    bracket_types = {
        "(": "start_bracket",
        ")": "end_bracket",
        "[": "start_square_bracket",
        "]": "end_square_bracket",
    }

    segments = []
    offset = 0
    for elem in elems:
        meta_class = meta_classes.get(elem)
        if meta_class is not None:
            # Meta segments occupy a single point and consume no characters.
            point = PositionMarker.from_point(offset, offset, templated_file)
            segments.append(meta_class(pos_marker=point))
            continue

        # Work out the segment class and (optionally) its instance type.
        chars = set(elem)
        if chars <= {" ", "\t"}:
            seg_class, instance_type = WhitespaceSegment, None
        elif chars <= {"\n"}:
            seg_class, instance_type = NewlineSegment, None
        elif elem in bracket_types:
            seg_class, instance_type = SymbolSegment, bracket_types[elem]
        elif elem.startswith("--"):
            seg_class, instance_type = CommentSegment, "inline_comment"
        elif elem.startswith('"'):
            seg_class, instance_type = CodeSegment, "double_quote"
        elif elem.startswith("'"):
            seg_class, instance_type = CodeSegment, "single_quote"
        else:
            seg_class, instance_type = CodeSegment, None

        extra_kwargs = {}
        if instance_type is not None:
            extra_kwargs = {"instance_types": (instance_type,)}

        end = offset + len(elem)
        marker = PositionMarker(
            slice(offset, end),
            slice(offset, end),
            templated_file,
        )
        segments.append(seg_class(raw=elem, pos_marker=marker, **extra_kwargs))
        offset = end
    return tuple(segments)
@pytest.fixture(scope="module")
def generate_test_segments():
    """Roughly generate test segments.

    This is a factory function so that it works as a fixture,
    but when actually used, this will return the inner function
    (``_generate_test_segments_func``) which is what you actually need.
    """
    return _generate_test_segments_func
@pytest.fixture
def raise_critical_errors_after_fix(monkeypatch):
    """Make critical errors in the fix process propagate.

    Outside of this fixture such errors are swallowed so that the lint
    messages still reach the end user. Here ``BaseRule._log_critical_errors``
    is replaced with a static method which raises the error it is given.
    """

    def _log_critical_errors(error: Exception):
        raise error

    monkeypatch.setattr(
        BaseRule, "_log_critical_errors", staticmethod(_log_critical_errors)
    )
@pytest.fixture(autouse=True)
def fail_on_parse_error_after_fix(monkeypatch):
    """Cause tests to fail if a lint fix introduces a parse error.

    In production, a couple of functions just log a warning when they detect
    a bug in a lint rule. To catch bugs in new or modified rules we want to be
    stricter during dev and CI/CD testing, so this fixture (applied to every
    test automatically) patches in replacements which raise ``ValueError``
    instead, causing the test to fail.
    """

    def raise_error_apply_fixes_check_issue(message, *args):  # pragma: no cover
        raise ValueError(message % args)

    def raise_error_conflicting_fixes_same_anchor(message: str):  # pragma: no cover
        raise ValueError(message)

    # (owner, attribute name, replacement) for each hook to make strict.
    replacements = (
        (
            BaseSegment,
            "_log_apply_fixes_check_issue",
            raise_error_apply_fixes_check_issue,
        ),
        (
            Linter,
            "_report_conflicting_fixes_same_anchor",
            raise_error_conflicting_fixes_same_anchor,
        ),
    )
    for owner, attr_name, replacement in replacements:
        monkeypatch.setattr(owner, attr_name, staticmethod(replacement))
@pytest.fixture(autouse=True)
def test_verbosity_level(request):
    """Report the verbosity level for a given pytest run.

    The value is read from the ``verbose`` option of the pytest config.

    For example:

        $ pytest -vv

    Has a verbosity level of 2

    While:

        $ pytest

    Has a verbosity level of 0
    """
    return request.config.getoption("verbose")
|