from __future__ import annotations

import asyncio
import logging
import socket
from datetime import timedelta
from functools import partial
from typing import Callable, cast
from unittest.mock import AsyncMock, MagicMock, call, create_autospec, patch

import pytest
from google.protobuf import message

from aioesphomeapi import APIClient
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.api_pb2 import (
    DeviceInfoResponse,
    DisconnectRequest,
    HelloResponse,
    PingRequest,
    PingResponse,
    TextSensorStateResponse,
)
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.core import (
    APIConnectionCancelledError,
    APIConnectionError,
    ConnectionNotEstablishedAPIError,
    HandshakeAPIError,
    InvalidAuthAPIError,
    ReadFailedAPIError,
    RequiresEncryptionAPIError,
    ResolveAPIError,
    SocketAPIError,
    SocketClosedAPIError,
    TimeoutAPIError,
)

from .common import (
    async_fire_time_changed,
    connect,
    connect_client,
    generate_plaintext_packet,
    get_mock_protocol,
    mock_data_received,
    send_ping_request,
    send_ping_response,
    send_plaintext_connect_response,
    send_plaintext_hello,
    utcnow,
)
from .conftest import KEEP_ALIVE_INTERVAL, _create_mock_transport_protocol

KEEP_ALIVE_TIMEOUT_RATIO = 4.5


@pytest.mark.asyncio
async def test_connect(
    plaintext_connect_task_no_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ]
) -> None:
    """Test that a plaintext connection works."""
    conn, transport, protocol, connect_task = plaintext_connect_task_no_login
    mock_data_received(
        protocol,
        bytes.fromhex(
            "003602080110091a216d6173746572617672656c61792028657"
            "370686f6d652076323032332e362e3329220d6d617374657261"
            "7672656c6179"
        ),
    )
    mock_data_received(
        protocol,
        bytes.fromhex(
            "005b0a120d6d6173746572617672656c61791a1130383a33413a"
            "46323a33453a35453a36302208323032332e362e332a154a756e"
            "20323820323032332c2031383a31323a3236320965737033322d"
            "65766250506209457370726573736966"
        ),
    )
    await connect_task
    assert conn.is_connected


@pytest.mark.asyncio
async def test_timeout_sending_message(
    plaintext_connect_task_no_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
    caplog: pytest.LogCaptureFixture,
) -> None:
    conn, transport, protocol, connect_task = plaintext_connect_task_no_login

    mock_data_received(
        protocol,
        b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
        b"5stackatomproxy"
        b"\x00\x00$"
        b"\x00\x00\x04"
        b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
        b"ev*\x15Jan  7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif",
    )

    await connect_task

    with pytest.raises(TimeoutAPIError):
        await conn.send_messages_await_response_complex(
            (PingRequest(),), None, None, (PingResponse,), 0
        )

    transport.reset_mock()
    with patch("aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0):
        await conn.disconnect()

    transport.write.assert_called_with(b"\x00\x00\x05")

    assert "disconnect request failed" in caplog.text
    assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text


@pytest.mark.asyncio
async def test_disconnect_when_not_fully_connected(
    plaintext_connect_task_no_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
    caplog: pytest.LogCaptureFixture,
) -> None:
    conn, transport, protocol, connect_task = plaintext_connect_task_no_login

    # Only send the first part of the handshake
    # so we are stuck in the middle of the connection process
    mock_data_received(
        protocol,
        b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
    )

    await asyncio.sleep(0)
    transport.reset_mock()

    with (
        patch("aioesphomeapi.connection.DISCONNECT_CONNECT_TIMEOUT", 0.0),
        patch("aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0),
    ):
        await conn.disconnect()

    with pytest.raises(
        APIConnectionError,
        match="Timed out waiting to finish connect before disconnecting",
    ):
        await connect_task

    transport.write.assert_called_with(b"\x00\x00\x05")

    assert "disconnect request failed" in caplog.text
    assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text


@pytest.mark.asyncio
async def test_requires_encryption_propagates(conn: APIConnection):
    loop = asyncio.get_event_loop()
    protocol = get_mock_protocol(conn)
    with patch.object(loop, "create_connection") as create_connection:
        create_connection.return_value = (MagicMock(), protocol)

        conn._socket = MagicMock()
        await conn._connect_init_frame_helper()
        loop.call_soon(conn._frame_helper.ready_future.set_result, None)
        conn.connection_state = ConnectionState.CONNECTED

        with pytest.raises(RequiresEncryptionAPIError):
            task = asyncio.create_task(conn._connect_hello_login(login=True))
            await asyncio.sleep(0)
            mock_data_received(protocol, b"\x01\x00\x00")
            await task

    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError)
    conn.force_disconnect()
    assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError)
    conn.report_fatal_error(Exception("test"))
    assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError)


