# Tests for polyfactory's annotation helpers (union / NewType / type-alias unwrapping)
# and for the constrained-number ``is_multiply_of_multiple_of_in_range`` predicate.
import random
import sys
import textwrap
from decimal import Decimal
from types import ModuleType
from typing import Any, Callable, NewType, Union
import pytest
from hypothesis import given
from hypothesis.strategies import decimals, floats, integers
from pydantic import BaseModel
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.utils.helpers import unwrap_annotation, unwrap_new_type
from polyfactory.utils.predicates import is_new_type, is_union
from polyfactory.value_generators.constrained_numbers import (
is_multiply_of_multiple_of_in_range,
)
def test_is_union() -> None:
    """Check ``is_union`` against the field annotations of small pydantic models.

    Fields annotated with a union must be reported as unions; the ``Any``
    field must not. On Python 3.10+ the same check is repeated for a model
    that also uses the ``X | Y`` pipe spelling.
    """

    class UnionTest(BaseModel):
        union: Union[int, str]
        no_union: Any

    class UnionTestFactory(ModelFactory):
        __model__ = UnionTest

    for field_meta in UnionTestFactory.get_model_fields():
        if field_meta.name != "union":
            assert not is_union(field_meta.annotation)
        else:
            assert is_union(field_meta.annotation)

    # The pipe operator between types is only usable from Python 3.10 onwards,
    # so the second half of the test is skipped on older interpreters.
    if sys.version_info < (3, 10):
        return

    class UnionTestWithPipe(BaseModel):
        union_pipe: int | str | None  # Pipe syntax supported from Python 3.10 onwards
        union_normal: Union[int, str]
        no_union: Any

    class UnionTestWithPipeFactory(ModelFactory):
        __model__ = UnionTestWithPipe

    for field_meta in UnionTestWithPipeFactory.get_model_fields():
        if field_meta.name not in ("union_pipe", "union_normal"):
            assert not is_union(field_meta.annotation)
        else:
            assert is_union(field_meta.annotation)
def test_is_new_type() -> None:
    """``is_new_type`` accepts a ``NewType`` and rejects a plain builtin type."""
    new_type = NewType("MyInt", int)

    assert is_new_type(new_type)
    assert not is_new_type(int)
def test_unwrap_new_type_is_needed() -> None:
    """``unwrap_new_type`` yields ``int`` for a NewType, a nested NewType and ``int`` itself."""
    single_layer = NewType("MyInt", int)
    double_layer = NewType("WrappedInt", single_layer)

    # Checked in order: one layer of NewType, two layers, and no wrapping.
    for candidate in (single_layer, double_layer, int):
        assert unwrap_new_type(candidate) is int
def test_is_multiply_of_multiple_of_in_range_extreme_cases() -> None:
    """Pin ``is_multiply_of_multiple_of_in_range`` on hand-picked boundary inputs.

    Each case is ``(minimum, maximum, multiple_of, expected)``, where
    ``expected`` is the truthiness the predicate must have for that input.
    """
    cases = [
        (0.0, 10.0, 20.0, True),
        (5.0, 10.0, 20.0, False),
        (1.0, 1.0, 0.33333333333, True),
        (Decimal(1), Decimal(1), Decimal("0.33333333333"), True),
        (Decimal(1), Decimal(1), Decimal("0.333"), False),
        (5, 5, 5, True),
        # multiple_of=0.0 leads to a ZeroDivision exception in pydantic, but
        # pydantic handles values close to zero properly, so this must be
        # supported too.
        (10.0, 20.0, 1e-10, True),
        # Corner case found by peterschutt.
        (
            Decimal("999999999.9999999343812775"),
            Decimal("999999999.990476"),
            Decimal("-0.556"),
            False,
        ),
    ]

    for minimum, maximum, multiple_of, expected in cases:
        outcome = is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=multiple_of)
        assert bool(outcome) is expected
@given(
    floats(allow_nan=False, allow_infinity=False, min_value=1e-6, max_value=1000000000),
    integers(min_value=-100000, max_value=100000),
)
def test_is_multiply_of_multiple_of_in_range_for_floats(base_multiple_of: float, multiplier: int) -> None:
    """Property test of the range predicate for float ``multiple_of`` values.

    For both signs of ``base_multiple_of`` two ranges are built: one whose
    bounds are derived from different multipliers (the predicate must hold),
    and one whose bounds both sit a fractional step past ``multiplier``
    (the predicate must not hold). A zero ``multiplier`` is ignored.
    """
    if multiplier == 0:
        return

    for multiple_of in (base_multiple_of, -base_multiple_of):
        # Bounds anchored at ``multiplier`` and at a multiplier 1..100 steps
        # further on, each shifted by a random offset below 100.
        near_bound = multiplier * multiple_of + random.random() * 100
        far_bound = (multiplier + random.randint(1, 100)) * multiple_of + random.random() * 100
        minimum, maximum = sorted((near_bound, far_bound))
        assert is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=multiple_of)

        # Both bounds lie a fraction of a step past ``multiplier``:
        # [0.01, 0.51) and [0.45, 0.95) of ``multiple_of`` respectively.
        inner_low = (multiplier + (random.random() / 2 + 0.01)) * multiple_of
        inner_high = (multiplier + (random.random() / 2 + 0.45)) * multiple_of
        minimum, maximum = sorted((inner_low, inner_high))
        assert not is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=multiple_of)
