File: test_utils.py

package info (click to toggle)
python-polyfactory 2.22.2-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,892 kB
  • sloc: python: 11,338; makefile: 103; sh: 37
file content (185 lines) | stat: -rw-r--r-- 7,112 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import random
import sys
import textwrap
from decimal import Decimal
from types import ModuleType
from typing import Any, Callable, NewType, Union

import pytest
from hypothesis import given
from hypothesis.strategies import decimals, floats, integers

from pydantic import BaseModel

from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.utils.helpers import unwrap_annotation, unwrap_new_type
from polyfactory.utils.predicates import is_new_type, is_union
from polyfactory.value_generators.constrained_numbers import (
    is_multiply_of_multiple_of_in_range,
)


def test_is_union() -> None:
    """Verify that ``is_union`` flags union annotations and nothing else.

    ``typing.Union`` is checked on every interpreter. On Python 3.10+ the
    PEP 604 ``X | Y`` spelling is checked as well.
    """

    class UnionTest(BaseModel):
        union: Union[int, str]
        no_union: Any

    class UnionTestFactory(ModelFactory):
        __model__ = UnionTest

    for meta in UnionTestFactory.get_model_fields():
        # Only the field literally named "union" is expected to be a union.
        assert bool(is_union(meta.annotation)) == (meta.name == "union")

    # The pipe-union syntax below is only evaluated (and only valid at runtime) from Python 3.10 onwards.
    if sys.version_info < (3, 10):
        return

    class UnionTestWithPipe(BaseModel):
        union_pipe: int | str | None  # Pipe syntax supported from Python 3.10 onwards
        union_normal: Union[int, str]
        no_union: Any

    class UnionTestWithPipeFactory(ModelFactory):
        __model__ = UnionTestWithPipe

    union_field_names = {"union_pipe", "union_normal"}
    for meta in UnionTestWithPipeFactory.get_model_fields():
        assert bool(is_union(meta.annotation)) == (meta.name in union_field_names)


def test_is_new_type() -> None:
    """``is_new_type`` accepts a ``NewType`` object and rejects a plain class."""
    my_int = NewType("MyInt", int)

    assert is_new_type(my_int)
    assert not is_new_type(int)


def test_unwrap_new_type_is_needed() -> None:
    """``unwrap_new_type`` strips every nested ``NewType`` layer and leaves plain types untouched."""
    MyInt = NewType("MyInt", int)
    WrappedInt = NewType("WrappedInt", MyInt)  # a NewType built on top of another NewType

    for candidate in (MyInt, WrappedInt, int):
        assert unwrap_new_type(candidate) is int


def test_is_multiply_of_multiple_of_in_range_extreme_cases() -> None:
    """Pin hand-picked boundary cases of ``is_multiply_of_multiple_of_in_range``.

    Each case is ``(minimum, maximum, multiple_of, expected)``. The cases are checked
    in order and cover floats, ints and ``Decimal`` values.
    """
    cases = [
        # 0 lies in [0, 10] and is a multiple of 20; [5, 10] contains no multiple of 20.
        (0.0, 10.0, 20.0, True),
        (5.0, 10.0, 20.0, False),
        # Degenerate single-point range 1: a step very close to 1/3 is accepted, a coarse 0.333 is not.
        (1.0, 1.0, 0.33333333333, True),
        (Decimal(1), Decimal(1), Decimal("0.33333333333"), True),
        (Decimal(1), Decimal(1), Decimal("0.333"), False),
        (5, 5, 5, True),
        # multiple_of=0.0 makes pydantic raise ZeroDivisionError, but pydantic copes with values
        # close to zero, so they must be supported here as well.
        (10.0, 20.0, 1e-10, True),
        # Corner case found by peterschutt (note that minimum > maximum here).
        (
            Decimal("999999999.9999999343812775"),
            Decimal("999999999.990476"),
            Decimal("-0.556"),
            False,
        ),
    ]

    for minimum, maximum, multiple_of, expected in cases:
        outcome = is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=multiple_of)
        assert bool(outcome) is expected


