1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
from typing import Callable, Tuple, Union, Optional
import pytest
import pytest_asyncio
import redis
import fakeredis
from fakeredis._server import _create_version
def _check_lua_module_supported() -> bool:
    """Return True if fakeredis can evaluate a script using the ``cjson`` Lua module.

    A throwaway ``FakeRedis`` client is created with ``lua_modules={"cjson"}``
    and asked to evaluate a trivial ``cjson.encode`` script. Any exception
    raised while evaluating the probe is treated as "not supported", so
    callers can skip Lua-module tests instead of erroring.
    """
    # Named ``client`` rather than ``redis`` so the imported ``redis`` module
    # is not shadowed inside this function.
    client = fakeredis.FakeRedis(lua_modules={"cjson"})
    try:
        client.eval("return cjson.encode({})", 0)
        return True
    except Exception:  # deliberate best-effort probe: any failure means "unsupported"
        return False
    finally:
        if hasattr(client, "close"):
            client.close()  # Absent in older versions of redis-py
@pytest_asyncio.fixture(scope="session")
def real_server_details() -> Tuple[str, Union[None, Tuple[int, ...]]]:
    """Probe the real server on localhost:6390 and report its type and version.

    Returns a ``(server_type, version_tuple)`` pair.

    ``server_type`` is chosen as follows:

    * ``server_name`` from the INFO reply, when present;
    * otherwise ``"dragonfly"`` if ``dragonfly_version`` is present;
    * otherwise ``"redis"``.

    A dragonfly server is reported as version 7.0. For any other server,
    ``redis_version`` is parsed, falling back to ``(7,)`` when parsing yields
    an empty result. If the server cannot be reached, the whole pytest
    session is aborted via ``pytest.exit``.
    """
    client = None
    try:
        client = redis.Redis("localhost", port=6390, db=2)
        info = client.info()
        if "server_name" in info:
            kind = info["server_name"]
        elif "dragonfly_version" in info:
            kind = "dragonfly"
        else:
            kind = "redis"
        raw_version = (7, 0) if kind == "dragonfly" else info["redis_version"]
        return kind, _create_version(raw_version) or (7,)
    except redis.ConnectionError as err:
        pytest.exit(f"Real server is not running {err}")
        # Unreachable (pytest.exit raises); keeps the declared return type satisfied.
        return "redis", (6,)
    finally:
        # ``close`` is absent in older versions of redis-py.
        if hasattr(client, "close"):
            client.close()
@pytest_asyncio.fixture(name="fake_server")
def _fake_server(request, real_server_details: Tuple[str, Tuple[int, ...]]) -> fakeredis.FakeServer:
    """Build a FakeServer that mimics the detected real server's type and version.

    The server's ``connected`` flag is False when the test carries the
    ``disconnected`` marker, and True otherwise.
    """
    kind, version = real_server_details
    server = fakeredis.FakeServer(server_type=kind, version=version)
    is_disconnected = request.node.get_closest_marker("disconnected") is not None
    server.connected = not is_disconnected
    return server
@pytest_asyncio.fixture
def r(request, create_connection: Callable[[int], redis.Redis]) -> redis.Redis:
    """Yield a sync client on db 2, flushing all data before and after the test.

    Flushing is skipped when the test carries the ``disconnected`` marker.
    The client is closed on teardown even if a flush raises.
    """
    rconn = create_connection(db=2)
    connected = request.node.get_closest_marker("disconnected") is None
    try:
        if connected:
            rconn.flushall()
        yield rconn
        if connected:
            rconn.flushall()
    finally:
        # Inspect the client object itself, not the fixture function ``r``.
        if hasattr(rconn, "close"):
            rconn.close()  # Older versions of redis-py don't have this method
def _marker_version_value(request, marker_name: str):
    """Return the version tuple given by the nearest ``marker_name`` marker.

    When the test has no such marker, a permissive default is returned:

    * ``(0,)`` for ``min_server``;
    * ``(100,)`` for any other marker name (e.g. ``max_server``).

    When the marker is present, its first positional argument is parsed with
    ``_create_version``.
    """
    marker = request.node.get_closest_marker(marker_name)
    if marker is not None:
        return _create_version(marker.args[0])
    # No marker: pick a bound loose enough not to restrict the version.
    return (0,) if marker_name == "min_server" else (100,)
