File: test_type_hints.py

package info (click to toggle)
pydantic 2.12.5-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 7,640 kB
  • sloc: python: 75,984; javascript: 181; makefile: 115; sh: 38
file content (168 lines) | stat: -rw-r--r-- 5,098 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
Test pydantic model type hints (annotations) and that they can be
queried by :py:func:`typing.get_type_hints` (exercised via ``typing_extensions.get_type_hints``).
"""

import inspect
import sys
from functools import lru_cache
from typing import (
    Any,
    Generic,
    Optional,
    TypeVar,
)

import pytest
import typing_extensions

from pydantic import (
    BaseModel,
    RootModel,
)
from pydantic.dataclasses import dataclass

# Deprecated `BaseModel` API names. The tests below pass this set as
# `exclude_members`, so the member walk in `inspect_type_hints` skips them.
DEPRECATED_MODEL_MEMBERS = set(
    'construct copy dict from_orm json json_schema parse_file parse_obj'.split()
)

# Walking model members with `inspect.getmembers` touches attributes that may be
# deprecated and emit warnings such as:
#     pydantic.warnings.PydanticDeprecatedSince20: The `__fields__` attribute is deprecated,
#     use `model_fields` instead.
# Those warnings are DeprecationWarning subclasses, so they are silenced for this whole module.
# Additionally, these tests are skipped on Python versions older than 3.10.
pytestmark = [
    pytest.mark.filterwarnings('ignore::DeprecationWarning'),
    pytest.mark.skipif(sys.version_info < (3, 10), reason='requires python3.10 or higher to work properly'),
]


@pytest.fixture(name='ParentModel', scope='session')
def parent_sub_model_fixture():
    """
    Provide (as the `ParentModel` fixture) a model class that uses another model as a field type.

    The nested model is defined locally so both classes are created once per test session.
    """

    class UltraSimpleModel(BaseModel):
        a: float
        b: int = 10

    # `banana` is annotated with the nested model, so its hint must resolve to a model class.
    class ParentModel(BaseModel):
        grape: bool
        banana: UltraSimpleModel

    model_cls = ParentModel
    return model_cls


@lru_cache
def get_type_checking_only_ns():
    """
    Return a namespace of names that `pydantic.main` imports only under `TYPE_CHECKING`.

    Those names are absent from the module's runtime globals, so they must be supplied
    explicitly (as `localns`) when calling `typing.get_type_hints` on pydantic objects.
    The result is computed once and cached by `lru_cache`.
    """

    from inspect import Signature

    from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator

    from pydantic.deprecated.parse import Protocol as DeprecatedParseProtocol
    from pydantic.fields import ComputedFieldInfo, FieldInfo, ModelPrivateAttr, PrivateAttr

    # Keys mirror the names used in pydantic's annotations (note the `_PrivateAttr` alias).
    return dict(
        Signature=Signature,
        CoreSchema=CoreSchema,
        SchemaSerializer=SchemaSerializer,
        SchemaValidator=SchemaValidator,
        DeprecatedParseProtocol=DeprecatedParseProtocol,
        ComputedFieldInfo=ComputedFieldInfo,
        FieldInfo=FieldInfo,
        ModelPrivateAttr=ModelPrivateAttr,
        _PrivateAttr=PrivateAttr,
    )


def inspect_type_hints(
    obj_type, members: Optional[set[str]] = None, exclude_members: Optional[set[str]] = None, recursion_limit: int = 3
):
    """
    Test an object and its members to make sure type hints can be resolved.

    The hints of ``obj_type`` itself are always resolved. When ``obj_type`` is a class and
    ``recursion_limit`` is positive, every public (non-underscore) member that is a class or a
    function is checked recursively with ``recursion_limit - 1``. The ``members`` and
    ``exclude_members`` filters apply only at this top level, not to nested members.

    :param obj_type: Type to check
    :param members: Explicit set of members to check, None to check all (an empty set checks none)
    :param exclude_members: Set of member names to exclude, None to exclude nothing
    :param recursion_limit: Recursion limit (0 to disallow)
    :raises AssertionError: if an annotation references a name that cannot be resolved
    """
    # Keep the try body minimal: only the hint resolution can raise NameError.
    try:
        hints = typing_extensions.get_type_hints(obj_type, localns=get_type_checking_only_ns())
    except NameError as ex:
        raise AssertionError(f'Type annotation(s) on {obj_type} are invalid: {ex}') from ex
    assert isinstance(hints, dict), f'Type annotation(s) on {obj_type} are invalid'

    # Only classes have members worth walking, and only while recursion budget remains.
    if recursion_limit <= 0 or not isinstance(obj_type, type):
        return

    for member_name, member_obj in inspect.getmembers(obj_type):
        if member_name.startswith('_'):
            # Ignore private and dunder members
            continue
        # Compare against None explicitly so an empty `members` set means "check nothing"
        # rather than being treated like None ("check all").
        if members is not None and member_name not in members:
            continue
        if exclude_members is not None and member_name in exclude_members:
            continue

        if inspect.isclass(member_obj) or inspect.isfunction(member_obj):
            # Inspect all child members (can't exclude specific ones)
            inspect_type_hints(member_obj, recursion_limit=recursion_limit - 1)


@pytest.mark.parametrize(
    ('obj_type', 'members', 'exclude_members'),
    [(model_cls, None, DEPRECATED_MODEL_MEMBERS) for model_cls in (BaseModel, RootModel)],
)
def test_obj_type_hints(obj_type, members: Optional[set[str]], exclude_members: Optional[set[str]]):
    """
    Check that ``obj_type`` and its public members have resolvable type hints.

    :param obj_type: Type to check
    :param members: Explicit set of members to check, None to check all
    :param exclude_members: Set of member names to exclude
    """
    inspect_type_hints(obj_type, members=members, exclude_members=exclude_members)


def test_parent_sub_model(ParentModel):
    """Type hints on a model whose field is another model must resolve."""
    inspect_type_hints(ParentModel, members=None, exclude_members=DEPRECATED_MODEL_MEMBERS)


def test_root_model_as_field():
    """A `RootModel` subclass, and a model using it as a field type, must both have resolvable hints."""

    class MyRootModel(RootModel[int]):
        pass

    class MyModel(BaseModel):
        root_model: MyRootModel

    # Root model first, then the model that embeds it.
    for model in (MyRootModel, MyModel):
        inspect_type_hints(model, None, DEPRECATED_MODEL_MEMBERS)


def test_generics():
    """Both a generic model and one of its concrete parametrizations must have resolvable hints."""
    data_type = TypeVar('data_type')

    class Result(BaseModel, Generic[data_type]):
        data: data_type

    inspect_type_hints(Result, members=None, exclude_members=DEPRECATED_MODEL_MEMBERS)

    concrete = Result[dict[str, Any]]
    inspect_type_hints(concrete, members=None, exclude_members=DEPRECATED_MODEL_MEMBERS)


def test_dataclasses():
    """Type hints on a pydantic dataclass (and its public members) must resolve."""

    @dataclass
    class MyDataclass:
        a: int
        b: float

    # No member filtering: every public class/function member is walked.
    inspect_type_hints(MyDataclass, members=None, exclude_members=None)