1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
|
import logging
from dataclasses import dataclass
from itertools import count
from socketserver import ThreadingTCPServer, StreamRequestHandler
from typing import BinaryIO, Dict, Tuple
from fakeredis import FakeRedis
from fakeredis import FakeServer
from fakeredis._server import ServerType
# The "fakeredis" logger (the same object any other getLogger("fakeredis") call
# returns). Its level is pinned to DEBUG at import time so the logger itself does
# not drop debug records; whether they are printed still depends on the handlers.
LOGGER = logging.getLogger(name="fakeredis")
LOGGER.setLevel(level=logging.DEBUG)
def to_bytes(value) -> bytes:
    """Coerce ``value`` to ``bytes``.

    ``bytes`` objects are returned unchanged (the very same object); any other
    value is converted with ``str()`` and encoded with the default UTF-8 codec.
    """
    return value if isinstance(value, bytes) else str(value).encode()
@dataclass
class Client:
    """Per-connection state that the TCP server keeps in its ``clients`` registry."""

    # FakeRedis client that executes this connection's commands.
    connection: FakeRedis
    # Peer address reported by socketserver for this connection: a (host, port)
    # tuple for the default IPv4 TCP server.
    client_address: Tuple[str, int]
@dataclass
class Reader:
reader: BinaryIO
def load(self):
line = self.reader.readline().strip()
match line[0:1], line[1:]:
case b"*", length:
length = int(length)
array = [None] * length
for i in range(length):
array[i] = self.load()
return array
case b"$", length:
bulk_string = self.reader.read(int(length) + 2).strip()
if len(bulk_string) != int(length):
raise ValueError()
return bulk_string
case b":", value:
return int(value)
case b"+", value:
return value
case b"-", value:
return Exception(value)
case _:
return None
@dataclass
class Writer:
    """Serialize Python values as RESP2 replies onto a binary stream."""

    # Writable binary stream, e.g. a StreamRequestHandler's ``wfile``.
    writer: BinaryIO

    def dump(self, value, dump_bulk=False):
        """Encode ``value`` as one RESP2 reply and write it with a single ``write`` call.

        Args:
            value: the reply to send; see ``_encode`` for how each type is mapped.
            dump_bulk: if true, ``str``/``bytes`` are always written as bulk strings;
                otherwise they are written as simple strings unless they contain
                CR or LF.
        """
        # Assemble the whole frame first. StreamRequestHandler's wfile is unbuffered
        # by default (wbufsize = 0), so piecewise writes would each be a separate send.
        self.writer.write(b"".join(self._encode(value, dump_bulk)))

    def _encode(self, value, dump_bulk):
        """Yield the RESP2 encoding of ``value`` as byte fragments.

        Mapping:
            None                        -> null bulk string (``$-1``)
            int (including bool)        -> integer (``:``); True/False become 1/0
            str, bytes                  -> bulk (``$``) or simple (``+``) string, see ``dump``
            list, tuple, set, frozenset -> array (``*``) of bulk-encoded elements
            dict                        -> flat array of alternating keys and values
            Exception                   -> simple error (``-``) from its first argument,
                                           or its class name when it has no arguments
            anything else (e.g. float)  -> bulk string of ``str(value)``
        """
        if value is None:
            yield b"$-1\r\n"
        elif isinstance(value, int):
            # bool subclasses int. Formatting True via an f-string gives ":True",
            # which is not valid RESP, so always emit the numeric value.
            yield b":%d\r\n" % value
        elif isinstance(value, (str, bytes)):
            data = to_bytes(value)
            if dump_bulk or b"\r" in data or b"\n" in data:
                yield b"$%d\r\n%s\r\n" % (len(data), data)
            else:
                yield b"+%s\r\n" % data
        elif isinstance(value, (list, tuple, set, frozenset)):
            yield b"*%d\r\n" % len(value)
            for item in value:
                yield from self._encode(item, True)
        elif isinstance(value, dict):
            # RESP2 has no map type, so a mapping is sent as a flat
            # [key1, value1, key2, value2, ...] array.
            yield b"*%d\r\n" % (2 * len(value))
            for key, item in value.items():
                yield from self._encode(key, True)
                yield from self._encode(item, True)
        elif isinstance(value, Exception):
            message = str(value.args[0]) if value.args else type(value).__name__
            # A CR or LF would end the error line early and desync the client's parser.
            message = message.replace("\r", " ").replace("\n", " ")
            yield b"-" + message.encode() + b"\r\n"
        else:
            # Without this branch such values (e.g. float) would produce no reply at
            # all, leaving the client waiting; send their text form as a bulk string.
            data = to_bytes(value)
            yield b"$%d\r\n%s\r\n" % (len(data), data)