@given(
    floats(allow_nan=False, allow_infinity=False, min_value=1e-6, max_value=1000000000),
    integers(min_value=-100000, max_value=100000),
)
def test_is_multiply_of_multiple_of_in_range_for_floats(base_multiple_of: float, multiplier: int) -> None:
    """Property test for ``is_multiply_of_multiple_of_in_range`` with float steps.

    The test runs for both signs of ``base_multiple_of``:

    * a range reaching from about ``multiplier * step`` to about ``(multiplier + k) * step``
      (``1 <= k <= 100``, each end shifted by up to +100) must contain a multiple;
    * a range lying strictly between ``multiplier * step`` and ``(multiplier + 1) * step``
      must not contain one.

    NOTE(review): the offsets come from the global ``random`` module. Hypothesis appears to
    manage that module's seed during examples, so these values are not explored or shrunk
    like drawn values are. Confirm this against the Hypothesis docs before relying on it.
    """
    if multiplier == 0:
        return

    for step in (base_multiple_of, -base_multiple_of):
        # Random draws happen in the same order as the expressions they feed.
        low_jitter = random.random() * 100
        extra_steps = random.randint(1, 100)
        high_jitter = random.random() * 100
        minimum, maximum = sorted(
            (multiplier * step + low_jitter, (multiplier + extra_steps) * step + high_jitter),
        )
        assert is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=step)

        # Fractions land in [0.01, 0.51) and [0.45, 0.95), so both ends stay between two multiples.
        low_fraction = random.random() / 2 + 0.01
        high_fraction = random.random() / 2 + 0.45
        minimum, maximum = sorted(
            ((multiplier + low_fraction) * step, (multiplier + high_fraction) * step),
        )
        assert not is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=step)


@given(
    integers(min_value=-1000000000, max_value=1000000000),
    integers(min_value=-100000, max_value=100000),
)
def test_is_multiply_of_multiple_of_in_range_for_int(base_multiple_of: int, multiplier: int) -> None:
    """Property test for ``is_multiply_of_multiple_of_in_range`` with integer steps.

    For both signs of ``base_multiple_of``, a range reaching from about ``multiplier * step``
    to about ``(multiplier + k) * step`` (``1 <= k <= 100``, each end shifted by +1..+100)
    must contain a multiple. A zero multiplier and steps of -1, 0 or 1 are skipped.
    """
    if multiplier == 0 or base_multiple_of in (-1, 0, 1):
        return

    for step in (base_multiple_of, -base_multiple_of):
        # Random draws happen in the same order as the expressions they feed.
        low_offset = random.randint(1, 100)
        extra_steps = random.randint(1, 100)
        high_offset = random.randint(1, 100)
        minimum, maximum = sorted(
            (multiplier * step + low_offset, (multiplier + extra_steps) * step + high_offset),
        )
        assert is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=step)


@pytest.mark.skip(reason="fails on edge cases")
@given(
    decimals(min_value=Decimal("-1000000000"), max_value=Decimal("1000000000")),
    integers(min_value=-100000, max_value=100000),
)
def test_is_multiply_of_multiple_of_in_range_for_decimals(base_multiple_of: Decimal, multiplier: int) -> None:
    """``Decimal`` counterpart of the float property test. It is currently skipped.

    The test runs for both signs of ``base_multiple_of``:

    * a range reaching from about ``multiplier * step`` to about ``(multiplier + k) * step``
      (``1 <= k <= 100``, each end shifted by up to +100) must contain a multiple;
    * a range lying strictly between ``multiplier * step`` and ``(multiplier + 1) * step``
      must not contain one.
    """
    if multiplier == 0 or base_multiple_of == 0:
        return

    for step in (base_multiple_of, -base_multiple_of):
        # Random draws happen in the same order as the expressions they feed.
        low_jitter = Decimal(random.random() * 100)
        extra_steps = random.randint(1, 100)
        high_jitter = Decimal(random.random() * 100)
        minimum, maximum = sorted(
            (multiplier * step + low_jitter, (multiplier + extra_steps) * step + high_jitter),
        )
        assert is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=step)

        # Fractions land in [0.01, 0.51) and [0.45, 0.95), so both ends stay between two multiples.
        low_fraction = Decimal(random.random() / 2 + 0.01)
        high_fraction = Decimal(random.random() / 2 + 0.45)
        minimum, maximum = sorted(
            ((multiplier + low_fraction) * step, (multiplier + high_fraction) * step),
        )
        assert not is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=step)


def test_unwrap_legacy_type_alias(create_module: Callable[[str], ModuleType]) -> None:
    """Ensure an alias declared with the ``TypeAlias`` annotation unwraps to its target type."""
    source = textwrap.dedent("""
        from typing_extensions import TypeAlias

        MyIntLegacyAlias: TypeAlias = int
        """)
    alias = create_module(source).MyIntLegacyAlias

    assert unwrap_annotation(alias) is int


@pytest.mark.skipif(sys.version_info < (3, 12), reason="3.12 only syntax")
def test_unwrap_pep695_type_alias(create_module: Callable[[str], ModuleType]) -> None:
    """Ensure an alias created by a PEP 695 ``type`` statement unwraps to its target type.

    See issue #683.
    """
    source = textwrap.dedent("""
        type MyInt = int
        """)
    alias = create_module(source).MyInt

    assert unwrap_annotation(alias) is int