1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
import unittest
from unittest import mock

from sqlglot import exp, parse_one
from sqlglot.expressions import Expression, Func
from sqlglot.parsers.snowflake import SnowflakeParser
import sqlglot.expressions.core as _core_module
# True when sqlglot.expressions.core was loaded from a native extension module (the
# mypyc build referenced by the skip reasons below): ".so" on Linux/macOS, ".pyd" on
# Windows. ``__file__`` may be absent or None for some module types, so both cases
# fall back to "" (treated as "not compiled") instead of raising.
_EXPRESSION_IS_COMPILED = (getattr(_core_module, "__file__", None) or "").endswith(
    (".so", ".pyd")
)
class TestGenerator(unittest.TestCase):
    """Tests for SQL text generation via ``Expression.sql()``."""

    @staticmethod
    def _register_snowflake_function(name, builder):
        """Temporarily map ``name`` to ``builder`` in ``SnowflakeParser.FUNCTIONS``.

        Returns a context manager. ``mock.patch.dict`` restores the table exactly on exit,
        including any entry that already existed under ``name``. The registration therefore
        cannot leak into other tests, even if the ``with`` body raises.
        """
        return mock.patch.dict(SnowflakeParser.FUNCTIONS, {name: builder})

    @unittest.skipIf(_EXPRESSION_IS_COMPILED, "mypyc compiled expressions cannot be subclassed")
    def test_fallback_function_sql(self):
        # A Func subclass defined only inside this test must still render as NAME(args)
        # through the generator's fallback. The omitted optional "b" is not emitted.
        class SpecialUdf(Expression, Func):
            arg_types = {"a": True, "b": False}

        with self._register_snowflake_function("SPECIAL_UDF", SpecialUdf.from_arg_list):
            sql = "SELECT SPECIAL_UDF(a) FROM x"
            expression = parse_one(sql, dialect="snowflake")
            self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a) FROM x")

    @unittest.skipIf(_EXPRESSION_IS_COMPILED, "mypyc compiled expressions cannot be subclassed")
    def test_fallback_function_var_args_sql(self):
        # Same as above, but variadic: every trailing argument, including the compound
        # ``d + 1``, must survive the round trip.
        class SpecialUdf(Expression, Func):
            arg_types = {"a": True, "expressions": False}
            is_var_len_args = True

        with self._register_snowflake_function("SPECIAL_UDF", SpecialUdf.from_arg_list):
            sql = "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x"
            expression = parse_one(sql, dialect="snowflake")
            self.assertEqual(expression.sql(), sql)

        # An expression built programmatically (not parsed) renders its unit argument as
        # a string literal placed first.
        self.assertEqual(
            exp.DateTrunc(this=exp.to_column("event_date"), unit=exp.var("MONTH")).sql(),
            "DATE_TRUNC('MONTH', event_date)",
        )

    def test_identify(self):
        # identify=True quotes every identifier. identify=False leaves unquoted names bare
        # but keeps quotes that were present in the input.
        self.assertEqual(parse_one("x").sql(identify=True), '"x"')
        self.assertEqual(parse_one("x").sql(identify=False), "x")
        self.assertEqual(parse_one("X").sql(identify=True), '"X"')
        self.assertEqual(parse_one('"x"').sql(identify=False), '"x"')

        # identify="safe": lowercase "x" and the numeric alias "1" are quoted, while
        # uppercase "X" is left bare.
        self.assertEqual(parse_one("x").sql(identify="safe"), '"x"')
        self.assertEqual(parse_one("X").sql(identify="safe"), "X")
        self.assertEqual(parse_one("x as 1").sql(identify="safe"), '"x" AS "1"')
        self.assertEqual(parse_one("X as 1").sql(identify="safe"), 'X AS "1"')

    def test_generate_nested_binary(self):
        # 1000 chained || operators produce a very deeply nested binary expression. It must
        # still round-trip (presumably this guards against recursion-depth failures).
        sql = "SELECT 'foo'" + (" || 'foo'" * 1000)
        self.assertEqual(parse_one(sql).sql(copy=False), sql)

    def test_overlap_operator(self):
        # The Postgres range operators &< and &> must be preserved. The ``::`` casts are
        # rendered as CAST(...) for both the default and the postgres dialect.
        for op in ("&<", "&>"):
            with self.subTest(op=op):
                input_sql = f"SELECT '[1,10]'::int4range {op} '[5,15]'::int4range"
                expected_sql = (
                    f"SELECT CAST('[1,10]' AS INT4RANGE) {op} CAST('[5,15]' AS INT4RANGE)"
                )
                ast = parse_one(input_sql, read="postgres")
                self.assertEqual(ast.sql(), expected_sql)
                self.assertEqual(ast.sql("postgres"), expected_sql)

    def test_pretty_nested_types(self):
        def assert_pretty_nested(
            type_str: str,
            pretty: str,
            max_text_width: int = 10,
            **kwargs,
        ) -> None:
            """Check ``type_str`` round-trips on one line and pretty-prints as ``pretty``.

            Extra keyword arguments (e.g. ``leading_comma``) are forwarded to ``sql()``.
            """
            datatype = exp.DataType.build(type_str)
            self.assertEqual(datatype.sql(), type_str)
            self.assertEqual(
                datatype.sql(pretty=True, max_text_width=max_text_width, **kwargs), pretty
            )

        # The small default max_text_width forces wrapping. Each member then goes on its own
        # line, indented two spaces deeper than its enclosing type.

        # STRUCT
        assert_pretty_nested("STRUCT<a INT, b TEXT>", "STRUCT<\n  a INT,\n  b TEXT\n>")

        # STRUCT - type def shorter than max text width so stays one line
        assert_pretty_nested(
            "STRUCT<a INT, b TEXT>", "STRUCT<a INT, b TEXT>", max_text_width=50
        )

        # STRUCT, leading_comma = True
        assert_pretty_nested(
            "STRUCT<a INT, b TEXT>",
            "STRUCT<\n  a INT\n  , b TEXT\n>",
            leading_comma=True,
        )

        # ARRAY
        assert_pretty_nested("ARRAY<DECIMAL(38, 9)>", "ARRAY<\n  DECIMAL(38, 9)\n>")

        # ARRAY nested STRUCT - inner members are indented one level further
        assert_pretty_nested(
            "ARRAY<STRUCT<a INT, b TEXT>>",
            "ARRAY<\n  STRUCT<\n    a INT,\n    b TEXT\n  >\n>",
        )

        # RANGE
        assert_pretty_nested("RANGE<DECIMAL(38, 9)>", "RANGE<\n  DECIMAL(38, 9)\n>")

        # LIST
        assert_pretty_nested("LIST<INT, INT, TEXT>", "LIST<\n  INT,\n  INT,\n  TEXT\n>")

        # MAP
        assert_pretty_nested(
            "MAP<INT, DECIMAL(38, 9)>", "MAP<\n  INT,\n  DECIMAL(38, 9)\n>"
        )
|