File: test_operations.py

package info (click to toggle)
python-advanced-alchemy 1.8.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,904 kB
  • sloc: python: 36,227; makefile: 153; sh: 4
file content (293 lines) | stat: -rw-r--r-- 12,193 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
"""Tests for advanced_alchemy.operations module."""

from typing import Any

import pytest
from sqlalchemy import Column, Integer, MetaData, String, Table

from advanced_alchemy.operations import MergeStatement, OnConflictUpsert, validate_identifier


@pytest.fixture
def sample_table() -> Table:
    """Build the ``test_table`` table used by the tests in this module.

    A fresh ``MetaData`` is created on every call. The table has an integer
    primary key ``id``, two non-nullable string columns (``key`` and
    ``namespace``) and a string ``value`` column.
    """
    columns = (
        Column("id", Integer, primary_key=True),
        Column("key", String(50), nullable=False),
        Column("namespace", String(50), nullable=False),
        Column("value", String(255)),
    )
    return Table("test_table", MetaData(), *columns)


class TestOnConflictUpsert:
    """Test OnConflictUpsert dialect detection and statement builders."""

    def test_supports_native_upsert(self) -> None:
        """Test dialect support detection for supported and unsupported names."""
        # Supported dialects
        assert OnConflictUpsert.supports_native_upsert("postgresql") is True
        assert OnConflictUpsert.supports_native_upsert("cockroachdb") is True
        assert OnConflictUpsert.supports_native_upsert("sqlite") is True
        assert OnConflictUpsert.supports_native_upsert("mysql") is True
        assert OnConflictUpsert.supports_native_upsert("mariadb") is True
        assert OnConflictUpsert.supports_native_upsert("duckdb") is True

        # Unsupported dialects (including a name that is not a dialect at all)
        assert OnConflictUpsert.supports_native_upsert("oracle") is False
        assert OnConflictUpsert.supports_native_upsert("mssql") is False
        assert OnConflictUpsert.supports_native_upsert("unknown") is False

    def test_create_postgresql_upsert(self, sample_table: Table) -> None:
        """Test PostgreSQL ON CONFLICT upsert generation."""
        values = {"key": "test_key", "namespace": "test_ns", "value": "test_value"}
        conflict_columns = ["key", "namespace"]

        upsert_stmt = OnConflictUpsert.create_upsert(
            table=sample_table,
            values=values,
            conflict_columns=conflict_columns,
            dialect_name="postgresql",
        )

        # This is primarily a smoke test: the call must not raise and the
        # result must expose the ON CONFLICT API.
        assert upsert_stmt is not None
        assert hasattr(upsert_stmt, "on_conflict_do_update")

    def test_create_mysql_upsert(self, sample_table: Table) -> None:
        """Test MySQL ON DUPLICATE KEY UPDATE upsert generation."""
        values = {"key": "test_key", "namespace": "test_ns", "value": "test_value"}
        conflict_columns = ["key", "namespace"]

        upsert_stmt = OnConflictUpsert.create_upsert(
            table=sample_table,
            values=values,
            conflict_columns=conflict_columns,
            dialect_name="mysql",
        )

        # Should return a MySQL insert statement with ON DUPLICATE KEY UPDATE
        assert upsert_stmt is not None
        assert hasattr(upsert_stmt, "on_duplicate_key_update")

    def test_create_duckdb_upsert(self, sample_table: Table) -> None:
        """Test DuckDB ON CONFLICT upsert generation."""
        values = {"key": "test_key", "namespace": "test_ns", "value": "test_value"}
        conflict_columns = ["key", "namespace"]

        upsert_stmt = OnConflictUpsert.create_upsert(
            table=sample_table,
            values=values,
            conflict_columns=conflict_columns,
            dialect_name="duckdb",
        )

        # Should return a PostgreSQL-style insert statement with ON CONFLICT (DuckDB uses PostgreSQL syntax)
        assert upsert_stmt is not None
        assert hasattr(upsert_stmt, "on_conflict_do_update")

    def test_create_upsert_unsupported_dialect(self, sample_table: Table) -> None:
        """Test that unsupported dialects raise NotImplementedError."""
        values = {"key": "test_key", "namespace": "test_ns", "value": "test_value"}
        conflict_columns = ["key", "namespace"]

        with pytest.raises(NotImplementedError, match="Native upsert not supported for dialect 'oracle'"):
            OnConflictUpsert.create_upsert(
                table=sample_table,
                values=values,
                conflict_columns=conflict_columns,
                dialect_name="oracle",
            )

    def test_create_merge_upsert(self, sample_table: Table) -> None:
        """Test MERGE-based upsert generation for the default and Oracle dialects."""
        values = {"key": "test_key", "namespace": "test_ns", "value": "test_value"}
        conflict_columns = ["key", "namespace"]

        # Test default (non-Oracle) MERGE
        merge_stmt, additional_params = OnConflictUpsert.create_merge_upsert(
            table=sample_table,
            values=values,
            conflict_columns=conflict_columns,
        )

        assert isinstance(merge_stmt, MergeStatement)
        assert merge_stmt.table == sample_table
        default_source = str(merge_stmt.source)
        assert "SELECT" in default_source
        # Accept either placeholder style (%(key)s or :key) for the default dialect
        assert "%(key)s" in default_source or ":key" in default_source
        assert "FROM DUAL" not in default_source  # Should not have FROM DUAL by default
        assert additional_params == {}  # No additional params for non-Oracle

        # Test Oracle-specific MERGE
        oracle_merge_stmt, oracle_additional_params = OnConflictUpsert.create_merge_upsert(
            table=sample_table,
            values=values,
            conflict_columns=conflict_columns,
            dialect_name="oracle",
        )

        assert isinstance(oracle_merge_stmt, MergeStatement)
        assert oracle_merge_stmt.table == sample_table
        oracle_source = str(oracle_merge_stmt.source)
        # Check the Oracle statement itself, not the default one built above
        assert "SELECT" in oracle_source
        assert ":key" in oracle_source  # Check for parameter placeholder
        assert "FROM DUAL" in oracle_source  # Oracle should have FROM DUAL
        # Only the container type is pinned here; the contents of the
        # additional params are not asserted for this table.
        assert isinstance(oracle_additional_params, dict)


