# File: manager.py
#
# Package info:
#   flask-limiter 3.12-2
#   links: PTS, VCS
#   area: main
#   in suites: forky, sid
#   size: 1,264 kB
#   sloc: python: 6,432; makefile: 165; sh: 67
# File content: 261 lines | stat: -rw-r--r-- 10,679 bytes
from __future__ import annotations

import itertools
import logging
from collections.abc import Iterable

import flask
from ordered_set import OrderedSet

from .constants import ExemptionScope
from .util import get_qualified_name
from .wrappers import Limit, LimitGroup


class LimitManager:
    def __init__(
        self,
        application_limits: list[LimitGroup],
        default_limits: list[LimitGroup],
        decorated_limits: dict[str, OrderedSet[LimitGroup]],
        blueprint_limits: dict[str, OrderedSet[LimitGroup]],
        route_exemptions: dict[str, ExemptionScope],
        blueprint_exemptions: dict[str, ExemptionScope],
    ) -> None:
        self._application_limits = application_limits
        self._default_limits = default_limits
        self._decorated_limits = decorated_limits
        self._blueprint_limits = blueprint_limits
        self._route_exemptions = route_exemptions
        self._blueprint_exemptions = blueprint_exemptions
        self._endpoint_hints: dict[str, OrderedSet[str]] = {}
        self._logger = logging.getLogger("flask-limiter")

    @property
    def application_limits(self) -> list[Limit]:
        return list(itertools.chain(*self._application_limits))

    @property
    def default_limits(self) -> list[Limit]:
        return list(itertools.chain(*self._default_limits))

    def set_application_limits(self, limits: list[LimitGroup]) -> None:
        self._application_limits = limits

    def set_default_limits(self, limits: list[LimitGroup]) -> None:
        self._default_limits = limits

    def add_decorated_limit(
        self, route: str, limit: LimitGroup | None, override: bool = False
    ) -> None:
        if limit:
            if not override:
                self._decorated_limits.setdefault(route, OrderedSet()).add(limit)
            else:
                self._decorated_limits[route] = OrderedSet([limit])

    def add_blueprint_limit(self, blueprint: str, limit: LimitGroup | None) -> None:
        if limit:
            self._blueprint_limits.setdefault(blueprint, OrderedSet()).add(limit)

    def add_route_exemption(self, route: str, scope: ExemptionScope) -> None:
        self._route_exemptions[route] = scope

    def add_blueprint_exemption(self, blueprint: str, scope: ExemptionScope) -> None:
        self._blueprint_exemptions[blueprint] = scope

    def add_endpoint_hint(self, endpoint: str, callable: str) -> None:
        """Remember that the callable named ``callable`` is tied to ``endpoint``.

        The names recorded here are the ones :meth:`resolve_limits` looks up
        as hinted limits for the endpoint.
        """
        hints = self._endpoint_hints.setdefault(endpoint, OrderedSet())
        hints.add(callable)

    def has_hints(self, endpoint: str) -> bool:
        return bool(self._endpoint_hints.get(endpoint))

    def resolve_limits(
        self,
        app: flask.Flask,
        endpoint: str | None = None,
        blueprint: str | None = None,
        callable_name: str | None = None,
        in_middleware: bool = False,
        marked_for_limiting: bool = False,
    ) -> tuple[list[Limit], ...]:
        """Work out which limits apply for the given endpoint / blueprint.

        :param app: application whose ``view_functions`` and ``blueprints``
         registries are consulted.
        :param endpoint: endpoint name; when falsy neither decorated nor
         hinted limits are looked up.
        :param blueprint: blueprint name; when falsy no blueprint limits are
         considered.
        :param callable_name: name under which decorated limits are looked
         up. When omitted, the qualified name of the endpoint's view function
         is used (or ``""`` if the endpoint has no view function).
        :param in_middleware: when true, decorated limits for the callable
         are not looked up, and application limits become candidates.
        :param marked_for_limiting: only meaningful together with
         ``in_middleware``; when both are true, blueprint limits are skipped
         and default limits are only included if hinted limits ask for them.
        :return: a two-tuple ``(all_limits, decorated_limits)``.
         ``all_limits`` holds the application limits (if applicable)
         followed by the default limits (if applicable);
         ``decorated_limits`` holds the callable's decorated limits followed
         by any blueprint limits that were added.
        """
        before_request_context = in_middleware and marked_for_limiting
        decorated_limits = []
        hinted_limits = []
        if endpoint:
            if not in_middleware:
                if not callable_name:
                    view_func = app.view_functions.get(endpoint, None)
                    name = get_qualified_name(view_func) if view_func else ""
                else:
                    name = callable_name
                decorated_limits.extend(self.decorated_limits(name))

            # hinted limits are collected regardless of in_middleware; they
            # are only used below to decide whether defaults are included
            # and are never returned themselves.
            for hint in self._endpoint_hints.get(endpoint, OrderedSet()):
                hinted_limits.extend(self.decorated_limits(hint))

        if blueprint:
            # blueprint limits are appended unless at least one decorated
            # limit overrides defaults (an empty list counts as "none do").
            if not before_request_context and (
                not decorated_limits
                or all(not limit.override_defaults for limit in decorated_limits)
            ):
                decorated_limits.extend(self.blueprint_limits(app, blueprint))
        exemption_scope = self.exemption_scope(app, endpoint, blueprint)

        # application_limits returns a fresh list each time, so the in-place
        # ``+=`` further down does not modify the stored limit groups.
        all_limits = (
            self.application_limits
            if in_middleware and not (exemption_scope & ExemptionScope.APPLICATION)
            else []
        )
        # decorated limits are not merged into all_limits here; they are
        # returned as the second element of the tuple.
        # NOTE: all() over an empty decorated_limits is True, so with no
        # decorated limits both flags below are True.
        explicit_limits_exempt = all(limit.method_exempt for limit in decorated_limits)

        # all  the decorated limits explicitly declared
        # that they don't override the defaults - so, they should
        # be included.
        combined_defaults = all(
            not limit.override_defaults for limit in decorated_limits
        )
        # previous requests to this endpoint have exercised decorated
        # rate limits on callables that are not view functions. check
        # if all of them declared that they don't override defaults
        # and if so include the default limits.
        hinted_limits_request_defaults = (
            all(not limit.override_defaults for limit in hinted_limits)
            if hinted_limits
            else False
        )
        # hinted limits requesting defaults bypass both the
        # before-request-context check and the DEFAULT exemption.
        if (
            (explicit_limits_exempt or combined_defaults)
            and (
                not (before_request_context or exemption_scope & ExemptionScope.DEFAULT)
            )
        ) or hinted_limits_request_defaults:
            all_limits += self.default_limits
        return all_limits, decorated_limits

    def exemption_scope(
        self, app: flask.Flask, endpoint: str | None, blueprint: str | None
    ) -> ExemptionScope:
        """Combine the route-level and blueprint-level exemption scopes.

        :param app: application whose ``view_functions`` and ``blueprints``
         registries are consulted.
        :param endpoint: endpoint whose view function (if any) is used to
         look up the route exemption.
        :param blueprint: blueprint name; if it is falsy or not registered
         on ``app`` only the route exemption is returned.
        :return: the route exemption OR-ed with the blueprint's own
         exemption and the exemptions of every entry reported by
         :meth:`_blueprint_exemption_scope`.
        """
        handler = app.view_functions.get(endpoint or "", None)
        qualified_name = get_qualified_name(handler) if handler else ""
        route_scope = self._route_exemptions.get(qualified_name, ExemptionScope.NONE)

        registered = app.blueprints.get(blueprint) if blueprint else None
        if not registered:
            return route_scope

        assert blueprint
        own_scope, inherited_scopes = self._blueprint_exemption_scope(app, blueprint)
        beyond_limit_flags = own_scope & ~(
            ExemptionScope.DEFAULT | ExemptionScope.APPLICATION
        )
        if beyond_limit_flags or inherited_scopes:
            # fold every inherited exemption into the blueprint's own scope
            for inherited in inherited_scopes.values():
                own_scope |= inherited
        return route_scope | own_scope

    def decorated_limits(self, callable_name: str) -> list[Limit]:
        """Return the limits registered for the callable ``callable_name``.

        An empty list is returned when the callable has a (truthy) route
        exemption or has no registered limit groups. A group that raises
        ``ValueError`` while being iterated is logged and abandoned; limits
        it yielded before failing are kept.
        """
        collected = []
        if self._route_exemptions.get(callable_name, ExemptionScope.NONE):
            return collected

        for group in self._decorated_limits.get(callable_name, ()):
            try:
                # append one at a time so a failure mid-group keeps the
                # limits that were produced before the error
                for limit in group:
                    collected.append(limit)
            except ValueError as e:
                self._logger.error(
                    f"failed to load ratelimit for function {callable_name}: {e}",
                )
        return collected

    def blueprint_limits(self, app: flask.Flask, blueprint: str) -> list[Limit]:
        """Return the limits that apply to ``blueprint``, including inherited ones.

        :param app: application whose ``blueprints`` registry is consulted.
        :param blueprint: name the blueprint is registered under on ``app``;
         its dot-separated segments are treated as the blueprint's lineage.
        :return: fresh :class:`Limit` copies. The list is empty when the
         blueprint is not registered or when its own exemption carries any
         flag other than ``DEFAULT`` / ``APPLICATION``. Only the blueprint's
         own limits are used when all of them override defaults or when the
         blueprint is exempt from its ancestors; otherwise the limits of
         every non-exempt lineage member are used, in root-to-leaf order.
        """
        limits: list[Limit] = []

        blueprint_instance = app.blueprints.get(blueprint) if blueprint else None
        if not blueprint_instance:
            return limits

        blueprint_name = blueprint_instance.name
        self_exemption, ancestor_exemptions = self._blueprint_exemption_scope(
            app, blueprint
        )
        if self_exemption & ~(ExemptionScope.DEFAULT | ExemptionScope.APPLICATION):
            return limits

        blueprint_self_limits = self._blueprint_limits.get(blueprint_name, OrderedSet())
        own_limits_override = bool(blueprint_self_limits) and all(
            limit.override_defaults for limit in blueprint_self_limits
        )
        exempt_from_ancestors = bool(
            self._blueprint_exemptions.get(blueprint_name, ExemptionScope.NONE)
            & ExemptionScope.ANCESTORS
        )

        blueprint_limits: Iterable[LimitGroup]
        if own_limits_override or exempt_from_ancestors:
            blueprint_limits = blueprint_self_limits
        else:
            # Walk the dotted path in order (dict.fromkeys de-duplicates while
            # preserving order) instead of iterating a set of segments, whose
            # order depends on string hash randomization and could therefore
            # differ from one process to the next.
            lineage = dict.fromkeys(blueprint.split("."))
            blueprint_limits = itertools.chain.from_iterable(
                self._blueprint_limits[member]
                for member in lineage
                if member in self._blueprint_limits
                and member not in ancestor_exemptions
            )

        for limit_group in blueprint_limits:
            try:
                # Build the whole group before extending so a ValueError
                # raised part-way through contributes nothing from it.
                # Arguments are positional, so their order has to match the
                # Limit constructor defined in .wrappers.
                limits.extend(
                    [
                        Limit(
                            limit.limit,
                            limit.key_func,
                            limit.scope,
                            limit.per_method,
                            limit.methods,
                            limit.error_message,
                            limit.exempt_when,
                            limit.override_defaults,
                            limit.deduct_when,
                            limit.on_breach,
                            limit.cost,
                            limit.shared,
                        )
                        for limit in limit_group
                    ]
                )
            except ValueError as e:
                self._logger.error(
                    f"failed to load ratelimit for blueprint {blueprint_name}: {e}",
                )
        return limits

    def _blueprint_exemption_scope(
        self, app: flask.Flask, blueprint_name: str
    ) -> tuple[ExemptionScope, dict[str, ExemptionScope]]:
        """Return a blueprint's own exemption and the exemptions it inherits.

        :param app: application whose ``blueprints`` registry is indexed with
         ``blueprint_name`` (a missing name raises ``KeyError``).
        :param blueprint_name: registered blueprint name; its dot-separated
         segments form the lineage that is checked for inherited exemptions.
        :return: a pair of

         - the exemption stored for the blueprint with the ``ANCESTORS``
           flag cleared, and
         - a mapping of every lineage segment whose stored exemption has the
           ``DESCENDENTS`` flag set to that stored exemption.
        """
        registered_name = app.blueprints[blueprint_name].name
        stored_scope = self._blueprint_exemptions.get(
            registered_name, ExemptionScope.NONE
        )
        own_scope = stored_scope & ~(ExemptionScope.ANCESTORS)

        lineage = set(blueprint_name.split("."))
        inherited = {
            member: scope
            for member, scope in self._blueprint_exemptions.items()
            if member in lineage and scope & ExemptionScope.DESCENDENTS
        }
        return own_scope, inherited