1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
|
import dataclasses
import re
from datetime import datetime
from decimal import Decimal
from typing import Any, Dict, Generic, List, Pattern, Tuple, Type, TypeVar, Union
import pytest
from bson import Binary, Decimal128, Int64, ObjectId, Regex
from motor.motor_asyncio import AsyncIOMotorDatabase
from pymongo.database import Database
from odmantic.bson import WithBsonSerializer
from odmantic.engine import AIOEngine, SyncEngine
from odmantic.model import Model
from odmantic.typing import Annotated
# Type variable that parametrizes TypeTestCase (declared just below).
T = TypeVar("T")

# Module-level marker: pytest applies it to every test collected from this file.
# NOTE(review): this also marks the synchronous tests; some pytest-asyncio
# versions warn about non-async tests carrying the asyncio mark — consider
# marking only the coroutine tests.
pytestmark = pytest.mark.asyncio
@dataclasses.dataclass()
class TypeTestCase(Generic[T]):
    """One BSON type-inference scenario used to parametrize the tests.

    Fields are listed in constructor order:
    ``TypeTestCase(python_type, bson_type, sample_value)``.
    """

    # Annotation given to the model field under test.
    python_type: Type[T]
    # Expected BSON type alias, used as the operand of a MongoDB ``$type`` filter.
    bson_type: str
    # Value assigned to the model field before it is saved.
    sample_value: T
# Edges of the signed 32-bit range: the cases below expect integers inside it
# to be stored as BSON "int" and integers just outside it as BSON "long".
MIN_INT32 = -(1 << 31)
UNDER_INT32_VALUE = MIN_INT32 - 1
MAX_INT32 = (1 << 31) - 1
OVER_INT32_VALUE = MAX_INT32 + 1

# Evaluated once at import time, so every case referencing it shares one value.
sample_datetime = datetime.now()

type_test_data = [
    # Simple types
    TypeTestCase(python_type=int, bson_type="int", sample_value=15),
    TypeTestCase(python_type=int, bson_type="int", sample_value=MIN_INT32),
    TypeTestCase(python_type=int, bson_type="int", sample_value=MAX_INT32),
    TypeTestCase(python_type=int, bson_type="long", sample_value=UNDER_INT32_VALUE),
    TypeTestCase(python_type=int, bson_type="long", sample_value=OVER_INT32_VALUE),
    TypeTestCase(python_type=Int64, bson_type="long", sample_value=13),
    TypeTestCase(python_type=Int64, bson_type="long", sample_value=Int64(13)),
    TypeTestCase(python_type=str, bson_type="string", sample_value="foo"),
    TypeTestCase(python_type=float, bson_type="double", sample_value=3.14),
    TypeTestCase(
        python_type=Decimal,
        bson_type="decimal",
        sample_value=Decimal("3.14159265359"),
    ),
    # A plain str sample is still expected to end up stored as BSON "decimal".
    # TODO split tests for odmantic type inference
    TypeTestCase(
        python_type=Decimal,
        bson_type="decimal",
        sample_value="3.14159265359",
    ),
    TypeTestCase(
        python_type=Decimal128,
        bson_type="decimal",
        sample_value=Decimal128(Decimal("3.14159265359")),
    ),
    TypeTestCase(
        python_type=Dict[str, Any],
        bson_type="object",
        sample_value={"foo": "bar", "fizz": {"foo": "bar"}},
    ),
    TypeTestCase(python_type=bool, bson_type="bool", sample_value=False),
    TypeTestCase(
        python_type=Pattern,
        bson_type="regex",
        sample_value=re.compile(r"^.*$"),
    ),
    TypeTestCase(
        python_type=Pattern,
        bson_type="regex",
        sample_value=re.compile(r"^.*$", flags=re.IGNORECASE),
    ),
    TypeTestCase(
        python_type=Pattern,
        bson_type="regex",
        sample_value=re.compile(r"^.*$", flags=re.IGNORECASE | re.MULTILINE),
    ),
    # flags=32 is the numeric value of re.UNICODE.
    TypeTestCase(
        python_type=Regex,
        bson_type="regex",
        sample_value=Regex(r"^.*$", flags=32),
    ),
    TypeTestCase(python_type=ObjectId, bson_type="objectId", sample_value=ObjectId()),
    TypeTestCase(python_type=bytes, bson_type="binData", sample_value=b"\xf0\xf1\xf2"),
    TypeTestCase(
        python_type=Binary,
        bson_type="binData",
        sample_value=Binary(b"\xf0\xf1\xf2"),
    ),
    TypeTestCase(python_type=datetime, bson_type="date", sample_value=sample_datetime),
    # Compound Types
    TypeTestCase(python_type=List[str], bson_type="array", sample_value=["one"]),
    TypeTestCase(
        python_type=Tuple[str, ...],  # type: ignore
        bson_type="array",
        sample_value=("one",),
    ),
    TypeTestCase(
        python_type=List[ObjectId],
        bson_type="array",
        sample_value=[ObjectId() for _ in range(5)],
    ),
    TypeTestCase(
        python_type=Union[Tuple[ObjectId, ...], None],  # type: ignore
        bson_type="array",
        sample_value=tuple(ObjectId() for _ in range(5)),
    ),
]
def id_from_test_case(case: TypeTestCase):
    """Build a readable pytest id for a ``TypeTestCase``.

    The id combines the Python annotation and the expected BSON type, for
    example ``int->long`` or ``Int64->long``. With the BSON type alone, many
    cases would share an id and only be told apart by pytest's numeric suffix.

    The sample value is deliberately left out. Some samples (``ObjectId()``,
    ``datetime.now()``) change on every import, and ids should stay
    deterministic, e.g. for pytest-xdist, which compares collections across
    workers.
    """
    python_type = case.python_type
    if isinstance(python_type, type):
        type_name = python_type.__name__
    else:
        # typing constructs (Dict[...], Pattern, Union[...]) are not classes;
        # their str() form keeps the type parameters, minus the noisy prefix.
        type_name = str(python_type).replace("typing.", "")
    return f"{type_name}->{case.bson_type}"
