File: test_optional_model_field_inference.py

package info (click to toggle)
python-polyfactory 2.22.2-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,892 kB
  • sloc: python: 11,338; makefile: 103; sh: 37
file content (159 lines) | stat: -rw-r--r-- 3,994 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from typing import Any, Dict, Generic, Type, TypedDict, TypeVar

import pytest
from attrs import define
from msgspec import Struct
from sqlalchemy import Column, Integer
from sqlalchemy.orm.decl_api import DeclarativeMeta, registry

from pydantic import BaseModel
from pydantic.generics import GenericModel

from polyfactory import ConfigurationException
from polyfactory.factories import TypedDictFactory
from polyfactory.factories.attrs_factory import AttrsFactory
from polyfactory.factories.base import BaseFactory
from polyfactory.factories.msgspec_factory import MsgspecFactory
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory

try:
    from odmantic import Model

    from polyfactory.factories.odmantic_odm_factory import OdmanticModelFactory
except ImportError:
    Model, OdmanticModelFactory = None, None  # type: ignore

try:
    from beanie import Document

    from polyfactory.factories.beanie_odm_factory import BeanieDocumentFactory
except ImportError:
    BeanieDocumentFactory = None  # type: ignore
    Document = None  # type: ignore


# attrs fixture model: passed to AttrsFactory as the generic argument in
# test_modeL_inference_ok.
@define()
class AttrsBase:
    bool_field: bool


# Pydantic fixture model: passed to ModelFactory as the generic argument in
# test_modeL_inference_ok, which checks it becomes the factory's __model__.
class ModelBase(BaseModel):
    dict_field: Dict[str, int]


# msgspec fixture model: passed to MsgspecFactory as the generic argument in
# test_modeL_inference_ok, which checks it becomes the factory's __model__.
class MsgspecBase(Struct):
    int_field: int


# Declarative base for the SQLAlchemy fixture, built directly on DeclarativeMeta
# instead of a declarative_base() helper.
# NOTE(review): ``__abstract__`` and ``registry`` are SQLAlchemy declarative
# directives; per the SQLAlchemy docs this base is not itself mapped and its
# registry is shared by subclasses -- confirm against the installed version.
class Base(metaclass=DeclarativeMeta):
    __abstract__ = True

    # The right-hand ``registry`` is the function imported from
    # sqlalchemy.orm.decl_api: inside a class body the name is not bound yet,
    # so the lookup falls through to module globals before being shadowed.
    registry = registry()


# Mapped SQLAlchemy fixture: one integer primary-key column on table "model".
# Passed to SQLAlchemyFactory as the generic argument in test_modeL_inference_ok.
class SQLAlchemyBase(Base):
    __tablename__ = "model"

    # Column instantiates a bare type class itself, so this equals Integer().
    id: Any = Column(Integer, primary_key=True)


# TypedDict fixture with a single required ``str`` key ``name``. Spelled with the
# functional syntax, which builds the same TypedDict type as the class form.
# Passed to TypedDictFactory as the generic argument in test_modeL_inference_ok.
TypedDictBase = TypedDict("TypedDictBase", {"name": str})


@pytest.mark.parametrize(
    ("base_factory", "generic_arg"),
    (
        (AttrsFactory, AttrsBase),
        (ModelFactory, ModelBase),
        (MsgspecFactory, MsgspecBase),
        (SQLAlchemyFactory, SQLAlchemyBase),
        (TypedDictFactory, TypedDictBase),
    ),
)
def test_modeL_inference_ok(base_factory: Type[BaseFactory], generic_arg: Type[Any]) -> None:
    """Subclassing ``base_factory[generic_arg]`` must set ``__model__`` to ``generic_arg``.

    NOTE: the capital ``L`` in the function name looks like a typo. It is kept
    so that the pytest node ids stay stable.
    """

    class Foo(base_factory[generic_arg]):  # type: ignore
        pass

    inferred_model = getattr(Foo, "__model__")
    assert inferred_model is generic_arg


@pytest.mark.skipif(Model is None, reason="Odmantic import error")
def test_odmantic_model_inference_ok() -> None:
    """``OdmanticModelFactory[M]`` must infer ``M`` as the factory's ``__model__``.

    Skipped when the optional odmantic import at module level failed.
    """

    class OdmanticModelBase(Model):  # type: ignore
        name: str

    class OdmanticFactory(OdmanticModelFactory[OdmanticModelBase]):
        pass

    inferred_model = getattr(OdmanticFactory, "__model__")
    assert inferred_model is OdmanticModelBase


@pytest.mark.skipif(Document is None, reason="Beanie import error")
def test_beanie_model_inference_ok() -> None:
    """``BeanieDocumentFactory[D]`` must infer ``D`` as the factory's ``__model__``.

    Skipped when the optional beanie import at module level failed.
    """

    class BeanieBase(Document):
        name: str

    class BeanieFactory(BeanieDocumentFactory[BeanieBase]):
        pass

    inferred_model = getattr(BeanieFactory, "__model__")
    assert inferred_model is BeanieBase


@pytest.mark.parametrize(
    "base_factory",
    (
        AttrsFactory,
        ModelFactory,
        MsgspecFactory,
        SQLAlchemyFactory,
        TypedDictFactory,
    ),
)
def test_model_without_generic_type_inference_error(base_factory: Type[BaseFactory]) -> None:
    """A factory subclass with no generic argument (and no ``__model__``) must raise.

    The exception is expected at class-creation time, so the class statement
    itself sits inside ``pytest.raises``.
    """
    with pytest.raises(ConfigurationException):

        class UntypedFactory(base_factory):  # type: ignore
            pass


@pytest.mark.parametrize(
    "base_factory",
    (
        AttrsFactory,
        ModelFactory,
        MsgspecFactory,
        SQLAlchemyFactory,
        TypedDictFactory,
    ),
)
def test_model_type_error(base_factory: Type[BaseFactory]) -> None:
    """Parametrizing a factory with ``int`` (not a model of its kind) must raise.

    Both the subscription and the class statement are kept inside
    ``pytest.raises`` so the test does not depend on which step raises.
    """
    with pytest.raises(ConfigurationException):

        class IntFactory(base_factory[int]):  # type: ignore
            pass


def test_model_multiple_inheritance_cannot_infer_error() -> None:
    """Inheriting from two factories parametrized with different models must raise.

    The model cannot be inferred from two competing generic bases.
    """

    class PydanticFoo(BaseModel):
        val: int

    class TypedDictFoo(TypedDict):
        val: str

    with pytest.raises(ConfigurationException):

        class AmbiguousFactory(ModelFactory[PydanticFoo], TypedDictFactory[TypedDictFoo]):  # type: ignore
            pass


def test_generic_model_is_not_an_error() -> None:
    """A concretely parametrized pydantic generic model is accepted as ``__model__``."""
    FirstT = TypeVar("FirstT")
    SecondT = TypeVar("SecondT")

    class Foo(GenericModel, Generic[FirstT, SecondT]):  # type: ignore[misc]
        val1: FirstT
        val2: SecondT

    class FooFactory(ModelFactory[Foo[str, int]]):
        pass

    # The right-hand ``Foo[str, int]`` is a fresh subscription. The identity
    # check therefore also requires repeated parametrization to return the very
    # same class object.
    assert getattr(FooFactory, "__model__") is Foo[str, int]