1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427
|
import itertools
import queue
import time
import weakref
from typing import List, Any, Tuple, Optional, Callable, Union, Match, AnyStr, Generator, Dict
from xmlrpc.client import ResponseError
import redis
from redis.connection import DefaultParser
from fakeredis.model import XStream, ZSet, Hash, ExpiringMembersSet
from . import _msgs as msgs
from ._command_args_parsing import extract_args
from ._commands import Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, Signature, CommandItem
from ._helpers import (
SimpleError,
valid_response_type,
SimpleString,
NoResponse,
casematch,
compile_pattern,
QUEUED,
decode_command_bytes,
)
def _extract_command(fields: List[bytes]) -> Tuple[str, List[bytes]]:
    """Split a raw request into its command name and its arguments.

    The first field is decoded into the command name. When that name is listed
    in ``COMMANDS_WITH_SUB`` and a second field is present, the second field is
    decoded as well and appended after a single space, so the subcommand becomes
    part of the returned name and is removed from the arguments.

    :param fields: Non-empty list of raw ``bytes`` fields as received from the
        client, command name first.
    :return: A ``(command, arguments)`` tuple. ``command`` is the ``str``
        produced by ``decode_command_bytes``; ``arguments`` are the remaining
        fields, left undecoded as ``bytes``.
    :raises IndexError: If ``fields`` is empty.

    Example (assuming ``decode_command_bytes`` lower-cases its input; the
    command comparisons in this module, e.g. against ``"exec"``, are lowercase)::

        _extract_command([b'GET', b'key1'])  # -> ('get', [b'key1'])
    """
    cmd = decode_command_bytes(fields[0])
    if cmd in COMMANDS_WITH_SUB and len(fields) >= 2:
        # "<command> <subcommand>": the subcommand is consumed from the arguments.
        cmd += " " + decode_command_bytes(fields[1])
        return cmd, fields[2:]
    return cmd, fields[1:]
def bin_reverse(x: int, bits_count: int) -> int:
    """Return the lowest ``bits_count`` bits of ``x`` in reversed order.

    Bits of ``x`` at positions ``bits_count`` and above are ignored, and a
    non-positive ``bits_count`` yields ``0``. Applying the function twice with
    the same ``bits_count`` restores the low bits of ``x``.

    :param x: Value whose low bits are reversed.
    :param bits_count: Width, in bits, of the field to reverse.
    :return: The reversed bit field as a non-negative integer.
    """
    reversed_bits = 0
    remaining = x
    for _ in range(bits_count):
        # Move the current least-significant bit of `remaining` onto the
        # low end of the output, which shifts earlier bits towards the top.
        reversed_bits = (reversed_bits << 1) | (remaining & 1)
        remaining >>= 1
    return reversed_bits
