File: base.py

package info (click to toggle)
python-fakeredis 2.29.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,772 kB
  • sloc: python: 19,002; sh: 8; makefile: 5
file content (344 lines) | stat: -rw-r--r-- 12,605 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
import functools
import math
import operator
import string
import sys
from typing import Any

import hypothesis
import hypothesis.stateful
import hypothesis.strategies as st
import pytest
import redis
from hypothesis.stateful import rule, initialize, precondition
from hypothesis.strategies import SearchStrategy

import fakeredis
from ._server_info import redis_ver, floats_kwargs, server_type

# Strategy that draws the currently running state machine (``st.runner()``);
# used by ``sample_attr`` and by the ``@rule`` definitions to read per-run state.
self_strategy = st.runner()


@st.composite
def sample_attr(draw, name):
    """Draw one element of the sequence stored as attribute ``name`` on the running machine.

    The machine is obtained through ``self_strategy`` (``st.runner()``), so this
    strategy is only usable inside a stateful test. The attribute must be a
    non-empty sequence; an index into it is drawn and the element returned.
    """
    pool = getattr(draw(self_strategy), name)
    last_index = len(pool) - 1
    chosen = draw(st.integers(min_value=0, max_value=last_index))
    return pool[chosen]


# Strategies drawing one element from the per-run pools that the state machine
# stores in its ``keys``/``fields``/``values``/``scores`` attributes.
keys = sample_attr("keys")
fields = sample_attr("fields")
values = sample_attr("values")
scores = sample_attr("scores")

# Non-empty strings of ASCII letters, encoded to bytes.
eng_text = st.builds(lambda x: x.encode(), st.text(alphabet=string.ascii_letters, min_size=1))
# Integers in the signed 32-bit range.
ints = st.integers(min_value=-2_147_483_648, max_value=2_147_483_647)
# ``default_normalize`` is defined further down the module; the lambda only
# looks the name up when the strategy is drawn from, so the forward reference works.
int_as_bytes = st.builds(lambda x: str(default_normalize(x)).encode(), ints)
# 32-bit floats; further options come from ``floats_kwargs`` in ``_server_info``
# (not visible here -- presumably server-dependent NaN/infinity settings; confirm).
floats = st.floats(width=32, **floats_kwargs)
float_as_bytes = st.builds(lambda x: repr(default_normalize(x)).encode(), floats)
# Either a small count in -3..3 or any 32-bit integer.
counts = st.integers(min_value=-3, max_value=3) | ints
# Redis has an integer overflow bug in swapdb, so we confine the numbers to
# a limited range (https://github.com/antirez/redis/issues/5737).
dbnums = st.integers(min_value=0, max_value=3) | st.integers(min_value=-1000, max_value=1000)
# Pattern arguments: text built from glob metacharacters, a few letters and
# whitespace, or arbitrary bytes without NUL.
# The filter is to work around https://github.com/antirez/redis/issues/5632
patterns = st.text(alphabet=st.sampled_from("[]^$*.?-azAZ\\\r\n\t")) | st.binary().filter(lambda x: b"\0" not in x)
# b"+", b"-", or a field prefixed with b"(" or b"[" (presumably range bounds for
# lexicographic range commands such as ZRANGEBYLEX -- confirm at call sites).
string_tests = st.sampled_from([b"+", b"-"]) | st.builds(operator.add, st.sampled_from([b"(", b"["]), fields)
# Redis has integer overflow bugs in time computations, which is why we set a maximum.
# Expiry durations: 5..1000 seconds, and 5_000..50_000 milliseconds.
expires_seconds = st.integers(min_value=5, max_value=1_000)
expires_ms = st.integers(min_value=5_000, max_value=50_000)


