# File: test_recursive_models.py
#
# Package: python-polyfactory 2.22.2-1
#   - links: PTS, VCS
#   - area: main
#   - in suites: sid
#   - size: 1,892 kB
#   - sloc: python: 11,338; makefile: 103; sh: 37
# File content: 134 lines, mode -rw-r--r--, 4,223 bytes
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, TypeVar, Union

import pytest

from pydantic import BaseModel, Field
from pydantic import __version__ as pydantic_version

from polyfactory.factories.dataclass_factory import DataclassFactory
from polyfactory.factories.pydantic_factory import ModelFactory


class _Sentinel:
    """Marker class used as the placeholder default of the recursive ``child`` fields."""


@dataclass
class Node:
    """Dataclass that refers to itself via union, list, optional and plain fields.

    ``child`` defaults to the ``_Sentinel`` class; ``__post_init__`` swaps that
    placeholder for the instance itself.
    """

    value: int
    union_child: Union[Node, int]  # noqa: UP007
    list_child: List[Node]  # noqa: UP006
    optional_child: Optional[Node]  # noqa: RUF100, UP007, UP045
    child: Node = field(default=_Sentinel)  # type: ignore[assignment]

    def __post_init__(self) -> None:
        """Point ``child`` back at this instance when no explicit child was given."""
        # An explicitly supplied child (anything but the placeholder) is kept untouched.
        if self.child is not _Sentinel:  # type: ignore[comparison-overlap]
            return
        # Emulate recursive models set by external init, e.g. ORM relationships
        self.child = self


def test_recursive_model() -> None:
    """Build ``Node`` instances and check how each self-referencing field is filled."""
    node_factory = DataclassFactory.create_factory(Node)

    built = node_factory.build()
    # The dataclass default is kept, so __post_init__ makes the node its own child.
    assert built.child is built, "Default is not used"
    assert isinstance(built.union_child, int)
    assert built.optional_child is None
    assert built.list_child == []

    # A nested override given as a dict is applied to the child node.
    overridden = node_factory.build(child={"child": None})
    assert overridden.child.child is None


class PydanticNode(BaseModel):
    """Pydantic model that refers to itself through unions, optionals, a list and dict keys/values."""

    value: int
    # Self-reference offered alongside a plain ``int`` alternative.
    union_child: Union[PydanticNode, int]  # noqa: UP007
    # Self-reference as the list element type.
    list_child: List[PydanticNode]  # noqa: UP006
    # The two spellings of an optional self-reference: Union[..., None] and Optional[...].
    optional_union_child: Union[PydanticNode, None]  # noqa: UP007
    optional_child: Optional[PydanticNode]  # noqa: RUF100, UP007, UP045
    # Direct self-reference whose default is the _Sentinel class itself (not an instance).
    child: PydanticNode = Field(default=_Sentinel)  # type: ignore[assignment]
    # Self-reference as the dict key type and as the dict value type.
    recursive_key: Dict[PydanticNode, Any]  # noqa: UP006
    recursive_value: Dict[str, PydanticNode]  # noqa: UP006


@pytest.mark.parametrize("factory_use_construct", (True, False))
def test_recursive_pydantic_models(factory_use_construct: bool) -> None:
    """Build a ``PydanticNode`` and check how each self-referencing field is filled.

    :param factory_use_construct: Forwarded as the first positional argument of ``build``.
    """
    node_factory = ModelFactory.create_factory(PydanticNode)
    built = node_factory.build(factory_use_construct)

    # The declared field default (the _Sentinel class) must be kept as-is.
    assert built.child is _Sentinel, "Default is not used"  # type: ignore[comparison-overlap]
    assert isinstance(built.union_child, int)

    # Both optional spellings are expected to be None.
    for optional_field in ("optional_union_child", "optional_child"):
        assert getattr(built, optional_field) is None

    assert built.list_child == []

    # Both recursive mappings are expected to be empty.
    for mapping_field in ("recursive_key", "recursive_value"):
        assert getattr(built, mapping_field) == {}