@pytest.mark.asyncio
async def test_plaintext_connection(
    plaintext_connect_task_no_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test that a plaintext connection works."""
    messages = []
    conn, transport, protocol, connect_task = plaintext_connect_task_no_login

    def on_msg(msg):
        messages.append(msg)

    remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
    mock_data_received(
        protocol,
        b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
    )
    mock_data_received(protocol, b"5stackatomproxy")
    mock_data_received(protocol, b"\x00\x00$")
    mock_data_received(protocol, b"\x00\x00\x04")
    mock_data_received(
        protocol,
        b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d',
    )
    mock_data_received(
        protocol, b"ev*\x15Jan  7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
    )
    await asyncio.sleep(0)
    await connect_task
    assert conn.is_connected
    assert len(messages) == 2
    assert isinstance(messages[0], HelloResponse)
    assert isinstance(messages[1], DeviceInfoResponse)
    assert messages[1].name == "m5stackatomproxy"
    remove()
    conn.force_disconnect()
    await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_start_connection_socket_error(
    conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
):
    """Test handling of socket error during start connection."""
    loop = asyncio.get_event_loop()

    with patch.object(loop, "create_connection", side_effect=OSError("Socket error")):
        connect_task = asyncio.create_task(connect(conn, login=False))
        await asyncio.sleep(0)
        with pytest.raises(APIConnectionError, match="Socket error"):
            await connect_task

    async_fire_time_changed(utcnow() + timedelta(seconds=600))
    await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_start_connection_cannot_increase_recv_buffer(
    conn: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection: MagicMock,
    caplog: pytest.LogCaptureFixture,
):
    """Test failing to increase the recv buffer."""
    loop = asyncio.get_event_loop()
    transport = MagicMock()
    connected = asyncio.Event()
    tried_sizes = []

    def _setsockopt(*args, **kwargs):
        if args[0] == socket.SOL_SOCKET and args[1] == socket.SO_RCVBUF:
            size = args[2]
            tried_sizes.append(size)
            raise OSError("Socket error")

    mock_socket: socket.socket = create_autospec(
        socket.socket, spec_set=True, instance=True
    )
    mock_socket.type = socket.SOCK_STREAM
    mock_socket.fileno.return_value = 1
    mock_socket.getpeername.return_value = ("10.0.0.512", 323)
    mock_socket.setsockopt = _setsockopt
    mock_socket.sendmsg.side_effect = OSError("Socket error")
    aiohappyeyeballs_start_connection.return_value = mock_socket

    with patch.object(
        loop,
        "create_connection",
        side_effect=partial(_create_mock_transport_protocol, transport, connected),
    ):
        connect_task = asyncio.create_task(connect(conn, login=False))
        await asyncio.sleep(0)
        await connected.wait()
        protocol = conn._frame_helper
        send_plaintext_hello(protocol)
        await connect_task

    assert "Unable to increase the socket receive buffer size to 131072" in caplog.text
    assert tried_sizes == [2097152, 1048576, 524288, 262144, 131072]

    # Failure to increase the buffer size should not cause the connection to fail
    assert conn.is_connected


@pytest.mark.asyncio
async def test_start_connection_can_only_increase_buffer_size_to_262144(
    conn: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection: MagicMock,
    caplog: pytest.LogCaptureFixture,
):
    """Test the receive buffer can only be increased to 262144."""
    loop = asyncio.get_event_loop()
    transport = MagicMock()
    connected = asyncio.Event()
    tried_sizes = []

    def _setsockopt(*args, **kwargs):
        if args[0] == socket.SOL_SOCKET and args[1] == socket.SO_RCVBUF:
            size = args[2]
            tried_sizes.append(size)
            if size != 262144:
                raise OSError("Socket error")

    mock_socket: socket.socket = create_autospec(
        socket.socket, spec_set=True, instance=True
    )
    mock_socket.type = socket.SOCK_STREAM
    mock_socket.fileno.return_value = 1
    mock_socket.getpeername.return_value = ("10.0.0.512", 323)
    mock_socket.setsockopt = _setsockopt
    mock_socket.sendmsg.side_effect = OSError("Socket error")
    aiohappyeyeballs_start_connection.return_value = mock_socket

    with patch.object(
        loop,
        "create_connection",
        side_effect=partial(_create_mock_transport_protocol, transport, connected),
    ):
        connect_task = asyncio.create_task(connect(conn, login=False))
        await asyncio.sleep(0)
        await connected.wait()
        protocol = conn._frame_helper
        send_plaintext_hello(protocol)
        await connect_task

    assert "Unable to increase the socket receive buffer size" not in caplog.text
    assert tried_sizes == [2097152, 1048576, 524288, 262144]

    # Failure to increase the buffer size should not cause the connection to fail
    assert conn.is_connected


@pytest.mark.asyncio
async def test_start_connection_times_out(
    conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
):
    """Test handling of start connection timing out."""
    asyncio.get_event_loop()

    async def _mock_socket_connect(*args, **kwargs):
        await asyncio.sleep(500)

    with (
        patch(
            "aioesphomeapi.connection.aiohappyeyeballs.start_connection",
            side_effect=_mock_socket_connect,
        ),
        patch("aioesphomeapi.connection.TCP_CONNECT_TIMEOUT", 0.0),
    ):
        connect_task = asyncio.create_task(connect(conn, login=False))
        await asyncio.sleep(0)

        async_fire_time_changed(utcnow() + timedelta(seconds=200))
        await asyncio.sleep(0)

    with pytest.raises(APIConnectionError, match="Timeout while connecting"):
        await connect_task

    async_fire_time_changed(utcnow() + timedelta(seconds=600))
    await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_start_connection_os_error(conn: APIConnection, resolve_host):
    """Test handling of start connection has an OSError."""
    asyncio.get_event_loop()

    with patch(
        "aioesphomeapi.connection.aiohappyeyeballs.start_connection",
        side_effect=OSError("Socket error"),
    ):
        connect_task = asyncio.create_task(connect(conn, login=False))
        await asyncio.sleep(0)
        with pytest.raises(APIConnectionError, match="Socket error"):
            await connect_task

    async_fire_time_changed(utcnow() + timedelta(seconds=600))
    await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_start_connection_is_cancelled(conn: APIConnection, resolve_host):
    """Test handling of start connection is cancelled."""
    asyncio.get_event_loop()

    with patch(
        "aioesphomeapi.connection.aiohappyeyeballs.start_connection",
        side_effect=asyncio.CancelledError,
    ):
        connect_task = asyncio.create_task(connect(conn, login=False))
        await asyncio.sleep(0)
        with pytest.raises(APIConnectionError, match="Starting connection cancelled"):
            await connect_task

    async_fire_time_changed(utcnow() + timedelta(seconds=600))
    await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_finish_connection_is_cancelled(
    conn: APIConnection, resolve_host, aiohappyeyeballs_start_connection
):
    """Test handling of finishing connection being cancelled."""
    loop = asyncio.get_event_loop()

    with patch.object(loop, "create_connection", side_effect=asyncio.CancelledError):
        connect_task = asyncio.create_task(connect(conn, login=False))
        await asyncio.sleep(0)
        with pytest.raises(APIConnectionError, match="Finishing connection cancelled"):
            await connect_task

    async_fire_time_changed(utcnow() + timedelta(seconds=600))
    await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_finish_connection_times_out(
    plaintext_connect_task_no_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test handling of finish connection timing out."""
    conn, transport, protocol, connect_task = plaintext_connect_task_no_login
    messages = []

    def on_msg(msg):
        messages.append(msg)

    remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
    mock_data_received(
        protocol,
        b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
    )
    await asyncio.sleep(0)

    async_fire_time_changed(utcnow() + timedelta(seconds=200))
    await asyncio.sleep(0)

    with pytest.raises(
        APIConnectionError, match="Timeout waiting for HelloResponse after 30.0s"
    ):
        await connect_task

    async_fire_time_changed(utcnow() + timedelta(seconds=600))
    await asyncio.sleep(0)

    assert not conn.is_connected
    remove()
    conn.force_disconnect()
    await asyncio.sleep(0)


@pytest.mark.parametrize(
    ("exception_map"),
    [
        (OSError("Socket error"), HandshakeAPIError),
        (APIConnectionError, APIConnectionError),
        (SocketClosedAPIError, SocketClosedAPIError),
        (asyncio.TimeoutError, TimeoutAPIError),
        (asyncio.CancelledError, APIConnectionError),
    ],
)
@pytest.mark.asyncio
async def test_plaintext_connection_fails_handshake(
    conn: APIConnection,
    resolve_host: AsyncMock,
    aiohappyeyeballs_start_connection: MagicMock,
    exception_map: tuple[Exception, Exception],
) -> None:
    """Test that the frame helper is closed before the underlying socket.

    If we don't do this, asyncio will get confused and not release the socket.
    """
    loop = asyncio.get_event_loop()
    exception, raised_exception = exception_map
    messages = []
    transport = MagicMock()
    connected = asyncio.Event()

    class APIPlaintextFrameHelperHandshakeException(APIPlaintextFrameHelper):
        """Plaintext frame helper that raises exception on handshake."""

    def _create_failing_mock_transport_protocol(
        transport: asyncio.Transport,
        connected: asyncio.Event,
        create_func: Callable[[], APIPlaintextFrameHelper],
        **kwargs,
    ) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]:
        protocol: APIPlaintextFrameHelperHandshakeException = create_func()
        protocol._transport = cast(asyncio.Transport, transport)
        protocol._writer = transport.write
        protocol.ready_future.set_exception(exception)
        connected.set()
        return transport, protocol

    def on_msg(msg):
        messages.append(msg)

    remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
    transport = MagicMock()

    call_order = []

    def _socket_close_call():
        call_order.append("socket_close")

    def _frame_helper_close_call():
        call_order.append("frame_helper_close")

    async def _do_finish_connect(self, *args, **kwargs):
        try:
            await conn._connect_init_frame_helper()
        finally:
            conn._socket.close = _socket_close_call
            conn._frame_helper.close = _frame_helper_close_call

    with (
        patch(
            "aioesphomeapi.connection.APIPlaintextFrameHelper",
            APIPlaintextFrameHelperHandshakeException,
        ),
        patch.object(
            loop,
            "create_connection",
            side_effect=partial(
                _create_failing_mock_transport_protocol, transport, connected
            ),
        ),
        patch.object(conn, "_do_finish_connect", _do_finish_connect),
    ):
        connect_task = asyncio.create_task(connect(conn, login=False))
        await connected.wait()

    with (pytest.raises(raised_exception),):
        await asyncio.sleep(0)
        await connect_task

    # Ensure the frame helper is closed before the socket
    # so asyncio releases the socket
    assert call_order == ["frame_helper_close", "socket_close"]
    assert not conn.is_connected
    remove()
    conn.force_disconnect()
    await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_connect_wrong_password(
    plaintext_connect_task_with_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
) -> None:
    conn, transport, protocol, connect_task = plaintext_connect_task_with_login

    send_plaintext_hello(protocol)
    send_plaintext_connect_response(protocol, True)

    with pytest.raises(InvalidAuthAPIError):
        await connect_task

    assert not conn.is_connected


