1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
|
import unittest
from sqlglot import exp, parse_one
from sqlglot.expressions import Func
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer
class TestGenerator(unittest.TestCase):
    """Round-trip tests for SQL generation via ``Expression.sql()``."""

    def test_fallback_function_sql(self):
        """A locally defined ``Func`` subclass parses and regenerates unchanged."""

        class SpecialUDF(Func):
            arg_types = {"a": True, "b": False}

        class NewParser(Parser):
            # Make this parser build SpecialUDF nodes for SPECIAL_UDF(...) calls.
            FUNCTIONS = SpecialUDF.default_parser_mappings()

        tokens = Tokenizer().tokenize("SELECT SPECIAL_UDF(a) FROM x")
        expression = NewParser().parse(tokens)[0]
        self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a) FROM x")

    def test_fallback_function_var_args_sql(self):
        """A variadic ``Func`` subclass regenerates with all of its arguments."""

        class SpecialUDF(Func):
            arg_types = {"a": True, "expressions": False}
            # Presumably lets arguments after ``a`` collect under ``expressions``;
            # the round-trip below checks that none of them are dropped.
            is_var_len_args = True

        class NewParser(Parser):
            # Make this parser build SpecialUDF nodes for SPECIAL_UDF(...) calls.
            FUNCTIONS = SpecialUDF.default_parser_mappings()

        tokens = Tokenizer().tokenize("SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x")
        expression = NewParser().parse(tokens)[0]
        self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x")

    def test_date_trunc_sql(self):
        """A programmatically built ``DateTrunc`` renders its unit as a quoted string."""
        # Kept separate from the var-args fallback test: it checks a different
        # generator path, and a failure here should be reported under its own name.
        self.assertEqual(
            exp.DateTrunc(this=exp.to_column("event_date"), unit=exp.var("MONTH")).sql(),
            "DATE_TRUNC('MONTH', event_date)",
        )

    def test_identify(self):
        """Check identifier quoting for ``identify`` = True, False and "safe".

        True quotes every identifier. False adds no quotes but keeps quotes
        that were already in the input. "safe" quotes ``x`` and ``1`` but
        leaves the uppercase ``X`` bare (default dialect).
        """
        cases = (
            ("x", True, '"x"'),
            ("x", False, "x"),
            ("X", True, '"X"'),
            ('"x"', False, '"x"'),
            ("x", "safe", '"x"'),
            ("X", "safe", "X"),
            ("x as 1", "safe", '"x" AS "1"'),
            ("X as 1", "safe", 'X AS "1"'),
        )
        for sql, identify, expected in cases:
            # subTest reports every failing case instead of stopping at the first.
            with self.subTest(sql=sql, identify=identify):
                self.assertEqual(parse_one(sql).sql(identify=identify), expected)

    def test_generate_nested_binary(self):
        """Generate SQL for a chain of 1000 ``||`` operators (deep-nesting stress test)."""
        sql = "SELECT 'foo'" + (" || 'foo'" * 1000)
        # copy=False skips copying the tree before generation; presumably so the
        # test exercises the generator's handling of depth rather than the copy.
        self.assertEqual(parse_one(sql).sql(copy=False), sql)

    def test_overlap_operator(self):
        """Postgres range operators ``&<`` and ``&>`` survive parse and generation.

        The ``::`` casts in the input are normalized to ``CAST(... AS ...)``,
        both in the default dialect and in postgres.
        """
        for op in ("&<", "&>"):
            with self.subTest(op=op):
                input_sql = f"SELECT '[1,10]'::int4range {op} '[5,15]'::int4range"
                expected_sql = (
                    f"SELECT CAST('[1,10]' AS INT4RANGE) {op} CAST('[5,15]' AS INT4RANGE)"
                )
                ast = parse_one(input_sql, read="postgres")
                self.assertEqual(ast.sql(), expected_sql)
                self.assertEqual(ast.sql("postgres"), expected_sql)
|