File: ratelimit.py

package info (click to toggle)
django-allauth 65.0.2-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 9,672 kB
  • sloc: python: 34,411; javascript: 3,070; xml: 849; makefile: 235; sh: 8
file content (149 lines) | stat: -rw-r--r-- 4,408 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import hashlib
import time
from collections import namedtuple
from typing import Optional

from django.conf import settings
from django.core.cache import cache
from django.core.exceptions import ImproperlyConfigured
from django.http import HttpResponse
from django.shortcuts import render

from allauth import app_settings
from allauth.utils import import_callable


# One parsed rate-limit rule: at most ``amount`` hits within ``duration``
# seconds, counted per ``per`` ("ip", "user" or "key").
Rate = namedtuple("Rate", ["amount", "duration", "per"])


def _parse_duration(duration):
    """Convert a duration spec such as ``"5m"`` or ``"h"`` into seconds.

    The last character is the unit: ``s``, ``m``, ``h`` or ``d``. Anything
    before it is parsed as a float multiplier. When the multiplier is
    omitted it defaults to the integer 1, so ``"m"`` yields ``60`` while
    ``"1m"`` yields ``60.0``.

    :raises ValueError: if *duration* is empty, the unit is unknown, or the
        multiplier is not a valid float.
    """
    if len(duration) < 1:
        raise ValueError(duration)
    multiplier_text, unit = duration[:-1], duration[-1]
    seconds_per_unit = {"s": 1, "m": 60, "h": 3600, "d": 86400}.get(unit)
    if seconds_per_unit is None:
        raise ValueError("Invalid duration unit: %s" % unit)
    multiplier = float(multiplier_text) if multiplier_text else 1
    return multiplier * seconds_per_unit


def _parse_rate(rate):
    """Parse one rate spec of the form ``amount/duration[/per]`` into a ``Rate``.

    When the ``per`` part is omitted it defaults to ``"ip"``. The duration
    part is converted to seconds by ``_parse_duration``.

    :raises ValueError: if the spec does not have two or three ``/``-separated
        parts, or if the amount or duration cannot be parsed.
    """
    pieces = rate.split("/")
    if len(pieces) not in (2, 3):
        raise ValueError(rate)
    if len(pieces) == 2:
        pieces.append("ip")
    amount, duration, per = pieces
    return Rate(int(amount), _parse_duration(duration), per)


def _parse_rates(rates):
    """Parse a comma-separated list of rate specs into a list of ``Rate``.

    A falsy value (e.g. ``None`` or ``""``) or a whitespace-only string
    yields an empty list. Each comma-separated part is stripped before
    being handed to ``_parse_rate``.
    """
    if not rates:
        return []
    stripped = rates.strip()
    if not stripped:
        return []
    return [_parse_rate(piece.strip()) for piece in stripped.split(",")]


def _cache_key(request, *, action, rate, key=None, user=None):
    """Return the cache key that holds the hit history for one rate of *action*.

    The key combines the action, the rate itself (amount and duration) and
    the source being limited. The source depends on ``rate.per``:

    - ``"ip"``: the client IP reported by the account adapter;
    - ``"user"``: the primary key of *user*, or of ``request.user`` when
      *user* is not given;
    - ``"key"``: the SHA-256 hex digest of the caller-supplied *key*.

    The rate is part of the key so that several rates with the same action
    and ``per`` (e.g. ``"5/m/ip,100/d/ip"``) keep separate histories.
    Without it they would share one cache entry. Every request would then be
    recorded once per rate, and the shortest window would prune timestamps
    the longer window still needs. Histories stored under an older key
    format are simply no longer read.

    :raises ImproperlyConfigured: for a per-user rate with no *user* and an
        unauthenticated request, or a per-key rate with no *key*.
    :raises ValueError: if ``rate.per`` is not one of the values above.
    """
    if rate.per == "ip":
        # Function-local import (presumably to avoid an import cycle); only
        # this branch needs the adapter.
        from allauth.account.adapter import get_adapter

        source = ("ip", get_adapter().get_client_ip(request))
    elif rate.per == "user":
        if user is None:
            if not request.user.is_authenticated:
                raise ImproperlyConfigured(
                    "ratelimit configured per user but used anonymously"
                )
            user = request.user
        source = ("user", str(user.pk))
    elif rate.per == "key":
        if key is None:
            raise ImproperlyConfigured(
                "ratelimit configured per key but no key specified"
            )
        # Hashing yields a fixed-length hex segment whatever the key contains.
        key_hash = hashlib.sha256(key.encode("utf8")).hexdigest()
        source = (key_hash,)
    else:
        raise ValueError(rate.per)
    # float() makes "m" (60) and "1m" (60.0) produce the same segment. The
    # segment contains no ":", so sources with colons (e.g. IPv6) placed
    # after it cannot make two keys ambiguous.
    bucket = "%s/%r" % (rate.amount, float(rate.duration))
    return ":".join(["allauth", "rl", action, bucket, *source])


