1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
|
from __future__ import annotations
import dataclasses
import typing
import weakref
from collections.abc import Iterator
from flask import request
from flask.wrappers import Response
from limits import RateLimitItem, parse_many
from limits.strategies import RateLimiter
from limits.util import WindowStats
from .typing import Callable
if typing.TYPE_CHECKING:
from .extension import Limiter
class RequestLimit:
    """
    Details of one rate limit as evaluated within the context of a request.

    Holds a weak proxy to the owning extension and fetches window
    statistics from the extension's limiter lazily, on first access.
    """

    #: The rate limit item that was evaluated
    limit: RateLimitItem
    #: Full key for the request, built with ``limit.key_for(*request_args)``
    key: str
    #: ``True`` if the limit was breached while handling this request
    breached: bool
    #: ``True`` if the limit is a shared limit
    shared: bool

    def __init__(
        self,
        extension: Limiter,
        limit: RateLimitItem,
        request_args: list[str],
        breached: bool,
        shared: bool,
    ) -> None:
        # Weak proxy: this object does not hold a strong reference to the
        # extension.
        self.extension: weakref.ProxyType[Limiter] = weakref.proxy(extension)
        self.limit, self.request_args = limit, request_args
        self.key = limit.key_for(*request_args)
        self.breached, self.shared = breached, shared
        # Filled in on first access of ``window``.
        self._window: WindowStats | None = None

    @property
    def limiter(self) -> RateLimiter:
        """The extension's ``limiter`` attribute, typed as a RateLimiter."""
        backend = self.extension.limiter
        return typing.cast(RateLimiter, backend)

    @property
    def window(self) -> WindowStats:
        """Window statistics for this limit, fetched once and then cached."""
        cached = self._window
        if cached:
            return cached
        self._window = self.limiter.get_window_stats(self.limit, *self.request_args)
        return self._window

    @property
    def reset_at(self) -> int:
        """Timestamp at which the rate limit will be reset"""
        reset_time = self.window[0]
        return int(reset_time + 1)

    @property
    def remaining(self) -> int:
        """Quantity remaining for this rate limit"""
        stats = self.window
        return stats[1]
@dataclasses.dataclass(eq=True, unsafe_hash=True)
class Limit:
    """
    A single parsed rate limit bundled with the context needed to apply it:
    key function, scope, method filter, cost, and the optional callbacks.
    """

    limit: RateLimitItem
    key_func: Callable[[], str]
    _scope: str | Callable[[str], str] | None
    per_method: bool = False
    methods: tuple[str, ...] | None = None
    error_message: str | None = None
    exempt_when: Callable[[], bool] | None = None
    override_defaults: bool | None = False
    deduct_when: Callable[[Response], bool] | None = None
    on_breach: Callable[[RequestLimit], Response | None] | None = None
    _cost: Callable[[], int] | int = 1
    shared: bool = False

    def __post_init__(self) -> None:
        # Store methods in lower case so ``method_exempt`` can compare them
        # against ``request.method.lower()``.
        if not self.methods:
            return
        self.methods = tuple(name.lower() for name in self.methods)

    @property
    def is_exempt(self) -> bool:
        """Result of ``exempt_when()`` if one is configured, else ``False``."""
        predicate = self.exempt_when
        return predicate() if predicate else False

    @property
    def scope(self) -> str | None:
        """
        The configured scope. A callable scope is called with the current
        request's endpoint (or ``""`` when there is none).
        """
        provider = self._scope
        if not callable(provider):
            return provider
        return provider(request.endpoint or "")

    @property
    def cost(self) -> int:
        """The configured integer cost, or the value returned by the cost callable."""
        configured = self._cost
        return configured if isinstance(configured, int) else configured()

    @property
    def method_exempt(self) -> bool:
        """``True`` when a method filter is set and the request method is not in it."""
        if self.methods is None:
            return False
        return request.method.lower() not in self.methods

    def scope_for(self, endpoint: str, method: str | None) -> str:
        """
        Build the final bucket (scope) name for this limit.

        Without a scope the endpoint alone is used. A shared limit uses its
        scope without the endpoint; otherwise the name is
        ``"<endpoint>:<scope>"``. Per-method limits append ``":<METHOD>"``,
        and ``method`` must then be provided.
        """
        limit_scope = self.scope
        if not limit_scope:
            scope = endpoint
        elif self.shared:
            scope = limit_scope
        else:
            scope = f"{endpoint}:{limit_scope}"
        if self.per_method:
            assert method
            scope += f":{method.upper()}"
        return scope
@dataclasses.dataclass(eq=True, unsafe_hash=True)
class LimitGroup:
    """
    A group of related limits, declared either as a limit string or as a
    callable that returns one.

    Iterating the group parses the limit string and yields one
    :class:`Limit` per parsed rate limit item. Each yielded limit carries
    the group's key function, scope, method filter, cost and callbacks.
    """

    limit_provider: Callable[[], str] | str
    key_function: Callable[[], str]
    scope: str | Callable[[str], str] | None = None
    methods: tuple[str, ...] | None = None
    error_message: str | None = None
    exempt_when: Callable[[], bool] | None = None
    override_defaults: bool | None = False
    deduct_when: Callable[[Response], bool] | None = None
    on_breach: Callable[[RequestLimit], Response | None] | None = None
    per_method: bool = False
    cost: Callable[[], int] | int | None = None
    shared: bool = False

    def __iter__(self) -> Iterator[Limit]:
        """
        Yield a :class:`Limit` for every rate limit in the provided string.

        A callable ``limit_provider`` is called again on each iteration. A
        falsy (empty or ``None``) limit string yields nothing.
        """
        limit_str = (
            self.limit_provider()
            if callable(self.limit_provider)
            else self.limit_provider
        )
        if not limit_str:
            return
        # Only an unset cost falls back to the default of 1. An explicit 0
        # is kept, matching how ``Limit.cost`` treats a callable returning 0
        # (``self.cost or 1`` used to turn ``cost=0`` into 1).
        cost = 1 if self.cost is None else self.cost
        for item in parse_many(limit_str):
            # Keyword arguments keep this call correct if Limit's field
            # order ever changes.
            yield Limit(
                limit=item,
                key_func=self.key_function,
                _scope=self.scope,
                per_method=self.per_method,
                methods=self.methods,
                error_message=self.error_message,
                exempt_when=self.exempt_when,
                override_defaults=self.override_defaults,
                deduct_when=self.deduct_when,
                on_breach=self.on_breach,
                _cost=cost,
                shared=self.shared,
            )
|