class WrappedException:
    """Wrap an exception so that replies containing errors can be compared with ``==``.

    Two wrappers compare equal when the wrapped exceptions have exactly the
    same type. Messages/arguments are deliberately ignored (presumably because
    they may legitimately differ between the real server and the fake).
    Comparing against anything that is not a ``WrappedException`` returns
    ``NotImplemented``, so Python falls back to its default handling.
    """

    def __init__(self, exc):
        self.wrapped = exc

    def __str__(self):
        return str(self.wrapped)

    def __repr__(self):
        return f"WrappedException({self.wrapped!r})"

    def __eq__(self, other):
        if not isinstance(other, WrappedException):
            return NotImplemented
        # Exact type match, not isinstance: a subclass counts as a different error.
        return type(self.wrapped) is type(other.wrapped)

    # ``__ne__`` is inherited: in Python 3 ``object.__ne__`` inverts ``__eq__``
    # and propagates ``NotImplemented``, matching the old hand-written version.

    def __hash__(self):
        # Defining ``__eq__`` alone sets ``__hash__`` to None (unhashable);
        # hash on the wrapped type so equal wrappers hash equally.
        return hash(type(self.wrapped))


def wrap_exceptions(obj):
    """Return ``obj`` with every ``Exception`` replaced by a ``WrappedException``.

    Lists are rebuilt recursively (nested lists included); a bare exception is
    wrapped; any other value is returned as-is.
    """
    if isinstance(obj, list):
        return list(map(wrap_exceptions, obj))
    if not isinstance(obj, Exception):
        return obj
    return WrappedException(obj)


def sort_list(lst):
    """Return a sorted copy of ``lst`` when it is a list; any other value is returned unchanged.

    Used as the normalizer for commands whose list replies have no defined order.
    """
    return sorted(lst) if isinstance(lst, list) else lst


def normalize_if_number(x):
    """Convert a reply (or each item of a list reply, recursively) to ``float`` where possible.

    Numerically equal replies spelled differently (e.g. ``b"1"`` and ``b"1.0"``)
    thereby compare equal. NaN is returned in its original form, since
    ``nan != nan`` would make identical replies compare unequal.

    Values that ``float()`` rejects are returned unchanged. That covers
    non-numeric bytes/str (``ValueError``), and also ``None`` (a nil reply) or
    non-numeric objects such as exception instances (``TypeError``). The latter
    previously escaped and made a normal reply look like a raised exception.
    """
    if isinstance(x, list):
        return [normalize_if_number(item) for item in x]
    try:
        res = float(x)
    except (TypeError, ValueError):
        return x
    return x if math.isnan(res) else res


def flatten(args):
    """Yield the leaves of an arbitrarily nested list/tuple structure, skipping ``None``.

    Only lists and tuples are descended into; strings and bytes are leaves.
    """
    if args is None:
        return
    if not isinstance(args, (list, tuple)):
        yield args
        return
    for element in args:
        yield from flatten(element)


def default_normalize(x: Any) -> Any:
    """Normalize numeric values when targeting Redis 7+; return anything else untouched.

    When ``redis_ver`` is at least 7 and ``x`` is an ``int`` or ``float``
    (``bool`` included, as an ``int`` subclass), return ``0 + x``. This turns
    ``-0.0`` into ``0.0`` and ``True``/``False`` into ``1``/``0``, and leaves
    other numbers unchanged.
    """
    if redis_ver < (7,) or not isinstance(x, (int, float)):
        return x
    return 0 + x


def optional(arg: Any) -> st.SearchStrategy:
    """Strategy that yields either ``None`` or exactly ``arg``."""
    return st.one_of(st.none(), st.just(arg))


def zero_or_more(*args: Any):
    """Return one ``optional`` strategy per argument, in order.

    When splatted into ``commands(...)``, each argument is independently present
    or ``None``; ``None`` entries are then dropped by ``flatten`` in ``Command``.
    """
    return list(map(optional, args))