@dataclass
class Author:
    """Author holding a list of ``Book`` objects; ``Book`` refers back to ``Author``."""

    name: str
    # Forward reference: Book is declared further down in the module.
    books: List[Book]  # noqa: UP006


# Shared Author instance handed out by Book.author's default factory;
# test_recursive_list_model compares against it by identity.
_DEFAULT_AUTHOR = Author(name="default", books=[])


@dataclass
class Book:
    """Book whose ``author`` defaults to the module-level ``_DEFAULT_AUTHOR``."""

    name: str
    # The factory returns the shared instance itself, not a copy.
    author: Author = field(default_factory=lambda: _DEFAULT_AUTHOR)


def test_recursive_list_model() -> None:
    """Check the ``Author`` <-> ``Book`` cycle from both ends, with and without overrides."""
    author_factory = DataclassFactory.create_factory(Author)

    # Generated books fall back to the default author instead of recursing.
    generated_author = author_factory.build()
    assert generated_author.books[0].author is _DEFAULT_AUTHOR

    # An explicit empty list is passed through unchanged.
    author_without_books = author_factory.build(books=[])
    assert author_without_books.books == []

    book_factory = DataclassFactory.create_factory(Book)

    generated_book = book_factory.build()
    assert generated_book.author.books == []

    # An explicit None override is kept.
    book_without_author = book_factory.build(author=None)
    assert book_without_author.author is None


# Match the "1." prefix rather than a bare "1" so that only the 1.x series is
# skipped; a bare "1" would also match a hypothetical major version 10-19.
@pytest.mark.skipif(pydantic_version.startswith("1."), reason="Pydantic v2+ is required for JsonValue")
def test_recursive_type_annotation() -> None:
    """Check that a ``JsonValue`` field is populated with every JSON-compatible type.

    Both a batch of 50 builds and the factory's ``coverage()`` output must
    yield exactly the set of types listed in ``valid_types``.
    """
    # Imported locally because the name is only available on Pydantic v2+
    # (see the skip reason above).
    from pydantic import JsonValue

    class RecursiveTypeModel(BaseModel):
        json_value: JsonValue

    factory = ModelFactory.create_factory(RecursiveTypeModel)
    results = factory.batch(50)

    valid_types = {int, str, bool, float, dict, list, type(None)}

    # NOTE(review): the equality check relies on 50 builds producing every
    # listed type at least once - confirm this cannot flake.
    assert _get_types(result.json_value for result in results) == valid_types
    assert _get_types(result.json_value for result in factory.coverage()) == valid_types


# Self-referential alias: either an int or a list of further RecursiveType values.
# The inner reference is spelled as a string, i.e. a forward reference.
RecursiveType = Union[List["RecursiveType"], int]


def test_recursive_model_with_forward_ref() -> None:
    """Check generation for a field typed with the string-based ``RecursiveType`` alias.

    The factory is created with ``__forward_references__`` mapping
    ``"RecursiveType"`` to ``int``; both batch and coverage output must
    contain exactly ``int`` and ``list`` values.
    """

    @dataclass
    class RecursiveTypeModel:
        json_value: RecursiveType

    forward_refs = {"RecursiveType": int}
    model_factory = DataclassFactory.create_factory(
        RecursiveTypeModel,
        __forward_references__=forward_refs,
    )
    expected_types = {int, list}

    batch_types = _get_types(item.json_value for item in model_factory.batch(50))
    assert batch_types == expected_types

    coverage_types = _get_types(item.json_value for item in model_factory.coverage())
    assert coverage_types == expected_types


# Generic item type used in the signature of the _get_types helper.
_T = TypeVar("_T")


def _get_types(items: Iterable[_T]) -> set[type[_T]]:
    """Return the set of exact runtime types of the elements in *items*.

    :param items: Any iterable; it is consumed once.
    :returns: A set with ``type(item)`` for every item (empty for empty input).
    """
    return set(map(type, items))