1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
|
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from click import Group
from litestar import Litestar, MediaType, get
from litestar.constants import UNDEFINED_SENTINELS
from litestar.plugins import CLIPluginProtocol, InitPlugin, OpenAPISchemaPlugin, PluginRegistry
from litestar.plugins.attrs import AttrsSchemaPlugin
from litestar.plugins.core import MsgspecDIPlugin
from litestar.plugins.pydantic import PydanticDIPlugin, PydanticInitPlugin, PydanticPlugin, PydanticSchemaPlugin
from litestar.plugins.sqlalchemy import SQLAlchemySerializationPlugin
from litestar.testing import create_test_client
from litestar.typing import FieldDefinition
if TYPE_CHECKING:
from litestar.config.app import AppConfig
def test_plugin_on_app_init() -> None:
    """An ``InitPlugin`` can add tags, startup hooks and route handlers in ``on_app_init``."""
    expected_tag = "on_app_init_called"

    @get("/", media_type=MediaType.TEXT)
    def greet() -> str:
        return "hello world"

    def mark_started(app: Litestar) -> None:
        # A flag on app state lets the test confirm the startup hook actually ran.
        app.state.called = True

    class PluginWithInitOnly(InitPlugin):
        def on_app_init(self, app_config: AppConfig) -> AppConfig:
            app_config.tags.append(expected_tag)
            app_config.on_startup.append(mark_started)
            app_config.route_handlers.append(greet)
            return app_config

    with create_test_client(plugins=[PluginWithInitOnly()]) as client:
        # The handler registered by the plugin must be routable.
        assert client.get("/").text == "hello world"
        assert expected_tag in client.app.tags
        assert client.app.state.called
def test_plugin_registry() -> None:
    """``PluginRegistry`` groups plugins by the protocol each one implements.

    Also checks that every registered plugin supports ``in`` and is yielded
    when iterating the registry.
    """

    class CLIPlugin(CLIPluginProtocol):
        def on_cli_init(self, cli: Group) -> None:
            pass

    cli_plugin = CLIPlugin()
    serialization_plugin = SQLAlchemySerializationPlugin()
    openapi_plugin = PydanticSchemaPlugin()
    init_plugin = PydanticInitPlugin()
    registered = [cli_plugin, serialization_plugin, openapi_plugin, init_plugin]

    registry = PluginRegistry(registered)

    # Each protocol-specific view holds exactly the one plugin implementing it.
    assert registry.cli == (cli_plugin,)
    assert registry.serialization == (serialization_plugin,)
    assert registry.openapi == (openapi_plugin,)
    assert registry.init == (init_plugin,)

    for plugin in registered:
        assert plugin in registry
    assert set(registry) == set(registered)
def test_plugin_registry_get() -> None:
    """``PluginRegistry.get`` finds a plugin by its class and raises ``KeyError`` when absent."""

    class CLIPlugin(CLIPluginProtocol):
        def on_cli_init(self, cli: Group) -> None:
            pass

    registered_plugin = CLIPlugin()

    # Lookup in an empty registry must fail with a message naming the class.
    with pytest.raises(KeyError, match="No plugin of type 'CLIPlugin' registered"):
        PluginRegistry([]).get(CLIPlugin)

    found = PluginRegistry([registered_plugin]).get(CLIPlugin)
    assert found is registered_plugin
def test_plugin_registry_stringified_get() -> None:
    """``PluginRegistry.get`` also resolves plugins from a string.

    A bare class name (``"PydanticPlugin"``) or a ``module.ClassName`` path
    (``"litestar.plugins.pydantic.PydanticPlugin"``) both return the registered
    instance. Strings that match no registered plugin raise ``KeyError``.
    """

    class CLIPlugin(CLIPluginProtocol):
        def on_cli_init(self, cli: Group) -> None:
            pass

    cli_plugin = CLIPlugin()
    pydantic_plugin = PydanticPlugin()

    # Each failing lookup needs its own ``pytest.raises`` block. Inside a single
    # block, the first raise would skip every statement after it.
    with pytest.raises(KeyError):
        # Qualified path of a plugin that is not registered (and whose module
        # path does not exist).
        PluginRegistry([CLIPlugin()]).get("litestar2.plugins.pydantic.PydanticPlugin")
    with pytest.raises(KeyError):
        # Bare class name looked up in an empty registry.
        PluginRegistry([]).get("CLIPlugin")

    registry = PluginRegistry([cli_plugin, pydantic_plugin])
    assert registry.get(CLIPlugin) is cli_plugin
    assert registry.get(PydanticPlugin) is pydantic_plugin
    assert registry.get("PydanticPlugin") is pydantic_plugin
    assert registry.get("litestar.plugins.pydantic.PydanticPlugin") is pydantic_plugin
def test_openapi_schema_plugin_is_constrained_field() -> None:
    """The base ``OpenAPISchemaPlugin.is_constrained_field`` returns ``False`` for a plain ``str``."""
    plain_str_field = FieldDefinition.from_annotation(str)
    assert OpenAPISchemaPlugin.is_constrained_field(plain_str_field) is False
def test_openapi_schema_plugin_is_undefined_sentinel() -> None:
    """The base ``OpenAPISchemaPlugin.is_undefined_sentinel`` returns ``False`` for every known sentinel."""
    # Guard against a vacuous pass if ``UNDEFINED_SENTINELS`` were ever empty.
    assert UNDEFINED_SENTINELS
    for value in UNDEFINED_SENTINELS:
        # Include the offending sentinel in the failure message for easier debugging.
        assert OpenAPISchemaPlugin.is_undefined_sentinel(value) is False, value
@pytest.mark.parametrize("init_plugin", [PydanticInitPlugin(), None])
@pytest.mark.parametrize("schema_plugin", [PydanticSchemaPlugin(), None])
@pytest.mark.parametrize("attrs_plugin", [AttrsSchemaPlugin(), None])
def test_app_get_default_plugins(
    init_plugin: PydanticInitPlugin | None,
    schema_plugin: PydanticSchemaPlugin | None,
    attrs_plugin: AttrsSchemaPlugin | None,
) -> None:
    """``Litestar._get_default_plugins`` fills in the plugins the user did not supply.

    Attrs, pydantic-DI and msgspec-DI plugins are always present. If the user
    supplied neither pydantic init nor schema plugin, the combined
    ``PydanticPlugin`` is expected. Otherwise, both the init and schema plugins
    are expected.
    """
    plugins = [p for p in (init_plugin, schema_plugin, attrs_plugin) if p is not None]
    # Explicit None checks: parametrized "absent" plugins are passed as None.
    any_pydantic = init_plugin is not None or schema_plugin is not None
    default_plugins = Litestar._get_default_plugins(plugins)  # type: ignore[arg-type]

    always_present = {AttrsSchemaPlugin, PydanticDIPlugin, MsgspecDIPlugin}
    if any_pydantic:
        expected = always_present | {PydanticInitPlugin, PydanticSchemaPlugin}
    else:
        expected = always_present | {PydanticPlugin}
    assert {type(p) for p in default_plugins} == expected
|