File: test_generator.py

package info (click to toggle)
sqlglot 28.5.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,672 kB
  • sloc: python: 84,517; sql: 22,534; makefile: 48
file content (61 lines) | stat: -rw-r--r-- 2,556 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import unittest

from sqlglot import exp, parse_one
from sqlglot.expressions import Func
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer


class TestGenerator(unittest.TestCase):
    """Tests for SQL generation: fallback functions, identifier quoting, and operators."""

    def _assert_roundtrip_with_function(self, func_class, sql):
        """Parse ``sql`` with a parser built from ``func_class`` and check the round trip.

        A throwaway ``Parser`` subclass is created whose ``FUNCTIONS`` table is
        replaced by ``func_class.default_parser_mappings()``. ``sql`` is then
        tokenized and parsed, and the first resulting expression must generate
        back to exactly ``sql``.

        Args:
            func_class: A ``Func`` subclass defined by the calling test.
            sql: The SQL text to parse; also the expected generator output.
        """

        class NewParser(Parser):
            # Overrides (rather than extends) the inherited FUNCTIONS mapping with
            # the one derived from the test's custom Func subclass.
            FUNCTIONS = func_class.default_parser_mappings()

        tokens = Tokenizer().tokenize(sql)
        expression = NewParser().parse(tokens)[0]
        self.assertEqual(expression.sql(), sql)

    def test_fallback_function_sql(self):
        """A custom Func with a required and an optional arg round-trips unchanged."""

        class SpecialUDF(Func):
            arg_types = {"a": True, "b": False}

        self._assert_roundtrip_with_function(SpecialUDF, "SELECT SPECIAL_UDF(a) FROM x")

    def test_fallback_function_var_args_sql(self):
        """A custom variadic Func keeps all of its arguments when regenerated."""

        class SpecialUDF(Func):
            arg_types = {"a": True, "expressions": False}
            is_var_len_args = True

        self._assert_roundtrip_with_function(
            SpecialUDF, "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x"
        )

    def test_date_trunc_sql(self):
        """A DateTrunc built programmatically renders with the unit as a string literal."""
        self.assertEqual(
            exp.DateTrunc(this=exp.to_column("event_date"), unit=exp.var("MONTH")).sql(),
            "DATE_TRUNC('MONTH', event_date)",
        )

    def test_identify(self):
        """Check identifier quoting for identify=True, False and "safe"."""
        # (input sql, identify option, expected output)
        cases = (
            ("x", True, '"x"'),
            ("x", False, "x"),
            ("X", True, '"X"'),
            ('"x"', False, '"x"'),
            ("x", "safe", '"x"'),
            ("X", "safe", "X"),
            ("x as 1", "safe", '"x" AS "1"'),
            ("X as 1", "safe", 'X AS "1"'),
        )
        for sql, identify, expected in cases:
            # subTest reports every failing case instead of stopping at the first.
            with self.subTest(sql=sql, identify=identify):
                self.assertEqual(parse_one(sql).sql(identify=identify), expected)

    def test_generate_nested_binary(self):
        """A chain of 1000 ``||`` concatenations regenerates verbatim.

        NOTE(review): presumably a guard against recursion-depth problems when
        generating deeply nested binary expressions -- confirm.
        """
        sql = "SELECT 'foo'" + (" || 'foo'" * 1000)
        self.assertEqual(parse_one(sql).sql(copy=False), sql)

    def test_overlap_operator(self):
        """Postgres ``&<`` / ``&>`` operators survive parsing and generation."""
        for op in ("&<", "&>"):
            with self.subTest(op=op):
                input_sql = f"SELECT '[1,10]'::int4range {op} '[5,15]'::int4range"
                # The ``::`` casts are normalized to CAST(...) in the output.
                expected_sql = (
                    f"SELECT CAST('[1,10]' AS INT4RANGE) {op} CAST('[5,15]' AS INT4RANGE)"
                )
                ast = parse_one(input_sql, read="postgres")
                # Generated output must match for both the default and postgres dialects.
                self.assertEqual(ast.sql(), expected_sql)
                self.assertEqual(ast.sql("postgres"), expected_sql)