1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122
|
"""Command mixin for emulating `redis-py`'s Count-min sketch functionality."""
from typing import Optional, Tuple, List, Any
import probables
from fakeredis import _msgs as msgs
from fakeredis._commands import command, CommandItem, Int, Key, Float
from fakeredis._helpers import OK, SimpleString, SimpleError, casematch, Database
class CountMinSketch(probables.CountMinSketch):
def __init__(
self,
width: Optional[int] = None,
depth: Optional[int] = None,
probability: Optional[float] = None,
error_rate: Optional[float] = None,
):
super().__init__(width=width, depth=depth, error_rate=error_rate, confidence=probability)
class CMSCommandsMixin:
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._db: Database
@command(
name="CMS.INCRBY",
fixed=(Key(CountMinSketch), bytes, bytes),
repeat=(
bytes,
bytes,
),
flags=msgs.FLAG_DO_NOT_CREATE,
)
def cms_incrby(self, key: CommandItem, *args: bytes) -> List[Tuple[bytes, int]]:
if key.value is None:
raise SimpleError("CMS: key does not exist")
pairs: List[Tuple[bytes, int]] = []
for i in range(0, len(args), 2):
try:
pairs.append((args[i], int(args[i + 1])))
except ValueError:
raise SimpleError("CMS: Cannot parse number")
res = []
for pair in pairs:
res.append(key.value.add(pair[0], pair[1]))
key.updated()
return res
@command(
name="CMS.INFO",
fixed=(Key(CountMinSketch),),
repeat=(),
flags=msgs.FLAG_DO_NOT_CREATE,
)
def cms_info(self, key: CommandItem) -> List[bytes]:
if key.value is None:
raise SimpleError("CMS: key does not exist")
return [
b"width",
key.value.width,
b"depth",
key.value.depth,
b"count",
key.value.elements_added,
]
@command(name="CMS.INITBYDIM", fixed=(Key(CountMinSketch), Int, Int), repeat=(), flags=msgs.FLAG_DO_NOT_CREATE)
def cms_initbydim(self, key: CommandItem, width: int, depth: int) -> SimpleString:
if key.value is not None:
raise SimpleError("CMS key already set")
if width < 1:
raise SimpleError("CMS: invalid width")
if depth < 1:
raise SimpleError("CMS: invalid depth")
key.update(CountMinSketch(width=width, depth=depth))
return OK
@command(name="CMS.INITBYPROB", fixed=(Key(CountMinSketch), Float, Float), repeat=(), flags=msgs.FLAG_DO_NOT_CREATE)
def cms_initby_prob(self, key: CommandItem, error_rate: float, probability: float) -> SimpleString:
if key.value is not None:
raise SimpleError("CMS key already set")
if error_rate <= 0 or error_rate >= 1:
raise SimpleError("CMS: invalid overestimation value")
if probability <= 0 or probability >= 1:
raise SimpleError("CMS: invalid prob value")
key.update(CountMinSketch(probability=probability, error_rate=error_rate))
return OK
@command(name="CMS.MERGE", fixed=(Key(CountMinSketch), Int, bytes), repeat=(bytes,), flags=msgs.FLAG_DO_NOT_CREATE)
def cms_merge(self, dest_key: CommandItem, num_keys: int, *args: bytes) -> SimpleString:
if dest_key.value is None:
raise SimpleError("CMS: key does not exist")
if num_keys < 1:
raise SimpleError("CMS: Number of keys must be positive")
weights = [
1,
]
for i, arg in enumerate(args):
if casematch(b"weights", arg):
weights = [int(i) for i in args[i + 1 :]]
if len(weights) != num_keys:
raise SimpleError("CMS: wrong number of keys/weights")
args = args[:i]
break
dest_key.value.clear()
for i, arg in enumerate(args):
item = self._db.get(arg, None)
if item is None or not isinstance(item.value, CountMinSketch):
raise SimpleError("CMS: key does not exist")
for _ in range(weights[i % len(weights)]):
dest_key.value.join(item.value)
return OK
@command(name="CMS.QUERY", fixed=(Key(CountMinSketch), bytes), repeat=(bytes,), flags=msgs.FLAG_DO_NOT_CREATE)
def cms_query(self, key: CommandItem, *items: bytes) -> List[int]:
if key.value is None:
raise SimpleError("CMS: key does not exist")
return [key.value.check(item) for item in items]
|