File: segments_test.py

package info (click to toggle)
sqlfluff 3.5.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 34,000 kB
  • sloc: python: 106,131; sql: 34,188; makefile: 52; sh: 8
file content (163 lines) | stat: -rw-r--r-- 4,430 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""Tests for the segments module."""

import pytest

import sqlfluff.utils.functional.segment_predicates as sp
from sqlfluff.core.linter.linter import Linter
from sqlfluff.core.parser.segments.raw import RawSegment
from sqlfluff.utils.functional import segments

# Four standalone raw segments with raw text "s1".."s4", shared by the tests
# below. Several tests select on the trailing digit of each raw string.
seg1, seg2, seg3, seg4 = (RawSegment(f"s{idx}") for idx in range(1, 5))


@pytest.mark.parametrize(
    ("lhs", "rhs", "expected"),
    [
        # Segments + Segments, in both operand orders.
        (
            segments.Segments(seg1, seg2),
            segments.Segments(seg3, seg4),
            segments.Segments(seg1, seg2, seg3, seg4),
        ),
        (
            segments.Segments(seg3, seg4),
            segments.Segments(seg1, seg2),
            segments.Segments(seg3, seg4, seg1, seg2),
        ),
        # Segments + list: the right-hand operand is a plain list.
        (
            segments.Segments(seg1, seg2),
            [seg3, seg4],
            segments.Segments(seg1, seg2, seg3, seg4),
        ),
        # list + Segments: the left-hand operand is a plain list.
        (
            [seg1, seg2],
            segments.Segments(seg3, seg4),
            segments.Segments(seg1, seg2, seg3, seg4),
        ),
    ],
)
def test_segments_add(lhs, rhs, expected):
    """Check that ``+`` concatenates Segments with Segments or plain lists.

    Regardless of which operand is the plain list, the sum must be a
    ``Segments`` instance holding the left operand's elements followed by
    the right operand's.
    """
    combined = lhs + rhs
    assert isinstance(combined, segments.Segments)
    assert combined == expected


@pytest.mark.parametrize(
    ("input", "expected"),
    [
        (segments.Segments(seg1, seg2), True),
        (segments.Segments(seg1, seg3), False),
    ],
)
def test_segments_all(input, expected):
    """Exercise ``Segments.all()`` with a predicate on the last raw character."""

    def last_char_at_most_two(seg):
        # Lexical comparison: "1" and "2" pass, "3" does not.
        return seg.raw[-1] <= "2"

    assert input.all(last_char_at_most_two) == expected


@pytest.mark.parametrize(
    ("input", "expected"),
    [
        (segments.Segments(seg1, seg2), True),
        (segments.Segments(seg1, seg3), True),
        (segments.Segments(seg3), False),
    ],
)
def test_segments_any(input, expected):
    """Exercise ``Segments.any()`` with a predicate on the last raw character."""

    def last_char_at_most_two(seg):
        # Lexical comparison: "1" and "2" pass, "3" does not.
        return seg.raw[-1] <= "2"

    assert input.any(last_char_at_most_two) == expected


def test_segments_reversed():
    """Check that ``reversed()`` yields the segments in the opposite order."""
    forward = segments.Segments(seg1, seg2)
    backward = segments.Segments(seg2, seg1)
    assert forward.reversed() == backward


def test_segments_raw_slices_no_templated_file():
    """Check that reading ``raw_slices`` without a TemplatedFile raises ValueError."""
    with pytest.raises(ValueError):
        # Only the property access matters; the value is discarded.
        _ = segments.Segments(seg1).raw_slices


def test_segments_first_no_predicate():
    """Without a predicate, ``first()`` keeps only the leading segment."""
    head = segments.Segments(seg1, seg2).first()
    assert head == segments.Segments(seg1)


def test_segments_first_with_predicate():
    """``first()`` with a predicate that matches nothing yields empty Segments."""
    # The expectation is that neither plain raw segment satisfies is_meta().
    first_meta = segments.Segments(seg1, seg2).first(sp.is_meta())
    assert first_meta == segments.Segments()


def test_segments_last():
    """Check that ``last()`` keeps only the trailing segment."""
    tail = segments.Segments(seg1, seg2).last()
    assert tail == segments.Segments(seg2)


def test_segments_apply():
    """Check that ``apply()`` maps a callable over every segment, in order."""
    last_chars = segments.Segments(seg1, seg2).apply(lambda seg: seg.raw[-1])
    assert last_chars == ["1", "2"]


@pytest.mark.parametrize(
    ["function", "expected"],
    [
        [sp.get_type(), ["raw", "raw"]],
        [sp.is_comment(), [False, False]],
        [sp.is_raw(), [True, True]],
    ],
)
def test_segments_apply_functions(function, expected):
    """Test the "apply()" function with callables from ``segment_predicates``.

    Covers one value accessor (``get_type()``) and two boolean predicates
    (``is_comment()``, ``is_raw()``), each applied to two plain raw segments.
    """
    assert segments.Segments(seg1, seg2).apply(function) == expected


def test_segment_predicates_and():
    """Check that ``sp.and_()`` selects only segments that pass every predicate."""
    pair = segments.Segments(seg1, seg2)

    # is_raw() holds for both segments; the second clause narrows to seg1.
    ends_in_one = sp.and_(sp.is_raw(), lambda seg: seg.raw[-1] == "1")
    assert pair.select(select_if=ends_in_one) == segments.Segments(seg1)

    # Neither raw string ends in "3", so the conjunction matches nothing.
    ends_in_three = sp.and_(sp.is_raw(), lambda seg: seg.raw[-1] == "3")
    assert pair.select(select_if=ends_in_three) == segments.Segments()


def test_segments_recursive_crawl():
    """Check that ``recursive_crawl()`` finds nested segments of a given type."""
    sql = """
    WITH cte AS (
        SELECT * FROM tab_a
    )
    SELECT
        cte.col_a,
        tab_b.col_b
    FROM cte
    INNER JOIN tab_b;
    """
    root = Linter(dialect="ansi").parse_string(sql).root_variant().tree
    functional_tree = segments.Segments(root)

    # One CTE definition; three FROM/JOIN targets (tab_a, cte, tab_b).
    expected_counts = {"common_table_expression": 1, "table_reference": 3}
    for seg_type, count in expected_counts.items():
        assert len(functional_tree.recursive_crawl(seg_type)) == count