@pytest_asyncio.fixture(
    name="create_connection",
    params=[
        pytest.param("StrictRedis", marks=pytest.mark.real),
        pytest.param("FakeStrictRedis", marks=pytest.mark.fake),
    ],
)
def _create_connection(request) -> Callable[[int], redis.Redis]:
    """Return a factory that builds sync clients of the parametrized class.

    The fixture is parametrized over a real ``redis.StrictRedis`` (marked
    ``real``) and ``fakeredis.FakeStrictRedis`` (marked ``fake``). Before
    handing out the factory, the test is skipped when any of these hold:

    * the class is real and no server version was detected;
    * the server type is listed in the ``unsupported_server_types`` marker;
    * the server version is outside the ``min_server`` / ``max_server`` bounds;
    * ``load_lua_modules`` is requested but fakeredis cannot run Lua modules.

    The ``decode_responses`` marker enables response decoding on every client
    the factory creates. Real clients connect to localhost:6390; fake clients
    share the ``fake_server`` fixture.
    """
    cls_name = request.param
    is_fake = cls_name.startswith("Fake")
    closest = request.node.get_closest_marker
    server_type, server_version = request.getfixturevalue("real_server_details")
    if not (is_fake or server_version):
        pytest.skip("Redis is not running")

    unsupported = closest("unsupported_server_types")
    if unsupported and server_type in unsupported.args:
        pytest.skip(f"Server type {server_type} is not supported")

    lower = _marker_version_value(request, "min_server")
    upper = _marker_version_value(request, "max_server")
    if server_version < lower:
        pytest.skip(f"Redis server {lower} or more required but {server_version} found")
    if server_version > upper:
        pytest.skip(f"Redis server {upper} or less required but {server_version} found")

    decode_responses = closest("decode_responses") is not None
    lua_marker = closest("load_lua_modules")
    lua_modules = set(lua_marker.args) if lua_marker else None
    if lua_modules and not _check_lua_module_supported():
        pytest.skip("LUA modules not supported by fakeredis")

    def factory(db=2):
        if not is_fake:
            real_cls = getattr(redis, cls_name)
            return real_cls("localhost", port=6390, db=db, decode_responses=decode_responses)
        # Resolved lazily so the fake server fixture is only requested for fake clients.
        fake_server = request.getfixturevalue("fake_server")
        fake_cls = getattr(fakeredis, cls_name)
        return fake_cls(db=db, decode_responses=decode_responses, server=fake_server, lua_modules=lua_modules)

    return factory
@pytest_asyncio.fixture(
    name="async_redis",
    params=[pytest.param("fake", marks=pytest.mark.fake), pytest.param("real", marks=pytest.mark.real)],
)
async def _req_aioredis2(request) -> redis.asyncio.Redis:
    """Yield an asyncio client (fake or real) with all data flushed around the test.

    The fixture is parametrized over ``"fake"`` (``fakeredis.FakeAsyncRedis``
    on the shared ``fake_server``) and ``"real"`` (``redis.asyncio.Redis`` on
    localhost:6390, db 2). The test is skipped under the same conditions as
    the sync ``create_connection`` fixture:

    * a real server is required but was not detected;
    * the server type is listed in ``unsupported_server_types``;
    * the version is outside ``min_server`` / ``max_server``;
    * Lua modules are requested but unsupported.

    Data is flushed before and after the test unless the fake server is
    marked disconnected. The connection pool is always disconnected on
    teardown, even if a flush raises.
    """
    server_type, server_version = request.getfixturevalue("real_server_details")
    if request.param != "fake" and not server_version:
        pytest.skip("Redis is not running")
    decode_responses = request.node.get_closest_marker("decode_responses") is not None
    unsupported_server_types = request.node.get_closest_marker("unsupported_server_types")
    if unsupported_server_types and server_type in unsupported_server_types.args:
        pytest.skip(f"Server type {server_type} is not supported")
    min_server_marker = _marker_version_value(request, "min_server")
    max_server_marker = _marker_version_value(request, "max_server")
    if server_version < min_server_marker:
        pytest.skip(f"Redis server {min_server_marker} or more required but {server_version} found")
    if server_version > max_server_marker:
        pytest.skip(f"Redis server {max_server_marker} or less required but {server_version} found")
    lua_modules_marker = request.node.get_closest_marker("load_lua_modules")
    lua_modules = set(lua_modules_marker.args) if lua_modules_marker else None
    if lua_modules and not _check_lua_module_supported():
        pytest.skip("LUA modules not supported by fakeredis")

    fake_server: Optional[fakeredis.FakeServer]
    if request.param == "fake":
        fake_server = request.getfixturevalue("fake_server")
        # NOTE(review): no ``db`` is passed here, unlike the real client (db=2) — confirm intended.
        ret = fakeredis.FakeAsyncRedis(server=fake_server, lua_modules=lua_modules, decode_responses=decode_responses)
    else:
        ret = redis.asyncio.Redis(host="localhost", port=6390, db=2, decode_responses=decode_responses)
        fake_server = None

    try:
        if not fake_server or fake_server.connected:
            await ret.flushall()
        yield ret
        # Re-read ``connected``: the test may have toggled it.
        if not fake_server or fake_server.connected:
            await ret.flushall()
    finally:
        await ret.connection_pool.disconnect()
|