1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
|
from typing import Any, List, Union, Dict
import fakeredis
from fakeredis import _msgs as msgs
from fakeredis._commands import command, DbIndex, Int
from fakeredis._helpers import SimpleError, OK, SimpleString, Database, casematch
# Reply returned by `ping` when it is called without a message and
# `self._pubsub` is falsy.
PONG = SimpleString(b"PONG")
class ConnectionCommandsMixin:
    """Mixin providing the connection-level commands.

    Implements ECHO, PING, SELECT, the CLIENT sub-commands (SETINFO, SETNAME,
    GETNAME, ID, INFO, LIST) and HELLO.

    The host class must supply the attributes annotated in ``__init__`` and a
    ``client_info_as_bytes`` attribute, which CLIENT INFO and CLIENT LIST read.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to the next class in the MRO."""
        super().__init__(*args, **kwargs)
        # The annotations below only declare attributes that are expected to
        # be assigned elsewhere; nothing is initialised here.
        self._db: Database  # database currently selected (see `select`)
        self._db_num: int  # index of the selected database
        # Truthy value switches PING to its list-shaped reply.
        # NOTE(review): presumably a subscription count - confirm.
        self._pubsub: int
        self._client_info: Dict[str, Union[str, int]]  # per-connection metadata
        self._server: Any  # provides `dbs`, `sockets` and `_acl`

    @command((bytes,))
    def echo(self, message: bytes) -> bytes:
        """Return ``message`` unchanged."""
        return message

    @command((), (bytes,))
    def ping(self, *args: bytes) -> Union[List[bytes], bytes, SimpleString]:
        """Answer a PING, optionally echoing a single message.

        Returns ``[b"pong", message or b""]`` when ``self._pubsub`` is truthy,
        otherwise the message itself, or ``PONG`` when none was given.

        Raises:
            SimpleError: if more than one argument is supplied.
        """
        if len(args) > 1:
            raise SimpleError(msgs.WRONG_ARGS_MSG6.format("ping"))
        if self._pubsub:
            return [b"pong", args[0] if args else b""]
        return args[0] if args else PONG

    @command(name="SELECT", fixed=(DbIndex,))
    def select(self, index: DbIndex) -> SimpleString:
        """Make ``self._server.dbs[index]`` the current database."""
        self._db = self._server.dbs[index]
        self._db_num = index  # type: ignore
        return OK

    @command(name="CLIENT SETINFO", fixed=(bytes, bytes), repeat=())
    def client_setinfo(self, lib_data: bytes, value: bytes) -> SimpleString:
        """Store the client library name or version in ``self._client_info``.

        ``lib_data`` is matched with ``casematch`` against ``LIB-NAME`` and
        ``LIB-VER``; ``value`` is stored UTF-8 decoded under ``"lib-name"`` or
        ``"lib-ver"`` respectively.

        Raises:
            SimpleError: if ``lib_data`` matches neither attribute.
        """
        if casematch(lib_data, b"LIB-NAME"):
            self._client_info["lib-name"] = value.decode("utf-8")
            return OK
        if casematch(lib_data, b"LIB-VER"):
            self._client_info["lib-ver"] = value.decode("utf-8")
            return OK
        raise SimpleError(msgs.SYNTAX_ERROR_MSG)

    @command(name="CLIENT SETNAME", fixed=(bytes,), repeat=())
    def client_setname(self, value: bytes) -> SimpleString:
        """Store the UTF-8 decoded ``value`` as the connection name."""
        self._client_info["name"] = value.decode("utf-8")
        return OK

    @command(name="CLIENT GETNAME", fixed=(), repeat=())
    def client_getname(self) -> bytes:
        """Return the connection name as bytes, or ``b""`` if none is set."""
        return self._client_info.get("name", "").encode("utf-8")

    @command(name="CLIENT ID", fixed=(), repeat=())
    def client_getid(self) -> int:
        """Return the stored client id, defaulting to ``1`` when absent."""
        return self._client_info.get("id", 1)

    @command(name="CLIENT INFO", fixed=(), repeat=())
    def client_info_cmd(self) -> bytes:
        """Return this connection's ``client_info_as_bytes``."""
        return self.client_info_as_bytes

    @command(name="CLIENT LIST", fixed=(), repeat=(bytes,))
    def client_list_cmd(self, *args: bytes) -> bytes:
        """Return ``client_info_as_bytes`` of the server's sockets, newline-joined.

        Supported options:

        * ``TYPE <value>`` - accepted, but the value is not used for filtering.
        * ``ID <id> [<id> ...]`` - consumes every remaining argument and keeps
          only sockets whose ``_client_info["id"]`` is one of the given ids.

        Raises:
            SimpleError: on an unknown option or an option without a value.
        """
        sockets = self._server.sockets.copy()
        filter_ids = set()
        i = 0
        while i < len(args):
            if casematch(args[i], b"TYPE") and i + 1 < len(args):
                i += 2
            # `elif`, not `if`: after TYPE consumed the last two arguments
            # `args[i]` would be out of range, so go back through the loop
            # condition instead of indexing unconditionally.
            elif casematch(args[i], b"ID") and i + 1 < len(args):
                i += 1
                while i < len(args):
                    filter_ids.add(Int.decode(args[i]))
                    i += 1
            else:
                raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        if filter_ids:
            sockets = [sock for sock in sockets if sock._client_info["id"] in filter_ids]
        res = [item.client_info_as_bytes for item in sockets]
        return b"\n".join(res)

    @command(name="HELLO", fixed=(), repeat=(bytes,))
    def hello(self, *args: bytes) -> Dict[str, Any]:
        """Record the requested protocol version and describe the server.

        The first argument, when present, is decoded with ``Int`` and stored
        as ``self._client_info["resp"]`` (``2`` when no argument is given).
        It may be followed by ``SETNAME <name>`` and/or
        ``AUTH <user> <password>``; the option keywords are matched with
        ``casematch``, like the option keywords of the other commands in
        this class.

        Returns:
            A dict with the keys ``server``, ``version``, ``proto``, ``id``,
            ``mode``, ``role`` and ``modules``.

        Raises:
            SimpleError: on an unknown option or an option missing its values.
        """
        self._client_info["resp"] = Int.decode(args[0]) if args else 2
        i = 1
        while i < len(args):
            if casematch(args[i], b"SETNAME") and i + 1 < len(args):
                self._client_info["name"] = args[i + 1].decode("utf-8")
                i += 2
            elif casematch(args[i], b"AUTH") and i + 2 < len(args):
                user = args[i + 1]
                password = args[i + 2]
                # NOTE(review): the result of check_password is discarded, so
                # unless it raises on a mismatch a wrong password is accepted
                # silently - confirm against the ACL implementation.
                self._server._acl.get_user_acl(user).check_password(password)
                i += 3
            else:
                raise SimpleError(msgs.SYNTAX_ERROR_MSG)
        data = dict(
            server="fakeredis",
            version=fakeredis.__version__,
            proto=self._client_info["resp"],
            id=self._client_info.get("id", 1),
            mode="standalone",
            role="master",
            modules=[],
        )
        return data
|