File: client.py

package info (click to toggle)
python-aiohasupervisor 0.3.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 884 kB
  • sloc: python: 4,353; sh: 37; makefile: 3
file content (239 lines) | stat: -rw-r--r-- 8,003 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""Internal client for making requests and managing session with Supervisor."""

from dataclasses import dataclass, field
from http import HTTPMethod, HTTPStatus
from importlib import metadata
from typing import Any

from aiohttp import (
    ClientError,
    ClientResponse,
    ClientResponseError,
    ClientSession,
    ClientTimeout,
)
from multidict import MultiDict
from yarl import URL

from .const import DEFAULT_TIMEOUT, ResponseType
from .exceptions import (
    SupervisorAuthenticationError,
    SupervisorBadRequestError,
    SupervisorConnectionError,
    SupervisorError,
    SupervisorForbiddenError,
    SupervisorNotFoundError,
    SupervisorResponseError,
    SupervisorServiceUnavailableError,
    SupervisorTimeoutError,
)
from .models.base import Response, ResultType
from .utils.aiohttp import ChunkAsyncStreamIterator

# Version of this package as recorded in its installed distribution metadata.
# It is sent to Supervisor in the User-Agent header of every request.
VERSION = metadata.version(__package__)


def is_json(response: ClientResponse, *, raise_on_fail: bool = False) -> bool:
    """Check if response is json according to Content-Type.

    Args:
        response: Response whose ``Content-Type`` header is inspected. A
            missing header is treated as an empty string.
        raise_on_fail: When true, raise instead of returning ``False`` for a
            response that is not JSON.

    Returns:
        ``True`` if the header contains ``application/json``, else ``False``.

    Raises:
        SupervisorResponseError: If the response is not JSON and
            ``raise_on_fail`` is true.
    """
    content_type = response.headers.get("Content-Type", "")
    if "application/json" in content_type:
        return True

    if raise_on_fail:
        # The two literals below are concatenated, so the first must end with a
        # space or the message would read "expectingJSON".
        raise SupervisorResponseError(
            "Unexpected response received from supervisor when expecting "
            f"JSON. Status: {response.status}, content type: {content_type}",
        )
    return False


@dataclass(slots=True)
class _SupervisorClient:
    """Main class for handling connections with Supervisor."""

    api_host: str
    token: str
    session: ClientSession | None = None
    _close_session: bool = field(default=False, init=False)

    async def _raise_on_status(self, response: ClientResponse) -> None:
        """Raise appropriate exception on status."""
        if response.status >= HTTPStatus.BAD_REQUEST.value:
            exc_type: type[SupervisorError] = SupervisorError
            match response.status:
                case HTTPStatus.BAD_REQUEST:
                    exc_type = SupervisorBadRequestError
                case HTTPStatus.UNAUTHORIZED:
                    exc_type = SupervisorAuthenticationError
                case HTTPStatus.FORBIDDEN:
                    exc_type = SupervisorForbiddenError
                case HTTPStatus.NOT_FOUND:
                    exc_type = SupervisorNotFoundError
                case HTTPStatus.SERVICE_UNAVAILABLE:
                    exc_type = SupervisorServiceUnavailableError

            if is_json(response):
                result = Response.from_json(await response.text())
                raise exc_type(result.message, result.job_id)
            raise exc_type()

    async def _request(
        self,
        method: HTTPMethod,
        uri: str,
        *,
        params: dict[str, str] | MultiDict[str] | None,
        response_type: ResponseType,
        json: dict[str, Any] | None = None,
        data: Any = None,
        timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
    ) -> Response:
        """Handle a request to Supervisor."""
        try:
            url = URL(self.api_host).joinpath(uri)
        except ValueError as err:
            raise SupervisorError from err

        # This check is to make sure the normalized URL string is the same as the URL
        # string that was passed in. If they are different, then the passed in uri
        # contained characters that were removed by the normalization
        # such as ../../../../etc/passwd
        if not url.raw_path.endswith(uri):
            raise SupervisorError(f"Invalid request {uri}")

        match response_type:
            case ResponseType.TEXT:
                accept = "text/plain, */*"
            case _:
                accept = "application/json, text/plain, */*"

        headers = {
            "User-Agent": f"AioHASupervisor/{VERSION}",
            "Accept": accept,
            "Authorization": f"Bearer {self.token}",
        }

        if self.session is None:
            self.session = ClientSession()
            self._close_session = True

        try:
            response = await self.session.request(
                method.value,
                url,
                timeout=timeout,
                headers=headers,
                params=params,
                json=json,
                data=data,
            )
            await self._raise_on_status(response)
            match response_type:
                case ResponseType.JSON:
                    is_json(response, raise_on_fail=True)
                    return Response.from_json(await response.text())
                case ResponseType.TEXT:
                    return Response(ResultType.OK, await response.text())
                case ResponseType.STREAM:
                    return Response(
                        ResultType.OK, ChunkAsyncStreamIterator(response.content)
                    )
                case _:
                    return Response(ResultType.OK)

        except (UnicodeDecodeError, ClientResponseError) as err:
            raise SupervisorResponseError(
                "Unusable response received from Supervisor, check logs",
            ) from err
        except TimeoutError as err:
            raise SupervisorTimeoutError("Timeout connecting to Supervisor") from err
        except ClientError as err:
            raise SupervisorConnectionError(
                "Error occurred connecting to supervisor",
            ) from err

    async def get(
        self,
        uri: str,
        *,
        params: dict[str, str] | MultiDict[str] | None = None,
        response_type: ResponseType = ResponseType.JSON,
        timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
    ) -> Response:
        """Handle a GET request to Supervisor."""
        return await self._request(
            HTTPMethod.GET,
            uri,
            params=params,
            response_type=response_type,
            timeout=timeout,
        )

    async def post(
        self,
        uri: str,
        *,
        params: dict[str, str] | MultiDict[str] | None = None,
        response_type: ResponseType = ResponseType.NONE,
        json: dict[str, Any] | None = None,
        data: Any = None,
        timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
    ) -> Response:
        """Handle a POST request to Supervisor."""
        return await self._request(
            HTTPMethod.POST,
            uri,
            params=params,
            response_type=response_type,
            json=json,
            data=data,
            timeout=timeout,
        )

    async def put(
        self,
        uri: str,
        *,
        params: dict[str, str] | MultiDict[str] | None = None,
        json: dict[str, Any] | None = None,
        timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
    ) -> Response:
        """Handle a PUT request to Supervisor."""
        return await self._request(
            HTTPMethod.PUT,
            uri,
            params=params,
            response_type=ResponseType.NONE,
            json=json,
            timeout=timeout,
        )

    async def delete(
        self,
        uri: str,
        *,
        params: dict[str, str] | MultiDict[str] | None = None,
        json: dict[str, Any] | None = None,
        timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
    ) -> Response:
        """Handle a DELETE request to Supervisor."""
        return await self._request(
            HTTPMethod.DELETE,
            uri,
            params=params,
            response_type=ResponseType.NONE,
            json=json,
            timeout=timeout,
        )

    async def close(self) -> None:
        """Close open client session."""
        if self.session and self._close_session:
            await self.session.close()


class _SupervisorComponentClient:
    """Base class for the per-component clients of the Supervisor API.

    It only stores the shared low-level client on ``self._client``.
    """

    def __init__(self, client: _SupervisorClient) -> None:
        """Store the client used for API calls."""
        self._client: _SupervisorClient = client