File: _cms_mixin.py

package info (click to toggle)
python-fakeredis 2.29.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,772 kB
  • sloc: python: 19,002; sh: 8; makefile: 5
file content (122 lines) | stat: -rw-r--r-- 4,696 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""Command mixin for emulating `redis-py`'s Count-min sketch functionality."""

from typing import Optional, Tuple, List, Any

import probables

from fakeredis import _msgs as msgs
from fakeredis._commands import command, CommandItem, Int, Key, Float
from fakeredis._helpers import OK, SimpleString, SimpleError, casematch, Database


class CountMinSketch(probables.CountMinSketch):
    """``probables.CountMinSketch`` with a constructor using RedisBloom's vocabulary.

    Callers supply either ``width``/``depth`` (as ``CMS.INITBYDIM`` does) or
    ``error_rate``/``probability`` (as ``CMS.INITBYPROB`` does); everything is
    forwarded to :class:`probables.CountMinSketch`.

    ``probability`` has RedisBloom's meaning: the desired probability that a
    count estimate is inflated beyond the error bound (e.g. ``0.01``).
    ``probables`` instead takes a ``confidence``, which is the probability
    that the estimate stays *within* the bound. The two are complements, so
    ``confidence = 1 - probability``. Forwarding ``probability`` unchanged
    would build a far shallower sketch than Redis does for the same arguments.
    """

    def __init__(
        self,
        width: Optional[int] = None,
        depth: Optional[int] = None,
        probability: Optional[float] = None,
        error_rate: Optional[float] = None,
    ):
        # Translate RedisBloom's "probability of overestimation" into probables'
        # "confidence"; keep None so the width/depth construction path is used.
        confidence = None if probability is None else 1.0 - probability
        super().__init__(width=width, depth=depth, error_rate=error_rate, confidence=confidence)