class TCPFakeRequestHandler(StreamRequestHandler):
    """Serve one client connection by decoding RESP requests and running them on FakeRedis.

    The owning server must expose ``fake_server``, the FakeServer that every
    connection's FakeRedis is bound to. It must also expose ``clients``, a dict of
    live connections keyed by client address. ``TcpFakeServer`` provides both.
    """

    def setup(self) -> None:
        """Open the socket streams, wrap them in RESP codecs and register the client."""
        super().setup()
        if self.client_address in self.server.clients:
            self.current_client = self.server.clients[self.client_address]
        else:
            # Each connection gets its own FakeRedis, but all of them share the
            # server's FakeServer and therefore see the same data.
            self.current_client = Client(
                connection=FakeRedis(server=self.server.fake_server),
                client_address=self.client_address,
            )
        self.reader = Reader(self.rfile)
        self.writer = Writer(self.wfile)
        self.server.clients[self.client_address] = self.current_client

    def handle(self) -> None:
        """Execute requests until the client disconnects or sends an unusable frame.

        A command that raises is answered with a RESP error and the connection
        stays open. This mirrors Redis, which reports command errors without
        closing the connection. Only protocol-level failures end the session:
        an unparsable frame, a non-array request, or a socket error.
        """
        while True:
            try:
                self.data = self.reader.load()
            except OSError as e:
                # Socket failure while reading (e.g. reset by peer): no one to answer.
                LOGGER.debug("!!! %s: %s", self.client_address[0], e)
                break
            except ValueError as e:
                # Malformed length/integer field: the stream position can no longer
                # be trusted, so report the problem and drop the connection.
                LOGGER.debug("!!! %s: %s", self.client_address[0], e)
                self._send(Exception(f"ERR Protocol error: {e}"))
                break
            if self.data is None:
                # Reader.load() returns None at end of stream (client hung up) and
                # for blank or unrecognised lines; either way stop serving.
                break
            if not isinstance(self.data, list):
                self._send(Exception("ERR Protocol error: expected an array of bulk strings"))
                break
            LOGGER.debug(">>> %s: %s", self.client_address[0], self.data)
            try:
                res = self.current_client.connection.execute_command(*self.data)
            except Exception as e:  # any command failure is reported back to the client
                LOGGER.debug("!!! %s: %s", self.client_address[0], e)
                res = e
            else:
                LOGGER.debug("<<< %s: %s", self.client_address[0], res)
            if not self._send(res):
                break

    def _send(self, value) -> bool:
        """Write one RESP reply; return False if the client can no longer be reached."""
        try:
            self.writer.dump(value)
        except OSError as e:
            LOGGER.debug("!!! %s: %s", self.client_address[0], e)
            return False
        return True

    def finish(self) -> None:
        """Unregister the client, then let StreamRequestHandler flush and close the streams."""
        # pop() rather than del: a missing entry must not skip the stream cleanup below.
        self.server.clients.pop(self.current_client.client_address, None)
        super().finish()
class TcpFakeServer(ThreadingTCPServer):
    """Threaded TCP server exposing one shared in-memory FakeServer over RESP.

    Each accepted connection is handled on its own thread by
    ``TCPFakeRequestHandler``, and all connections share ``fake_server``.

    Args:
        server_address: ``(host, port)`` to bind.
        bind_and_activate: bind and listen immediately (forwarded to ``TCPServer``).
        server_type: forwarded to ``FakeServer`` as ``server_type``.
        server_version: forwarded to ``FakeServer`` as ``version``.
    """

    def __init__(
        self,
        server_address: Tuple[str | bytes | bytearray, int],
        bind_and_activate: bool = True,
        server_type: ServerType = "redis",
        server_version: Tuple[int, ...] = (7, 4),
    ):
        super().__init__(server_address, TCPFakeRequestHandler, bind_and_activate)
        # Backing store shared by every connection's FakeRedis.
        self.fake_server = FakeServer(server_type=server_type, version=server_version)
        # NOTE(review): not used by the code visible here; confirm whether callers rely on it.
        self.client_ids = count(0)
        # Live connections keyed by their client address: a (host, port) tuple for
        # the default IPv4 server. Entries are added and removed by the request handler.
        self.clients: Dict[Tuple[str, int], Client] = dict()
if __name__ == "__main__":
    # Standalone mode: serve a fake Redis on localhost:19000 until interrupted.
    host, port = "localhost", 19000
    server = TcpFakeServer((host, port))
    server.serve_forever()
|