1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
|
"""Tests for the segments module."""
import pytest
import sqlfluff.utils.functional.segment_predicates as sp
from sqlfluff.core.linter.linter import Linter
from sqlfluff.core.parser.segments.raw import RawSegment
from sqlfluff.utils.functional import segments
# Four distinct raw segments (raw text "s1".."s4") shared as fixtures by the tests below.
seg1, seg2, seg3, seg4 = (RawSegment(raw) for raw in ("s1", "s2", "s3", "s4"))
@pytest.mark.parametrize(
    "lhs,rhs,expected",
    [
        # Segments + Segments, in both orders.
        (
            segments.Segments(seg1, seg2),
            segments.Segments(seg3, seg4),
            segments.Segments(seg1, seg2, seg3, seg4),
        ),
        (
            segments.Segments(seg3, seg4),
            segments.Segments(seg1, seg2),
            segments.Segments(seg3, seg4, seg1, seg2),
        ),
        # Segments + plain list.
        (
            segments.Segments(seg1, seg2),
            [seg3, seg4],
            segments.Segments(seg1, seg2, seg3, seg4),
        ),
        # Plain list + Segments.
        (
            [seg1, seg2],
            segments.Segments(seg3, seg4),
            segments.Segments(seg1, seg2, seg3, seg4),
        ),
    ],
)
def test_segments_add(lhs, rhs, expected):
    """Check that "+" concatenates Segments with Segments or with plain lists.

    Whichever side is a plain list, the sum must still be a Segments
    instance holding the operands' elements in left-to-right order.
    """
    combined = lhs + rhs
    assert isinstance(combined, segments.Segments)
    assert combined == expected
@pytest.mark.parametrize(
    ["segs", "expected"],
    [
        [
            segments.Segments(seg1, seg2),
            True,
        ],
        [
            segments.Segments(seg1, seg3),
            False,
        ],
    ],
)
def test_segments_all(segs, expected):
    """Test the "all()" function.

    The predicate accepts segments whose raw text ends in a character
    ``<= "2"``. Per the expected values, seg1/seg2 satisfy it and seg3
    does not, so ``all()`` is True only when seg3 is absent.

    The parameter is named ``segs`` rather than ``input`` to avoid
    shadowing the ``input`` builtin.
    """
    assert segs.all(lambda s: s.raw[-1] <= "2") == expected
@pytest.mark.parametrize(
    ["segs", "expected"],
    [
        [
            segments.Segments(seg1, seg2),
            True,
        ],
        [
            segments.Segments(seg1, seg3),
            True,
        ],
        [
            segments.Segments(seg3),
            False,
        ],
    ],
)
def test_segments_any(segs, expected):
    """Test the "any()" function.

    The predicate accepts segments whose raw text ends in a character
    ``<= "2"``. Per the expected values, ``any()`` is True whenever seg1
    or seg2 is present and False for seg3 alone.

    The parameter is named ``segs`` rather than ``input`` to avoid
    shadowing the ``input`` builtin.
    """
    assert segs.any(lambda s: s.raw[-1] <= "2") == expected
def test_segments_reversed():
    """Check that "reversed()" yields the segments in the opposite order."""
    forward = segments.Segments(seg1, seg2)
    backward = segments.Segments(seg2, seg1)
    assert forward.reversed() == backward
def test_segments_raw_slices_no_templated_file():
    """Accessing "raw_slices" without a TemplatedFile must raise ValueError."""
    with pytest.raises(ValueError):
        # Only the property access matters; the value itself is discarded.
        _ = segments.Segments(seg1).raw_slices
def test_segments_first_no_predicate():
    """Without a predicate, "first()" returns just the leading segment."""
    pair = segments.Segments(seg1, seg2)
    assert pair.first() == segments.Segments(seg1)
def test_segments_first_with_predicate():
    """With a predicate nothing satisfies, "first()" returns an empty Segments."""
    pair = segments.Segments(seg1, seg2)
    # The test expects neither raw segment to be a meta segment.
    no_match = pair.first(sp.is_meta())
    assert no_match == segments.Segments()
def test_segments_last():
    """Check that "last()" returns just the trailing segment."""
    pair = segments.Segments(seg1, seg2)
    assert pair.last() == segments.Segments(seg2)
def test_segments_apply():
    """Check that "apply()" maps a function over each segment into a list."""
    last_chars = segments.Segments(seg1, seg2).apply(lambda seg: seg.raw[-1])
    assert last_chars == ["1", "2"]
@pytest.mark.parametrize(
    ["function", "expected"],
    [
        [sp.get_type(), ["raw", "raw"]],
        [sp.is_comment(), [False, False]],
        [sp.is_raw(), [True, True]],
    ],
)
def test_segments_apply_functions(function, expected):
    """Test the "apply()" function with callables from segment_predicates.

    Each case applies one of ``get_type()``, ``is_comment()`` or
    ``is_raw()`` to both raw segments and compares the per-segment
    results, which ``apply()`` returns as a list.
    """
    assert segments.Segments(seg1, seg2).apply(function) == expected
def test_segment_predicates_and():
    """Check that "and_()" selects only segments satisfying every predicate."""
    pair = segments.Segments(seg1, seg2)

    def ends_with(char):
        """Build a predicate matching segments whose raw text ends in ``char``."""
        return lambda seg: seg.raw[-1] == char

    # Both conditions hold for seg1 only.
    picked = pair.select(select_if=sp.and_(sp.is_raw(), ends_with("1")))
    assert picked == segments.Segments(seg1)

    # No segment ends in "3", so the conjunction matches nothing.
    picked = pair.select(select_if=sp.and_(sp.is_raw(), ends_with("3")))
    assert picked == segments.Segments()
def test_segments_recursive_crawl():
    """Check "recursive_crawl()" counts on the parse tree of a small query."""
    sql = """
WITH cte AS (
SELECT * FROM tab_a
)
SELECT
cte.col_a,
tab_b.col_b
FROM cte
INNER JOIN tab_b;
"""
    parse_result = Linter(dialect="ansi").parse_string(sql)
    root = segments.Segments(parse_result.root_variant().tree)
    # A single CTE ("cte") is declared in the WITH clause.
    assert len(root.recursive_crawl("common_table_expression")) == 1
    # Expected three table references, presumably tab_a inside the CTE
    # plus cte and tab_b in the outer FROM / JOIN.
    assert len(root.recursive_crawl("table_reference")) == 3
|