def clear(request, *, action, key=None, user=None):
    """Delete the recorded hit history for every configured rate of *action*.

    Rates are read from ``RATE_LIMITS[action]`` in the account app settings.
    For each rate, the cache entry whose key ``_cache_key`` derives from
    *request*, *key* and *user* is deleted.
    """
    from allauth.account import app_settings

    for rate in _parse_rates(app_settings.RATE_LIMITS.get(action)):
        stale_key = _cache_key(request, action=action, rate=rate, key=key, user=user)
        cache.delete(stale_key)


def consume(request, *, action, key=None, user=None, dry_run: bool = False):
    """Check *action* against its configured rates and record the hit.

    Returns ``True`` without consulting any limits when *request* is falsy,
    when it is a GET request, or when no rates are configured for *action*
    in ``RATE_LIMITS``. Otherwise every rate is evaluated via
    ``_consume_rate``. The result is ``True`` only if all of them allow the
    hit. With *dry_run* the rates are checked but nothing is recorded.
    """
    from allauth.account import app_settings

    if not request or request.method == "GET":
        return True

    rates = _parse_rates(app_settings.RATE_LIMITS.get(action))
    # Build the full list before reducing: every rate must be consumed even
    # after one has already refused, so all() must not short-circuit.
    verdicts = [
        _consume_rate(
            request, action=action, rate=rate, key=key, user=user, dry_run=dry_run
        )
        for rate in rates
    ]
    return all(verdicts)


def _consume_rate(request, *, action, rate, key=None, user=None, dry_run: bool = False):
    """Apply a single *rate* and report whether this hit is within it.

    The cache stores a list of hit timestamps, newest first. Timestamps at
    least ``rate.duration`` seconds old are popped from the tail. The hit is
    allowed if fewer than ``rate.amount`` timestamps remain. An allowed hit
    is prepended and the list is stored back with a timeout of
    ``rate.duration``, unless *dry_run* is set.
    """
    bucket_key = _cache_key(request, action=action, rate=rate, key=key, user=user)
    timestamps = cache.get(bucket_key, [])
    current = time.time()
    cutoff = current - rate.duration
    while timestamps and timestamps[-1] <= cutoff:
        timestamps.pop()
    if len(timestamps) >= rate.amount:
        return False
    if not dry_run:
        timestamps.insert(0, current)
        cache.set(bucket_key, timestamps, rate.duration)
    return True


def _handler429(request):
    """Fallback 429 view: render the ``429`` template with HTTP status 429.

    The template file extension comes from the account app setting
    ``TEMPLATE_EXTENSION``.
    """
    from allauth.account import app_settings

    extension = app_settings.TEMPLATE_EXTENSION
    return render(request, "429." + extension, status=429)


def respond_429(request) -> HttpResponse:
    """Build the "429 Too Many Requests" response for *request*.

    When headless support is enabled and the request is marked as headless,
    a ``RateLimitResponse`` is returned. Otherwise a ``handler429`` attribute
    is looked up in the project's ``ROOT_URLCONF`` module. If that lookup
    raises ImportError or AttributeError, the template-based
    ``_handler429`` is used instead.
    """
    is_headless = app_settings.HEADLESS_ENABLED and hasattr(
        request.allauth, "headless"
    )
    if is_headless:
        from allauth.headless.base.response import RateLimitResponse

        return RateLimitResponse(request)

    try:
        # import_callable() is applied twice, presumably so that a
        # ``handler429`` given as a dotted-path string (as with Django's
        # handler404) is resolved as well -- confirm against allauth.utils.
        custom = import_callable(
            import_callable(settings.ROOT_URLCONF + ".handler429")
        )
    except (ImportError, AttributeError):
        custom = _handler429
    return custom(request)


def consume_or_429(request, *args, **kwargs) -> Optional[HttpResponse]:
    """Consume rate-limit budget, returning a 429 response if it is exhausted.

    All arguments are forwarded to ``consume``. Returns ``None`` when the
    request is allowed, otherwise the response built by ``respond_429``.
    """
    if consume(request, *args, **kwargs):
        return None
    return respond_429(request)