File: test_serialization_plugin.py

package info (click to toggle)
python-advanced-alchemy 1.4.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 3,708 kB
  • sloc: python: 25,811; makefile: 162; javascript: 123; sh: 4
file content (76 lines) | stat: -rw-r--r-- 2,400 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from types import ModuleType
from typing import Callable

from litestar import get
from litestar.status_codes import HTTP_200_OK
from litestar.testing import RequestFactory, create_test_client
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column

from advanced_alchemy.base import UUIDAuditBase
from advanced_alchemy.extensions.litestar import SQLAlchemySerializationPlugin
from advanced_alchemy.service.pagination import OffsetPagination


async def test_serialization_plugin(
    create_module: Callable[[str], ModuleType],
    request_factory: RequestFactory,
) -> None:
    """Round-trip a SQLAlchemy model through Litestar handlers with the serialization plugin.

    A throwaway module defining a declarative model ``A`` and three route
    handlers is built from source text, mounted in a test client together with
    ``SQLAlchemySerializationPlugin``, and exercised over HTTP. Each response
    is checked for both its status code and its JSON payload.

    Args:
        create_module: Callable that turns module source text into a module
            object exposing ``post_handler``, ``get_handler`` and ``get_a``
            (presumably a pytest fixture; defined outside this file).
        request_factory: Not referenced in the body; kept so the test's
            signature (and any fixture it requests) is unchanged.
    """
    module = create_module(
        """
from __future__ import annotations

from typing import Dict, List, Set, Tuple, Type, List

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from litestar import Litestar, get, post
from advanced_alchemy.extensions.litestar import SQLAlchemySerializationPlugin

class Base(DeclarativeBase):
    id: Mapped[int] = mapped_column(primary_key=True)

class A(Base):
    __tablename__ = "a"
    a: Mapped[str]

@post("/a")
def post_handler(data: A) -> A:
    return data

@get("/a")
def get_handler() -> List[A]:
    return [A(id=1, a="test"), A(id=2, a="test2")]

@get("/a/1")
def get_a() -> A:
    return A(id=1, a="test")
""",
    )
    with create_test_client(
        route_handlers=[module.post_handler, module.get_handler, module.get_a],
        plugins=[SQLAlchemySerializationPlugin()],
    ) as client:
        # Request body is decoded into ``A`` and echoed back; 201 is the
        # default status Litestar assigns to POST handlers.
        response = client.post("/a", json={"id": 1, "a": "test"})
        assert response.status_code == 201
        assert response.json() == {"id": 1, "a": "test"}

        # A list of model instances is returned as a JSON array.
        response = client.get("/a")
        assert response.status_code == HTTP_200_OK
        assert response.json() == [{"id": 1, "a": "test"}, {"id": 2, "a": "test2"}]

        # A single model instance is returned as a JSON object.
        response = client.get("/a/1")
        assert response.status_code == HTTP_200_OK
        assert response.json() == {"id": 1, "a": "test"}


class User(UUIDAuditBase):
    """Minimal model used by the pagination serialization test in this module."""

    # Single string column with a maximum length of 200.
    first_name: Mapped[str] = mapped_column(String(length=200))


def test_pagination_serialization() -> None:
    """Check that an ``OffsetPagination[User]`` handler result is served with HTTP 200.

    Only the status code is asserted; the response body is not inspected.
    """
    seeded_users = [User(first_name=name) for name in ("ASD", "qwe")]

    @get("/paginated")
    async def paginated_handler() -> OffsetPagination[User]:
        # Wrap the two in-memory users in a single page.
        return OffsetPagination[User](
            items=seeded_users,
            limit=2,
            offset=0,
            total=2,
        )

    with create_test_client(
        route_handlers=[paginated_handler],
        plugins=[SQLAlchemySerializationPlugin()],
    ) as test_client:
        resp = test_client.get("/paginated")
        assert resp.status_code == HTTP_200_OK