class CMSCommandsMixin:
    """Command mixin emulating the RedisBloom ``CMS.*`` (count-min sketch) commands.

    Sketches are stored as :class:`CountMinSketch` values under ordinary keys.
    Only ``CMS.INITBYDIM`` and ``CMS.INITBYPROB`` create a sketch. Every other
    command reports a missing key with ``"CMS: key does not exist"``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Annotation only (nothing is assigned here): declares the type of the
        # database handle that CMS.MERGE reads source sketches from. Presumably
        # the host class sets it -- confirm.
        self._db: Database

    @command(
        name="CMS.INCRBY",
        fixed=(Key(CountMinSketch), bytes, bytes),
        repeat=(
            bytes,
            bytes,
        ),
        flags=msgs.FLAG_DO_NOT_CREATE,
    )
    def cms_incrby(self, key: CommandItem, *args: bytes) -> List[int]:
        """Increase the count of one or more items in the sketch.

        :param args: flat ``item, increment, item, increment, ...`` sequence.
        :returns: the value returned by the sketch's ``add`` for each pair, in order.
        :raises SimpleError: if the key does not exist, or if any increment is
            not an integer. All increments are parsed before the sketch is
            touched, so a malformed number leaves it unchanged.
        """
        if key.value is None:
            raise SimpleError("CMS: key does not exist")
        pairs: List[Tuple[bytes, int]] = []
        for i in range(0, len(args), 2):
            try:
                pairs.append((args[i], int(args[i + 1])))
            except ValueError:
                raise SimpleError("CMS: Cannot parse number") from None
        res = [key.value.add(item, increment) for item, increment in pairs]
        key.updated()
        return res

    @command(
        name="CMS.INFO",
        fixed=(Key(CountMinSketch),),
        repeat=(),
        flags=msgs.FLAG_DO_NOT_CREATE,
    )
    def cms_info(self, key: CommandItem) -> List[Any]:
        """Return ``[b"width", width, b"depth", depth, b"count", elements_added]``.

        :raises SimpleError: if the key does not exist.
        """
        if key.value is None:
            raise SimpleError("CMS: key does not exist")
        return [
            b"width",
            key.value.width,
            b"depth",
            key.value.depth,
            b"count",
            key.value.elements_added,
        ]

    @command(name="CMS.INITBYDIM", fixed=(Key(CountMinSketch), Int, Int), repeat=(), flags=msgs.FLAG_DO_NOT_CREATE)
    def cms_initbydim(self, key: CommandItem, width: int, depth: int) -> SimpleString:
        """Create an empty sketch with explicit ``width`` and ``depth``.

        :raises SimpleError: if the key already holds a value, or if either
            dimension is smaller than 1.
        """
        if key.value is not None:
            raise SimpleError("CMS key already set")
        if width < 1:
            raise SimpleError("CMS: invalid width")
        if depth < 1:
            raise SimpleError("CMS: invalid depth")
        key.update(CountMinSketch(width=width, depth=depth))
        return OK

    @command(name="CMS.INITBYPROB", fixed=(Key(CountMinSketch), Float, Float), repeat=(), flags=msgs.FLAG_DO_NOT_CREATE)
    def cms_initby_prob(self, key: CommandItem, error_rate: float, probability: float) -> SimpleString:
        """Create an empty sketch sized from an error rate and a probability.

        :param error_rate: allowed overestimation, strictly between 0 and 1.
        :param probability: probability parameter, strictly between 0 and 1.
        :raises SimpleError: if the key already holds a value or a parameter
            is out of range.
        """
        if key.value is not None:
            raise SimpleError("CMS key already set")
        if error_rate <= 0 or error_rate >= 1:
            raise SimpleError("CMS: invalid overestimation value")
        if probability <= 0 or probability >= 1:
            raise SimpleError("CMS: invalid prob value")
        key.update(CountMinSketch(probability=probability, error_rate=error_rate))
        return OK

    @command(name="CMS.MERGE", fixed=(Key(CountMinSketch), Int, bytes), repeat=(bytes,), flags=msgs.FLAG_DO_NOT_CREATE)
    def cms_merge(self, dest_key: CommandItem, num_keys: int, *args: bytes) -> SimpleString:
        """Overwrite ``dest_key`` with the weighted sum of ``num_keys`` source sketches.

        ``args`` is ``source [source ...] [WEIGHTS weight [weight ...]]``:
        exactly ``num_keys`` source key names, optionally followed by the
        ``WEIGHTS`` keyword and exactly ``num_keys`` integer weights (all
        weights default to 1).

        Every argument and source sketch is validated before the destination
        is modified, so a failing command leaves the destination intact. The
        destination may also appear among its own sources.

        :raises SimpleError: if the destination or any source is missing or is
            not a sketch, if the key/weight counts are wrong, if a weight is not
            an integer, or if a source's dimensions differ from the destination's.
        """
        if dest_key.value is None:
            raise SimpleError("CMS: key does not exist")
        if num_keys < 1:
            raise SimpleError("CMS: Number of keys must be positive")

        # The first num_keys arguments are always key names (even a key literally
        # called "weights"); only what follows them may be a WEIGHTS clause.
        source_names = args[:num_keys]
        options = args[num_keys:]
        if len(source_names) != num_keys:
            raise SimpleError("CMS: wrong number of keys/weights")
        weights = [1] * num_keys
        if options:
            if not casematch(b"weights", options[0]) or len(options) - 1 != num_keys:
                raise SimpleError("CMS: wrong number of keys/weights")
            try:
                weights = [int(weight) for weight in options[1:]]
            except ValueError:
                raise SimpleError("CMS: invalid weight value") from None

        dest = dest_key.value
        sources: List[CountMinSketch] = []
        for name in source_names:
            item = self._db.get(name, None)
            if item is None or not isinstance(item.value, CountMinSketch):
                raise SimpleError("CMS: key does not exist")
            if item.value.width != dest.width or item.value.depth != dest.depth:
                raise SimpleError("CMS: width/depth is not equal")
            sources.append(item.value)

        # Accumulate into a scratch sketch first: the destination may itself be one
        # of the sources and must not be cleared before it has been read.
        merged = CountMinSketch(width=dest.width, depth=dest.depth)
        for sketch, weight in zip(sources, weights):
            # NOTE: a weight is applied by joining repeatedly, so a weight <= 0
            # contributes nothing and the cost grows linearly with the weight.
            for _ in range(weight):
                merged.join(sketch)
        dest.clear()
        dest.join(merged)
        dest_key.updated()
        return OK

    @command(name="CMS.QUERY", fixed=(Key(CountMinSketch), bytes), repeat=(bytes,), flags=msgs.FLAG_DO_NOT_CREATE)
    def cms_query(self, key: CommandItem, *items: bytes) -> List[int]:
        """Return the sketch's estimated count for each item, in argument order.

        :raises SimpleError: if the key does not exist.
        """
        if key.value is None:
            raise SimpleError("CMS: key does not exist")
        return [key.value.check(item) for item in items]