File: handlers.py

package info (click to toggle)
python-openapi-core 0.19.4-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,008 kB
  • sloc: python: 18,868; makefile: 47
file content (65 lines) | stat: -rw-r--r-- 2,270 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""OpenAPI core contrib starlette handlers module"""

from typing import Any
from typing import Dict
from typing import Iterable
from typing import Type

from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.responses import Response

from openapi_core.templating.media_types.exceptions import MediaTypeNotFound
from openapi_core.templating.paths.exceptions import OperationNotFound
from openapi_core.templating.paths.exceptions import PathNotFound
from openapi_core.templating.paths.exceptions import ServerNotFound
from openapi_core.templating.security.exceptions import SecurityNotFound
from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult


class StarletteOpenAPIErrorsHandler:
    """Turn OpenAPI validation errors into a single JSON error response.

    Each error is formatted into a ``{"title", "status", "type"}`` dict; the
    response carries all of them under ``"errors"`` and uses the highest
    status code among them as the HTTP status.
    """

    # Maps an exception class to the HTTP status reported for it. The lookup
    # in ``format_openapi_error`` is by exact class, so unlisted classes
    # (including subclasses of the listed ones) fall back to 400.
    OPENAPI_ERROR_STATUS: Dict[Type[BaseException], int] = {
        ServerNotFound: 400,
        SecurityNotFound: 403,
        OperationNotFound: 405,
        PathNotFound: 404,
        MediaTypeNotFound: 415,
    }

    def __call__(
        self,
        errors: Iterable[Exception],
    ) -> JSONResponse:
        """Build the JSON error response for ``errors``.

        :param errors: exceptions collected while handling the request. It
            must contain at least one error, otherwise ``max`` raises
            ``ValueError``.
        :returns: a ``JSONResponse`` whose body is ``{"errors": [...]}`` and
            whose status code is the largest status among the formatted
            errors.
        """
        data_errors = [self.format_openapi_error(err) for err in errors]
        data = {
            "errors": data_errors,
        }
        # Compare statuses numerically. ``get_error_status`` returns a str,
        # and str ordering is lexicographic ("1000" < "415", "99" > "415"),
        # which only agrees with numeric order when all codes have the same
        # number of digits.
        data_error_max = max(
            data_errors, key=lambda err: int(self.get_error_status(err))
        )
        return JSONResponse(data, status_code=data_error_max["status"])

    @classmethod
    def format_openapi_error(cls, error: BaseException) -> Dict[str, Any]:
        """Format one exception as a JSON-serialisable error dict.

        If the exception has an explicit ``__cause__`` (it was raised
        ``from`` another exception), the cause is reported instead. Only one
        level is unwrapped.

        :param error: the exception to format.
        :returns: a dict with ``"title"`` (``str(error)``), ``"status"`` (from
            ``OPENAPI_ERROR_STATUS`` by exact class, default 400) and
            ``"type"`` (``str`` of the exception class).
        """
        if error.__cause__ is not None:
            error = error.__cause__
        return {
            "title": str(error),
            "status": cls.OPENAPI_ERROR_STATUS.get(error.__class__, 400),
            "type": str(type(error)),
        }

    @classmethod
    def get_error_status(cls, error: Dict[str, Any]) -> str:
        """Return the ``"status"`` of a formatted error dict as a string.

        The ``str`` return type is kept for backward compatibility with
        existing callers and overrides; ``__call__`` converts the value to
        ``int`` before comparing.
        """
        return str(error["status"])


class StarletteOpenAPIValidRequestHandler:
    """Continue the middleware chain for a request that passed validation.

    The unmarshal result is stored in the Starlette request scope under the
    ``"openapi"`` key before the request is handed to ``call_next``.
    """

    def __init__(self, request: Request, call_next: RequestResponseEndpoint):
        # Remember the continuation and the request it should be called with.
        self.call_next = call_next
        self.request = request

    async def __call__(
        self, request_unmarshal_result: RequestUnmarshalResult
    ) -> Response:
        """Attach ``request_unmarshal_result`` to the scope and continue.

        :param request_unmarshal_result: the result of unmarshalling the
            current request.
        :returns: the response produced by ``call_next`` for the request.
        """
        current_request = self.request
        scope = current_request.scope
        scope["openapi"] = request_unmarshal_result
        response = await self.call_next(current_request)
        return response