File: rate_limit.py

package info (click to toggle)
python-asyncprawcore 3.0.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 4,328 kB
  • sloc: python: 2,224; makefile: 4
file content (108 lines) | stat: -rw-r--r-- 3,887 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Provide the RateLimiter class."""

from __future__ import annotations

import asyncio
import logging
import time
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import TYPE_CHECKING, Any, Callable

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Awaitable, Mapping

    from aiohttp import ClientResponse

from asyncprawcore.const import NANOSECONDS

# Module-level logger shared by this module; it is named after the containing
# package (``__package__``) rather than after this individual module.
log = logging.getLogger(__package__)


class RateLimiter:
    """Throttle outgoing requests to Reddit.

    The pacing is driven by the ``x-ratelimit-*`` headers found on the responses to
    earlier requests: every response updates the earliest moment at which the next
    request may be sent.

    """

    def __init__(self, *, window_size: int) -> None:
        """Create an instance of the RateLimiter class.

        :param window_size: The length of the rate limit window. It is combined with
            the ``x-ratelimit-reset`` value (seconds) in :meth:`update`, so it is
            presumably expressed in seconds as well.

        """
        # Requests left in the current window; unknown until a response reports it.
        self.remaining: int | None = None
        # Earliest ``time.monotonic_ns()`` value at which the next request may go out.
        self.next_request_timestamp_ns: int | None = None
        # Requests already consumed in the current window; unknown until reported.
        self.used: int | None = None
        self.window_size: int = window_size

    @asynccontextmanager
    async def call(
        self,
        # ``request_function`` must return an async context manager
        request_function: Callable[..., AbstractAsyncContextManager[ClientResponse]],
        set_header_callback: Callable[[], Awaitable[dict[str, str]]],
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[ClientResponse]:
        """Issue a request via ``request_function`` while honoring the rate limit.

        The response is yielded from inside the context manager returned by
        ``request_function``, after the rate limiter state has been refreshed from the
        response headers.

        :param request_function: A function call that returns an HTTP response object
            (as an async context manager).
        :param set_header_callback: An awaitable callback producing the request headers.
            It is awaited only after any necessary sleep has finished, and its result
            replaces any ``headers`` entry in ``kwargs``.
        :param args: The positional arguments to ``request_function``.
        :param kwargs: The keyword arguments to ``request_function``.

        """
        await self.delay()
        request_headers = await set_header_callback()
        kwargs["headers"] = request_headers
        async with request_function(*args, **kwargs) as response:
            self.update(response.headers)
            yield response

    async def delay(self) -> None:
        """Wait, if needed, until the next request is allowed by the rate limit.

        Returns immediately when no timestamp has been scheduled yet or when the
        scheduled timestamp is not in the future.

        """
        scheduled_ns = self.next_request_timestamp_ns
        if scheduled_ns is not None:
            pending_ns = scheduled_ns - time.monotonic_ns()
            sleep_seconds = float(pending_ns) / NANOSECONDS
            if sleep_seconds > 0:
                log.debug(f"Sleeping: {sleep_seconds:0.2f} seconds prior to call")
                await asyncio.sleep(sleep_seconds)

    def update(self, response_headers: Mapping[str, str]) -> None:
        """Refresh the rate limiter state from a response's headers.

        This method should only be called following an HTTP request to Reddit.

        When the headers carry no ``x-ratelimit-remaining`` field, the response is
        counted as a single request against the already known counters (if any). This
        is to err on the side of caution, as such responses should trigger exceptions
        that indicate invalid behavior.

        Otherwise ``remaining`` and ``used`` are taken from the headers and the
        timestamp of the next permitted request is recomputed:

        - with no requests remaining, the next request waits for the reset, and for at
          least ``NANOSECONDS`` nanoseconds;
        - with requests remaining, the wait is the time to reset minus the share of the
          window proportional to the unused part of the quota, bounded below by zero
          and above by both the time to reset and 10 seconds.

        :param response_headers: The headers of the response that was just received.

        """
        reports_rate_limit = "x-ratelimit-remaining" in response_headers
        if not reports_rate_limit:
            counters_unknown = self.remaining is None or self.used is None
            if not counters_unknown:
                self.remaining -= 1
                self.used += 1
            return

        # Parse in the same order the fields are stored so that a malformed or missing
        # later header leaves the earlier attributes already updated.
        self.remaining = remaining = int(float(response_headers["x-ratelimit-remaining"]))
        self.used = used = int(response_headers["x-ratelimit-used"])

        now_ns = time.monotonic_ns()
        seconds_to_reset = int(response_headers["x-ratelimit-reset"])

        if remaining <= 0:
            # Quota exhausted: hold off until the window resets.
            reset_ns = seconds_to_reset * NANOSECONDS
            self.next_request_timestamp_ns = now_ns + max(NANOSECONDS, reset_ns)
            return

        total_requests = float(remaining) + used
        # Equivalent to ``window_size * remaining / total_requests``: the part of the
        # window matching the still unused fraction of the quota.
        unused_share = self.window_size - self.window_size / total_requests * used
        wait_seconds = min(seconds_to_reset, max(seconds_to_reset - unused_share, 0), 10)
        self.next_request_timestamp_ns = int(now_ns + wait_seconds * NANOSECONDS)