1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
|
import datetime
import sys
from decimal import Decimal
from typing import Any, Generic, Optional, Type, TypeVar, Union
import pydantic as pydantic_v2
import pytest
from pydantic import v1 as pydantic_v1
from pydantic.v1.generics import GenericModel
from typing_extensions import Annotated
from litestar import Litestar, post
from litestar._openapi.schema_generation import SchemaCreator
from litestar.openapi.spec import OpenAPIType
from litestar.openapi.spec.schema import Schema
from litestar.plugins.pydantic import PydanticSchemaPlugin
from litestar.typing import FieldDefinition
from litestar.utils.helpers import get_name
from tests.helpers import get_schema_for_field_definition
# Type parameter shared by the generic Pydantic models declared in this module.
T = TypeVar("T")
class PydanticV1Generic(GenericModel, Generic[T]):
    # Pydantic v1 generic model. The type parameter ``T`` is declared three
    # ways: bare, wrapped in ``Optional`` and wrapped in ``Annotated``.
    # Documented with comments rather than a docstring so that the model's
    # ``__doc__`` stays untouched.
    foo: T
    optional_foo: Optional[T]
    annotated_foo: Annotated[T, object()]
class PydanticV2Generic(pydantic_v2.BaseModel, Generic[T]):
    # Pydantic v2 generic model. The type parameter ``T`` is declared three
    # ways: bare, wrapped in ``Optional`` and wrapped in ``Annotated``.
    # Documented with comments rather than a docstring so that the model's
    # ``__doc__`` stays untouched.
    foo: T
    optional_foo: Optional[T]
    annotated_foo: Annotated[T, object()]
@pytest.mark.parametrize(
    "model",
    [
        pytest.param(
            PydanticV1Generic,
            marks=[pytest.mark.skipif(sys.version_info >= (3, 14), reason="not supported")],
        ),
        PydanticV2Generic,
    ],
)
def test_schema_generation_with_generic_classes(model: Type[Union[PydanticV1Generic, PydanticV2Generic]]) -> None:
    """Check the property schemas produced for a generic model specialised with ``int``.

    The model is parametrised with ``int`` and run through
    ``get_schema_for_field_definition`` with a ``PydanticSchemaPlugin``.
    ``foo`` and ``annotated_foo`` must both be plain integer schemas, and
    ``optional_foo`` must be a ``one_of`` of integer and null.

    Args:
        model: The generic Pydantic model class (v1 or v2) to specialise.
    """
    specialised = model[int]  # type: ignore[index]
    properties = get_schema_for_field_definition(
        FieldDefinition.from_kwarg(name=get_name(specialised), annotation=specialised),
        plugins=[PydanticSchemaPlugin()],
    ).properties

    # Guard against a missing/empty mapping before indexing into it.
    assert properties

    integer_schema = Schema(type=OpenAPIType.INTEGER)
    nullable_integer_schema = Schema(one_of=[Schema(type=OpenAPIType.INTEGER), Schema(type=OpenAPIType.NULL)])

    assert properties["foo"] == integer_schema
    assert properties["annotated_foo"] == integer_schema
    assert properties["optional_foo"] == nullable_integer_schema
@pytest.mark.parametrize(
    "constrained",
    [
        # Pydantic v1 constrained type factories.
        pydantic_v1.constr(regex="^[a-zA-Z]$"),
        pydantic_v1.conlist(int, min_items=1),
        pydantic_v1.conset(int, min_items=1),
        pydantic_v1.conint(gt=10, lt=100),
        pydantic_v1.confloat(gt=10, lt=100),
        pydantic_v1.condecimal(gt=Decimal("10")),
        pydantic_v1.condate(gt=datetime.date.today()),
        # Pydantic v2 constrained type factories.
        pydantic_v2.constr(pattern="^[a-zA-Z]$"),
        pydantic_v2.conlist(int, min_length=1),
        pydantic_v2.conset(int, min_length=1),
        pydantic_v2.conint(gt=10, lt=100),
        pydantic_v2.confloat(ge=10, le=100),
        pydantic_v2.condecimal(gt=Decimal("10")),
        pydantic_v2.condate(gt=datetime.date.today()),
    ],
)
def test_is_pydantic_constrained_field(constrained: Any) -> None:
    """Call ``PydanticSchemaPlugin.is_constrained_field`` on each constrained type.

    NOTE(review): the value returned by ``is_constrained_field`` is discarded,
    so this test only verifies that the call completes without raising.
    Confirm whether an assertion on the result was intended.

    Args:
        constrained: A constrained type built by a Pydantic v1 or v2 ``con*`` factory.
    """
    field_definition = FieldDefinition.from_annotation(constrained)
    PydanticSchemaPlugin.is_constrained_field(field_definition)