@given(
    integers(min_value=-1000000000, max_value=1000000000),
    integers(min_value=-100000, max_value=100000),
)
def test_is_multiply_of_multiple_of_in_range_for_int(base_multiple_of: int, multiplier: int) -> None:
    """Property test of the range predicate for integer ``multiple_of`` values.

    Inputs with a zero ``multiplier`` or with ``base_multiple_of`` in
    ``-1, 0, 1`` are ignored. For both signs of the remaining values, a range
    whose bounds are derived from different multipliers must satisfy the
    predicate.
    """
    if multiplier == 0 or base_multiple_of in (-1, 0, 1):
        return

    for multiple_of in (base_multiple_of, -base_multiple_of):
        # Bounds anchored at ``multiplier`` and at a multiplier 1..100 steps
        # further on, each shifted by a random integer offset in 1..100.
        near_bound = multiplier * multiple_of + random.randint(1, 100)
        far_bound = (multiplier + random.randint(1, 100)) * multiple_of + random.randint(1, 100)
        minimum, maximum = sorted((near_bound, far_bound))
        assert is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=multiple_of)
@pytest.mark.skip(reason="fails on edge cases")
@given(
    decimals(min_value=Decimal("-1000000000"), max_value=Decimal("1000000000")),
    integers(min_value=-100000, max_value=100000),
)
def test_is_multiply_of_multiple_of_in_range_for_decimals(base_multiple_of: Decimal, multiplier: int) -> None:
    """Property test of the range predicate for ``Decimal`` ``multiple_of`` values.

    Currently skipped (see the marker's reason). Mirrors the float variant:
    for both signs of ``base_multiple_of`` one range must satisfy the
    predicate and one range, with both bounds a fractional step past
    ``multiplier``, must not. Zero ``multiplier`` or zero ``base_multiple_of``
    inputs are ignored.
    """
    if multiplier == 0 or base_multiple_of == 0:
        return

    for multiple_of in (base_multiple_of, -base_multiple_of):
        # Bounds anchored at ``multiplier`` and at a multiplier 1..100 steps
        # further on, each shifted by a random offset below 100.
        near_bound = multiplier * multiple_of + Decimal(random.random() * 100)
        far_bound = (multiplier + random.randint(1, 100)) * multiple_of + Decimal(random.random() * 100)
        minimum, maximum = sorted((near_bound, far_bound))
        assert is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=multiple_of)

        # Both bounds lie a fraction of a step past ``multiplier``.
        inner_low = (multiplier + Decimal(random.random() / 2 + 0.01)) * multiple_of
        inner_high = (multiplier + Decimal(random.random() / 2 + 0.45)) * multiple_of
        minimum, maximum = sorted((inner_low, inner_high))
        assert not is_multiply_of_multiple_of_in_range(minimum=minimum, maximum=maximum, multiple_of=multiple_of)
def test_unwrap_legacy_type_alias(create_module: Callable[[str], ModuleType]) -> None:
    """Ensure a ``TypeAlias``-annotated alias of ``int`` unwraps to ``int``.

    The alias lives in a module built from source text by the
    ``create_module`` fixture.
    """
    source = textwrap.dedent(
        """
        from typing_extensions import TypeAlias
        MyIntLegacyAlias: TypeAlias = int
        """
    )
    alias_module = create_module(source)

    assert unwrap_annotation(alias_module.MyIntLegacyAlias) is int
@pytest.mark.skipif(sys.version_info < (3, 12), reason="3.12 only syntax")
def test_unwrap_pep695_type_alias(create_module: Callable[[str], ModuleType]) -> None:
    """Ensure a PEP 695 ``type`` alias of ``int`` unwraps to ``int``.

    The ``type`` statement is kept in source text and built via the
    ``create_module`` fixture; the test is skipped below Python 3.12.

    See issue #683.
    """
    source = textwrap.dedent(
        """
        type MyInt = int
        """
    )
    alias_module = create_module(source)

    assert unwrap_annotation(alias_module.MyInt) is int
|