class Command:
    """A single fuzzer-generated Redis command.

    Constructor arguments are flattened (nested lists/tuples expanded, ``None``
    dropped) and passed through ``default_normalize``; the result is stored in
    ``self.args`` as a tuple.
    """

    # Commands that return a list in arbitrary order. Their replies are sorted
    # before comparison rather than converted to numbers. Built once here
    # instead of on every ``normalize`` access.
    _UNORDERED_COMMANDS = frozenset(
        {
            b"keys",
            b"sort",
            b"hgetall",
            b"hkeys",
            b"hvals",
            b"sdiff",
            b"sinter",
            b"sunion",
            b"smembers",
            b"hexpire",
        }
    )

    def __init__(self, *args):
        self.args = tuple(default_normalize(x) for x in flatten(args))

    def __repr__(self):
        parts = [repr(arg) for arg in self.args]
        return "Command({})".format(", ".join(parts))

    @staticmethod
    def encode(arg):
        """Encode one argument to bytes using redis-py's connection ``Encoder``.

        The encoder is configured for UTF-8 with ``"replace"`` error handling.
        """
        encoder = redis.connection.Encoder("utf-8", "replace", False)
        return encoder.encode(arg)

    @property
    def normalize(self):
        """Function used to normalize this command's reply before comparison.

        Returns ``sort_list`` for commands listed in ``_UNORDERED_COMMANDS``;
        otherwise ``normalize_if_number``.
        """
        command = self.encode(self.args[0]).lower() if self.args else None
        if command in self._UNORDERED_COMMANDS:
            return sort_list
        return normalize_if_number

    @property
    def testable(self):
        """Whether this command is suitable for a test.

        The fuzzer can create commands with behaviour that is
        non-deterministic, not supported, or which hits redis bugs.
        """
        n_args = len(self.args)
        if n_args == 0:
            return False
        command = self.encode(self.args[0]).lower()
        # Reject command names that are empty or only whitespace.
        if not command.split():
            return False
        # KEYS with a pattern other than "*" is excluded (cf. the disabled
        # KEYS-pattern rule, redis issue 5632). The pattern may be ``str``
        # (common_commands uses "*") or ``bytes``, so accept both spellings;
        # comparing only against b"*" wrongly rejected ``KEYS "*"``.
        if command == b"keys" and n_args == 2 and self.args[1] not in (b"*", "*"):
            return False
        # Redis will ignore a NULL character in some commands but not others,
        # e.g., it recognizes EXEC\0 but not MULTI\00.
        # Rather than try to reproduce this quirky behavior, just skip these tests.
        if b"\0" in command:
            return False
        return True


def commands(*args, **kwargs):
    """Strategy producing ``Command`` objects whose arguments are drawn from ``args``.

    Each positional argument is a strategy for the corresponding command
    argument. ``kwargs`` are bound onto the ``Command`` constructor with
    ``functools.partial``. Note: ``Command.__init__`` as defined in this module
    accepts no keyword arguments, so non-empty ``kwargs`` would fail when an
    example is built.
    """
    build_command = functools.partial(Command, **kwargs)
    return st.builds(build_command, *args)


# # TODO: all expiry-related commands
common_commands = (
    commands(st.sampled_from(["del", "persist", "type", "unlink"]), keys)
    | commands(st.just("exists"), st.lists(keys))
    | commands(st.just("keys"), st.just("*"))
    # Disabled for now due to redis giving wrong answers
    # (https://github.com/antirez/redis/issues/5632)
    # | commands(st.just('keys'), patterns)
    | commands(st.just("move"), keys, dbnums)
    | commands(st.sampled_from(["rename", "renamenx"]), keys, keys)
    # TODO: find a better solution to sort instability than throwing
    #  away the sort entirely with normalize. This also prevents us
    #  using LIMIT.
    | commands(st.just("sort"), keys, *zero_or_more("asc", "desc", "alpha"))
)