class BaseFakeSocket:
    """Common machinery for fakeredis' in-process fake sockets.

    Bytes written with :meth:`sendall` are fed to a generator-based RESP request
    parser (:meth:`_parse_commands`). Each complete request is dispatched by
    :meth:`_process_command` to the method named by the command's ``Signature``,
    and the decoded reply is put on the :attr:`responses` queue.

    Command implementations and some hooks used here (for example
    ``_clear_watches`` and the transaction/pub-sub attributes) are not defined in
    this class; presumably they come from mixins combined with it.
    """

    # Implemented outside this class (not visible in this module).
    _clear_watches: Callable[[], None]
    # Commands that _run_command still accepts while the connection is in pub/sub mode.
    ACCEPTED_COMMANDS_WHILE_PUBSUB = {
        "ping",
        "subscribe",
        "unsubscribe",
        "psubscribe",
        "punsubscribe",
        "ssubscribe",
        "sunsubscribe",
    }
    # Exception class raised by sendall() when the fake server is disconnected.
    _connection_error_class = redis.ConnectionError

    def __init__(self, server: "FakeServer", db: int, *args: Any, **kwargs: Any) -> None:  # type: ignore # noqa: F821
        """Attach a new fake socket to ``server`` using database number ``db``.

        :param server: The fake server whose databases and lock are shared.
        :param db: Index into ``server.dbs`` selecting the initial database.
        :param args: Forwarded unchanged to the next ``__init__`` in the MRO.
        :param kwargs: Forwarded unchanged to the next ``__init__`` in the MRO.
            An optional ``client_info`` dict is also read here; underscores in
            its keys are replaced by dashes.
        """
        super(BaseFakeSocket, self).__init__(*args, **kwargs)
        # Local import; presumably avoids a circular import at module load time.
        from fakeredis import FakeServer

        self._server: FakeServer = server
        self._db_num = db
        self._db = server.dbs[self._db_num]
        self.responses: Optional[queue.Queue[bytes]] = queue.Queue()
        # Prevents parser from processing commands. Not used in this module,
        # but set by aioredis module to prevent new commands being processed
        # while handling a blocking command.
        self._paused = False
        self._parser = self._parse_commands()
        self._parser.send(None)  # prime the generator up to its first `yield`
        # Assigned elsewhere
        self._transaction: Optional[List[Any]]
        self._in_transaction: bool
        self._pubsub: int
        self._transaction_failed: bool
        # NOTE(review): "client_info" has already been forwarded to super().__init__
        # above; it is only read here.
        info = kwargs.get("client_info", dict(user="default"))
        # Keys that started with "_" become "-..." and are hidden by `client_info`.
        self._client_info: Dict[str, Union[str, int]] = {k.replace("_", "-"): v for k, v in info.items()}
        self._server.sockets.append(self)

    @property
    def client_info(self):
        """Public client-info fields plus a computed ``age`` (seconds since ``-created``)."""
        res = {k: v for k, v in self._client_info.items() if not k.startswith("-")}
        res["age"] = int(time.time()) - self._client_info.get("-created", 0)
        return res

    @property
    def client_info_as_bytes(self) -> bytes:
        """:attr:`client_info` rendered as space-separated ``key=value`` pairs."""
        return " ".join([f"{k}={v}" for k, v in self.client_info.items()]).encode()

    @property
    def current_user(self) -> bytes:
        """Name of the authenticated user, or ``b""`` when none is recorded."""
        return self._client_info.get("user", "").encode()

    @property
    def protocol_version(self) -> int:
        """RESP protocol version from the client info (defaults to 2)."""
        return self._client_info.get("resp", 2)

    @property
    def version(self) -> Tuple[int, ...]:
        """Redis version emulated by the server."""
        return self._server.version

    @property
    def server_type(self) -> str:
        """Server flavour reported by the fake server."""
        return self._server.server_type

    def put_response(self, msg: Any) -> None:
        """Put a response message into the queue of responses.

        The message is silently dropped if the socket has been closed.

        :param msg: The response message.
        """
        # redis.Connection.__del__ might call self.close at any time, which
        # will set self.responses to None. We assume this will happen
        # atomically, and the code below then protects us against this.
        responses = self.responses
        if responses is not None:
            responses.put(msg)

    def pause(self) -> None:
        """Stop the parser from dispatching further commands."""
        self._paused = True

    def resume(self) -> None:
        """Re-enable dispatching and process any commands already buffered."""
        self._paused = False
        # An empty chunk wakes the parser so it re-checks its buffer.
        self._parser.send(b"")

    def shutdown(self, _: Any) -> None:
        """Close the request parser; the argument (socket ``how``) is ignored."""
        self._parser.close()

    @staticmethod
    def fileno() -> int:
        """Return a dummy file descriptor."""
        # Our fake socket must return an integer from `FakeSocket.fileno()` since a real selector
        # will be created. The value does not matter since we replace the selector with our own
        # `FakeSelector` before it is ever used.
        return 0

    def _cleanup(self, server: Any) -> None:  # noqa: F821
        """Remove all the references to `self` from `server`.

        This is called with the server lock held, but it may be some time after
        self.close.
        """
        for subs in server.subscribers.values():
            subs.discard(self)
        for subs in server.psubscribers.values():
            subs.discard(self)
        self._clear_watches()

    def close(self) -> None:
        """Detach from the server and queue this socket for deferred cleanup.

        Calling ``close`` again after the first call is a no-op, mirroring a real
        socket whose ``close()`` may be called repeatedly.
        """
        # Mark ourselves for cleanup. This might be called from
        # redis.Connection.__del__, which the garbage collection could call
        # at any time, and hence we can't safely take the server lock.
        # We rely on list.append being atomic.
        server = self._server
        if server is None:
            return  # already closed
        server.sockets.remove(self)
        server.closed_sockets.append(weakref.ref(self))
        self._server = None  # type: ignore
        self._db = None
        self.responses = None

    @staticmethod
    def _extract_line(buf: bytes) -> Tuple[bytes, bytes]:
        """Split ``buf`` after its first newline into ``(line, rest)``.

        The returned line keeps its trailing CRLF. Asserts that a newline exists
        and that the line ends with ``\\r\\n``.
        """
        pos = buf.find(b"\n") + 1
        assert pos > 0
        line = buf[:pos]
        buf = buf[pos:]
        assert line.endswith(b"\r\n")
        return line, buf

    def _parse_commands(self) -> Generator[None, Any, None]:
        """Generator that parses commands.

        It is fed pieces of redis protocol data (via `send`) and calls
        `_process_command` whenever it has a complete one. Each request is
        expected to be an array header (``*<n>``) followed by ``n`` bulk strings
        (``$<len>`` then the payload and CRLF).
        """
        buf = b""
        while True:
            # Wait for a full array header line; also hold off while paused.
            while self._paused or b"\n" not in buf:
                buf += yield
            line, buf = self._extract_line(buf)
            assert line[:1] == b"*"  # array
            n_fields = int(line[1:-2])
            fields = []
            for _ in range(n_fields):
                while b"\n" not in buf:
                    buf += yield
                line, buf = self._extract_line(buf)
                assert line[:1] == b"$"  # string
                length = int(line[1:-2])
                # Need the payload plus its trailing CRLF before consuming it.
                while len(buf) < length + 2:
                    buf += yield
                fields.append(buf[:length])
                buf = buf[length + 2 :]  # +2 to skip the CRLF
            self._process_command(fields)

    def _process_command(self, fields: List[bytes]) -> None:
        """Run (or queue, inside MULTI) one parsed request and emit its reply.

        :param fields: The request's raw fields, command name first. An empty
            list is ignored.
        """
        if not fields:
            return
        result: Any
        cmd, cmd_arguments = _extract_command(fields)
        try:
            func, sig = self._name_to_func(cmd)
            # ACL check
            self._server.acl.validate_command(self.current_user, self.client_info_as_bytes, fields)
            with self._server.lock:
                # Clean out old connections
                while True:
                    try:
                        weak_sock = self._server.closed_sockets.pop()
                    except IndexError:
                        break
                    else:
                        sock = weak_sock()
                        if sock:
                            sock._cleanup(self._server)
                # Every database sees the same "now" for this command.
                now = time.time()
                for db in self._server.dbs.values():
                    db.time = now
                sig.check_arity(cmd_arguments, self.version)
                if self._transaction is not None and msgs.FLAG_TRANSACTION not in sig.flags:
                    # Inside MULTI: defer execution until EXEC.
                    self._transaction.append((func, sig, cmd_arguments))
                    result = QUEUED
                else:
                    result = self._run_command(func, sig, cmd_arguments, False)
        except SimpleError as exc:
            if self._transaction is not None:
                # TODO: should not apply if the exception is from _run_command
                # e.g. watch inside multi
                self._transaction_failed = True
            if cmd == "exec" and exc.value.startswith("ERR "):
                exc.value = "EXECABORT Transaction discarded because of: " + exc.value[4:]
                self._transaction = None
                self._transaction_failed = False
                self._clear_watches()
            result = exc
        result = self._decode_result(result)
        if not isinstance(result, NoResponse):
            self.put_response(result)

    def _run_command(
        self, func: Optional[Callable[[Any], Any]], sig: Signature, args: List[Any], from_script: bool
    ) -> Any:
        """Apply ``sig`` to ``args`` and call ``func``; return its result.

        ``SimpleError`` raised during argument handling or by the command is
        returned (not raised). Any ``CommandItem`` produced by ``sig.apply`` is
        written back to the database afterwards, even when an error occurred.

        :param func: Bound command method (may be ``None`` if not implemented).
        :param sig: The command's signature.
        :param args: Raw command arguments.
        :param from_script: True when invoked from a script; commands flagged
            ``FLAG_NO_SCRIPT`` are then rejected.
        """
        command_items: List[CommandItem] = []
        try:
            ret = sig.apply(args, self._db, self.version)
            if from_script and msgs.FLAG_NO_SCRIPT in sig.flags:
                raise SimpleError(msgs.COMMAND_IN_SCRIPT_MSG)
            if self._pubsub and sig.name not in BaseFakeSocket.ACCEPTED_COMMANDS_WHILE_PUBSUB:
                raise SimpleError(msgs.BAD_COMMAND_IN_PUBSUB_MSG)
            if len(ret) == 1:
                # sig.apply produced the result directly; func is not called.
                result = ret[0]
            else:
                args, command_items = ret
                result = func(*args)  # type: ignore
                assert valid_response_type(result)
        except SimpleError as exc:
            result = exc
        for command_item in command_items:
            command_item.writeback(remove_empty_val=msgs.FLAG_LEAVE_EMPTY_VAL not in sig.flags)
        return result

    def _decode_error(self, error: SimpleError) -> redis.RedisError:
        """Convert an error reply into the redis-py exception instance for it."""
        return DefaultParser(socket_read_size=65536).parse_error(error.value)  # type: ignore

    def _decode_result(self, result: Any) -> Any:
        """Convert SimpleString and SimpleError, recursively"""
        if isinstance(result, list):
            return [self._decode_result(r) for r in result]
        elif isinstance(result, SimpleString):
            return result.value
        elif isinstance(result, SimpleError):
            return self._decode_error(result)
        else:
            return result

    def _blocking(self, timeout: Optional[Union[float, int]], func: Callable[[bool], Any]) -> Any:
        """Run a function until it succeeds or timeout is reached.

        The timeout is in seconds, and 0 means infinite. The function
        is called with a boolean to indicate whether this is the first call.
        If it returns None, it is considered to have "failed" and is retried
        each time the condition variable is notified, until the timeout is
        reached. Inside a transaction only the first attempt is made.

        Returns the function return value, or None if the timeout has passed.
        """
        ret = func(True)  # Call with first_pass=True
        if ret is not None or self._in_transaction:
            return ret
        deadline = time.time() + timeout if timeout else None
        while True:
            timeout = (deadline - time.time()) if deadline is not None else None
            if timeout is not None and timeout <= 0:
                return None
            if self._db.condition.wait(timeout=timeout) is False:
                return None  # Timeout expired
            ret = func(False)  # Second pass => first_pass=False
            if ret is not None:
                return ret

    def _name_to_func(self, cmd_name: str) -> Tuple[Optional[Callable[[Any], Any]], Signature]:
        """Get the signature and the method from the command name.

        :raises SimpleError: If the command is unknown or not supported by the
            current server type.
        """
        if cmd_name in SUPPORTED_COMMANDS:
            sig = SUPPORTED_COMMANDS[cmd_name]
            if self._server.server_type in sig.server_types:
                return getattr(self, sig.func_name, None), sig
        # redis remaps \r or \n in an error to ' ' to make it legal protocol
        clean_name = cmd_name.replace("\r", " ").replace("\n", " ")
        raise SimpleError(msgs.UNKNOWN_COMMAND_MSG.format(clean_name))

    def sendall(self, data: AnyStr) -> None:
        """Feed request bytes (``str`` is ASCII-encoded) to the command parser.

        :raises redis.ConnectionError: If the fake server is disconnected.
        """
        if not self._server.connected:
            raise self._connection_error_class(msgs.CONNECTION_ERROR_MSG)
        if isinstance(data, str):
            data = data.encode("ascii")  # type: ignore
        self._parser.send(data)

    def _scan(self, keys, cursor, *args):
        """This is the basis of most of the ``scan`` methods.

        This implementation is KNOWN to be un-performant, as it requires grabbing the full set of keys over which
        we are investigating subsets.

        The SCAN command, and the other commands in the SCAN family, are able to provide to the user a set of
        guarantees associated with full iterations.

        - A full iteration always retrieves all the elements that were present in the collection from the start to the
          end of a full iteration. This means that if a given element is inside the collection when an iteration is
          started and is still there when an iteration terminates, then at some point the SCAN command returned it to
          the user.
        - A full iteration never returns any element that was NOT present in the collection from the start to the end
          of a full iteration. So if an element was removed before the start of an iteration and is never added back
          to the collection for all the time an iteration lasts, the SCAN command ensures that this element will never
          be returned.

        However, because the SCAN command has very little state associated (just the cursor),
        it has the following drawbacks:

        - A given element may be returned multiple times. It is up to the application to handle the case of duplicated
          elements, for example, only using the returned elements to perform operations that are safe when re-applied
          multiple times.
        - Elements that were not constantly present in the collection during a full iteration may be returned or not:
          it is undefined.

        :param keys: Collection to iterate; it is sorted, and tuple elements are
            matched on their first item.
        :param cursor: Cursor from the previous call (``0`` to start).
        :param args: Optional ``MATCH``, ``TYPE`` and ``COUNT`` (default 10) arguments.
        :return: ``[next_cursor, elements]`` where ``next_cursor`` is a bytes
            string; ``b"0"`` means the iteration is complete.
        """
        cursor = int(cursor)
        (pattern, _type, count), _ = extract_args(args, ("*match", "*type", "+count"))
        count = 10 if count is None else count
        data = sorted(keys)
        # Cursors are the bit-reversed form of the next index into `data`,
        # using just enough bits to address every index.
        bits_len = (len(data) - 1).bit_length()
        cursor = bin_reverse(cursor, bits_len)
        if cursor >= len(data):
            # Redis always replies with the cursor as a bulk string.
            return [b"0", []]
        result_cursor = cursor + count
        result_data = []

        regex = compile_pattern(pattern) if pattern is not None else None

        def match_key(key: bytes) -> Union[bool, Match[bytes], None]:
            return regex.match(key) if regex is not None else True

        def match_type(key) -> bool:
            return _type is None or casematch(BaseFakeSocket._key_value_type(self._db[key]).value, _type)

        if pattern is not None or _type is not None:
            for val in itertools.islice(data, cursor, cursor + count):
                compare_val = val[0] if isinstance(val, tuple) else val
                if match_key(compare_val) and match_type(compare_val):
                    result_data.append(val)
        else:
            result_data = data[cursor : cursor + count]

        if result_cursor >= len(data):
            result_cursor = 0
        return [str(bin_reverse(result_cursor, bits_len)).encode(), result_data]

    def _ttl(self, key: CommandItem, scale: float) -> int:
        """Remaining time to live of ``key`` multiplied by ``scale``, rounded.

        Returns -2 if the key is missing (falsy) and -1 if it has no expiry.
        """
        if not key:
            return -2
        elif key.expireat is None:
            return -1
        else:
            return int(round((key.expireat - self._db.time) * scale))

    def _encodefloat(self, value: float, humanfriendly: bool) -> bytes:
        """Encode a float reply via ``Float.encode``."""
        if self.version >= (7,):
            # `0 + -0.0` is `0.0`: normalizes negative zero for version >= 7.
            # NOTE(review): presumably mirrors Redis 7 output — confirm.
            value = 0 + value
        return Float.encode(value, humanfriendly)

    def _encodeint(self, value: int) -> bytes:
        """Encode an integer reply via ``Int.encode``."""
        if self.version >= (7,):
            # No-op for plain ints; presumably kept for symmetry with _encodefloat.
            value = 0 + value
        return Int.encode(value)

    @staticmethod
    def _key_value_type(key: CommandItem) -> SimpleString:
        """Return the Redis type name (``TYPE`` reply) of the key's value."""
        if key.value is None:
            return SimpleString(b"none")
        elif isinstance(key.value, bytes):
            return SimpleString(b"string")
        elif isinstance(key.value, list):
            return SimpleString(b"list")
        elif isinstance(key.value, ExpiringMembersSet):
            return SimpleString(b"set")
        elif isinstance(key.value, ZSet):
            return SimpleString(b"zset")
        elif isinstance(key.value, Hash):
            return SimpleString(b"hash")
        elif isinstance(key.value, XStream):
            return SimpleString(b"stream")
        else:
            assert False  # pragma: nocover
|