"""Tests for building factories over recursive (self-referencing) dataclass and pydantic models."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, TypeVar, Union
import pytest
from pydantic import BaseModel, Field
from pydantic import __version__ as pydantic_version
from polyfactory.factories.dataclass_factory import DataclassFactory
from polyfactory.factories.pydantic_factory import ModelFactory
class _Sentinel:
    """Marker class used as the placeholder default of the recursive ``child`` fields."""
@dataclass
class Node:
    """Self-referencing dataclass exercising several recursive field shapes.

    ``child`` defaults to the ``_Sentinel`` class itself; ``__post_init__``
    replaces that placeholder with the instance, producing a cycle.
    """

    value: int
    # Recursive reference inside a union with a plain type.
    union_child: Union[Node, int]  # noqa: UP007
    # Recursive reference inside a list.
    list_child: List[Node]  # noqa: UP006
    # Recursive reference that may be absent.
    optional_child: Optional[Node]  # noqa: RUF100, UP007, UP045
    # Direct recursive reference with a placeholder default.
    child: Node = field(default=_Sentinel)  # type: ignore[assignment]

    def __post_init__(self) -> None:
        """Point ``child`` back at ``self`` when it still holds the placeholder."""
        # Emulate recursive models set by external init, e.g. ORM relationships
        if self.child is not _Sentinel:  # type: ignore[comparison-overlap]
            return
        self.child = self
def test_recursive_model() -> None:
    """Check the values a dataclass factory produces for the recursive fields of ``Node``."""
    node_factory = DataclassFactory.create_factory(Node)

    built = node_factory.build()
    # ``child`` pointing back at the instance means the field default was kept
    # and ``__post_init__`` replaced the placeholder.
    assert built.child is built, "Default is not used"
    assert isinstance(built.union_child, int)
    assert built.optional_child is None
    assert built.list_child == []

    # A nested override given as a dict is expected to show up on the built child.
    overridden = node_factory.build(child={"child": None})
    assert overridden.child.child is None
class PydanticNode(BaseModel):
    """Self-referencing pydantic model exercising several recursive field shapes."""

    value: int

    # Recursive reference combined with a plain type or ``None``.
    union_child: Union[PydanticNode, int]  # noqa: UP007
    list_child: List[PydanticNode]  # noqa: UP006
    optional_union_child: Union[PydanticNode, None]  # noqa: UP007
    optional_child: Optional[PydanticNode]  # noqa: RUF100, UP007, UP045

    # Direct recursive reference whose default is the ``_Sentinel`` class.
    child: PydanticNode = Field(default=_Sentinel)  # type: ignore[assignment]

    # Recursive reference used as mapping key and as mapping value.
    recursive_key: Dict[PydanticNode, Any]  # noqa: UP006
    recursive_value: Dict[str, PydanticNode]  # noqa: UP006
@pytest.mark.parametrize("factory_use_construct", (True, False))
def test_recursive_pydantic_models(factory_use_construct: bool) -> None:
    """Check the values a model factory produces for ``PydanticNode``.

    Parametrized over ``factory_use_construct``, which is forwarded as the
    first positional argument of ``build``.
    """
    model_factory = ModelFactory.create_factory(PydanticNode)
    built = model_factory.build(factory_use_construct)

    # The declared default must be kept for the direct recursive field.
    assert built.child is _Sentinel, "Default is not used"  # type: ignore[comparison-overlap]
    assert isinstance(built.union_child, int)

    # Optional recursive fields are expected to be ``None``.
    assert built.optional_union_child is None
    assert built.optional_child is None

    # Recursive collections are expected to be empty.
    assert built.list_child == []
    assert built.recursive_key == {}
    assert built.recursive_value == {}
@dataclass
class Author:
    """Dataclass whose ``books`` list refers forward to ``Book``."""

    name: str
    # Forward reference: ``Book`` is declared later in the module.
    books: List[Book]  # noqa: UP006
# Shared author instance; ``Book.author`` falls back to this exact object.
_DEFAULT_AUTHOR = Author(name="default", books=[])
def _default_author() -> Author:
    """Return the shared module-level ``_DEFAULT_AUTHOR`` instance."""
    return _DEFAULT_AUTHOR


@dataclass
class Book:
    """Dataclass referring back to ``Author``; ``author`` defaults to ``_DEFAULT_AUTHOR``."""

    name: str
    author: Author = field(default_factory=_default_author)
def test_recursive_list_model() -> None:
    """Check recursion through a list field (``Author.books``) and back (``Book.author``)."""
    author_factory = DataclassFactory.create_factory(Author)

    first_book = author_factory.build().books[0]
    assert first_book.author is _DEFAULT_AUTHOR

    without_books = author_factory.build(books=[])
    assert without_books.books == []

    book_factory = DataclassFactory.create_factory(Book)

    default_built = book_factory.build()
    assert default_built.author.books == []

    explicit_none = book_factory.build(author=None)
    assert explicit_none.author is None
@pytest.mark.skipif(pydantic_version.startswith("1"), reason="Pydantic v2+ is required for JsonValue")
def test_recursive_type_annotation() -> None:
    """Check that a ``JsonValue`` field yields every expected value type.

    Both a batch of 50 builds and the ``coverage()`` output must produce
    exactly the expected set of runtime types.
    """
    # Imported locally because the test is skipped on pydantic v1.
    from pydantic import JsonValue

    class RecursiveTypeModel(BaseModel):
        json_value: JsonValue

    model_factory = ModelFactory.create_factory(RecursiveTypeModel)
    batch = model_factory.batch(50)
    expected_types = {int, str, bool, float, dict, list, type(None)}

    batch_types = _get_types(model.json_value for model in batch)
    assert batch_types == expected_types

    coverage_types = _get_types(model.json_value for model in model_factory.coverage())
    assert coverage_types == expected_types
# Recursive alias: a list of further ``RecursiveType`` items (string forward reference) or an int.
RecursiveType = Union[List["RecursiveType"], int]
def test_recursive_model_with_forward_ref() -> None:
    """Check that a field typed with the ``RecursiveType`` alias yields ints and lists.

    The factory is created with a ``__forward_references__`` mapping for the
    ``"RecursiveType"`` string reference. Both a batch of 50 builds and the
    ``coverage()`` output must produce exactly ``{int, list}``.
    """

    @dataclass
    class RecursiveTypeModel:
        json_value: RecursiveType

    model_factory = DataclassFactory.create_factory(
        RecursiveTypeModel,
        __forward_references__={"RecursiveType": int},
    )
    batch = model_factory.batch(50)
    expected_types = {int, list}

    batch_types = _get_types(model.json_value for model in batch)
    assert batch_types == expected_types

    coverage_types = _get_types(model.json_value for model in model_factory.coverage())
    assert coverage_types == expected_types
# Generic item type used in the signature of ``_get_types``.
_T = TypeVar("_T")
def _get_types(items: Iterable[_T]) -> set[type[_T]]:
    """Return the set of exact runtime types found among ``items``."""
    return set(map(type, items))
|