1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
|
"""
Test pydantic model type hints (annotations) and that they can be
queried by :py:meth:`typing.get_type_hints`.
"""
import inspect
import sys
from functools import lru_cache
from typing import (
Any,
Generic,
Optional,
TypeVar,
)
import pytest
import typing_extensions
from pydantic import (
BaseModel,
RootModel,
)
from pydantic.dataclasses import dataclass
# Model member names that the tests below pass as ``exclude_members`` so they are
# skipped during type-hint inspection (deprecated model API, per the constant's name).
DEPRECATED_MODEL_MEMBERS = set(
    """
    construct copy dict from_orm
    json json_schema parse_file parse_obj
    """.split()
)
# Disable deprecation warnings: the tests enumerate model members, some of which
# are deprecated and may warn when accessed, e.g.
# pydantic.warnings.PydanticDeprecatedSince20: The `__fields__` attribute is deprecated,
# use `model_fields` instead.
# Additionally, only run these tests on Python 3.10+.
pytestmark = [
    pytest.mark.filterwarnings('ignore::DeprecationWarning'),
    pytest.mark.skipif(sys.version_info < (3, 10), reason='requires python3.10 or higher to work properly'),
]
@pytest.fixture(scope='session', name='ParentModel')
def parent_sub_model_fixture():
    """Provide a model class that has another model as one of its fields.

    Tests request it under the fixture name ``ParentModel``. It is created once per
    test session.
    """

    class UltraSimpleModel(BaseModel):
        a: float
        b: int = 10

    class ParentModel(BaseModel):
        grape: bool
        banana: UltraSimpleModel  # nested model field

    return ParentModel
@lru_cache
def get_type_checking_only_ns():
    """Return a namespace of the names `pydantic.main` imports only under `TYPE_CHECKING`.

    `BaseModel`'s annotations refer to these names, but they are not present at
    runtime. They must therefore be passed explicitly (as ``localns``) when calling
    `typing.get_type_hints`. The result is cached, so every call returns the same
    dict object.
    """
    from inspect import Signature

    from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator

    from pydantic.deprecated.parse import Protocol as DeprecatedParseProtocol
    from pydantic.fields import ComputedFieldInfo, FieldInfo, ModelPrivateAttr
    from pydantic.fields import PrivateAttr as _PrivateAttr

    # Each key is the exact name used in the annotations; the value is the imported object.
    return dict(
        Signature=Signature,
        CoreSchema=CoreSchema,
        SchemaSerializer=SchemaSerializer,
        SchemaValidator=SchemaValidator,
        DeprecatedParseProtocol=DeprecatedParseProtocol,
        ComputedFieldInfo=ComputedFieldInfo,
        FieldInfo=FieldInfo,
        ModelPrivateAttr=ModelPrivateAttr,
        _PrivateAttr=_PrivateAttr,
    )
def inspect_type_hints(
    obj_type, members: Optional[set[str]] = None, exclude_members: Optional[set[str]] = None, recursion_limit: int = 3
):
    """
    Test an object and its members to make sure type hints can be resolved.

    Hints are resolved with `typing_extensions.get_type_hints`, passing the
    `TYPE_CHECKING`-only names from `get_type_checking_only_ns()` as ``localns``.
    If ``obj_type`` is a class, its public (non-underscore) members that are classes
    or functions are then inspected recursively. Member filtering applies only at
    the top level; nested members are always checked in full.

    :param obj_type: Type to check
    :param members: Explicit set of members to check, None to check all
        (an empty set checks no members)
    :param exclude_members: Set of member names to exclude, None to exclude nothing
    :param recursion_limit: Recursion limit (0 to disallow)
    :raises AssertionError: if a type hint refers to a name that cannot be resolved
    """
    try:
        hints = typing_extensions.get_type_hints(obj_type, localns=get_type_checking_only_ns())
    except NameError as ex:
        raise AssertionError(f'Type annotation(s) on {obj_type} are invalid: {ex}') from ex
    assert isinstance(hints, dict), f'Type annotation(s) on {obj_type} are invalid'

    # Only classes have members worth walking; stop once the recursion budget is spent.
    if recursion_limit <= 0 or not isinstance(obj_type, type):
        return

    for member_name, member_obj in inspect.getmembers(obj_type):
        if member_name.startswith('_'):
            # Ignore private members
            continue
        # Compare with None explicitly: an empty ``members`` set must mean
        # "check nothing", not fall through to "check everything".
        if members is not None and member_name not in members:
            continue
        if exclude_members and member_name in exclude_members:
            continue
        if inspect.isclass(member_obj) or inspect.isfunction(member_obj):
            # Inspect all child members (can't exclude specific ones)
            inspect_type_hints(member_obj, recursion_limit=recursion_limit - 1)
@pytest.mark.parametrize(
    ('obj_type', 'members', 'exclude_members'),
    [(model_cls, None, DEPRECATED_MODEL_MEMBERS) for model_cls in (BaseModel, RootModel)],
)
def test_obj_type_hints(obj_type, members: Optional[set[str]], exclude_members: Optional[set[str]]):
    """
    Check that the type hints on ``obj_type`` and on its selected members resolve.

    :param obj_type: Type to check
    :param members: Explicit set of members to check, None to check all
    :param exclude_members: Set of member names to exclude
    """
    inspect_type_hints(obj_type, members=members, exclude_members=exclude_members)
def test_parent_sub_model(ParentModel):
    """Type hints resolve for a model that has another model as a field."""
    inspect_type_hints(ParentModel, members=None, exclude_members=DEPRECATED_MODEL_MEMBERS)
def test_root_model_as_field():
    """Type hints resolve for a `RootModel` subclass and for a model that uses it as a field."""

    class MyRootModel(RootModel[int]):
        pass

    class MyModel(BaseModel):
        root_model: MyRootModel

    # Check the root model first, then the model that embeds it.
    for model_cls in (MyRootModel, MyModel):
        inspect_type_hints(model_cls, None, DEPRECATED_MODEL_MEMBERS)
def test_generics():
    """Type hints resolve for a generic model, both unparametrized and parametrized."""
    data_type = TypeVar('data_type')

    class Result(BaseModel, Generic[data_type]):
        data: data_type

    inspect_type_hints(Result, members=None, exclude_members=DEPRECATED_MODEL_MEMBERS)
    inspect_type_hints(Result[dict[str, Any]], members=None, exclude_members=DEPRECATED_MODEL_MEMBERS)
def test_dataclasses():
    """Type hints resolve for a pydantic dataclass and its members."""

    @dataclass
    class MyDataclass:
        a: int
        b: float

    # No member filtering: every public class/function member is checked.
    inspect_type_hints(MyDataclass, members=None, exclude_members=None)
|