def test_v2_constrained_secrets() -> None:
    """Check that ``min_length`` on v2 ``SecretStr``/``SecretBytes`` fields reaches the schema.

    Both properties are expected to be string schemas carrying ``min_length=1``.
    """

    # https://github.com/litestar-org/litestar/issues/3148
    class Model(pydantic_v2.BaseModel):
        string: pydantic_v2.SecretStr = pydantic_v2.Field(min_length=1)
        bytes_: pydantic_v2.SecretBytes = pydantic_v2.Field(min_length=1)

    creator = SchemaCreator(plugins=[PydanticSchemaPlugin()])
    schema = PydanticSchemaPlugin.for_pydantic_model(FieldDefinition.from_annotation(Model), schema_creator=creator)

    assert schema.properties

    # The same expectation applies to both secret fields.
    constrained_string_schema = Schema(min_length=1, type=OpenAPIType.STRING)
    assert schema.properties["string"] == constrained_string_schema
    assert schema.properties["bytes_"] == constrained_string_schema
class V1ModelWithPrivateFields(pydantic_v1.BaseModel):
    # Pydantic v1 model whose only attributes are underscore-prefixed; it is one
    # of the parameters of ``test_exclude_private_fields``.
    class Config:
        # NOTE(review): pydantic v1 documents ``underscore_attrs_are_private``;
        # confirm that ``underscore_fields_are_private`` is the intended name.
        underscore_fields_are_private = True

    # Explicit private attribute.
    _field: str = pydantic_v1.PrivateAttr()
    # include an invalid annotation here to ensure we never touch those fields
    # NOTE(review): as written, the annotation is a plain ``str`` with a
    # default value rather than an invalid annotation - confirm the intent.
    _underscore_field: str = "foo"
class V1GenericModelWithPrivateFields(pydantic_v1.generics.GenericModel, Generic[T]):  # pyright: ignore
    # Pydantic v1 generic model whose only attributes are underscore-prefixed;
    # it is one of the parameters of ``test_exclude_private_fields``.
    class Config:
        # NOTE(review): pydantic v1 documents ``underscore_attrs_are_private``;
        # confirm that ``underscore_fields_are_private`` is the intended name.
        underscore_fields_are_private = True

    # Explicit private attribute.
    _field: str = pydantic_v1.PrivateAttr()
    # include an invalid annotation here to ensure we never touch those fields
    # NOTE(review): as written, the annotation is a plain ``str`` with a
    # default value rather than an invalid annotation - confirm the intent.
    _underscore_field: str = "foo"
class V2ModelWithPrivateFields(pydantic_v2.BaseModel):
    # Pydantic v2 model whose only attributes are underscore-prefixed; it is one
    # of the parameters of ``test_exclude_private_fields``.

    # Explicit private attribute.
    _field: str = pydantic_v2.PrivateAttr()
    # include an invalid annotation here to ensure we never touch those fields
    # NOTE(review): as written, the annotation is a plain ``str`` with a
    # default value rather than an invalid annotation - confirm the intent.
    _underscore_field: str = "foo"
class V2GenericModelWithPrivateFields(pydantic_v2.BaseModel, Generic[T]):
    # Pydantic v2 generic model whose only attributes are underscore-prefixed;
    # it is one of the parameters of ``test_exclude_private_fields``.

    # Explicit private attribute.
    _field: str = pydantic_v2.PrivateAttr()
    # include an invalid annotation here to ensure we never touch those fields
    # NOTE(review): as written, the annotation is a plain ``str`` with a
    # default value rather than an invalid annotation - confirm the intent.
    _underscore_field: str = "foo"
@pytest.mark.parametrize(
    "model_class",
    [
        V1ModelWithPrivateFields,
        V1GenericModelWithPrivateFields,
        V2ModelWithPrivateFields,
        V2GenericModelWithPrivateFields,
    ],
)
def test_exclude_private_fields(model_class: Type[Union[pydantic_v1.BaseModel, pydantic_v2.BaseModel]]) -> None:
    """Check that underscore-prefixed attributes produce no schema properties.

    Each parametrised model declares only underscore-prefixed attributes, so
    the schema built by ``PydanticSchemaPlugin.for_pydantic_model`` is expected
    to have falsy ``properties``.

    Args:
        model_class: The Pydantic v1 or v2 model class to generate a schema for.
    """
    # https://github.com/litestar-org/litestar/issues/3150
    creator = SchemaCreator(plugins=[PydanticSchemaPlugin()])
    model_field = FieldDefinition.from_annotation(model_class)
    schema = PydanticSchemaPlugin.for_pydantic_model(model_field, schema_creator=creator)

    assert not schema.properties
@pytest.mark.skipif(sys.version_info >= (3, 14), reason="not supported")
def test_v1_constrained_str_with_default_factory_does_not_generate_title() -> None:
    """Check that a v1 ``str`` field with ``default_factory`` and ``max_length`` gets no ``title``.

    A Litestar app with a single POST handler taking the model as ``data`` is
    built, and the ``test_str`` property of the model's component schema is
    inspected in the generated OpenAPI document.
    """

    # https://github.com/litestar-org/litestar/issues/3710
    class Model(pydantic_v1.BaseModel):
        test_str: str = pydantic_v1.Field(default_factory=str, max_length=600)

    @post(path="/")
    async def test(data: Model) -> str:
        return "success"

    openapi_document = Litestar(route_handlers=[test]).openapi_schema.to_schema()

    # Key under which the locally defined model appears in the components section.
    component_key = "test_v1_constrained_str_with_default_factory_does_not_generate_title.Model"
    test_str_schema = openapi_document["components"]["schemas"][component_key]["properties"]["test_str"]

    assert "title" not in test_str_schema
|