1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
|
"""Internal client for making requests and managing session with Supervisor."""
from dataclasses import dataclass, field
from http import HTTPMethod, HTTPStatus
from importlib import metadata
from typing import Any
from aiohttp import (
ClientError,
ClientResponse,
ClientResponseError,
ClientSession,
ClientTimeout,
)
from multidict import MultiDict
from yarl import URL
from .const import DEFAULT_TIMEOUT, ResponseType
from .exceptions import (
SupervisorAuthenticationError,
SupervisorBadRequestError,
SupervisorConnectionError,
SupervisorError,
SupervisorForbiddenError,
SupervisorNotFoundError,
SupervisorResponseError,
SupervisorServiceUnavailableError,
SupervisorTimeoutError,
)
from .models.base import Response, ResultType
from .utils.aiohttp import ChunkAsyncStreamIterator
# Version of the installed distribution named after this package (looked up via
# importlib.metadata); sent in the User-Agent header of every request.
# NOTE(review): assumes the distribution name equals the import package name.
VERSION = metadata.version(__package__)
def is_json(response: ClientResponse, *, raise_on_fail: bool = False) -> bool:
    """Check if response is json according to Content-Type.

    Args:
        response: Response whose ``Content-Type`` header is inspected.
        raise_on_fail: When True, raise instead of returning False.

    Returns:
        True if the Content-Type header contains ``application/json``
        (compared case-insensitively, as media types are), otherwise False.

    Raises:
        SupervisorResponseError: If the header does not indicate JSON and
            ``raise_on_fail`` is True.
    """
    content_type = response.headers.get("Content-Type", "")
    # Media type names are case-insensitive, so normalise before matching but
    # keep the raw header value for the error message.
    if "application/json" in content_type.lower():
        return True
    if raise_on_fail:
        raise SupervisorResponseError(
            "Unexpected response received from supervisor when expecting "
            f"JSON. Status: {response.status}, content type: {content_type}",
        )
    return False
@dataclass(slots=True)
class _SupervisorClient:
"""Main class for handling connections with Supervisor."""
api_host: str
token: str
session: ClientSession | None = None
_close_session: bool = field(default=False, init=False)
async def _raise_on_status(self, response: ClientResponse) -> None:
"""Raise appropriate exception on status."""
if response.status >= HTTPStatus.BAD_REQUEST.value:
exc_type: type[SupervisorError] = SupervisorError
match response.status:
case HTTPStatus.BAD_REQUEST:
exc_type = SupervisorBadRequestError
case HTTPStatus.UNAUTHORIZED:
exc_type = SupervisorAuthenticationError
case HTTPStatus.FORBIDDEN:
exc_type = SupervisorForbiddenError
case HTTPStatus.NOT_FOUND:
exc_type = SupervisorNotFoundError
case HTTPStatus.SERVICE_UNAVAILABLE:
exc_type = SupervisorServiceUnavailableError
if is_json(response):
result = Response.from_json(await response.text())
raise exc_type(result.message, result.job_id)
raise exc_type()
async def _request(
self,
method: HTTPMethod,
uri: str,
*,
params: dict[str, str] | MultiDict[str] | None,
response_type: ResponseType,
json: dict[str, Any] | None = None,
data: Any = None,
timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
) -> Response:
"""Handle a request to Supervisor."""
try:
url = URL(self.api_host).joinpath(uri)
except ValueError as err:
raise SupervisorError from err
# This check is to make sure the normalized URL string is the same as the URL
# string that was passed in. If they are different, then the passed in uri
# contained characters that were removed by the normalization
# such as ../../../../etc/passwd
if not url.raw_path.endswith(uri):
raise SupervisorError(f"Invalid request {uri}")
match response_type:
case ResponseType.TEXT:
accept = "text/plain, */*"
case _:
accept = "application/json, text/plain, */*"
headers = {
"User-Agent": f"AioHASupervisor/{VERSION}",
"Accept": accept,
"Authorization": f"Bearer {self.token}",
}
if self.session is None:
self.session = ClientSession()
self._close_session = True
try:
response = await self.session.request(
method.value,
url,
timeout=timeout,
headers=headers,
params=params,
json=json,
data=data,
)
await self._raise_on_status(response)
match response_type:
case ResponseType.JSON:
is_json(response, raise_on_fail=True)
return Response.from_json(await response.text())
case ResponseType.TEXT:
return Response(ResultType.OK, await response.text())
case ResponseType.STREAM:
return Response(
ResultType.OK, ChunkAsyncStreamIterator(response.content)
)
case _:
return Response(ResultType.OK)
except (UnicodeDecodeError, ClientResponseError) as err:
raise SupervisorResponseError(
"Unusable response received from Supervisor, check logs",
) from err
except TimeoutError as err:
raise SupervisorTimeoutError("Timeout connecting to Supervisor") from err
except ClientError as err:
raise SupervisorConnectionError(
"Error occurred connecting to supervisor",
) from err
async def get(
self,
uri: str,
*,
params: dict[str, str] | MultiDict[str] | None = None,
response_type: ResponseType = ResponseType.JSON,
timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
) -> Response:
"""Handle a GET request to Supervisor."""
return await self._request(
HTTPMethod.GET,
uri,
params=params,
response_type=response_type,
timeout=timeout,
)
async def post(
self,
uri: str,
*,
params: dict[str, str] | MultiDict[str] | None = None,
response_type: ResponseType = ResponseType.NONE,
json: dict[str, Any] | None = None,
data: Any = None,
timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
) -> Response:
"""Handle a POST request to Supervisor."""
return await self._request(
HTTPMethod.POST,
uri,
params=params,
response_type=response_type,
json=json,
data=data,
timeout=timeout,
)
async def put(
self,
uri: str,
*,
params: dict[str, str] | MultiDict[str] | None = None,
json: dict[str, Any] | None = None,
timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
) -> Response:
"""Handle a PUT request to Supervisor."""
return await self._request(
HTTPMethod.PUT,
uri,
params=params,
response_type=ResponseType.NONE,
json=json,
timeout=timeout,
)
async def delete(
self,
uri: str,
*,
params: dict[str, str] | MultiDict[str] | None = None,
json: dict[str, Any] | None = None,
timeout: ClientTimeout | None = DEFAULT_TIMEOUT,
) -> Response:
"""Handle a DELETE request to Supervisor."""
return await self._request(
HTTPMethod.DELETE,
uri,
params=params,
response_type=ResponseType.NONE,
json=json,
timeout=timeout,
)
async def close(self) -> None:
"""Close open client session."""
if self.session and self._close_session:
await self.session.close()
class _SupervisorComponentClient:
    """Shared base class for the Supervisor component clients.

    Keeps a handle on the ``_SupervisorClient`` used to perform API calls.
    """

    def __init__(self, client: _SupervisorClient) -> None:
        """Remember ``client`` as the transport for this component's API calls."""
        self._client: _SupervisorClient = client
|