@pytest.mark.asyncio
async def test_connect_correct_password(
    plaintext_connect_task_with_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
) -> None:
    conn, transport, protocol, connect_task = plaintext_connect_task_with_login

    send_plaintext_hello(protocol)
    send_plaintext_connect_response(protocol, False)

    await connect_task

    assert conn.is_connected


@pytest.mark.asyncio
async def test_connect_wrong_version(
    plaintext_connect_task_with_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
) -> None:
    conn, transport, protocol, connect_task = plaintext_connect_task_with_login

    send_plaintext_hello(protocol, 3, 2)
    send_plaintext_connect_response(protocol, False)

    with pytest.raises(APIConnectionError, match="Incompatible API version"):
        await connect_task

    assert conn.is_connected is False


@pytest.mark.asyncio
async def test_connect_wrong_name(
    plaintext_connect_task_expected_name: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
) -> None:
    conn, transport, protocol, connect_task = plaintext_connect_task_expected_name
    send_plaintext_hello(protocol)
    send_plaintext_connect_response(protocol, False)

    with pytest.raises(
        APIConnectionError,
        match="Expected 'test' but server sent a different name: 'fake'",
    ):
        await connect_task

    assert conn.is_connected is False


@pytest.mark.asyncio
async def test_force_disconnect_fails(
    caplog: pytest.LogCaptureFixture,
    plaintext_connect_task_with_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
) -> None:
    conn, transport, protocol, connect_task = plaintext_connect_task_with_login

    send_plaintext_hello(protocol)
    send_plaintext_connect_response(protocol, False)

    await connect_task
    assert conn.is_connected

    with patch.object(protocol, "_writer", side_effect=OSError):
        conn.force_disconnect()
    assert "Failed to send (forced) disconnect request" in caplog.text
    await asyncio.sleep(0)