@hypothesis.settings(max_examples=1000)
class CommonMachine(hypothesis.stateful.RuleBasedStateMachine):
    """Run generated commands against a real Redis server and fakeredis, and compare replies.

    The real server is expected at ``localhost:6390`` (db 2). The test is skipped
    if that server is unreachable or is not a 64-bit build. Subclasses provide:

    - ``create_command_strategy``: commands run once, to populate data.
    - ``command_strategy``: commands run afterwards, one per step.
    """

    create_command_strategy = st.nothing()

    def __init__(self):
        super().__init__()
        try:
            self.real = redis.StrictRedis("localhost", port=6390, db=2)
            self.real.ping()
        except redis.ConnectionError:
            pytest.skip("redis is not running")
        # Done before the response callbacks are cleared below, so ``info``
        # is still parsed into a dict here.
        if self.real.info("server").get("arch_bits") != 64:
            self.real.connection_pool.disconnect()
            pytest.skip("redis server is not 64-bit")
        self.fake = fakeredis.FakeStrictRedis(server=fakeredis.FakeServer(version=redis_ver), port=6390, db=2)
        # Disable the response parsing so that we can check the raw values returned
        self.fake.response_callbacks.clear()
        self.real.response_callbacks.clear()
        # Normalize functions of the commands queued in the current MULTI
        # block; applied element-wise to the EXEC reply.
        self.transaction_normalize = []
        # Per-run sample pools, filled by ``init_attrs``.
        self.keys = []
        self.fields = []
        self.values = []
        self.scores = []
        self.initialized_data = False
        # Defensive DISCARD on the real server; a ResponseError (e.g. no MULTI
        # in progress) is ignored. Then start from an empty dataset.
        try:
            self.real.execute_command("discard")
        except redis.ResponseError:
            pass
        self.real.flushall()

    def teardown(self):
        # Close both clients' pooled connections before hypothesis's teardown.
        self.real.connection_pool.disconnect()
        self.fake.connection_pool.disconnect()
        super().teardown()

    @staticmethod
    def _evaluate(client, command):
        """Run ``command`` on ``client`` and return ``(result, exc)``.

        The reply is passed through ``command.normalize`` unless it is the
        ``QUEUED`` acknowledgement of a command issued inside MULTI. Raw replies
        are bytes because response parsing is disabled, so both spellings are
        checked. If the command raises, the exception is both the result and
        ``exc``. The result is passed through ``wrap_exceptions`` so it can be
        compared with ``==``.
        """
        try:
            result = client.execute_command(*command.args)
            if result not in (b"QUEUED", "QUEUED"):
                result = command.normalize(result)
            exc = None
        except Exception as e:
            result = exc = e
        return wrap_exceptions(result), exc

    @staticmethod
    def _command_name(command):
        """Return the command name as lower-cased bytes, or None for an empty command."""
        if not command.args:
            return None
        return Command.encode(command.args[0]).lower()

    @staticmethod
    def _discrepancy(command, fake_result, real_result):
        """Assertion message for a reply mismatch (only built when an assertion fails)."""
        return f"Discrepancy when running command {command}, fake({fake_result}) != real({real_result})"

    def _compare(self, command):
        """Run ``command`` on both clients and assert that the outcomes agree.

        - An exception raised only by the fake is re-raised.
        - An exception expected from the real server but not raised by the fake fails the test.
        - An EXEC reply is compared element-wise using the normalize functions of
          the queued commands.
        """
        fake_result, fake_exc = self._evaluate(self.fake, command)
        real_result, real_exc = self._evaluate(self.real, command)

        if fake_exc is not None and real_exc is None:
            print(f"{fake_exc} raised only on fake when running {command}", file=sys.stderr)
            raise fake_exc
        elif real_exc is not None and fake_exc is None:
            raise AssertionError(f"Expected exception `{real_exc}` not raised when running {command}")
        elif real_exc is None and isinstance(real_result, list) and self._command_name(command) == b"exec":
            # Command names may be str or bytes; compare the encoded form.
            assert fake_result is not None
            # Transactions need to use the normalize functions of the
            # component commands.
            assert len(self.transaction_normalize) == len(real_result)
            assert len(self.transaction_normalize) == len(fake_result)
            for n, r, f in zip(self.transaction_normalize, real_result, fake_result):
                assert n(f) == n(r)
            self.transaction_normalize = []
        elif isinstance(fake_result, list):
            # Check the real reply is a list too, so a mismatch fails with a
            # clear message instead of a TypeError from len().
            assert isinstance(real_result, list) and len(fake_result) == len(real_result), self._discrepancy(
                command, fake_result, real_result
            )
            for fake_item, real_item in zip(fake_result, real_result):
                assert fake_item == real_item or (
                    type(fake_item) is float and fake_item == pytest.approx(real_item)
                ), self._discrepancy(command, fake_result, real_result)
        else:
            assert fake_result == real_result or (
                type(fake_result) is float and fake_result == pytest.approx(real_result)
            ), self._discrepancy(command, fake_result, real_result)
            if real_result == b"QUEUED":
                # Since redis removes the distinction between simple strings and
                # bulk strings, this might not actually indicate that we're in a
                # transaction. But it is extremely unlikely that hypothesis will
                # find such examples.
                self.transaction_normalize.append(command.normalize)
        if len(command.args) == 1 and self._command_name(command) in (b"discard", b"exec"):
            self.transaction_normalize = []

    @initialize(
        attrs=st.fixed_dictionaries(
            dict(
                keys=st.lists(eng_text, min_size=2, max_size=5, unique=True),
                fields=st.lists(eng_text, min_size=2, max_size=5, unique=True),
                values=st.lists(eng_text | int_as_bytes | float_as_bytes, min_size=2, max_size=5, unique=True),
                scores=st.lists(floats, min_size=2, max_size=5, unique=True),
            )
        )
    )
    def init_attrs(self, attrs):
        # Store the drawn pools as ``self.keys``/``fields``/``values``/``scores``.
        for key, value in attrs.items():
            setattr(self, key, value)

    # hypothesis doesn't allow ordering of @initialize, so we have to put
    # preconditions on rules to ensure we call init_data exactly once and
    # after init_attrs.
    @precondition(lambda self: not self.initialized_data)
    @rule(commands=self_strategy.flatmap(lambda self: st.lists(self.create_command_strategy)))
    def init_data(self, commands):
        for command in commands:
            self._compare(command)
        self.initialized_data = True

    @precondition(lambda self: self.initialized_data)
    @rule(command=self_strategy.flatmap(lambda self: self.command_strategy))
    def one_command(self, command):
        self._compare(command)