class TestMergeStatement:
    """Checks for building ``MergeStatement`` objects and the fallback compiler."""

    def test_merge_statement_creation(self, sample_table: Table) -> None:
        """Constructor arguments must be readable back as same-named attributes."""
        from sqlalchemy import bindparam, text

        select_sql = "SELECT 'key1' as key, 'ns1' as namespace, 'value1' as value"
        join_clause = text("tgt.key = src.key AND tgt.namespace = src.namespace")
        update_map: dict[str, Any] = {"value": bindparam("value")}
        insert_map: dict[str, Any] = {
            column_name: bindparam(column_name) for column_name in ("key", "namespace", "value")
        }

        stmt = MergeStatement(
            table=sample_table,
            source=select_sql,
            on_condition=join_clause,
            when_matched_update=update_map,
            when_not_matched_insert=insert_map,
        )

        # Each attribute is compared against the exact object passed in.
        expectations = (
            ("table", sample_table),
            ("source", select_sql),
            ("on_condition", join_clause),
            ("when_matched_update", update_map),
            ("when_not_matched_insert", insert_map),
        )
        for attribute, expected in expectations:
            assert getattr(stmt, attribute) == expected

    def test_compile_merge_default_raises_error(self, sample_table: Table) -> None:
        """The default compile hook must raise NotImplementedError for an unknown dialect."""
        from sqlalchemy import text

        from advanced_alchemy.operations import compile_merge_default

        # Minimal stand-ins: a compiler object whose dialect reports an
        # unsupported name.
        class FakeDialect:
            name = "unsupported"

        class FakeCompiler:
            dialect = FakeDialect()

        stmt = MergeStatement(
            table=sample_table,
            source="SELECT 1",
            on_condition=text("1=1"),
        )
        fake_compiler = FakeCompiler()

        with pytest.raises(NotImplementedError, match="MERGE statement not supported for dialect 'unsupported'"):
            compile_merge_default(stmt, fake_compiler)  # type: ignore[arg-type]  # pyright: ignore