@pytest.mark.parametrize(
    ("exception_map"),
    [
        (OSError("original message"), ReadFailedAPIError),
        (APIConnectionError("original message"), APIConnectionError),
        (SocketClosedAPIError("original message"), SocketClosedAPIError),
    ],
)
@pytest.mark.asyncio
async def test_connection_lost_while_connecting(
    plaintext_connect_task_with_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
    exception_map: tuple[Exception, Exception],
) -> None:
    conn, transport, protocol, connect_task = plaintext_connect_task_with_login

    exception, raised_exception = exception_map
    protocol.connection_lost(exception)

    with pytest.raises(raised_exception, match="original message"):
        await connect_task

    assert not conn.is_connected


@pytest.mark.parametrize(
    ("exception_map"),
    [
        (OSError("original message"), SocketAPIError),
        (APIConnectionError("original message"), APIConnectionError),
        (SocketClosedAPIError("original message"), SocketClosedAPIError),
    ],
)
@pytest.mark.asyncio
async def test_connection_error_during_hello(
    conn: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection,
    exception_map: tuple[Exception, Exception],
) -> None:
    loop = asyncio.get_event_loop()
    transport = MagicMock()
    connected = asyncio.Event()
    exception, raised_exception = exception_map

    with (
        patch.object(
            loop,
            "create_connection",
            side_effect=partial(_create_mock_transport_protocol, transport, connected),
        ),
        patch.object(conn, "_connect_hello_login", side_effect=exception),
    ):
        connect_task = asyncio.create_task(connect(conn, login=False))
        await connected.wait()

    with pytest.raises(raised_exception, match="original message"):
        await connect_task

    assert not conn.is_connected


