File: test_common.py

package info (click to toggle)
python-advanced-alchemy 1.4.1-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 3,708 kB
  • sloc: python: 25,811; makefile: 162; javascript: 123; sh: 4
file content (159 lines) | stat: -rw-r--r-- 5,963 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from __future__ import annotations

import datetime
import uuid
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch

import pytest
from litestar.datastructures import State
from sqlalchemy import create_engine

from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.extensions.litestar._utils import _SCOPE_NAMESPACE  # pyright: ignore[reportPrivateUsage]
from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig
from advanced_alchemy.extensions.litestar.plugins.init.config.common import SESSION_SCOPE_KEY

if TYPE_CHECKING:
    from typing import Any

    from litestar.types import Scope
    from pytest import MonkeyPatch


@pytest.fixture(name="config_cls", params=[SQLAlchemySyncConfig, SQLAlchemyAsyncConfig])
def _config_cls(request: Any) -> type[SQLAlchemySyncConfig | SQLAlchemyAsyncConfig]:
    """Yield each config class in ``params`` so dependent tests run for sync and async."""
    selected_cls = request.param
    return selected_cls  # type:ignore[no-any-return]


def test_raise_improperly_configured_exception(config_cls: type[SQLAlchemySyncConfig]) -> None:
    """Passing both ``connection_string`` and ``engine_instance`` must raise ImproperConfigurationError."""
    existing_engine = create_engine("sqlite://")
    with pytest.raises(ImproperConfigurationError):
        config_cls(connection_string="sqlite://", engine_instance=existing_engine)


def test_engine_config_dict_with_no_provided_config(
    config_cls: type[SQLAlchemySyncConfig],
) -> None:
    """A default config's engine options contain exactly the JSON (de)serializer keys."""
    expected_keys = {"json_deserializer", "json_serializer"}
    engine_options = config_cls().engine_config_dict
    assert engine_options.keys() == expected_keys


def test_session_config_dict_with_no_provided_config(
    config_cls: type[SQLAlchemySyncConfig],
) -> None:
    """A default config yields an empty session options dict."""
    session_options = config_cls().session_config_dict
    assert session_options == {}


def test_config_create_engine_if_engine_instance_provided(
    config_cls: type[SQLAlchemySyncConfig],
) -> None:
    """Test create_engine if engine instance provided.

    When an ``engine_instance`` is supplied, ``get_engine`` must return that
    very object.
    """
    engine = create_engine("sqlite://")
    config = config_cls(engine_instance=engine)
    # Identity, not equality: the contract is "hand back the instance you were given".
    assert config.get_engine() is engine


def test_create_engine_if_no_engine_instance_or_connection_string_provided(
    config_cls: type[SQLAlchemySyncConfig],
) -> None:
    """``get_engine`` must raise when neither an engine nor a connection string is configured."""
    bare_config = config_cls()
    with pytest.raises(ImproperConfigurationError):
        bare_config.get_engine()


def test_call_create_engine_callable_type_error_handling(
    config_cls: type[SQLAlchemySyncConfig],
    monkeypatch: MonkeyPatch,
) -> None:
    """If the dialect doesn't support JSON types, the engine factory raises a TypeError.

    This should be handled by retrying ``create_engine_callable`` with the
    ``json_serializer``/``json_deserializer`` kwargs removed.
    """
    config = config_cls(connection_string="sqlite://")
    # An iterable side_effect is consumed one item per call: exception instances
    # are raised, other values are returned. So the first call fails with
    # TypeError and the retry returns None.
    create_engine_callable_mock = MagicMock(side_effect=[TypeError(), None])
    monkeypatch.setattr(config, "create_engine_callable", create_engine_callable_mock)

    config.get_engine()

    assert create_engine_callable_mock.call_count == 2
    # call_args_list only records direct calls to the mock (unlike mock_calls,
    # which also includes calls on child mocks).
    first_call, second_call = create_engine_callable_mock.call_args_list
    assert first_call.kwargs.keys() == {"json_deserializer", "json_serializer"}
    assert second_call.kwargs.keys() == set()


def test_create_session_maker_if_session_maker_provided(
    config_cls: type[SQLAlchemySyncConfig],
) -> None:
    """Test create_session_maker if session maker provided to config.

    A pre-built ``session_maker`` passed to the config must be returned as-is
    by ``create_session_maker``.
    """
    session_maker = MagicMock()
    config = config_cls(session_maker=session_maker)
    # Identity check: the provided object itself must come back, not an equal stand-in.
    assert config.create_session_maker() is session_maker


def test_create_session_maker_if_no_session_maker_or_bind_provided(
    config_cls: type[SQLAlchemySyncConfig],
    monkeypatch: MonkeyPatch,
) -> None:
    """With no session maker configured, ``create_session_maker`` builds one and calls ``get_engine`` once."""
    config = config_cls()
    get_engine_stub = MagicMock(return_value=create_engine("sqlite://"))
    monkeypatch.setattr(config, "get_engine", get_engine_stub)

    assert config.session_maker is None
    built_maker = config.create_session_maker()
    assert isinstance(built_maker, config.session_maker_class)
    get_engine_stub.assert_called_once()


def test_create_session_instance_if_session_not_in_scope_state(
    config_cls: type[SQLAlchemySyncConfig],
) -> None:
    """``provide_session`` returns a session from the stored maker and records it in the scope namespace."""
    scope_state_target = "advanced_alchemy.extensions.litestar._utils.get_aa_scope_state"
    with patch(scope_state_target, return_value=None):
        config = config_cls()
        app_state = State()
        # The session maker is a MagicMock, so the session it yields is a MagicMock too.
        app_state[config.session_maker_app_state_key] = MagicMock()
        scope: Scope = {}  # type:ignore[assignment]

        session = config.provide_session(app_state, scope)

        assert isinstance(session, MagicMock)
        assert SESSION_SCOPE_KEY in scope[_SCOPE_NAMESPACE]  # type: ignore[literal-required]


def test_app_state(config_cls: type[SQLAlchemySyncConfig], monkeypatch: MonkeyPatch) -> None:
    """``create_app_state_items`` keys the engine and session maker, calling each factory once."""
    config = config_cls(connection_string="sqlite://")
    with (
        patch.object(config, "create_session_maker") as session_maker_factory,
        patch.object(config, "get_engine") as engine_factory,
    ):
        state_items = config.create_app_state_items()
        assert state_items.keys() == {
            config.engine_app_state_key,
            config.session_maker_app_state_key,
        }
        session_maker_factory.assert_called_once()
        engine_factory.assert_called_once()


def test_namespace_resolution() -> None:
    """Constructing an app whose handler annotates ``datetime``/``uuid`` types must not raise.

    This module uses ``from __future__ import annotations``, so those annotations
    are strings that Litestar has to resolve.
    Regression test for https://github.com/litestar-org/advanced-alchemy/issues/256
    """
    from litestar import Litestar, get

    @get("/")
    async def handler(param: datetime.datetime, other_param: uuid.UUID) -> None:
        return None

    route_handlers = [handler]
    Litestar(route_handlers)