File: respace_test.py

package info (click to toggle)
sqlfluff 3.5.0-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 34,000 kB
  • sloc: python: 106,131; sql: 34,188; makefile: 52; sh: 8
file content (106 lines) | stat: -rw-r--r-- 3,643 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Tests for respacing methods.

These are mostly on the ReflowPoint class.
"""

import logging

import pytest

from sqlfluff.core import Linter
from sqlfluff.utils.reflow.elements import ReflowPoint
from sqlfluff.utils.reflow.helpers import fixes_from_results
from sqlfluff.utils.reflow.sequence import ReflowSequence


def parse_ansi_string(sql, config):
    """Parse ``sql`` with a ``Linter`` built from ``config`` and return its tree.

    A fresh ``Linter`` is constructed on every call. The ``.tree`` attribute of
    the ``parse_string`` result is returned unchanged; this helper does not
    check whether parsing succeeded.

    NOTE(review): despite the name, the dialect is whatever ``config``
    selects. Presumably callers pass an ANSI config; confirm at call sites.
    """
    parsed = Linter(config=config).parse_string(sql)
    return parsed.tree


@pytest.mark.parametrize(
    ("raw_sql_in", "kwargs", "raw_sql_out"),
    [
        # Basic cases: operators get single spaces and trailing whitespace
        # is removed.
        ("select 1+2", {}, "select 1 + 2"),
        ("select    1   +   2    ", {}, "select 1 + 2"),
        # Newline handling: an existing indent is kept, unless newlines are
        # stripped outright.
        ("select\n    1   +   2", {}, "select\n    1 + 2"),
        ("select\n  1   +   2", {}, "select\n  1 + 2"),
        ("select\n  1   +   2", {"strip_newlines": True}, "select 1 + 2"),
        # Filtering: with no filter, the result matches "all". "inline" only
        # touches points without a newline, and "newline" only touches points
        # that contain one.
        ("select  \n  1   +   2 \n ", {}, "select\n  1 + 2\n"),
        ("select  \n  1   +   2 \n ", {"filter": "all"}, "select\n  1 + 2\n"),
        ("select  \n  1   +   2 \n ", {"filter": "inline"}, "select  \n  1 + 2 \n "),
        ("select  \n  1   +   2 \n ", {"filter": "newline"}, "select\n  1   +   2\n"),
    ],
)
def test_reflow__sequence_respace(
    raw_sql_in, kwargs, raw_sql_out, default_config, caplog
):
    """Respace a whole ReflowSequence via respace() and compare its raw SQL."""
    sequence = ReflowSequence.from_root(
        parse_ansi_string(raw_sql_in, default_config), config=default_config
    )

    # Capture reflow debug logging, so pytest's captured-log report shows it
    # if the assertion below fails.
    with caplog.at_level(logging.DEBUG, logger="sqlfluff.rules.reflow"):
        respaced = sequence.respace(**kwargs)

    assert respaced.get_raw() == raw_sql_out


@pytest.mark.parametrize(
    ("raw_sql_in", "point_idx", "kwargs", "raw_point_sql_out", "fixes_out"),
    [
        # Basic cases: a single point is respaced, with the expected set of
        # (edit_type, anchor raw) fixes.
        ("select    1", 1, {}, " ", {("replace", "    ")}),
        ("select 1+2", 3, {}, " ", {("create_after", "1")}),
        ("select (1+2)", 3, {}, "", set()),
        ("select (  1+2)", 3, {}, "", {("delete", "  ")}),
        # Newline handling: newlines and their indents survive by default,
        # and whitespace before the newline is deleted.
        ("select\n1", 1, {}, "\n", set()),
        ("select\n  1", 1, {}, "\n  ", set()),
        ("select  \n  1", 1, {}, "\n  ", {("delete", "  ")}),
        (
            "select  \n 1",
            1,
            {"strip_newlines": True},
            " ",
            {("delete", "\n"), ("delete", " "), ("replace", "  ")},
        ),
        (
            "select ( \n  1)",
            3,
            {"strip_newlines": True},
            "",
            {("delete", "\n"), ("delete", " "), ("delete", "  ")},
        ),
    ],
)
def test_reflow__point_respace_point(
    raw_sql_in, point_idx, kwargs, raw_point_sql_out, fixes_out, default_config, caplog
):
    """Respace one ReflowPoint via respace_point() and check its output and fixes.

    NOTE: This doesn't check any pre-existing fixes.
    That should be a separate more specific test.
    """
    root = parse_ansi_string(raw_sql_in, default_config)
    elements = ReflowSequence.from_root(root, config=default_config).elements
    # Every point_idx used above is odd, presumably because blocks and points
    # alternate in `elements`. Guard that the chosen element is a point.
    point = elements[point_idx]
    assert isinstance(point, ReflowPoint)

    with caplog.at_level(logging.DEBUG, logger="sqlfluff.rules.reflow"):
        results, new_point = point.respace_point(
            prev_block=elements[point_idx - 1],
            next_block=elements[point_idx + 1],
            root_segment=root,
            lint_results=[],
            **kwargs,
        )

    assert new_point.raw == raw_point_sql_out
    # Compare as sets, because the order in which fixes are produced doesn't
    # matter here.
    produced_fixes = {
        (fix.edit_type, fix.anchor.raw) for fix in fixes_from_results(results)
    }
    assert produced_fixes == fixes_out