File: test_dialect_entry_points.py

package info (click to toggle)
sqlglot 28.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,816 kB
  • sloc: python: 86,744; sql: 22,739; makefile: 48
file content (64 lines) | stat: -rw-r--r-- 2,104 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import unittest
from unittest.mock import Mock, patch

from sqlglot import Dialect
from sqlglot.dialects.dialect import Dialect as DialectBase


class FakeDialect(DialectBase):
    """Minimal concrete dialect handed back by the mocked entry points' ``load()``."""


class TestDialectEntryPoints(unittest.TestCase):
    """Tests for resolving dialects registered under the ``sqlglot.dialects`` entry-point group.

    Each test replaces ``sqlglot.dialects.dialect.entry_points`` with a mock, so no
    installed packages are consulted, and then calls ``Dialect.get``.
    """

    def setUp(self):
        # Start each test with an empty dialect registry so a lookup cannot be
        # satisfied by a dialect that is already registered. patch.dict takes a
        # snapshot of the current contents and restores it on stop(). A bare
        # clear() here would leave the registry empty for every test that runs
        # after this class.
        self._registry_patch = patch.dict(Dialect._classes, clear=True)
        self._registry_patch.start()

    def tearDown(self):
        # Discard anything registered during the test and restore the snapshot.
        self._registry_patch.stop()

    def test_entry_point_plugin_discovery_modern_api(self):
        """A plugin dialect is resolved through the ``select(group=..., name=...)`` interface."""
        fake_entry_point = Mock()
        fake_entry_point.name = "fakedb"
        fake_entry_point.load.return_value = FakeDialect

        # A plain Mock exposes a ``select`` attribute, so it stands in for a
        # selectable entry-point collection.
        mock_selectable = Mock()
        mock_selectable.select.return_value = [fake_entry_point]

        mock_entry_points = Mock(return_value=mock_selectable)

        with patch("sqlglot.dialects.dialect.entry_points", mock_entry_points):
            dialect = Dialect.get("fakedb")

        self.assertIs(dialect, FakeDialect)
        fake_entry_point.load.assert_called_once()
        mock_selectable.select.assert_called_once_with(group="sqlglot.dialects", name="fakedb")

    def test_entry_point_plugin_discovery_legacy_api(self):
        """A plugin dialect is resolved through the dict-style ``get(group, default)`` interface."""
        fake_entry_point = Mock()
        fake_entry_point.name = "fakedb"
        fake_entry_point.load.return_value = FakeDialect

        # spec=["get"] gives the mock a ``get`` method but no ``select``
        # attribute, so the code under test must use the dict-style lookup.
        mock_dict = Mock(spec=["get"])
        mock_dict.get.return_value = [fake_entry_point]

        mock_entry_points = Mock(return_value=mock_dict)

        with patch("sqlglot.dialects.dialect.entry_points", mock_entry_points):
            dialect = Dialect.get("fakedb")

        self.assertIs(dialect, FakeDialect)
        fake_entry_point.load.assert_called_once()
        mock_dict.get.assert_called_once_with("sqlglot.dialects", [])

    def test_entry_point_plugin_not_found(self):
        """``Dialect.get`` returns None when no entry point matches the requested name."""
        mock_selectable = Mock()
        mock_selectable.select.return_value = []

        mock_entry_points = Mock(return_value=mock_selectable)

        with patch("sqlglot.dialects.dialect.entry_points", mock_entry_points):
            dialect = Dialect.get("nonexistent")

        self.assertIsNone(dialect)
        # Confirm the entry-point lookup actually ran. Without this check the
        # test would also pass if the patched target were never reached.
        mock_selectable.select.assert_called_once_with(
            group="sqlglot.dialects", name="nonexistent"
        )