File: test_postgres_concurrent.py

import threading

import pytest

from pyrate_limiter import Duration, PostgresBucket, Rate
from pyrate_limiter.abstracts import RateItem


@pytest.mark.postgres
class TestPostgresConcurrent:
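    """Thread-based concurrency tests for PostgresBucket.

    Every thread builds its own PostgresBucket over a shared psycopg
    connection pool and the same table, so these tests exercise
    concurrent writes against a single bucket.
    """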

    @pytest.fixture
    def pg_pool(self):
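        # Assumes a local Postgres reachable at the DSN below with the
        # postgres/postgres superuser credentials; adjust for your
        # environment if needed.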
        from psycopg_pool import ConnectionPool

        pool = ConnectionPool(
            "postgresql://postgres:postgres@localhost:5432",
            min_size=4,
            max_size=10,
            open=True,
        )
        yield pool
        pool.close()

    @pytest.fixture
    def clean_table(self, pg_pool):
        from pyrate_limiter import id_generator

        table = f"test_concurrent_{id_generator()}"
        yield table
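        # PostgresBucket stores its rows in a table prefixed with
        # "ratelimit___" (see the queries below), so drop the prefixed name.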
        with pg_pool.connection() as conn:
            conn.execute(f"DROP TABLE IF EXISTS ratelimit___{table}")

    def test_concurrent_put(self, pg_pool, clean_table):
        rate_limit = 5
        rates = [Rate(rate_limit, Duration.SECOND)]
        num_threads = 8
        attempts_per_thread = 10
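        # 8 threads x 10 attempts = 80 puts against a 5/second limit, so
        # most attempts are expected to be rejected.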

        results = []
        results_lock = threading.Lock()

        def worker(thread_id: int):
            bucket = PostgresBucket(pg_pool, clean_table, rates)
            thread_results = []

            for _ in range(attempts_per_thread):
                timestamp = bucket.now()
                item = RateItem(f"thread_{thread_id}", timestamp, weight=1)
                success = bucket.put(item)
                thread_results.append((timestamp, success))

            with results_lock:
                results.extend(thread_results)

        threads = [
            threading.Thread(target=worker, args=(i,)) for i in range(num_threads)
        ]

        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Verify DB state
        full_table = f"ratelimit___{clean_table}"
        with pg_pool.connection() as conn:
            cur = conn.execute(f"SELECT COUNT(*) FROM {full_table}")  # noqa: S608
            total_in_db = cur.fetchone()[0]
            cur.close()

            # Check sliding windows to make sure the rate was never exceeded.
            # Extract timestamps in milliseconds to match the limiter's clock;
            # truncating to whole seconds would collapse the sliding-window
            # check into a per-second bucket count.
            cur = conn.execute(
                f"SELECT (EXTRACT(EPOCH FROM item_timestamp) * 1000)::bigint AS ts FROM {full_table}"  # noqa: S608
            )
            db_timestamps = [row[0] for row in cur.fetchall()]
            cur.close()

        for ts in db_timestamps:
            window_start = ts - 1000  # 1-second window, in milliseconds
            count_in_window = sum(1 for t in db_timestamps if window_start < t <= ts)
            assert count_in_window <= rate_limit, (
                f"Rate limit exceeded in DB: {count_in_window} items in 1-second window ending at {ts}"
            )

        # Verify that at least some acquisitions succeeded
        total_success = sum(1 for _, success in results if success)
        assert total_success > 0, "No successful acquisitions"
        assert total_success == total_in_db, (
            f"Mismatch: {total_success} reported successes but {total_in_db} items in DB"
        )

        # Verify some rejections
        total_rejected = sum(1 for _, success in results if not success)
        assert total_rejected > 0, (
            "No rejections occurred - rate limiting may not be working"
        )

    def test_concurrent_put_multiple_rates(self, pg_pool, clean_table):
        rates = [
            Rate(3, 500),   # 3 per 500ms
            Rate(5, 1000),  # 5 per second
        ]
        num_threads = 4
        attempts_per_thread = 5
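        # 4 threads x 5 attempts = 20 puts against 3/500ms and 5/second
        # limits, so both windows should see contention.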

        results = []
        results_lock = threading.Lock()

        def worker(thread_id: int):
            bucket = PostgresBucket(pg_pool, clean_table, rates)
            thread_results = []

            for _ in range(attempts_per_thread):
                timestamp = bucket.now()
                item = RateItem(f"thread_{thread_id}", timestamp, weight=1)
                success = bucket.put(item)
                thread_results.append((timestamp, success))

            with results_lock:
                results.extend(thread_results)

        threads = [
            threading.Thread(target=worker, args=(i,)) for i in range(num_threads)
        ]

        for t in threads:
            t.start()
        for t in threads:
            t.join()

        successful_timestamps = sorted(ts for ts, success in results if success)

        # Check sliding windows for both rates. Windows are half-open,
        # (ts - interval, ts], so an item exactly one full interval
        # earlier has already left the window.
        for ts in successful_timestamps:
            # 1-second sliding window
            count_1s = sum(1 for t in successful_timestamps if ts - 1000 < t <= ts)
            assert count_1s <= 5, f"1-second rate exceeded: {count_1s} items in window ending at {ts}"

            # 500ms sliding window
            count_500ms = sum(1 for t in successful_timestamps if ts - 500 < t <= ts)
            assert count_500ms <= 3, f"500ms rate exceeded: {count_500ms} items in window ending at {ts}"

    def test_concurrent_put_weighted(self, pg_pool, clean_table):
        rate_limit = 10
        rates = [Rate(rate_limit, Duration.SECOND)]
        num_threads = 4
        weight = 3
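        # With weight 3 against a limit of 10, at most 3 items (total
        # weight 9) fit in any 1-second window; a 4th would push it to 12.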

        results = []
        results_lock = threading.Lock()

        def worker(thread_id: int):
            bucket = PostgresBucket(pg_pool, clean_table, rates)
            thread_results = []

            for _ in range(5):
                timestamp = bucket.now()
                item = RateItem(f"thread_{thread_id}", timestamp, weight=weight)
                success = bucket.put(item)
                thread_results.append((timestamp, success, weight))

            with results_lock:
                results.extend(thread_results)

        threads = [
            threading.Thread(target=worker, args=(i,)) for i in range(num_threads)
        ]

        for t in threads:
            t.start()
        for t in threads:
            t.join()

        successful_results = sorted((ts, w) for ts, success, w in results if success)
        assert successful_results, "No successful acquisitions"

        # Check sliding windows for weighted items (half-open, as above)
        for ts, _ in successful_results:
            weight_in_window = sum(w for t, w in successful_results if ts - 1000 < t <= ts)
            assert weight_in_window <= rate_limit, (
                f"Rate limit exceeded: weight {weight_in_window} in 1-second window ending at {ts}"
            )
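

# ---------------------------------------------------------------------------
# Reference sketch (not exercised by the tests above): every sliding-window
# assertion in this module is an instance of the check below. This is plain
# Python with no pyrate_limiter API; "events" is assumed to be a list of
# (timestamp_ms, weight) pairs, like successful_results in the weighted test.
def _max_weight_in_any_window(events, interval_ms):
    """Max total weight in any half-open (ts - interval_ms, ts] window
    ending at an observed event timestamp."""
    if not events:
        return 0
    return max(
        sum(w for t, w in events if ts - interval_ms < t <= ts)
        for ts, _ in events
    )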