@pytest.mark.parametrize(
    ("exception_map"),
    [
        (OSError("original message"), APIConnectionCancelledError),
        (APIConnectionError("original message"), APIConnectionError),
        (SocketClosedAPIError("original message"), SocketClosedAPIError),
    ],
)
@pytest.mark.asyncio
async def test_connection_cancelled_during_hello(
    conn: APIConnection,
    resolve_host,
    aiohappyeyeballs_start_connection,
    exception_map: tuple[Exception, Exception],
) -> None:
    loop = asyncio.get_event_loop()
    transport = MagicMock()
    connected = asyncio.Event()
    exception, raised_exception = exception_map

    async def _mock_frame_helper_error(*args, **kwargs):
        conn._frame_helper.connection_lost(exception)
        raise asyncio.CancelledError

    with (
        patch.object(
            loop,
            "create_connection",
            side_effect=partial(_create_mock_transport_protocol, transport, connected),
        ),
        patch.object(conn, "_connect_hello_login", _mock_frame_helper_error),
    ):
        connect_task = asyncio.create_task(connect(conn, login=False))
        await connected.wait()

    with pytest.raises(raised_exception, match="original message"):
        await connect_task

    assert not conn.is_connected


@pytest.mark.asyncio
async def test_connect_resolver_times_out(
    conn: APIConnection, aiohappyeyeballs_start_connection
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
    transport = MagicMock()
    connected = asyncio.Event()
    event_loop = asyncio.get_running_loop()

    with (
        patch(
            "aioesphomeapi.host_resolver.async_resolve_host",
            side_effect=asyncio.TimeoutError,
        ),
        patch.object(
            event_loop,
            "create_connection",
            side_effect=partial(_create_mock_transport_protocol, transport, connected),
        ),
        pytest.raises(
            ResolveAPIError,
            match="Timeout while resolving IP address for fake.address",
        ),
    ):
        await connect(conn, login=False)


@pytest.mark.asyncio
async def test_disconnect_fails_to_send_response(
    connection_params: ConnectionParams,
    resolve_host,
    aiohappyeyeballs_start_connection,
) -> None:
    loop = asyncio.get_event_loop()
    transport = MagicMock()
    connected = asyncio.Event()
    client = APIClient(
        address="mydevice.local",
        port=6052,
        password=None,
    )
    expected_disconnect = None

    async def _on_stop(_expected_disconnect: bool) -> None:
        nonlocal expected_disconnect
        expected_disconnect = _expected_disconnect

    with patch.object(
        loop,
        "create_connection",
        side_effect=partial(_create_mock_transport_protocol, transport, connected),
    ):
        connect_task = asyncio.create_task(
            connect_client(client, login=False, on_stop=_on_stop)
        )
        await connected.wait()
        protocol = client._connection._frame_helper
        send_plaintext_hello(protocol)
        await connect_task
        transport.reset_mock()

    send_plaintext_hello(protocol)
    send_plaintext_connect_response(protocol, False)

    await connect_task
    assert client._connection.is_connected

    with patch.object(protocol, "_writer", side_effect=OSError):
        disconnect_request = DisconnectRequest()
        mock_data_received(protocol, generate_plaintext_packet(disconnect_request))

    # Wait one loop iteration for the disconnect to be processed
    await asyncio.sleep(0)
    assert expected_disconnect is True


@pytest.mark.asyncio
async def test_disconnect_success_case(
    connection_params: ConnectionParams,
    resolve_host,
    aiohappyeyeballs_start_connection,
) -> None:
    loop = asyncio.get_running_loop()
    transport = MagicMock()
    connected = asyncio.Event()
    client = APIClient(
        address="mydevice.local",
        port=6052,
        password=None,
    )
    expected_disconnect = None

    async def _on_stop(_expected_disconnect: bool) -> None:
        nonlocal expected_disconnect
        expected_disconnect = _expected_disconnect

    with patch.object(
        loop,
        "create_connection",
        side_effect=partial(_create_mock_transport_protocol, transport, connected),
    ):
        connect_task = asyncio.create_task(
            connect_client(client, login=False, on_stop=_on_stop)
        )
        await connected.wait()
        protocol = client._connection._frame_helper
        send_plaintext_hello(protocol)
        await connect_task
        transport.reset_mock()

    send_plaintext_hello(protocol)
    send_plaintext_connect_response(protocol, False)

    await connect_task
    assert client._connection.is_connected

    disconnect_request = DisconnectRequest()
    mock_data_received(protocol, generate_plaintext_packet(disconnect_request))

    # Wait one loop iteration for the disconnect to be processed
    await asyncio.sleep(0)
    assert expected_disconnect is True
    assert not client._connection


@pytest.mark.asyncio
async def test_ping_disconnects_after_no_responses(
    plaintext_connect_task_with_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
) -> None:
    conn, transport, protocol, connect_task = plaintext_connect_task_with_login

    send_plaintext_hello(protocol)
    send_plaintext_connect_response(protocol, False)

    await connect_task

    ping_request_bytes = b"\x00\x00\x07"

    assert conn.is_connected
    transport.reset_mock()
    expected_calls = []
    start_time = utcnow()
    max_pings_to_disconnect_after = int(KEEP_ALIVE_TIMEOUT_RATIO)
    for count in range(1, max_pings_to_disconnect_after + 1):
        async_fire_time_changed(
            start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count)
        )
        assert transport.write.call_count == count
        expected_calls.append(call(ping_request_bytes))
        assert transport.write.mock_calls == expected_calls

    assert conn.is_connected is True

    # We should disconnect once we reach more than 4 missed pings
    async_fire_time_changed(
        start_time
        + timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1))
    )
    assert transport.write.call_count == max_pings_to_disconnect_after + 1

    assert conn.is_connected is False