class BaseTest:
    """Base class for hypothesis-driven real-vs-fake comparison tests.

    Subclasses set the strategy attributes below. ``test`` combines them
    according to the server under test and runs a ``CommonMachine`` subclass
    built from them.
    """

    # Commands run at each step after initialization; subclasses must set it.
    command_strategy: SearchStrategy
    # Commands run once to populate initial data.
    create_command_strategy = st.nothing()
    # Extra commands enabled only for a real Redis server of version 7 or newer.
    command_strategy_redis7 = st.nothing()
    # Extra commands enabled only when ``server_type == "redis"``.
    command_strategy_redis_only = st.nothing()

    @pytest.mark.slow
    def test(self):
        initial = self.create_command_strategy
        combined = self.command_strategy
        if server_type == "redis":
            combined = combined | self.command_strategy_redis_only
            if redis_ver >= (7,):
                combined = combined | self.command_strategy_redis7

        class Machine(CommonMachine):
            create_command_strategy = initial
            command_strategy = combined

        # For verbose debugging, enable a profile such as:
        # hypothesis.settings.register_profile("debug", max_examples=10, verbosity=hypothesis.Verbosity.debug)
        # hypothesis.settings.load_profile("debug")
        hypothesis.stateful.run_state_machine_as_test(Machine)