File: sessions.py

package info (click to toggle)
python-asyncprawcore 3.0.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,328 kB
  • sloc: python: 2,224; makefile: 4
file content (445 lines) | stat: -rw-r--r-- 15,552 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
"""asyncprawcore.sessions: Provides asyncprawcore.Session and asyncprawcore.session."""

from __future__ import annotations

import asyncio
import logging
import random
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from copy import deepcopy
from dataclasses import dataclass
from pprint import pformat
from typing import TYPE_CHECKING, BinaryIO, TextIO
from urllib.parse import urljoin

from aiohttp.web import HTTPRequestTimeout

from .auth import BaseAuthorizer
from .codes import codes
from .const import TIMEOUT, WINDOW_SIZE
from .exceptions import (
    BadJSON,
    BadRequest,
    Conflict,
    InvalidInvocation,
    NotFound,
    Redirect,
    RequestException,
    ResponseException,
    ServerError,
    SpecialError,
    TooLarge,
    TooManyRequests,
    UnavailableForLegalReasons,
    URITooLong,
)
from .rate_limit import RateLimiter
from .util import authorization_error_class

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator

    from aiohttp import ClientResponse
    from typing_extensions import Self

    from .auth import Authorizer
    from .requestor import Requestor

# Logger named after the enclosing package (``__package__``); used by every class in
# this module for debug and retry messages.
log = logging.getLogger(__package__)


class RetryStrategy(ABC):
    """Abstract base for objects that schedule request retries.

    A concrete strategy decides how many retries remain and how long to wait before
    the next attempt.

    Instances of this class are immutable.

    """

    @abstractmethod
    def _sleep_seconds(self) -> float | None:
        """Return the delay in seconds before the next attempt, or ``None`` for none."""

    @abstractmethod
    def consume_available_retry(self) -> RetryStrategy:
        """Allow one fewer retry."""

    @abstractmethod
    def should_retry_on_failure(self) -> bool:
        """Return True when a retry should occur."""

    async def sleep(self) -> None:
        """Wait for the delay reported by ``_sleep_seconds``, if it reports one."""
        delay = self._sleep_seconds()
        if delay is None:
            # Nothing to wait for: attempt the request immediately.
            return
        notice = f"Sleeping: {delay:0.2f} seconds prior to retry"
        log.debug(notice)
        await asyncio.sleep(delay)


@dataclass(frozen=True)
class FiniteRetryStrategy(RetryStrategy):
    """A ``RetryStrategy`` that permits only a bounded number of retries."""

    DEFAULT_RETRIES = 2

    # Number of retries still permitted.
    retries: int = DEFAULT_RETRIES

    def _sleep_seconds(self) -> float | None:
        """Return a randomized delay, or ``None`` when no delay applies.

        No delay is reported while ``retries`` is at or above ``DEFAULT_RETRIES``.
        Below that, the delay is drawn from 0-2 seconds while retries remain and from
        2-4 seconds once none remain.

        """
        if self.retries >= self.DEFAULT_RETRIES:
            return None
        offset = 2 if self.retries <= 0 else 0
        jitter = 2 * random.random()  # noqa: S311
        return offset + jitter

    def consume_available_retry(self) -> FiniteRetryStrategy:
        """Return a new strategy of the same type with one fewer retry."""
        remaining = self.retries - 1
        return type(self)(retries=remaining)

    def should_retry_on_failure(self) -> bool:
        """Return ``True`` if and only if at least one retry is still allowed."""
        return self.retries > 0


