1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
|
import pytest
from sqlparse import parse
from pgcli.packages.parseutils.ctes import (
token_start_pos,
extract_ctes,
extract_column_names as _extract_column_names,
)
def extract_column_names(sql):
    """Parse raw SQL text and extract column names from its first statement.

    Test convenience wrapper: the imported ``_extract_column_names`` is handed
    the first element of ``parse(sql)``, so the tests below can pass plain
    SQL strings instead of parsed statements.

    Returns whatever ``_extract_column_names`` returns; the tests in this
    file compare it against tuples of names.
    """
    statements = parse(sql)
    return _extract_column_names(statements[0])
def test_token_str_pos():
    """token_start_pos yields the character offset where the last token begins.

    The second case puts a newline right before the token to check that it is
    counted like any other character.
    """
    cases = (
        ("SELECT * FROM xxx", "SELECT * FROM "),
        ("SELECT * FROM \nxxx", "SELECT * FROM \n"),
    )
    for query, text_before_last_token in cases:
        statement = parse(query)[0]
        last_idx = statement.token_index(statement.tokens[-1])
        offset = token_start_pos(statement.tokens, last_idx)
        assert offset == len(text_before_last_token)
def test_single_column_name_extraction():
    """A lone, unaliased column is reported under its own name."""
    names = extract_column_names("SELECT abc FROM xxx")
    assert names == ("abc",)
def test_aliased_single_column_name_extraction():
    """When a column carries an alias, the alias is the extracted name."""
    names = extract_column_names("SELECT abc def FROM xxx")
    assert names == ("def",)
def test_aliased_expression_name_extraction():
    """An aliased literal expression is reported under its alias."""
    names = extract_column_names("SELECT 99 abc FROM xxx")
    assert names == ("abc",)
def test_multiple_column_name_extraction():
    """Several comma-separated columns come back in select-list order."""
    names = extract_column_names("SELECT abc, def FROM xxx")
    assert names == ("abc", "def")
def test_missing_column_name_handled_gracefully():
    """An unaliased literal has no name and is left out of the result.

    Covers the nameless expression both at the end and in the middle of the
    select list; the named columns around it must still be returned.
    """
    cases = (
        ("SELECT abc, 99 FROM xxx", ("abc",)),
        ("SELECT abc, 99, def FROM xxx", ("abc", "def")),
    )
    for query, expected_names in cases:
        assert extract_column_names(query) == expected_names
def test_aliased_multiple_column_name_extraction():
    """Each aliased column in a list is reported under its own alias."""
    names = extract_column_names("SELECT abc def, ghi jkl FROM xxx")
    assert names == ("def", "jkl")
def test_table_qualified_column_name_extraction():
    """A ``table.column`` reference yields only the column part."""
    names = extract_column_names("SELECT abc.def, ghi.jkl FROM xxx")
    assert names == ("def", "jkl")
@pytest.mark.parametrize(
    "sql",
    (
        "INSERT INTO foo (x, y, z) VALUES (5, 6, 7) RETURNING x, y",
        "DELETE FROM foo WHERE x > y RETURNING x, y",
        "UPDATE foo SET x = 9 RETURNING x, y",
    ),
)
def test_extract_column_names_from_returning_clause(sql):
    """RETURNING columns are extracted for INSERT, DELETE and UPDATE alike."""
    returned_names = extract_column_names(sql)
    assert returned_names == ("x", "y")
def test_simple_cte_extraction():
    """A single CTE is extracted with its name, columns and character span.

    The expected span runs from the offset of the CTE's opening parenthesis
    to the offset just past its closing parenthesis; the rest of the query is
    returned as the remainder.
    """
    query = "WITH a AS (SELECT abc FROM xxx) SELECT * FROM a"
    span_start = len("WITH a AS ")
    span_stop = len("WITH a AS (SELECT abc FROM xxx)")
    expected = (("a", ("abc",), span_start, span_stop),)

    found, rest = extract_ctes(query)

    assert tuple(found) == expected
    assert rest.strip() == "SELECT * FROM a"
def test_cte_extraction_around_comments():
    """A comment line before WITH does not prevent CTE extraction.

    The expected offsets are measured from the very start of the text, so the
    comment line is included in them.
    """
    query = """--blah blah blah
WITH a AS (SELECT abc def FROM x)
SELECT * FROM a"""
    text_before_cte_body = """--blah blah blah
WITH a AS """
    text_through_cte_body = """--blah blah blah
WITH a AS (SELECT abc def FROM x)"""
    expected = (
        ("a", ("def",), len(text_before_cte_body), len(text_through_cte_body)),
    )

    found, rest = extract_ctes(query)

    assert tuple(found) == expected
    assert rest.strip() == "SELECT * FROM a"
def test_multiple_cte_extraction():
    """Two comma-separated CTEs are each extracted with their own span.

    Every expected entry is ``(name, column_names, start, stop)``. ``start``
    and ``stop`` are character offsets into the full query: the length of the
    text before the CTE's opening parenthesis, and the length of the text up
    to and including its closing parenthesis.
    """
    sql = """WITH
x AS (SELECT abc, def FROM x),
y AS (SELECT ghi, jkl FROM y)
SELECT * FROM a, b"""

    start1 = len(
        """WITH
x AS """
    )
    stop1 = len(
        """WITH
x AS (SELECT abc, def FROM x)"""
    )

    start2 = len(
        """WITH
x AS (SELECT abc, def FROM x),
y AS """
    )
    stop2 = len(
        """WITH
x AS (SELECT abc, def FROM x),
y AS (SELECT ghi, jkl FROM y)"""
    )

    ctes, remainder = extract_ctes(sql)

    assert tuple(ctes) == (
        ("x", ("abc", "def"), start1, stop1),
        ("y", ("ghi", "jkl"), start2, stop2),
    )
    # Everything after the last CTE is handed back as the remainder query;
    # check it so a regression there cannot pass unnoticed.
    assert remainder.strip() == "SELECT * FROM a, b"
|