1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
|
from __future__ import annotations
from http import HTTPStatus
from typing import Any
import pytest
from litestar import get
from litestar.exceptions.http_exceptions import HTTPException, ValidationException
from litestar.plugins.problem_details import ProblemDetailsConfig, ProblemDetailsException, ProblemDetailsPlugin
from litestar.testing.helpers import create_test_client
@pytest.mark.parametrize(
    ("exception", "expected"),
    [
        # Bare exception: defaults to status 500 and the standard reason phrase as detail.
        (
            ProblemDetailsException(),
            {"status": 500, "detail": HTTPStatus(500).phrase},
        ),
        # Explicit status, detail and instance are all echoed in the body.
        (
            ProblemDetailsException(status_code=400, detail="validation error", instance="https://example.net/error"),
            {"status": 400, "detail": "validation error", "instance": "https://example.net/error"},
        ),
        # A dict ``extra`` is flattened into the top level of the body.
        (
            ProblemDetailsException(
                status_code=400,
                detail="validation error",
                extra={"error": "must be positive integer", "pointer": "#age"},
            ),
            {
                "status": 400,
                "detail": "validation error",
                "error": "must be positive integer",
                "pointer": "#age",
            },
        ),
        # A non-dict ``extra`` (here a list) is kept under the "extra" key.
        (
            ProblemDetailsException(
                status_code=400,
                detail="validation error",
                extra=[{"error": "must be positive integer", "pointer": "#age"}],
            ),
            {
                "status": 400,
                "detail": "validation error",
                "extra": [{"error": "must be positive integer", "pointer": "#age"}],
            },
        ),
        # ``type_`` is emitted as "type" alongside the default status/detail.
        (
            ProblemDetailsException(type_="https://example.net/validation-error"),
            {
                "type": "https://example.net/validation-error",
                "status": 500,
                "detail": HTTPStatus(500).phrase,
            },
        ),
    ],
)
def test_raising_problem_details_exception(exception: ProblemDetailsException, expected: dict[str, Any]) -> None:
    """A ProblemDetailsException raised by a handler becomes an ``application/problem+json`` response.

    The JSON body must equal ``expected`` exactly, and the HTTP status code must match
    ``expected["status"]``.
    """

    @get("/")
    async def raise_exception() -> None:
        raise exception

    plugin = ProblemDetailsPlugin()
    with create_test_client([raise_exception], plugins=[plugin]) as client:
        resp = client.get("/")
        assert resp.headers["content-type"] == "application/problem+json"
        assert resp.json() == expected
        assert resp.status_code == expected["status"]
@pytest.mark.parametrize("enable", (True, False))
def test_enable_for_all_http_exceptions(enable: bool) -> None:
    """A plain HTTPException is rendered as problem+json only when the config opts in.

    With ``enable_for_all_http_exceptions=True`` the response content type must be
    ``application/problem+json``; with ``False`` it must be anything else.
    """

    @get("/")
    async def raise_http_exception() -> None:
        raise HTTPException()

    plugin = ProblemDetailsPlugin(ProblemDetailsConfig(enable_for_all_http_exceptions=enable))
    with create_test_client([raise_http_exception], plugins=[plugin]) as client:
        content_type = client.get("/").headers["content-type"]
        # One check covers both parametrizations: problem+json iff the option is enabled.
        is_problem_json = content_type == "application/problem+json"
        assert is_problem_json == enable
def test_exception_to_problem_detail_map() -> None:
    """A custom ``exception_to_problem_detail_map`` entry converts ValidationException to problem+json.

    The mapper copies the original exception's detail, status code and ``extra`` into a
    ProblemDetailsException with ``type_="validation-error"``. Because ``extra`` is a dict,
    its keys appear at the top level of the response body, next to ``type``/``status``/``detail``.
    """
    # Defined before the handler that closes over it, so the test reads top to bottom
    # instead of relying on late binding of a name assigned further down.
    errors = {"accounts": ["/account/1", "/account/2"]}

    def validation_exception_to_problem_details_exception(exc: ValidationException) -> ProblemDetailsException:
        return ProblemDetailsException(
            type_="validation-error", detail=exc.detail, extra=exc.extra, status_code=exc.status_code
        )

    @get("/")
    async def get_foo() -> None:
        raise ValidationException(detail="Not enough balance", extra=errors)

    config = ProblemDetailsConfig(
        exception_to_problem_detail_map={ValidationException: validation_exception_to_problem_details_exception}
    )
    with create_test_client([get_foo], plugins=[ProblemDetailsPlugin(config)]) as client:
        response = client.get("/")
        # The status code is carried over unchanged from the ValidationException.
        assert response.status_code == 400
        assert response.headers["content-type"] == "application/problem+json"
        assert response.json() == {
            "type": "validation-error",
            "status": 400,
            "detail": "Not enough balance",
            "accounts": ["/account/1", "/account/2"],
        }
|