File: redis.py

package info (click to toggle)
python-limits 4.4.1-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,064 kB
  • sloc: python: 7,833; makefile: 162; sh: 59
file content (314 lines) | stat: -rw-r--r-- 10,613 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
from __future__ import annotations

import time
from typing import TYPE_CHECKING, cast

from deprecated.sphinx import versionchanged
from packaging.version import Version

from limits.typing import Literal, RedisClient

from ..util import get_package_data
from .base import MovingWindowSupport, SlidingWindowCounterSupport, Storage

if TYPE_CHECKING:
    import redis


@versionchanged(
    version="4.3",
    reason=(
        "Added support for using the redis client from :pypi:`valkey`"
        " if :paramref:`uri` has the ``valkey://`` schema"
    ),
)
class RedisStorage(Storage, MovingWindowSupport, SlidingWindowCounterSupport):
    """
    Rate limit storage with redis as backend.

    Depends on :pypi:`redis` (or :pypi:`valkey` if :paramref:`uri` starts with
    ``valkey://``)
    """

    STORAGE_SCHEME = [
        "redis",
        "rediss",
        "redis+unix",
        "valkey",
        "valkeys",
        "valkey+unix",
    ]
    """The storage scheme for redis"""

    # Minimum supported version of each client library, keyed by target server.
    DEPENDENCIES = {"redis": Version("3.0"), "valkey": Version("6.0")}

    # Package-relative directory holding the Lua scripts loaded below.
    RES_DIR = "resources/redis/lua_scripts"

    # Lua script sources, read once at class-definition time and registered
    # against the client in ``initialize_storage``.
    SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
    SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
        f"{RES_DIR}/acquire_moving_window.lua"
    )
    SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
    SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")

    SCRIPT_SLIDING_WINDOW = get_package_data(f"{RES_DIR}/sliding_window.lua")
    SCRIPT_ACQUIRE_SLIDING_WINDOW = get_package_data(
        f"{RES_DIR}/acquire_sliding_window.lua"
    )

    # Registered script callables, assigned in ``initialize_storage``.
    lua_moving_window: redis.commands.core.Script
    lua_acquire_moving_window: redis.commands.core.Script
    lua_clear_keys: redis.commands.core.Script
    lua_incr_expire: redis.commands.core.Script
    lua_sliding_window: redis.commands.core.Script
    lua_acquire_sliding_window: redis.commands.core.Script

    # Namespace prepended to every key written by this storage (see
    # ``prefixed_key``); ``reset`` clears keys matching ``PREFIX:*``.
    PREFIX = "LIMITS"
    target_server: Literal["redis", "valkey"]

    def __init__(
        self,
        uri: str,
        connection_pool: redis.connection.ConnectionPool | None = None,
        wrap_exceptions: bool = False,
        **options: float | str | bool,
    ) -> None:
        """
        :param uri: uri of the form ``redis://[:password]@host:port``,
         ``redis://[:password]@host:port/db``,
         ``rediss://[:password]@host:port``, ``redis+unix:///path/to/sock`` etc.
         This uri is passed directly to :func:`redis.from_url` except for the
         case of ``redis+unix://`` where it is replaced with ``unix://``.

         If the uri scheme is ``valkey`` the implementation used will be from
         :pypi:`valkey`.
        :param connection_pool: if provided, the redis client is initialized with
         the connection pool and any other params passed as :paramref:`options`
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`redis.Redis`
        :raise ConfigurationError: when the :pypi:`redis` library is not available
        """
        super().__init__(uri, wrap_exceptions=wrap_exceptions, **options)
        self.target_server = "valkey" if uri.startswith("valkey") else "redis"
        self.dependency = self.dependencies[self.target_server].module

        # Translate the ``<server>+unix`` scheme to ``unix`` as documented
        # above. Only the leading scheme is rewritten so the same substring
        # appearing elsewhere in the uri (e.g. inside a password) is preserved.
        unix_scheme = f"{self.target_server}+unix"
        if uri.startswith(unix_scheme):
            uri = "unix" + uri[len(unix_scheme) :]

        if connection_pool is None:
            self.storage = self.dependency.from_url(uri, **options)
        elif self.target_server == "redis":
            self.storage = self.dependency.Redis(
                connection_pool=connection_pool, **options
            )
        else:
            self.storage = self.dependency.Valkey(
                connection_pool=connection_pool, **options
            )
        self.initialize_storage(uri)

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
        """
        Base exception class of the selected client library:
        ``RedisError`` for redis, ``ValkeyError`` for valkey.
        """
        return (  # type: ignore[no-any-return]
            self.dependency.RedisError
            if self.target_server == "redis"
            else self.dependency.ValkeyError
        )

    def initialize_storage(self, _uri: str) -> None:
        """
        Register every Lua script used by this storage with the client
        returned by :meth:`get_connection`.

        :param _uri: unused here.
        """
        self.lua_moving_window = self.get_connection().register_script(
            self.SCRIPT_MOVING_WINDOW
        )
        self.lua_acquire_moving_window = self.get_connection().register_script(
            self.SCRIPT_ACQUIRE_MOVING_WINDOW
        )
        self.lua_clear_keys = self.get_connection().register_script(
            self.SCRIPT_CLEAR_KEYS
        )
        self.lua_incr_expire = self.get_connection().register_script(
            self.SCRIPT_INCR_EXPIRE
        )
        self.lua_sliding_window = self.get_connection().register_script(
            self.SCRIPT_SLIDING_WINDOW
        )
        self.lua_acquire_sliding_window = self.get_connection().register_script(
            self.SCRIPT_ACQUIRE_SLIDING_WINDOW
        )

    def get_connection(self, readonly: bool = False) -> RedisClient:
        """
        Return the client used for all commands.

        :param readonly: not used by this implementation (presumably a hook
         for subclasses that route reads elsewhere — confirm in subclasses).
        """
        return cast(RedisClient, self.storage)

    def _current_window_key(self, key: str) -> str:
        """
        Return the current window's storage key (Sliding window strategy)

        Contrary to other strategies that have one key per rate limit item,
        this strategy has two keys per rate limit item that must be on the same machine.
        To keep the current key and the previous key on the same Redis cluster node,
        curly braces are added.

        Eg: "{constructed_key}"
        """
        return f"{{{key}}}"

    def _previous_window_key(self, key: str) -> str:
        """
        Return the previous window's storage key (Sliding window strategy).

        Curly braces are added on the common pattern with the current window's key,
        so the current and the previous key are stored on the same Redis cluster node.

        Eg: "{constructed_key}/-1"
        """
        return f"{self._current_window_key(key)}/-1"

    def prefixed_key(self, key: str) -> str:
        """
        Namespace ``key`` under :attr:`PREFIX` (``"LIMITS:<key>"``).
        """
        return f"{self.PREFIX}:{key}"

    def get_moving_window(self, key: str, limit: int, expiry: int) -> tuple[float, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param limit: amount of entries allowed (forwarded to the Lua script)
        :param expiry: expiry of entry
        :return: (start of window, number of acquired entries). When the script
         returns nothing, ``(now, 0)`` is returned.
        """
        key = self.prefixed_key(key)
        timestamp = time.time()
        if window := self.lua_moving_window([key], [timestamp - expiry, limit]):
            return float(window[0]), int(window[1])

        return timestamp, 0

    def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        """
        Read the previous and current window state for ``key`` via the
        ``sliding_window`` Lua script.

        :param key: rate limit key
        :param expiry: window length, forwarded to the Lua script
        :return: a 4-tuple. The first and third elements are ints. The second
         and fourth elements are clamped at 0 and divided by 1000, presumably
         converting millisecond TTLs to seconds. Verify against the Lua
         script. ``(0, 0.0, 0, 0.0)`` is returned when the script returns
         nothing.
        """
        previous_key = self.prefixed_key(self._previous_window_key(key))
        current_key = self.prefixed_key(self._current_window_key(key))
        if window := self.lua_sliding_window([previous_key, current_key], [expiry]):
            return (
                int(window[0] or 0),
                max(0, float(window[1] or 0)) / 1000,
                int(window[2] or 0),
                max(0, float(window[3] or 0)) / 1000,
            )
        return 0, 0.0, 0, 0.0

    def incr(
        self,
        key: str,
        expiry: int,
        elastic_expiry: bool = False,
        amount: int = 1,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: if ``True`` the expiry is reset on every
         increment (``INCRBY`` followed by ``EXPIRE``). Otherwise the
         ``incr_expire`` Lua script is used.
        :param amount: the number to increment by
        :return: the counter value after incrementing
        """
        key = self.prefixed_key(key)
        if elastic_expiry:
            value = self.get_connection().incrby(key, amount)
            self.get_connection().expire(key, expiry)
            return int(value)
        return int(self.lua_incr_expire([key], [expiry, amount]))

    def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        :return: the stored counter, or ``0`` if the key does not exist
        """

        key = self.prefixed_key(key)
        return int(self.get_connection(True).get(key) or 0)

    def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        key = self.prefixed_key(key)
        self.get_connection().delete(key)

    def acquire_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        :return: whether the ``acquire_moving_window`` script reported success
        """
        key = self.prefixed_key(key)
        timestamp = time.time()
        acquired = self.lua_acquire_moving_window(
            [key], [timestamp, limit, expiry, amount]
        )

        return bool(acquired)

    def acquire_sliding_window_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        """
        Acquire an entry. Shift the current window to the previous window if it expired.

        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        :return: whether the ``acquire_sliding_window`` script reported success
        """
        previous_key = self.prefixed_key(self._previous_window_key(key))
        current_key = self.prefixed_key(self._current_window_key(key))
        acquired = self.lua_acquire_sliding_window(
            [previous_key, current_key], [limit, expiry, amount]
        )
        return bool(acquired)

    def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        :return: absolute expiry timestamp: now plus the key's remaining TTL.
         Negative TTL replies are clamped to 0, so a missing key or a key with
         no expiry yields "now".
        """

        key = self.prefixed_key(key)
        return max(self.get_connection(True).ttl(key), 0) + time.time()

    def check(self) -> bool:
        """
        check if storage is healthy

        :return: the result of ``PING``, or ``False`` if the ping raised.
        """
        try:
            return self.get_connection().ping()
        except Exception:
            # Best-effort health check: any client or connection error means
            # "unhealthy". KeyboardInterrupt and SystemExit still propagate.
            return False

    def reset(self) -> int | None:
        """
        This function calls a Lua Script to delete keys prefixed with
        ``self.PREFIX`` in blocks of 5000.

        .. warning::
           This operation was designed to be fast, but was not tested
           on a large production based system. Be careful with its usage as it
           could be slow on very large data sets.

        """

        prefix = self.prefixed_key("*")
        return int(self.lua_clear_keys([prefix]))