import asyncio
import sys

if sys.version_info >= (3, 11):
    from asyncio import timeout as async_timeout
else:
    from async_timeout import timeout as async_timeout
import pytest
import pytest_asyncio
import redis
import redis.asyncio

from fakeredis import FakeServer, aioredis
from test import testtools

# Every test in this module is a coroutine, so mark them all for pytest-asyncio.
pytestmark = [pytest.mark.asyncio]


@pytest_asyncio.fixture
async def conn(async_redis: redis.asyncio.Redis):
    """A single connection, rather than a pool."""
    async with async_redis.client() as conn:
        yield conn


async def test_ping(async_redis: redis.asyncio.Redis):
    """PING through the async client reports success as ``True``."""
    assert (await async_redis.ping()) is True


async def test_types(async_redis: redis.asyncio.Redis):
    """HSET accepts non-string values; HGETALL returns everything as bytes."""
    fields = {"key1": "value1", "key2": "value2", "key3": 123}
    await async_redis.hset("hash", mapping=fields)

    expected = {b"key1": b"value1", b"key2": b"value2", b"key3": b"123"}
    assert await async_redis.hgetall("hash") == expected


async def test_transaction(async_redis: redis.asyncio.Redis):
    """Commands queued in a transactional pipeline all succeed on EXEC."""
    async with async_redis.pipeline(transaction=True) as pipe:
        # Queued, not sent, until execute().
        pipe.set("key1", "value1")
        pipe.set("key2", "value2")
        first_reply, second_reply = await pipe.execute()
    assert first_reply
    assert second_reply
    assert await async_redis.get("key1") == b"value1"


async def test_transaction_fail(async_redis: redis.asyncio.Redis):
    """Modifying a WATCHed key from another connection makes EXEC raise WatchError."""
    await async_redis.set("foo", "1")
    async with async_redis.pipeline(transaction=True) as pipe:
        await pipe.watch("foo")
        # This write goes over a different connection than the pipeline's.
        await async_redis.set("foo", "2")
        pipe.multi()
        pipe.get("foo")
        with pytest.raises(redis.asyncio.WatchError):
            await pipe.execute()


async def test_pubsub(async_redis):
    """Published messages reach a subscriber that reads them in a background task.

    A reader task forwards every non-subscribe message to a queue until it sees
    the ``b"stop"`` payload. The whole exchange is bounded by a 5 second timeout.
    """
    queue = asyncio.Queue()

    async def reader(ps):
        while True:
            message = await ps.get_message(ignore_subscribe_messages=True, timeout=5)
            if message is None:
                continue
            if message.get("data") == b"stop":
                break
            queue.put_nowait(message)

    async with async_timeout(5), async_redis.pubsub() as ps:
        await ps.subscribe("channel")
        # Schedule on the running loop directly; the pytest-asyncio ``event_loop``
        # fixture previously requested for this is deprecated.
        task = asyncio.create_task(reader(ps))
        try:
            await async_redis.publish("channel", "message1")
            await async_redis.publish("channel", "message2")
            result1 = await queue.get()
            result2 = await queue.get()
            assert result1 == {"channel": b"channel", "pattern": None, "type": "message", "data": b"message1"}
            assert result2 == {"channel": b"channel", "pattern": None, "type": "message", "data": b"message2"}
            await async_redis.publish("channel", "stop")
            await task
        finally:
            # If an assertion or the timeout fired first, don't leave the reader
            # running after the test.
            if not task.done():
                task.cancel()


@pytest.mark.slow
async def test_pubsub_timeout(async_redis: redis.asyncio.Redis):
    """get_message returns None when nothing arrives before its timeout."""
    async with async_redis.pubsub() as pubsub:
        await pubsub.subscribe("channel")
        # First read consumes the subscription confirmation.
        await pubsub.get_message(timeout=0.5)
        assert await pubsub.get_message(timeout=0.5) is None


@pytest.mark.slow
async def test_pubsub_disconnect(async_redis: redis.asyncio.Redis):
    """After the pubsub connection is disconnected, reads yield one message and then None."""
    async with async_redis.pubsub() as pubsub:
        await pubsub.subscribe("channel")
        await pubsub.connection.disconnect()
        subscription_msg = await pubsub.get_message(timeout=0.5)
        assert subscription_msg is not None
        follow_up = await pubsub.get_message(timeout=0.5)
        assert follow_up is None


async def test_blocking_ready(async_redis, conn):
    """BLPOP on an already non-empty list returns without having to block."""
    await async_redis.rpush("list", "x")
    assert await conn.blpop("list", timeout=1) == (b"list", b"x")


@pytest.mark.slow
async def test_blocking_timeout(conn):
    """BLPOP on a missing key gives up after its timeout and returns None."""
    popped = await conn.blpop("missing", timeout=1)
    assert popped is None


@pytest.mark.slow
async def test_blocking_unblock(async_redis, conn):
    """Blocking command that gets unblocked after some time.

    A helper task pushes to the list 0.1 s after starting, while ``conn`` waits
    in BLPOP with a 1 s timeout.
    """

    async def unblock():
        await asyncio.sleep(0.1)
        await async_redis.rpush("list", "y")

    # Schedule on the running loop directly; the deprecated pytest-asyncio
    # ``event_loop`` fixture is not needed for this.
    task = asyncio.create_task(unblock())
    try:
        result = await conn.blpop("list", timeout=1)
    finally:
        # Always reap the helper, even if BLPOP raised, so it is never left pending.
        await task
    assert result == (b"list", b"y")


