1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
|
from __future__ import annotations
from typing import Dict, Generic, List, TypeVar
import msgspec
import pytest
from litestar._openapi.datastructures import SchemaRegistry, _get_normalized_schema_key
from litestar.exceptions import ImproperlyConfiguredException
from litestar.openapi.spec import Reference, Schema
from litestar.params import KwargDefinition
from litestar.typing import FieldDefinition
from tests.models import DataclassPerson
@pytest.fixture()
def schema_registry() -> SchemaRegistry:
return SchemaRegistry()
def test_get_schema_for_field_definition(schema_registry: SchemaRegistry) -> None:
assert not schema_registry._schema_key_map
assert not schema_registry._schema_reference_map
assert not schema_registry._model_name_groups
field = FieldDefinition.from_annotation(str)
schema = schema_registry.get_schema_for_field_definition(field)
key = _get_normalized_schema_key(field)
assert isinstance(schema, Schema)
assert key in schema_registry._schema_key_map
assert not schema_registry._schema_reference_map
assert len(schema_registry._model_name_groups[key[-1]]) == 1
assert schema_registry._model_name_groups[key[-1]][0].schema is schema
assert schema_registry.get_schema_for_field_definition(field) is schema
def test_get_reference_for_field_definition(schema_registry: SchemaRegistry) -> None:
assert not schema_registry._schema_key_map
assert not schema_registry._schema_reference_map
assert not schema_registry._model_name_groups
field = FieldDefinition.from_annotation(str)
key = _get_normalized_schema_key(field)
assert schema_registry.get_reference_for_field_definition(field) is None
schema_registry.get_schema_for_field_definition(field)
reference = schema_registry.get_reference_for_field_definition(field)
assert isinstance(reference, Reference)
assert id(reference) in schema_registry._schema_reference_map
assert reference in schema_registry._schema_key_map[key].references
def test_get_normalized_schema_key() -> None:
class LocalClass(msgspec.Struct):
id: str
T = TypeVar("T")
# replace each of the long strings with underscores with a tuple of strings split at each underscore
assert _get_normalized_schema_key(FieldDefinition.from_annotation(LocalClass)) == (
"tests",
"unit",
"test_openapi",
"test_datastructures",
"test_get_normalized_schema_key.LocalClass",
)
assert _get_normalized_schema_key(FieldDefinition.from_annotation(DataclassPerson)) == (
"tests",
"models",
"DataclassPerson",
)
builtin_dict = Dict[str, List[int]]
assert _get_normalized_schema_key(FieldDefinition.from_annotation(builtin_dict)) == (
"typing",
"Dict_str_typing.List_int_",
)
builtin_with_custom = Dict[str, DataclassPerson]
assert _get_normalized_schema_key(FieldDefinition.from_annotation(builtin_with_custom)) == (
"typing",
"Dict_str_tests.models.DataclassPerson_",
)
class LocalGeneric(Generic[T]):
pass
assert _get_normalized_schema_key(FieldDefinition.from_annotation(LocalGeneric)) == (
"tests",
"unit",
"test_openapi",
"test_datastructures",
"test_get_normalized_schema_key.LocalGeneric",
)
generic_int = LocalGeneric[int]
generic_str = LocalGeneric[str]
assert _get_normalized_schema_key(FieldDefinition.from_annotation(generic_int)) == (
"tests",
"unit",
"test_openapi",
"test_datastructures",
"test_get_normalized_schema_key.LocalGeneric_int_",
)
assert _get_normalized_schema_key(FieldDefinition.from_annotation(generic_str)) == (
"tests",
"unit",
"test_openapi",
"test_datastructures",
"test_get_normalized_schema_key.LocalGeneric_str_",
)
assert _get_normalized_schema_key(FieldDefinition.from_annotation(generic_int)) != _get_normalized_schema_key(
FieldDefinition.from_annotation(generic_str)
)
def test_raise_on_override_for_same_field_definition() -> None:
registry = SchemaRegistry()
schema = registry.get_schema_for_field_definition(
FieldDefinition.from_annotation(str, kwarg_definition=KwargDefinition(schema_component_key="foo"))
)
# registering the same thing again with the same name should work
assert (
registry.get_schema_for_field_definition(
FieldDefinition.from_annotation(str, kwarg_definition=KwargDefinition(schema_component_key="foo"))
)
is schema
)
# registering the same *type* with a different name should result in a different schema
assert (
registry.get_schema_for_field_definition(
FieldDefinition.from_annotation(str, kwarg_definition=KwargDefinition(schema_component_key="bar"))
)
is not schema
)
# registering a different type with a previously used name should raise an exception
with pytest.raises(ImproperlyConfiguredException):
registry.get_schema_for_field_definition(
FieldDefinition.from_annotation(int, kwarg_definition=KwargDefinition(schema_component_key="foo"))
)
|