@pytest.mark.parametrize("case", type_test_data, ids=id_from_test_case)
async def test_bson_type_inference(
    motor_database: AsyncIOMotorDatabase, aio_engine: AIOEngine, case: TypeTestCase
):
    """Check that ``case.sample_value`` is stored with BSON type ``case.bson_type``.

    The engine (AIOEngine) saves a model whose single field is annotated with
    ``case.python_type``. The raw document is then queried through motor with
    a ``$type`` filter, so a miss means the value landed under another BSON
    type. Finally, the raw stored value is fed back into the model and
    compared with the saved instance.
    """

    class ModelWithTypedField(Model):
        field: case.python_type  # type: ignore

    # TODO: Fix objectid optional (type: ignore)
    saved = await aio_engine.save(ModelWithTypedField(field=case.sample_value))
    collection = motor_database[ModelWithTypedField.__collection__]
    type_filter = {
        +ModelWithTypedField.id: saved.id,  # type: ignore
        +ModelWithTypedField.field: {"$type": case.bson_type},
    }
    raw_doc = await collection.find_one(type_filter)
    assert raw_doc is not None, (
        f"Type inference error: {case.python_type} -> {case.bson_type}"
        f" ({case.sample_value})"
    )

    reloaded = ModelWithTypedField(field=raw_doc["field"])
    assert reloaded.field == saved.field
@pytest.mark.parametrize("case", type_test_data, ids=id_from_test_case)
def test_sync_bson_type_inference(
    pymongo_database: Database, sync_engine: SyncEngine, case: TypeTestCase
):
    """Synchronous check that ``case.sample_value`` is stored as ``case.bson_type``.

    The SyncEngine saves a model whose single field is annotated with
    ``case.python_type``. The raw document is then read back through pymongo
    with a ``$type`` filter, so a miss means the value was stored under
    another BSON type. The raw stored value is then re-validated through the
    model and compared with the saved instance.
    """

    class ModelWithTypedField(Model):
        field: case.python_type  # type: ignore

    # TODO: Fix objectid optional (type: ignore)
    saved = sync_engine.save(ModelWithTypedField(field=case.sample_value))
    collection = pymongo_database[ModelWithTypedField.__collection__]
    type_filter = {
        +ModelWithTypedField.id: saved.id,  # type: ignore
        +ModelWithTypedField.field: {"$type": case.bson_type},
    }
    raw_doc = collection.find_one(type_filter)
    assert raw_doc is not None, (
        f"Type inference error: {case.python_type} -> {case.bson_type}"
        f" ({case.sample_value})"
    )

    reloaded = ModelWithTypedField(field=raw_doc["field"])
    assert reloaded.field == saved.field
async def test_custom_bson_serializable(
    motor_database: AsyncIOMotorDatabase, aio_engine: AIOEngine
):
    """Check that a ``WithBsonSerializer(str)`` field is stored as a BSON string.

    ``FancyFloat`` is a ``float`` at the model level. The test expects the
    attached serializer to make the stored value match a ``$type: "string"``
    filter. It then expects ``model_validate_doc`` on the raw document to give
    back a field equal to the saved one.
    """
    FancyFloat = Annotated[float, WithBsonSerializer(str)]

    class ModelWithCustomField(Model):
        field: FancyFloat

    instance = await aio_engine.save(ModelWithCustomField(field=3.14))
    document = await motor_database[ModelWithCustomField.__collection__].find_one(
        {
            +ModelWithCustomField.id: instance.id,  # type: ignore
            +ModelWithCustomField.field: {"$type": "string"},
        }
    )
    assert document is not None, "Couldn't retrieve the document with its string value"
    recovered_instance = ModelWithCustomField.model_validate_doc(document)
    assert recovered_instance.field == instance.field
def test_sync_custom_bson_serializable(
    pymongo_database: Database, sync_engine: SyncEngine
):
    """Synchronous check that a ``WithBsonSerializer(str)`` field is stored as a string.

    ``FancyFloat`` is a ``float`` at the model level. The test expects the
    attached serializer to make the stored value match a ``$type: "string"``
    filter when read back through pymongo. It then expects
    ``model_validate_doc`` on the raw document to give back a field equal to
    the saved one.
    """
    FancyFloat = Annotated[float, WithBsonSerializer(str)]

    class ModelWithCustomField(Model):
        field: FancyFloat

    instance = sync_engine.save(ModelWithCustomField(field=3.14))
    document = pymongo_database[ModelWithCustomField.__collection__].find_one(
        {
            +ModelWithCustomField.id: instance.id,  # type: ignore
            +ModelWithCustomField.field: {"$type": "string"},
        }
    )
    assert document is not None, "Couldn't retrieve the document with its string value"
    recovered_instance = ModelWithCustomField.model_validate_doc(document)
    assert recovered_instance.field == instance.field
|