File: test_persistence_handlers.py

package info (click to toggle)
python-polyfactory 2.22.2-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,892 kB
  • sloc: python: 11,338; makefile: 103; sh: 37
file content (66 lines) | stat: -rw-r--r-- 2,105 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from typing import Any

import pytest

from pydantic import BaseModel

from polyfactory import AsyncPersistenceProtocol, SyncPersistenceProtocol
from polyfactory.factories.pydantic_factory import ModelFactory


class MyModel(BaseModel):
    """Minimal pydantic model with one required string field.

    Used as the ``__model__`` of the factories declared inside the tests in
    this module.
    """

    # The only field; the tests assert that generated instances have a truthy
    # value here.
    name: str


class MySyncPersistenceHandler(SyncPersistenceProtocol):
    """Synchronous persistence stub that stores nothing.

    Both methods hand back the data they were given, so a test can inspect
    whatever passed through the handler.
    """

    def save(self, data: Any, *args: Any, **kwargs: Any) -> Any:
        """Return ``data`` unchanged; extra arguments are accepted and ignored."""
        return data

    def save_many(self, data: Any, *args: Any, **kwargs: Any) -> Any:
        """Return ``data`` unchanged; extra arguments are accepted and ignored."""
        return data


class MyAsyncPersistenceHandler(AsyncPersistenceProtocol):
    """Asynchronous persistence stub that stores nothing.

    Both coroutines hand back the data they were given, so a test can inspect
    whatever passed through the handler.
    """

    async def save(self, data: Any, *args: Any, **kwargs: Any) -> Any:
        """Return ``data`` unchanged; extra arguments are accepted and ignored."""
        return data

    async def save_many(self, data: Any, *args: Any, **kwargs: Any) -> Any:
        """Return ``data`` unchanged; extra arguments are accepted and ignored."""
        return data


def test_sync_persistence_handler_is_set_and_called_with_instance() -> None:
    """Exercise the sync create methods with a handler *instance* configured.

    ``__sync_persistence__`` is set to an instantiated
    ``MySyncPersistenceHandler``. That handler returns its input unchanged, so
    the test expects ``create_sync`` and ``create_batch_sync`` to yield
    populated ``MyModel`` objects.
    """

    class MyFactory(ModelFactory):
        __model__ = MyModel
        __sync_persistence__ = MySyncPersistenceHandler()

    assert MyFactory.create_sync().name

    # Asserting on a list of names only proves the batch is non-empty; pin the
    # requested size and check each returned instance individually instead.
    batch = MyFactory.create_batch_sync(size=2)
    assert len(batch) == 2
    assert all(instance.name for instance in batch)


def test_sync_persistence_handler_is_set_and_called_with_class() -> None:
    """Exercise the sync create methods with a handler *class* configured.

    ``__sync_persistence__`` is set to the ``MySyncPersistenceHandler`` class
    itself rather than an instance. That handler returns its input unchanged,
    so the test expects ``create_sync`` and ``create_batch_sync`` to yield
    populated ``MyModel`` objects.
    """

    class MyFactory(ModelFactory):
        __model__ = MyModel
        __sync_persistence__ = MySyncPersistenceHandler

    assert MyFactory.create_sync().name

    # Asserting on a list of names only proves the batch is non-empty; pin the
    # requested size and check each returned instance individually instead.
    batch = MyFactory.create_batch_sync(size=2)
    assert len(batch) == 2
    assert all(instance.name for instance in batch)


@pytest.mark.asyncio()
async def test_async_persistence_handler_is_set_and_called_with_instance() -> None:
    """Exercise the async create methods with a handler *instance* configured.

    ``__async_persistence__`` is set to an instantiated
    ``MyAsyncPersistenceHandler``. That handler returns its input unchanged,
    so the test expects ``create_async`` and ``create_batch_async`` to yield
    populated ``MyModel`` objects.
    """

    class MyFactory(ModelFactory):
        __model__ = MyModel
        __async_persistence__ = MyAsyncPersistenceHandler()

    assert (await MyFactory.create_async()).name

    # Asserting on a list of names only proves the batch is non-empty; pin the
    # requested size and check each returned instance individually instead.
    batch = await MyFactory.create_batch_async(size=2)
    assert len(batch) == 2
    assert all(instance.name for instance in batch)


@pytest.mark.asyncio()
async def test_async_persistence_handler_is_set_and_called_with_class() -> None:
    """Exercise the async create methods with a handler *class* configured.

    ``__async_persistence__`` is set to the ``MyAsyncPersistenceHandler``
    class itself rather than an instance. That handler returns its input
    unchanged, so the test expects ``create_async`` and
    ``create_batch_async`` to yield populated ``MyModel`` objects.
    """

    class MyFactory(ModelFactory):
        __model__ = MyModel
        __async_persistence__ = MyAsyncPersistenceHandler

    assert (await MyFactory.create_async()).name

    # Asserting on a list of names only proves the batch is non-empty; pin the
    # requested size and check each returned instance individually instead.
    batch = await MyFactory.create_batch_async(size=2)
    assert len(batch) == 2
    assert all(instance.name for instance in batch)