1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
|
from typing import Any, Dict, Generic, Type, TypedDict, TypeVar
import pytest
from attrs import define
from msgspec import Struct
from sqlalchemy import Column, Integer
from sqlalchemy.orm.decl_api import DeclarativeMeta, registry
from pydantic import BaseModel
from pydantic.generics import GenericModel
from polyfactory import ConfigurationException
from polyfactory.factories import TypedDictFactory
from polyfactory.factories.attrs_factory import AttrsFactory
from polyfactory.factories.base import BaseFactory
from polyfactory.factories.msgspec_factory import MsgspecFactory
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory
try:
from odmantic import Model
from polyfactory.factories.odmantic_odm_factory import OdmanticModelFactory
except ImportError:
Model, OdmanticModelFactory = None, None # type: ignore
try:
from beanie import Document
from polyfactory.factories.beanie_odm_factory import BeanieDocumentFactory
except ImportError:
BeanieDocumentFactory = None # type: ignore
Document = None # type: ignore
@define
class AttrsBase:
    """Minimal attrs class used as the generic argument for ``AttrsFactory``."""

    bool_field: bool
class ModelBase(BaseModel):
    """Minimal pydantic model used as the generic argument for ``ModelFactory``."""

    dict_field: Dict[str, int]
class MsgspecBase(Struct):
    """Minimal msgspec ``Struct`` used as the generic argument for ``MsgspecFactory``."""

    int_field: int
class Base(metaclass=DeclarativeMeta):
    """Abstract declarative base for the SQLAlchemy model used in these tests.

    Per SQLAlchemy's declarative ``__abstract__`` convention, this class itself
    is not mapped; concrete subclasses (e.g. ``SQLAlchemyBase``) are.
    """

    # The right-hand side is evaluated before the class attribute exists, so it
    # calls the module-level ``registry`` imported from ``sqlalchemy.orm``.
    registry = registry()
    __abstract__ = True
class SQLAlchemyBase(Base):
    """Mapped model on table ``"model"``; generic argument for ``SQLAlchemyFactory``."""

    __tablename__ = "model"
    # Single integer primary-key column. The ``Any`` annotation is presumably
    # there to keep type checkers quiet about the ``Column`` assignment.
    id: Any = Column(Integer(), primary_key=True)
class TypedDictBase(TypedDict):
    """Minimal ``TypedDict`` used as the generic argument for ``TypedDictFactory``."""

    name: str
@pytest.mark.parametrize(
    ("base_factory", "generic_arg"),
    [
        (AttrsFactory, AttrsBase),
        (ModelFactory, ModelBase),
        (MsgspecFactory, MsgspecBase),
        (SQLAlchemyFactory, SQLAlchemyBase),
        (TypedDictFactory, TypedDictBase),
    ],
)
def test_modeL_inference_ok(base_factory: Type[BaseFactory], generic_arg: Type[Any]) -> None:
    """Subclassing ``base_factory[generic_arg]`` sets ``__model__`` to ``generic_arg``.

    The subclass declares no ``__model__`` of its own, so the value must be
    inferred from the generic argument.
    """
    parametrized_base = base_factory[generic_arg]  # type: ignore

    class InferredFactory(parametrized_base):  # type: ignore
        ...

    inferred_model = getattr(InferredFactory, "__model__")
    assert inferred_model is generic_arg
@pytest.mark.skipif(Model is None, reason="Odmantic import error")
def test_odmantic_model_inference_ok() -> None:
    """``OdmanticModelFactory[M]`` infers ``__model__`` as ``M``.

    Skipped when the guarded odmantic import at the top of the module failed.
    """

    class OdmanticModelBase(Model):  # type: ignore
        name: str

    class OdmanticFactory(OdmanticModelFactory[OdmanticModelBase]):
        ...

    inferred_model = getattr(OdmanticFactory, "__model__")
    assert inferred_model is OdmanticModelBase
@pytest.mark.skipif(Document is None, reason="Beanie import error")
def test_beanie_model_inference_ok() -> None:
    """``BeanieDocumentFactory[D]`` infers ``__model__`` as ``D``.

    Skipped when the guarded beanie import at the top of the module failed.
    """

    class BeanieBase(Document):
        name: str

    class BeanieFactory(BeanieDocumentFactory[BeanieBase]):
        ...

    inferred_model = getattr(BeanieFactory, "__model__")
    assert inferred_model is BeanieBase
@pytest.mark.parametrize(
    "base_factory",
    (
        AttrsFactory,
        ModelFactory,
        MsgspecFactory,
        SQLAlchemyFactory,
        TypedDictFactory,
    ),
)
def test_model_without_generic_type_inference_error(base_factory: Type[BaseFactory]) -> None:
    """Subclassing a bare factory (no generic argument, no ``__model__``) is rejected.

    ``ConfigurationException`` is expected while the ``class`` statement executes.
    """
    with pytest.raises(ConfigurationException):

        class UnparametrizedFactory(base_factory):  # type: ignore
            ...
@pytest.mark.parametrize(
    "base_factory",
    (
        AttrsFactory,
        ModelFactory,
        MsgspecFactory,
        SQLAlchemyFactory,
        TypedDictFactory,
    ),
)
def test_model_type_error(base_factory: Type[BaseFactory]) -> None:
    """Parametrizing a factory with ``int`` instead of a model type is rejected.

    Both the subscript and the subclass definition happen inside
    ``pytest.raises``, so an exception from either step satisfies the test.
    """
    with pytest.raises(ConfigurationException):

        class IntModelFactory(base_factory[int]):  # type: ignore
            ...
def test_model_multiple_inheritance_cannot_infer_error() -> None:
    """Inheriting from two differently-parametrized factories is rejected.

    ``ModelFactory[PFoo]`` and ``TypedDictFactory[TDFoo]`` each carry their own
    generic model argument; the test expects ``ConfigurationException`` because
    ``__model__`` cannot be inferred from both.
    """

    class TDFoo(TypedDict):
        val: str

    class PFoo(BaseModel):
        val: int

    # Subscripts stay inside the context manager, matching where an error may arise.
    with pytest.raises(ConfigurationException):

        class Foo(ModelFactory[PFoo], TypedDictFactory[TDFoo]):  # type: ignore
            ...
def test_generic_model_is_not_an_error() -> None:
    """A parametrized pydantic ``GenericModel`` is accepted as a factory's model.

    The concrete parametrization ``Foo[str, int]`` is created once and reused.
    This keeps the identity assertion from depending on pydantic returning the
    same cached class for two separate ``Foo[str, int]`` subscriptions.
    """
    T = TypeVar("T")
    P = TypeVar("P")

    class Foo(GenericModel, Generic[T, P]):  # type: ignore[misc]
        val1: T
        val2: P

    # Bind the concrete model once; evaluating ``Foo[str, int]`` twice would only
    # yield the same object if pydantic's parametrization cache returned it.
    concrete_model = Foo[str, int]

    class FooFactory(ModelFactory[concrete_model]):  # type: ignore
        ...

    assert getattr(FooFactory, "__model__") is concrete_model
|