class TestIdentifierValidation:
    """Checks for ``validate_identifier`` and the builders that can invoke it."""

    def test_valid_identifiers(self) -> None:
        """Well-formed names are returned unchanged."""
        accepted = (
            (("user_id",), "user_id"),
            (("users_table", "table"), "users_table"),
            (("created_at", "column"), "created_at"),
            (("_private_field",), "_private_field"),
            (("table123",), "table123"),
        )
        for call_args, expected in accepted:
            assert validate_identifier(*call_args) == expected

    def test_empty_identifier(self) -> None:
        """An empty string is rejected with a ValueError."""
        with pytest.raises(ValueError, match="Empty identifier name"):
            validate_identifier("")

    def test_invalid_characters(self) -> None:
        """Names containing disallowed characters are rejected with a ValueError."""
        rejected = (
            "user-id",  # hyphen
            "user.id",  # dot
            "user id",  # space
            "123user",  # starts with number
            "user;",  # semicolon
            "user'",  # quote
            "user`",  # backtick
            "drop table users; --",  # SQL injection attempt
        )

        for candidate in rejected:
            with pytest.raises(ValueError, match=r"Invalid.*Only alphanumeric"):
                validate_identifier(candidate)

    def test_sql_keywords_allowed(self) -> None:
        """SQL keywords pass validation in original, lower and upper case."""
        # SQL keywords should be allowed since they can be quoted in SQL
        reserved_words = ["select", "SELECT", "insert", "UPDATE", "delete", "DROP", "create", "ALTER", "truncate"]

        for word in reserved_words:
            # None of the spellings should raise; each is echoed back as-is.
            for spelling in (word, word.lower(), word.upper()):
                assert validate_identifier(spelling) == spelling

    def test_identifier_type_in_error(self) -> None:
        """The identifier type passed as second argument shows up in the error text."""
        with pytest.raises(ValueError, match="Empty column name"):
            validate_identifier("", "column")

        with pytest.raises(ValueError, match="Invalid table name"):
            validate_identifier("123invalid", "table")

    def test_upsert_with_validation(self, sample_table: Table) -> None:
        """``create_upsert`` succeeds with ``validate_identifiers=True`` and valid names."""
        row = {"key": "test_key", "namespace": "test_ns", "value": "test_value"}

        # Valid identifiers must not trip the validation path.
        result = OnConflictUpsert.create_upsert(
            table=sample_table,
            values=row,
            conflict_columns=["key", "namespace"],
            update_columns=["value"],
            dialect_name="postgresql",
            validate_identifiers=True,
        )
        assert result is not None

    def test_merge_with_validation(self, sample_table: Table) -> None:
        """``create_merge_upsert`` succeeds with ``validate_identifiers=True`` and valid names."""
        row = {"key": "test_key", "namespace": "test_ns", "value": "test_value"}

        # Valid identifiers must not trip the validation path.
        result, _ = OnConflictUpsert.create_merge_upsert(
            table=sample_table,
            values=row,
            conflict_columns=["key", "namespace"],
            update_columns=["value"],
            dialect_name="oracle",
            validate_identifiers=True,
        )
        assert result is not None


class TestStoreIntegration:
    """Smoke test that the Litestar store module and the operations import together."""

    def test_store_imports_operations(self) -> None:
        """Importing the store and the operations must not raise."""
        from advanced_alchemy.extensions.litestar.store import SQLAlchemyStore
        from advanced_alchemy.operations import MergeStatement, OnConflictUpsert

        # Reaching this point means the imports succeeded; the checks below
        # only confirm the names are bound.
        for imported in (OnConflictUpsert, MergeStatement, SQLAlchemyStore):
            assert imported is not None