@pytest.mark.asyncio
async def test_ping_does_not_disconnect_if_we_get_responses(
    plaintext_connect_task_with_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
) -> None:
    conn, transport, protocol, connect_task = plaintext_connect_task_with_login

    send_plaintext_hello(protocol)
    send_plaintext_connect_response(protocol, False)

    await connect_task
    ping_request_bytes = b"\x00\x00\x07"

    assert conn.is_connected
    transport.reset_mock()
    start_time = utcnow()
    max_pings_to_disconnect_after = int(KEEP_ALIVE_TIMEOUT_RATIO)
    for count in range(1, max_pings_to_disconnect_after + 2):
        async_fire_time_changed(
            start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count)
        )
        send_ping_response(protocol)

    # We should only send 1 ping request if we are getting responses
    assert transport.write.call_count == 1
    assert transport.write.mock_calls == [call(ping_request_bytes)]

    # We should disconnect if we are getting ping responses
    assert conn.is_connected is True


def test_raise_during_send_messages_when_not_yet_connected(conn: APIConnection) -> None:
    """Test that we raise when sending messages before we are connected."""
    with pytest.raises(ConnectionNotEstablishedAPIError):
        conn.send_message(PingRequest())


@pytest.mark.asyncio
async def test_respond_to_ping_request(
    caplog: pytest.LogCaptureFixture,
    plaintext_connect_task_with_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
) -> None:
    conn, transport, protocol, connect_task = plaintext_connect_task_with_login

    send_plaintext_hello(protocol)
    send_plaintext_connect_response(protocol, False)

    await connect_task
    assert conn.is_connected

    transport.reset_mock()
    send_ping_request(protocol)
    # We should respond to ping requests
    ping_response_bytes = b"\x00\x00\x08"
    assert transport.write.call_count == 1
    assert transport.write.mock_calls == [call(ping_response_bytes)]