async def test_wrongtype_error(async_redis: redis.asyncio.Redis):
    """A list command against a string key fails with a WRONGTYPE error."""
    await async_redis.set("foo", "bar")
    expect_wrongtype = pytest.raises(redis.asyncio.ResponseError, match="^WRONGTYPE")
    with expect_wrongtype:
        await async_redis.rpush("foo", "baz")


async def test_syntax_error(async_redis: redis.asyncio.Redis):
    """GET without a key is rejected with the exact arity error message."""
    arity_message = "^wrong number of arguments for 'get' command$"
    with pytest.raises(redis.asyncio.ResponseError, match=arity_message):
        await async_redis.execute_command("get")


@pytest.mark.decode_responses
async def test_never_decode(async_redis: redis.asyncio.Redis):
    """NEVER_DECODE=True returns raw bytes even from a decode_responses client."""
    encoder = async_redis.connection_pool.get_encoder()
    assert encoder.decode_responses

    await async_redis.execute_command("set", "key", "some ascii")
    decoded = await async_redis.execute_command("get", "key")
    assert isinstance(decoded, str)
    raw = await async_redis.execute_command("get", "key", NEVER_DECODE=True)
    assert isinstance(raw, bytes)


@testtools.run_test_if_lupa
class TestScripts:
    """Error handling for EVAL/EVALSHA, gated by ``testtools.run_test_if_lupa``."""

    # ZCOUNT is called with only a key (too few arguments), so running this fails.
    _ZCOUNT_SCRIPT = 'return redis.call("ZCOUNT", KEYS[1])'

    async def test_no_script_error(self, async_redis: redis.asyncio.Redis):
        """EVALSHA with a digest that was never loaded raises NoScriptError."""
        unknown_sha = "0123456789abcdef0123456789abcdef"
        with pytest.raises(redis.exceptions.NoScriptError):
            await async_redis.evalsha(unknown_sha, 0)

    @pytest.mark.max_server("6.2.7")
    async def test_failed_script_error6(self, async_redis):
        """Under ``max_server("6.2.7")`` the error text starts with 'Error running script'."""
        await async_redis.set("foo", "bar")
        with pytest.raises(redis.asyncio.ResponseError, match="^Error running script"):
            await async_redis.eval(self._ZCOUNT_SCRIPT, 1, "foo")

    @pytest.mark.min_server("7")
    async def test_failed_script_error7(self, async_redis):
        """Under ``min_server("7")`` only the exception type is checked."""
        await async_redis.set("foo", "bar")
        with pytest.raises(redis.asyncio.ResponseError):
            await async_redis.eval(self._ZCOUNT_SCRIPT, 1, "foo")


async def test_type(async_redis: redis.asyncio.Redis):
    """TYPE names each data structure and reports ``none`` for a missing key."""
    await async_redis.set("string_key", "value")
    await async_redis.lpush("list_key", "value")
    await async_redis.sadd("set_key", "value")
    await async_redis.zadd("zset_key", {"value": 1})
    await async_redis.hset("hset_key", "key", "value")

    expected_types = {
        "string_key": b"string",
        "list_key": b"list",
        "set_key": b"set",
        "zset_key": b"zset",
        "hset_key": b"hash",
        "none_key": b"none",
    }
    for key, expected in expected_types.items():
        # Bind first so linters don't mistake ``.type(...) ==`` for an E721 type comparison.
        actual = await async_redis.type(key)
        assert actual == expected


async def test_xdel(async_redis: redis.asyncio.Redis):
    """XDEL returns the number of stream entries it actually removed."""
    stream = "stream"

    # Nothing exists yet, so nothing is deleted.
    assert await async_redis.xdel(stream, 1) == 0

    first, second, third = [await async_redis.xadd(stream, {"foo": "bar"}) for _ in range(3)]

    assert await async_redis.xdel(stream, first) == 1
    assert await async_redis.xdel(stream, second, third) == 2


async def test_connection_with_username_and_password():
    """A FakeRedis client built with username/password credentials can read and write."""
    server = FakeServer()
    test_value = "this_is_a_test"

    # Use the client as an async context manager so its connection is closed
    # when the test finishes instead of being leaked.
    async with aioredis.FakeRedis(server=server, username="username", password="password") as r:
        await r.hset("test:key", "test_hash", test_value)
        result = await r.hget("test:key", "test_hash")

    assert result.decode() == test_value


@pytest.mark.skipif(sys.version_info < (3, 11), reason="asyncio.TaskGroup requires Python 3.11+")
async def test_cause_fakeredis_bug(async_redis):
    """Interleaved blocking pops across two concurrent tasks leave the client consistent.

    The trailing ``# N`` comments mark the expected order of operations between
    the worker task and the main task. The test is reported as skipped, not
    silently passed, on interpreters without ``asyncio.TaskGroup``.
    """

    async def worker_task():
        assert await async_redis.rpush("list1", "list1_val") == 1  # 1
        assert await async_redis.blpop("list2") == (b"list2", b"list2_val")  # 4
        assert await async_redis.set("foo", "bar") is True  # 5

    async with asyncio.TaskGroup() as tg:
        tg.create_task(worker_task())
        assert await async_redis.blpop("list1") == (b"list1", b"list1_val")  # 2
        assert await async_redis.rpush("list2", "list2_val") == 1  # 3

    # Regression check: the original bug report noted that issuing an extra GET
    # before this one masked the failure, so this must be the first GET after
    # the TaskGroup exits.
    assert await async_redis.get("foo") == b"bar"
