1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
|
import unittest
from unittest.mock import Mock, patch
from sqlglot import Dialect
from sqlglot.dialects.dialect import Dialect as DialectBase
class FakeDialect(DialectBase):
    """Empty dialect subclass returned by the mocked entry points' ``load()``.

    It adds nothing to the base dialect; the tests only check that this exact
    class is what dialect lookup hands back.
    """
class TestDialectEntryPoints(unittest.TestCase):
    """Tests for dialect discovery through the ``sqlglot.dialects`` entry-point group.

    Every test patches ``sqlglot.dialects.dialect.entry_points`` with a mock, so
    no installed distributions are consulted, and runs against an emptied dialect
    registry that is restored afterwards.
    """

    def setUp(self):
        # Snapshot the shared registry before emptying it. A bare clear() with no
        # restore would leave every test that runs later in the same process
        # with no registered dialects.
        self._saved_classes = dict(Dialect._classes)
        Dialect._classes.clear()

    def tearDown(self):
        # Restore in place (clear + update) so the same dict object stays in use,
        # dropping anything a test added to it (e.g. a loaded plugin).
        Dialect._classes.clear()
        Dialect._classes.update(self._saved_classes)

    def _make_entry_point(self, name="fakedb"):
        """Return a mock entry point called *name* whose ``load()`` returns FakeDialect."""
        entry_point = Mock()
        # Assigned after construction: ``Mock(name=...)`` sets the mock's repr
        # name, not a ``.name`` attribute.
        entry_point.name = name
        entry_point.load.return_value = FakeDialect
        return entry_point

    def test_entry_point_plugin_discovery_modern_api(self):
        """A plugin is found when ``entry_points()`` returns an object with ``select()``."""
        fake_entry_point = self._make_entry_point()
        mock_selectable = Mock()
        mock_selectable.select.return_value = [fake_entry_point]
        mock_entry_points = Mock(return_value=mock_selectable)

        with patch("sqlglot.dialects.dialect.entry_points", mock_entry_points):
            dialect = Dialect.get("fakedb")

        # assertIs rather than assertEqual: identity does not depend on any
        # custom ``__eq__`` the dialect metaclass may define.
        self.assertIs(dialect, FakeDialect)
        fake_entry_point.load.assert_called_once()
        mock_selectable.select.assert_called_once_with(group="sqlglot.dialects", name="fakedb")

    def test_entry_point_plugin_discovery_legacy_api(self):
        """A plugin is found when ``entry_points()`` returns a dict-like mapping."""
        fake_entry_point = self._make_entry_point()
        # spec=["get"] leaves the mock without a ``select`` attribute, presumably
        # steering the lookup onto its dict-style fallback.
        mock_dict = Mock(spec=["get"])
        mock_dict.get.return_value = [fake_entry_point]
        mock_entry_points = Mock(return_value=mock_dict)

        with patch("sqlglot.dialects.dialect.entry_points", mock_entry_points):
            dialect = Dialect.get("fakedb")

        self.assertIs(dialect, FakeDialect)
        fake_entry_point.load.assert_called_once()
        mock_dict.get.assert_called_once_with("sqlglot.dialects", [])

    def test_entry_point_plugin_not_found(self):
        """Looking up a name that no entry point provides returns None."""
        mock_selectable = Mock()
        mock_selectable.select.return_value = []
        mock_entry_points = Mock(return_value=mock_selectable)

        with patch("sqlglot.dialects.dialect.entry_points", mock_entry_points):
            dialect = Dialect.get("nonexistent")

        self.assertIsNone(dialect)
        # Guard against a vacuous pass: the entry-point group must actually have
        # been queried for the requested name.
        mock_selectable.select.assert_called_with(group="sqlglot.dialects", name="nonexistent")
|