@pytest.mark.asyncio
async def test_unknown_protobuf_message_type_logged(
    api_client: tuple[
        APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
    ],
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test unknown protobuf messages are logged but do not cause the connection to collapse."""
    client, connection, transport, protocol = api_client
    response: message.Message = DeviceInfoResponse(
        name="realname",
        friendly_name="My Device",
        has_deep_sleep=True,
    )
    caplog.set_level(logging.DEBUG)
    client.set_debug(True)
    bytes_ = response.SerializeToString()
    message_with_invalid_protobuf_number = (
        b"\0"
        + _cached_varuint_to_bytes(len(bytes_))
        + _cached_varuint_to_bytes(16385)
        + bytes_
    )

    mock_data_received(protocol, message_with_invalid_protobuf_number)

    assert "Skipping unknown message type 16385" in caplog.text
    assert connection.is_connected
    connection.force_disconnect()
    await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_bad_protobuf_message_drops_connection(
    api_client: tuple[
        APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
    ],
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test ad bad protobuf messages is logged and causes the connection to collapse."""
    client, connection, transport, protocol = api_client
    msg: message.Message = TextSensorStateResponse(
        key=1, state="invalid", missing_state=False
    )
    caplog.clear()
    caplog.set_level(logging.DEBUG)
    client.set_debug(True)
    bytes_ = msg.SerializeToString()
    # Replace the bytes with invalid UTF-8
    bytes_ = bytes.replace(bytes_, b"invalid", b"inval\xe9 ")

    message_with_bad_protobuf_data = (
        b"\0"
        + _cached_varuint_to_bytes(len(bytes_))
        + _cached_varuint_to_bytes(27)
        + bytes_
    )
    mock_data_received(protocol, message_with_bad_protobuf_data)
    assert "Invalid protobuf message: type=TextSensorStateResponse" in caplog.text
    assert connection.is_connected is False


@pytest.mark.asyncio
async def test_connection_cannot_be_reused(
    plaintext_connect_task_with_login: tuple[
        APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
    ],
) -> None:
    """Test that we raise when trying to connect when already connected."""
    conn, transport, protocol, connect_task = plaintext_connect_task_with_login
    send_plaintext_hello(protocol)
    send_plaintext_connect_response(protocol, False)
    await connect_task
    with pytest.raises(RuntimeError):
        await conn.start_connection()


@pytest.mark.asyncio
async def test_attempting_to_finish_unstarted_connection(
    conn: APIConnection,
) -> None:
    """Test that we raise when trying to finish an unstarted connection."""
    with pytest.raises(RuntimeError):
        await conn.finish_connection(login=False)