class Session:
    """The low-level connection interface to Reddit's API."""

    # Exception types that make a failed request eligible for a retry; they are matched
    # against ``RequestException.original_exception`` in ``_request_with_retries``.
    # NOTE(review): ``aiohttp.web.HTTPRequestTimeout`` appears to be aiohttp's
    # server-side HTTP 408 exception class, so a client-side timeout may never match
    # it -- confirm against the aiohttp client documentation.
    RETRY_EXCEPTIONS = (ConnectionError, HTTPRequestTimeout)
    # Response statuses that make a request eligible for a retry.
    RETRY_STATUSES = {
        520,
        522,
        codes["bad_gateway"],
        codes["gateway_timeout"],
        codes["internal_server_error"],
        codes["request_timeout"],
        codes["service_unavailable"],
    }
    # Maps a response status to the exception class raised for it.
    STATUS_EXCEPTIONS = {
        codes["bad_gateway"]: ServerError,
        codes["bad_request"]: BadRequest,
        codes["conflict"]: Conflict,
        codes["found"]: Redirect,
        codes["forbidden"]: authorization_error_class,
        codes["gateway_timeout"]: ServerError,
        codes["internal_server_error"]: ServerError,
        codes["media_type"]: SpecialError,
        codes["moved_permanently"]: Redirect,
        codes["not_found"]: NotFound,
        codes["request_entity_too_large"]: TooLarge,
        codes["request_uri_too_large"]: URITooLong,
        codes["service_unavailable"]: ServerError,
        codes["too_many_requests"]: TooManyRequests,
        codes["unauthorized"]: authorization_error_class,
        codes[
            "unavailable_for_legal_reasons"
        ]: UnavailableForLegalReasons,
        # 520 and 522 are Cloudflare's statuses (not named in requests).
        520: ServerError,
        522: ServerError,
    }
    # Statuses whose responses are parsed and returned rather than raised.
    SUCCESS_STATUSES = {codes["accepted"], codes["created"], codes["ok"]}

    @staticmethod
    def _log_request(
        *,
        data: list[tuple[str, object]] | None,
        method: str,
        params: dict[str, object],
        url: str,
    ) -> None:
        """Write debug records describing an outgoing request.

        :param data: The body fields about to be sent, if any.
        :param method: The request verb.
        :param params: The query parameters about to be sent.
        :param url: The full URL being requested.

        """
        # ``pformat`` walks the whole payload, so skip the formatting entirely when no
        # debug record would be emitted.
        if not log.isEnabledFor(logging.DEBUG):
            return
        log.debug("Fetching: %s %s at %s", method, url, time.monotonic())
        log.debug("Data: %s", pformat(data))
        log.debug("Params: %s", pformat(params))

    @staticmethod
    def _preprocess_dict(data: dict[str, object]) -> dict[str, object]:
        new_data = {}
        for key, value in data.items():
            if isinstance(value, bool):
                new_data[key] = str(value).lower()
            elif value is not None:
                new_data[key] = str(value) if not isinstance(value, str) else value
        return new_data

    @property
    def _requestor(self) -> Requestor:
        return self._authorizer._authenticator._requestor

    async def __aenter__(self) -> Self:
        """Enter an ``async with`` block, returning this session itself."""
        return self

    async def __aexit__(self, *_args: object) -> None:
        """Leave an ``async with`` block by closing the session.

        The exception details passed by the ``async with`` machinery are ignored.

        """
        await self.close()

    def __init__(
        self,
        authorizer: BaseAuthorizer | None,
        window_size: int = WINDOW_SIZE,
    ) -> None:
        """Prepare the connection to Reddit's API.

        :param authorizer: An instance of :class:`.Authorizer`.
        :param window_size: The size of the rate limit reset window in seconds.

        :raises: :class:`.InvalidInvocation` if ``authorizer`` is not a
            ``BaseAuthorizer`` instance.

        """
        if not isinstance(authorizer, BaseAuthorizer):
            problem = f"invalid Authorizer: {authorizer}"
            raise InvalidInvocation(problem)
        # Each request chain starts from a fresh instance of this strategy class.
        self._retry_strategy_class = FiniteRetryStrategy
        self._rate_limiter = RateLimiter(window_size=window_size)
        self._authorizer = authorizer

    async def _do_retry(
        self,
        *,
        data: list[tuple[str, object]] | None,
        json: dict[str, object] | None,
        method: str,
        params: dict[str, object],
        retry_strategy_state: FiniteRetryStrategy,
        status: str,
        timeout: float,
        url: str,
    ) -> dict[str, object] | str | None:
        """Log the reason for a retry and re-issue the request with one fewer retry.

        :param retry_strategy_state: The strategy state of the attempt that failed.
        :param status: A description of the failure; used only in the log message.

        The remaining parameters are forwarded unchanged to ``_request_with_retries``.

        """
        log.warning("Retrying due to %s: %s %s", status, method, url)
        next_state = retry_strategy_state.consume_available_retry()
        return await self._request_with_retries(
            data=data,
            json=json,
            method=method,
            params=params,
            retry_strategy_state=next_state,
            timeout=timeout,
            url=url,
        )

    @asynccontextmanager
    async def _make_request(
        self,
        data: list[tuple[str, object]] | None,
        json: dict[str, object] | None,
        method: str,
        params: dict[str, object],
        timeout: float,
        url: str,
    ) -> AsyncGenerator[ClientResponse]:
        """Send one request through the rate limiter and yield its response.

        The call is made with ``allow_redirects=False``. The status and the rate limit
        headers of the response are written to the debug log before it is yielded.

        """
        rate_limited_call = self._rate_limiter.call(
            self._requestor.request,
            self._set_header_callback,
            method,
            url,
            allow_redirects=False,
            data=data,
            json=json,
            params=params,
            timeout=timeout,
        )
        async with rate_limited_call as response:
            headers = response.headers
            log.debug(
                "Response: %s (%s bytes) (rst-%s:rem-%s:used-%s ratelimit) at %s",
                response.status,
                headers.get("content-length"),
                headers.get("x-ratelimit-reset"),
                headers.get("x-ratelimit-remaining"),
                headers.get("x-ratelimit-used"),
                time.monotonic(),
            )
            yield response

    def _preprocess_data(
        self,
        data: dict[str, object],
        files: dict[str, BinaryIO | TextIO] | None,
    ) -> dict[str, object]:
        """Preprocess data and files before request.

        This is to convert requests that are formatted for the ``requests`` package to
        be compatible with the ``aiohttp`` package. The motivation for this is so that
        ``praw`` and ``asyncpraw`` can remain as similar as possible and thus making
        contributions to ``asyncpraw`` simpler.

        This method does the following:

        - Removes keys that have a value of ``None`` from ``data``.
        - Moves ``files`` into ``data``.

        :param data: Dictionary, bytes, or file-like object to send in the body of the
            request.
        :param files: Dictionary, mapping ``filename`` to file-like object to add to
            ``data``.

        """
        if isinstance(data, dict):
            data = self._preprocess_dict(data)
            if files is not None:
                data.update(files)
        return data

    def _preprocess_params(self, params: dict[str, object]) -> dict[str, object]:
        """Preprocess params before request.

        This is to convert requests that are formatted for the ``requests`` package to
        be compatible with the ``aiohttp`` package. The motivation for this is so that
        ``praw`` and ``asyncpraw`` can remain as similar as possible and thus making
        contributions to ``asyncpraw`` simpler.

        This method does the following:

        - Removes keys that have a value of ``None`` from ``params``.
        - Casts bool values in ``params`` to the strings ``"true"``/``"false"``.
        - Casts every other non-string value in ``params`` to str.

        :param params: The query parameters to send with the request.

        :returns: A new dictionary; ``params`` itself is not modified.

        """
        return self._preprocess_dict(params)

    async def _request_with_retries(  # noqa: PLR0912
        self,
        *,
        data: list[tuple[str, object]] | None,
        json: dict[str, object] | None,
        method: str,
        params: dict[str, object],
        retry_strategy_state: FiniteRetryStrategy | None = None,
        timeout: float,
        url: str,
    ) -> dict[str, object] | str | None:
        """Issue a request, retrying retryable failures while retries remain.

        :param data: The body fields to send, if any.
        :param json: Object to be serialized to JSON in the body of the request.
        :param method: The request verb.
        :param params: The query parameters to send with the request.
        :param retry_strategy_state: The retry state for this attempt. A fresh instance
            of the session's retry strategy class is created when omitted.
        :param timeout: Specifies a particular timeout, in seconds.
        :param url: The full URL to request.

        :returns: ``None`` for a "no content" response, ``""`` for a successful
            response whose ``content-length`` header is ``"0"``, otherwise the decoded
            JSON body.

        :raises: The exception mapped in ``STATUS_EXCEPTIONS`` for the response status,
            :class:`.ResponseException` for any other non-success status,
            :class:`.BadJSON` when the body cannot be decoded, and
            :class:`.RequestException` when the request itself fails and is not
            retried.

        """
        if retry_strategy_state is None:
            retry_strategy_state = self._retry_strategy_class()

        await retry_strategy_state.sleep()
        self._log_request(data=data, method=method, params=params, url=url)

        try:
            async with self._make_request(
                data=data,
                json=json,
                method=method,
                params=params,
                timeout=timeout,
                url=url,
            ) as response:
                retry_status = None
                if response.status == codes["unauthorized"]:
                    # The token was rejected: discard it, and retry only when the
                    # authorizer is able to obtain a new one.
                    self._authorizer._clear_access_token()
                    if hasattr(self._authorizer, "refresh"):
                        retry_status = f"{response.status} status"
                elif response.status in self.RETRY_STATUSES:
                    retry_status = f"{response.status} status"

                # Every path inside this branch returns or raises, so control only
                # falls out of the ``async with`` when a retry is due.
                if retry_status is None or not retry_strategy_state.should_retry_on_failure():
                    if response.status == codes["no_content"]:
                        return None
                    if response.status in self.STATUS_EXCEPTIONS:
                        if response.status == codes["media_type"]:
                            # since exception class needs response.json
                            raise self.STATUS_EXCEPTIONS[response.status](response, await response.json())
                        raise self.STATUS_EXCEPTIONS[response.status](response)
                    if response.status not in self.SUCCESS_STATUSES:
                        raise ResponseException(response)
                    if response.headers.get("content-length") == "0":
                        return ""
                    try:
                        return await response.json()
                    except ValueError:
                        raise BadJSON(response) from None
        except RequestException as exception:
            if not (
                retry_strategy_state.should_retry_on_failure()
                and isinstance(exception.original_exception, self.RETRY_EXCEPTIONS)
            ):
                raise
            retry_status = repr(exception.original_exception)

        # The retry is issued only after leaving both the ``async with`` and the
        # ``try``. Retrying inside them kept the failed response's context open for the
        # whole retry, and let a ``RequestException`` escaping the nested attempt be
        # caught here again, where this frame's unspent retries started a further retry
        # chain beyond the configured budget.
        return await self._do_retry(
            data=data,
            json=json,
            method=method,
            params=params,
            retry_strategy_state=retry_strategy_state,
            status=retry_status,
            timeout=timeout,
            url=url,
        )

    async def _set_header_callback(self) -> dict[str, str]:
        refresh_method = getattr(self._authorizer, "refresh", None)
        if not self._authorizer.is_valid() and refresh_method is not None:
            await refresh_method()
        return {"Authorization": f"bearer {self._authorizer.access_token}"}

    async def close(self) -> None:
        """Close the session by closing the requestor it sends requests through."""
        await self._requestor.close()

    async def request(
        self,
        method: str,
        path: str,
        data: dict[str, object] | None = None,
        files: dict[str, BinaryIO | TextIO] | None = None,
        json: dict[str, object] | None = None,
        params: dict[str, object] | None = None,
        timeout: float = TIMEOUT,
    ) -> dict[str, object] | str | None:
        """Return the json content from the resource at ``path``.

        :param method: The request verb. E.g., ``"GET"``, ``"POST"``, ``"PUT"``.
        :param path: The path of the request. This path will be combined with the
            ``oauth_url`` of the Requestor.
        :param data: Dictionary, bytes, or file-like object to send in the body of the
            request.
        :param files: Dictionary, mapping ``filename`` to file-like object. The files
            are merged into ``data`` when it is a dictionary and sent on their own when
            ``data`` is ``None``.
        :param json: Object to be serialized to JSON in the body of the request.
        :param params: The query parameters to send with the request.
        :param timeout: Specifies a particular timeout, in seconds.

        The caller's ``data``, ``json`` and ``params`` objects are not modified: copies
        are made before ``raw_json`` is added to the params and ``api_type`` is added to
        dictionary ``data``/``json`` bodies.

        Automatically refreshes the access token if it becomes invalid and a refresh
        token is available.

        :raises: :class:`.InvalidInvocation` in such a case if a refresh token is not
            available.

        """
        params = self._preprocess_params(deepcopy(params) or {})
        params["raw_json"] = "1"
        if isinstance(data, dict):
            data = self._preprocess_data(deepcopy(data), files)
            data["api_type"] = "json"
            # Sorting gives the body fields a deterministic order.
            data_list = sorted(data.items())
        elif data is None and files:
            # With no ``data`` dictionary there is nothing to merge ``files`` into;
            # send them on their own instead of silently dropping the upload.
            data_list = sorted(files.items())
        else:
            data_list = data
        if isinstance(json, dict):
            json = deepcopy(json)
            json["api_type"] = "json"
        url = urljoin(self._requestor.oauth_url, path)
        return await self._request_with_retries(
            data=data_list,
            json=json,
            method=method,
            params=params,
            timeout=timeout,
            url=url,
        )


def session(
    authorizer: Authorizer | None = None,
    window_size: int = WINDOW_SIZE,
) -> Session:
    """Return a :class:`.Session` instance.

    :param authorizer: An instance of :class:`.Authorizer`. The default of ``None`` is
        not a usable value: :class:`.Session` rejects anything that is not a
        ``BaseAuthorizer`` instance.
    :param window_size: The size of the rate limit reset window in seconds.

    :raises: :class:`.InvalidInvocation` if ``authorizer`` is not a ``BaseAuthorizer``
        instance.

    """
    return Session(authorizer=authorizer, window_size=window_size)