"""Unit tests for the SQL text generated from sqlglot expressions."""
import unittest
from sqlglot import exp, parse_one
from sqlglot.expressions import Func
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer
class TestGenerator(unittest.TestCase):
    """Checks on the SQL text produced by ``Expression.sql()``."""

    def test_fallback_function_sql(self):
        """Round-trip a call to a locally defined ``Func`` registered on a custom parser."""

        class SpecialUDF(Func):
            arg_types = {"a": True, "b": False}

        class NewParser(Parser):
            FUNCTIONS = SpecialUDF.default_parser_mappings()

        sql = "SELECT SPECIAL_UDF(a) FROM x"
        tokens = Tokenizer().tokenize(sql)
        expression = NewParser().parse(tokens)[0]
        self.assertEqual(expression.sql(), sql)

    def test_fallback_function_var_args_sql(self):
        """Round-trip a variable-length-argument ``Func`` and check ``DateTrunc`` rendering."""

        class SpecialUDF(Func):
            arg_types = {"a": True, "expressions": False}
            is_var_len_args = True

        class NewParser(Parser):
            FUNCTIONS = SpecialUDF.default_parser_mappings()

        sql = "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x"
        tokens = Tokenizer().tokenize(sql)
        expression = NewParser().parse(tokens)[0]
        self.assertEqual(expression.sql(), sql)

        # A DateTrunc node built by hand renders the unit as a quoted string
        # ahead of the column argument.
        date_trunc = exp.DateTrunc(
            this=exp.to_column("event_date"),
            unit=exp.var("MONTH"),
        )
        self.assertEqual(date_trunc.sql(), "DATE_TRUNC('MONTH', event_date)")

    def test_identify(self):
        """Check identifier quoting for ``identify=True``, ``False`` and ``"safe"``."""
        # (input SQL, identify option, expected output)
        cases = (
            ("x", True, '"x"'),
            ("x", False, "x"),
            ("X", True, '"X"'),
            # Quotes already present in the input are kept even when identify is off.
            ('"x"', False, '"x"'),
            ("x", "safe", '"x"'),
            ("X", "safe", "X"),
            ("x as 1", "safe", '"x" AS "1"'),
            ("X as 1", "safe", 'X AS "1"'),
        )
        for sql, identify, expected in cases:
            # subTest keeps going after a failing case so every mismatch is reported.
            with self.subTest(sql=sql, identify=identify):
                self.assertEqual(parse_one(sql).sql(identify=identify), expected)

    def test_generate_nested_binary(self):
        """A chain of 1000 ``||`` operators round-trips unchanged.

        NOTE(review): presumably guards against recursion-depth problems when
        generating deeply nested binary expressions -- confirm against the generator.
        """
        sql = "SELECT 'foo'" + (" || 'foo'" * 1000)
        self.assertEqual(parse_one(sql).sql(copy=False), sql)

    def test_overlap_operator(self):
        """``&<`` and ``&>`` parsed as Postgres survive generation in both dialects."""
        for op in ("&<", "&>"):
            with self.subTest(op=op):
                input_sql = f"SELECT '[1,10]'::int4range {op} '[5,15]'::int4range"
                expected_sql = (
                    f"SELECT CAST('[1,10]' AS INT4RANGE) {op} CAST('[5,15]' AS INT4RANGE)"
                )

                ast = parse_one(input_sql, read="postgres")
                self.assertEqual(ast.sql(), expected_sql)
                self.assertEqual(ast.sql("postgres"), expected_sql)

    def test_pretty_nested_types(self):
        """Nested data types wrap onto indented lines once they exceed ``max_text_width``.

        The expected pretty strings use two spaces per nesting level, which is
        the generator's default indentation.
        """

        def assert_pretty_nested(
            datatype: exp.DataType,
            single_line: str,
            pretty: str,
            max_text_width: int = 10,
            **kwargs,
        ) -> None:
            """Assert both the default and the pretty rendering of ``datatype``.

            Args:
                datatype: The type expression under test.
                single_line: Expected output of ``datatype.sql()``.
                pretty: Expected output with ``pretty=True``.
                max_text_width: Width passed to the generator; the small default
                    forces the types used below to wrap.
                **kwargs: Extra generator options forwarded to ``sql()``.
            """
            self.assertEqual(datatype.sql(), single_line)
            self.assertEqual(
                datatype.sql(pretty=True, max_text_width=max_text_width, **kwargs), pretty
            )

        # STRUCT
        type_str = "STRUCT<a INT, b TEXT>"
        assert_pretty_nested(
            exp.DataType.build(type_str),
            type_str,
            "STRUCT<\n  a INT,\n  b TEXT\n>",
        )

        # STRUCT - type def shorter than max text width so stays one line
        assert_pretty_nested(
            exp.DataType.build(type_str),
            type_str,
            "STRUCT<a INT, b TEXT>",
            max_text_width=50,
        )

        # STRUCT, leading_comma = True
        assert_pretty_nested(
            exp.DataType.build(type_str),
            type_str,
            "STRUCT<\n  a INT\n  , b TEXT\n>",
            leading_comma=True,
        )

        # ARRAY
        type_str = "ARRAY<DECIMAL(38, 9)>"
        assert_pretty_nested(
            exp.DataType.build(type_str),
            type_str,
            "ARRAY<\n  DECIMAL(38, 9)\n>",
        )

        # ARRAY nested STRUCT - inner fields sit one level deeper than STRUCT<
        type_str = "ARRAY<STRUCT<a INT, b TEXT>>"
        assert_pretty_nested(
            exp.DataType.build(type_str),
            type_str,
            "ARRAY<\n  STRUCT<\n    a INT,\n    b TEXT\n  >\n>",
        )

        # RANGE
        type_str = "RANGE<DECIMAL(38, 9)>"
        assert_pretty_nested(
            exp.DataType.build(type_str),
            type_str,
            "RANGE<\n  DECIMAL(38, 9)\n>",
        )

        # LIST
        type_str = "LIST<INT, INT, TEXT>"
        assert_pretty_nested(
            exp.DataType.build(type_str),
            type_str,
            "LIST<\n  INT,\n  INT,\n  TEXT\n>",
        )

        # MAP
        type_str = "MAP<INT, DECIMAL(38, 9)>"
        assert_pretty_nested(
            exp.DataType.build(type_str),
            type_str,
            "MAP<\n  INT,\n  DECIMAL(38, 9)\n>",
        )
|