1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
|
import secrets
from typing import Any, Tuple, List, Callable, Dict, Optional, Union
from fakeredis import _msgs as msgs
from fakeredis._commands import command, Int
from fakeredis._helpers import SimpleError, OK, casematch, SimpleString
from fakeredis.model import AccessControlList
from fakeredis.model import get_categories, get_commands_by_category
class AclCommandsMixin:
    """Mixin providing the ``AUTH``, ``CONFIG SET`` and ``ACL ...`` commands.

    The mixin expects the concrete class to supply ``self._server`` (exposing
    ``config`` and ``acl``), ``self._client_info`` and
    ``self.client_info_as_bytes``; none of them are created here.
    """

    _get_command_info: Callable[[bytes], List[Any]]

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward construction to the next class in the MRO.

        The annotated names below are declarations only (no values are
        assigned); they document attributes that the concrete class provides.
        """
        super().__init__(*args, **kwargs)
        self.version: Tuple[int]
        self._server: Any
        self._client_info: Dict[str, Union[str, int]]

    @property
    def _server_config(self) -> Dict[bytes, bytes]:
        """Return the server-wide configuration mapping (``self._server.config``)."""
        return self._server.config

    @property
    def _acl(self) -> AccessControlList:
        """Return the server-wide access control list (``self._server.acl``)."""
        return self._server.acl

    def _check_user_password(self, username: bytes, password: Optional[bytes]) -> bool:
        """Return the result of checking ``password`` against the ACL entry of ``username``."""
        return self._acl.get_user_acl(username).check_password(password)

    def _set_user_acl(self, username: bytes, *args: bytes) -> None:
        """Apply a sequence of ACL rules to the ACL entry of ``username``.

        Rules are applied in order. Keyword rules (``on``, ``off``, ``nopass``,
        ``reset``, ``resetkeys``, ``resetchannels``) are matched with
        ``casematch``. The aliases ``nocommands``, ``allcommands``, ``allkeys``
        and ``allchannels`` are rewritten to their prefixed form (``-@all``,
        ``+@all``, ``~*``, ``&*``) and then handled like any prefixed rule.
        A rule wrapped in parentheses is registered as a selector. Empty rules
        and rules with an unrecognised prefix are ignored.

        :param username: name of the user whose ACL entry is modified.
        :param args: the rules to apply.
        """
        user_acl = self._acl.get_user_acl(username)
        for arg in args:
            if not arg:
                # An empty rule has no prefix byte to dispatch on; skip it like
                # any other unrecognised rule instead of raising IndexError.
                continue
            if casematch(arg, b"resetchannels"):
                user_acl.reset_channels_patterns()
                continue
            elif casematch(arg, b"resetkeys"):
                user_acl.reset_key_patterns()
                continue
            elif casematch(arg, b"on"):
                user_acl.enabled = True
                continue
            elif casematch(arg, b"off"):
                user_acl.enabled = False
                continue
            elif casematch(arg, b"nopass"):
                user_acl.set_nopass()
                continue
            elif casematch(arg, b"reset"):
                user_acl.reset()
                continue
            elif casematch(arg, b"nocommands"):
                arg = b"-@all"
            elif casematch(arg, b"allcommands"):
                arg = b"+@all"
            elif casematch(arg, b"allkeys"):
                arg = b"~*"
            elif casematch(arg, b"allchannels"):
                arg = b"&*"
            elif arg[0] == ord("(") and arg[-1] == ord(")"):
                # Selector: hand over the content without the parentheses.
                user_acl.add_selector(arg[1:-1])
                continue

            # Everything that reaches this point is dispatched on its first byte.
            prefix = arg[0]
            if prefix == ord(">"):
                user_acl.add_password(arg[1:])
            elif prefix == ord("<"):
                user_acl.remove_password(arg[1:])
            elif prefix == ord("#"):
                user_acl.add_password_hex(arg[1:])
            elif prefix == ord("!"):
                user_acl.remove_password_hex(arg[1:])
            elif prefix == ord("+") or prefix == ord("-"):
                # The sign is part of the rule, so the whole argument is passed.
                user_acl.add_command_or_category(arg)
            elif prefix == ord("~"):
                user_acl.add_key_pattern(arg[1:])
            elif prefix == ord("&"):
                user_acl.add_channel_pattern(arg[1:])

    @command(name="CONFIG SET", fixed=(bytes, bytes), repeat=(bytes, bytes))
    def config_set(self, *args: bytes):
        """Store ``name value`` pairs in the server configuration.

        :param args: flat sequence ``name1, value1, name2, value2, ...``.
        :raises SimpleError: if an odd number of arguments is given.
        :return: ``OK``.
        """
        if len(args) % 2 != 0:
            raise SimpleError(msgs.WRONG_ARGS_MSG6.format("CONFIG SET"))
        for i in range(0, len(args), 2):
            self._server_config[args[i]] = args[i + 1]
        return OK

    @command(name="AUTH", fixed=(), repeat=(bytes,))
    def _auth(self, *args: bytes) -> bytes:
        """Authenticate the connection as ``AUTH [username] password``.

        With a single argument the user is taken to be ``default``. The
        ``default`` user is accepted when the password equals the
        ``requirepass`` configuration value (empty when unset); otherwise the
        password is checked against the user's ACL entry. A failed attempt is
        recorded in the ACL log before the error is raised.

        :raises SimpleError: on a wrong number of arguments or failed authentication.
        :return: ``OK`` on success.
        """
        if not 1 <= len(args) <= 2:
            raise SimpleError(msgs.WRONG_ARGS_MSG6.format("AUTH"))
        username: Optional[bytes]
        if len(args) == 2:
            username, password = args
        else:
            username, password = None, args[0]
        if (username is None or username == b"default") and (
            password == self._server_config.get(b"requirepass", b"")
        ):
            self._client_info["user"] = "default"
            return OK
        username = username or b"default"
        if self._check_user_password(username, password):
            self._client_info["user"] = username.decode()
            return OK
        self._acl.add_log_record(b"auth", b"auth", b"AUTH", username, self.client_info_as_bytes)
        raise SimpleError(msgs.AUTH_FAILURE)

    @command(name="ACL CAT", fixed=(), repeat=(bytes,))
    def acl_cat(self, *category: bytes) -> List[bytes]:
        """List ACL categories, or the commands of one category.

        Without arguments the result of ``get_categories()`` is returned. With
        arguments, only the first one is used: the commands of that category
        are returned with spaces in their names replaced by ``|``.
        """
        if len(category) == 0:
            res = get_categories()
        else:
            res = get_commands_by_category(category[0])
            res = [cmd.replace(b" ", b"|") for cmd in res]
        return res

    @command(name="ACL GENPASS", fixed=(), repeat=(bytes,))
    def acl_genpass(self, *args: bytes) -> bytes:
        """Generate a random password as a lowercase hexadecimal string.

        The requested size is rounded up to the next multiple of 4 bits, i.e.
        to a whole number of hex digits, so asking for 1 bit yields a single
        hex character and the default of 256 bits yields 64 characters.

        :param args: optional number of bits (first argument); defaults to 256.
        :raises SimpleError: if the number of bits is not in ``1..4096``.
        :return: the hexadecimal password.
        """
        bits = Int.decode(args[0]) if len(args) > 0 else 256
        if not 0 < bits <= 4096:
            # NOTE(review): text mirrors the Redis server reply for this case;
            # confirm it follows the prefix convention used by ``msgs``.
            raise SimpleError(
                "ERR ACL GENPASS argument must be the number of bits for the output password, "
                "a positive number up to 4096"
            )
        hex_digits = (bits + 3) // 4  # round up to whole hex digits (4 bits each)
        # token_hex() yields two digits per byte: request enough bytes and
        # trim the surplus digit when an odd number of digits is needed.
        return secrets.token_hex((hex_digits + 1) // 2)[:hex_digits].encode()

    @command(name="ACL SETUSER", fixed=(bytes,), repeat=(bytes,))
    def acl_setuser(self, username: bytes, *args: bytes) -> bytes:
        """Apply the given ACL rules to ``username`` and return ``OK``."""
        self._set_user_acl(username, *args)
        return OK

    @command(name="ACL LIST", fixed=(), repeat=())
    def acl_list(self) -> List[bytes]:
        """Return the ACL rendered as rules (``self._acl.as_rules()``)."""
        return self._acl.as_rules()

    @command(name="ACL DELUSER", fixed=(bytes,), repeat=())
    def acl_deluser(self, username: bytes) -> bytes:
        """Delete ``username`` from the ACL and return ``OK``."""
        self._acl.del_user(username)
        return OK

    @command(name="ACL GETUSER", fixed=(bytes,), repeat=())
    def acl_getuser(self, username: bytes) -> List[bytes]:
        """Return the ACL entry of ``username`` in its array representation."""
        res = self._acl.get_user_acl(username).as_array()
        return res

    @command(name="ACL USERS", fixed=(), repeat=())
    def acl_users(self) -> List[bytes]:
        """Return the users known to the ACL (``self._acl.get_users()``)."""
        res = self._acl.get_users()
        return res

    @command(name="ACL WHOAMI", fixed=(), repeat=())
    def acl_whoami(self) -> bytes:
        """Return the connection's current user name, or ``b""`` if none is recorded."""
        return self._client_info.get("user", "").encode()

    @command(name="ACL SAVE", fixed=(), repeat=())
    def acl_save(self) -> SimpleString:
        """Write the current ACL rules to the configured ``aclfile``.

        Rules are written one per line, separated by ``\\n``.

        :raises SimpleError: if ``aclfile`` is not set in the server configuration.
        :return: ``OK``.
        """
        if b"aclfile" not in self._server_config:
            raise SimpleError(msgs.MISSING_ACLFILE_CONFIG)
        acl_filename = self._server_config[b"aclfile"]
        with open(acl_filename, "wb") as f:
            f.write(b"\n".join(self._acl.as_rules()))
        return OK

    @command(name="ACL LOAD", fixed=(), repeat=())
    def acl_load(self) -> SimpleString:
        """Load ACL rules from the configured ``aclfile`` and apply them.

        Only lines starting with ``user `` are considered. Each such line is
        split on whitespace into ``user <name> <rule>...``; tokens belonging to
        a parenthesised selector are re-joined with single spaces so that the
        selector is applied as one rule.

        :raises SimpleError: if ``aclfile`` is not set in the server configuration.
        :return: ``OK``.
        """
        if b"aclfile" not in self._server_config:
            raise SimpleError(msgs.MISSING_ACLFILE_CONFIG)
        acl_filename = self._server_config[b"aclfile"]
        with open(acl_filename, "rb") as f:
            rules_list = f.readlines()
        for rule in rules_list:
            if not rule.startswith(b"user "):
                continue
            # Whitespace splitting drops the line terminator and any empty
            # tokens caused by repeated spaces, so rules reach the parser clean.
            tokens = rule.split()
            components: List[bytes] = []
            i = 1  # tokens[0] is the literal b"user"
            while i < len(tokens):
                current_component = tokens[i]
                if current_component.startswith(b"("):
                    # Re-assemble a selector that was split on its inner
                    # spaces; stop at the end of the line if it is never closed.
                    while not current_component.endswith(b")") and i + 1 < len(tokens):
                        i += 1
                        current_component += b" " + tokens[i]
                components.append(current_component)
                i += 1
            if not components:
                # "user" without a name: nothing to apply.
                continue
            self._set_user_acl(components[0], *components[1:])
        return OK

    @command(name="ACL LOG", fixed=(), repeat=(bytes,))
    def acl_log(self, *args: bytes) -> Union[SimpleString, List[List[bytes]]]:
        """Return ACL log entries, or clear the log with ``ACL LOG RESET``.

        :param args: either ``RESET`` (matched with ``casematch``) or an integer
            count. The count passed to ``self._acl.log`` is 0 unless exactly
            one argument is given.
        :return: ``OK`` after a reset, otherwise the result of ``self._acl.log(count)``.
        """
        if len(args) == 1 and casematch(args[0], b"RESET"):
            self._acl.reset_log()
            return OK
        count = Int.decode(args[0]) if len(args) == 1 else 0
        return self._acl.log(count)
|