File: test_datastructures.py

package info (click to toggle)
litestar 2.19.0-2
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 12,500 kB
  • sloc: python: 70,169; makefile: 254; javascript: 105; sh: 60
file content (142 lines) | stat: -rw-r--r-- 5,245 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import annotations

from typing import Dict, Generic, List, TypeVar

import msgspec
import pytest

from litestar._openapi.datastructures import SchemaRegistry, _get_normalized_schema_key
from litestar.exceptions import ImproperlyConfiguredException
from litestar.openapi.spec import Reference, Schema
from litestar.params import KwargDefinition
from litestar.typing import FieldDefinition
from tests.models import DataclassPerson


@pytest.fixture()
def schema_registry() -> SchemaRegistry:
    """Provide a newly constructed, empty ``SchemaRegistry`` to the requesting test."""
    registry = SchemaRegistry()
    return registry


def test_get_schema_for_field_definition(schema_registry: SchemaRegistry) -> None:
    """Registering a field creates and caches a ``Schema`` but hands out no references."""
    # The registry must start out completely empty.
    assert not schema_registry._schema_key_map
    assert not schema_registry._schema_reference_map
    assert not schema_registry._model_name_groups

    str_field = FieldDefinition.from_annotation(str)
    created = schema_registry.get_schema_for_field_definition(str_field)
    normalized_key = _get_normalized_schema_key(str_field)

    assert isinstance(created, Schema)
    assert normalized_key in schema_registry._schema_key_map
    # Only a schema was requested, so no reference should have been recorded.
    assert not schema_registry._schema_reference_map

    # The new schema is grouped under the last component of its normalized key.
    name_group = schema_registry._model_name_groups[normalized_key[-1]]
    assert len(name_group) == 1
    assert name_group[0].schema is created

    # Asking again for the same field must return the very same schema object.
    assert schema_registry.get_schema_for_field_definition(str_field) is created


def test_get_reference_for_field_definition(schema_registry: SchemaRegistry) -> None:
    """A ``Reference`` is only produced once the field's schema has been registered."""
    assert not schema_registry._schema_key_map
    assert not schema_registry._schema_reference_map
    assert not schema_registry._model_name_groups

    str_field = FieldDefinition.from_annotation(str)
    normalized_key = _get_normalized_schema_key(str_field)

    # Nothing has been registered yet, so no reference is available.
    assert schema_registry.get_reference_for_field_definition(str_field) is None

    schema_registry.get_schema_for_field_definition(str_field)
    ref = schema_registry.get_reference_for_field_definition(str_field)

    assert isinstance(ref, Reference)
    # The reference map is keyed by the reference object's id()...
    assert id(ref) in schema_registry._schema_reference_map
    # ...and the reference is also recorded on the registered schema's entry.
    assert ref in schema_registry._schema_key_map[normalized_key].references


def test_get_normalized_schema_key() -> None:
    """Check the tuple keys produced by ``_get_normalized_schema_key``.

    For user-defined types the expected key is the defining module's dotted path split into
    components, followed by the type's qualified name (without the ``<locals>`` marker that
    Python adds for function-local classes). For ``typing`` generics, the type arguments are
    folded into the final component, separated by underscores.
    """

    class LocalClass(msgspec.Struct):
        id: str

    T = TypeVar("T")

    class LocalGeneric(Generic[T]):
        pass

    def key_for(annotation: object) -> tuple[str, ...]:
        # Shorthand: wrap the annotation in a FieldDefinition and normalize it.
        return _get_normalized_schema_key(FieldDefinition.from_annotation(annotation))

    # Module path shared by every type defined in this test module.
    this_module = ("tests", "unit", "test_openapi", "test_datastructures")

    # Function-local classes: module path + "<function>.<Class>".
    assert key_for(LocalClass) == (*this_module, "test_get_normalized_schema_key.LocalClass")
    assert key_for(LocalGeneric) == (*this_module, "test_get_normalized_schema_key.LocalGeneric")

    # Module-level model: its defining module's components + class name.
    assert key_for(DataclassPerson) == ("tests", "models", "DataclassPerson")

    # typing generics: arguments (with their module paths) are flattened into the last component.
    assert key_for(Dict[str, List[int]]) == ("typing", "Dict_str_typing.List_int_")
    assert key_for(Dict[str, DataclassPerson]) == ("typing", "Dict_str_tests.models.DataclassPerson_")

    # Parametrised user generics embed the type argument, so each parametrisation is distinct.
    int_key = key_for(LocalGeneric[int])
    str_key = key_for(LocalGeneric[str])
    assert int_key == (*this_module, "test_get_normalized_schema_key.LocalGeneric_int_")
    assert str_key == (*this_module, "test_get_normalized_schema_key.LocalGeneric_str_")
    assert int_key != str_key


def test_raise_on_override_for_same_field_definition() -> None:
    """A ``schema_component_key`` may be reused for the same type, but not for a different one."""
    registry = SchemaRegistry()

    def register(annotation: type, component_key: str) -> Schema:
        # Build a field definition carrying an explicit component key and register it.
        kwarg_def = KwargDefinition(schema_component_key=component_key)
        return registry.get_schema_for_field_definition(
            FieldDefinition.from_annotation(annotation, kwarg_definition=kwarg_def)
        )

    foo_schema = register(str, "foo")

    # Same type, same key: the previously registered schema is returned.
    assert register(str, "foo") is foo_schema

    # Same type, new key: a separate schema is produced.
    assert register(str, "bar") is not foo_schema

    # Different type, already-used key: registration is rejected.
    with pytest.raises(ImproperlyConfiguredException):
        register(int, "foo")