File: test_roborock_session.py

package info (click to toggle)
python-roborock 2.39.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,128 kB
  • sloc: python: 10,342; makefile: 17
file content (231 lines) | stat: -rw-r--r-- 7,449 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""Tests for the MQTT session module."""

import asyncio
from collections.abc import AsyncGenerator, Callable
from queue import Queue
from typing import Any
from unittest.mock import AsyncMock, Mock, patch

import aiomqtt
import paho.mqtt.client as mqtt
import pytest

from roborock.mqtt.roborock_session import create_mqtt_session
from roborock.mqtt.session import MqttParams, MqttSessionException
from tests import mqtt_packet
from tests.conftest import FakeSocketHandler

# Connection settings handed to create_mqtt_session. The connection itself is
# mocked out by the test fixtures, so none of these values are used or verified.
FAKE_PARAMS = MqttParams(
    username="username",
    password="password",
    host="localhost",
    port=1883,
    tls=False,
    timeout=10.0,
)


@pytest.fixture(autouse=True)
def mqtt_server_fixture(mock_create_connection: None, mock_select: None) -> None:
    """Set up the fake MQTT server for every test in this module.

    All of the work is done by the requested ``mock_create_connection`` and
    ``mock_select`` fixtures (defined elsewhere, presumably in conftest); this
    fixture exists only so that both are applied automatically via ``autouse``.
    """
    return None


@pytest.fixture(autouse=True)
async def mock_client_fixture() -> AsyncGenerator[None, None]:
    """Fixture to patch the MQTT underlying sync client.

    The tests use fake sockets, so this ensures that the async mqtt client does not
    attempt to listen on them directly. We instead just poll the socket for
    data ourselves.

    Every paho client constructed while the fixture is active gets its own
    polling task. On teardown all of those tasks are cancelled and awaited, so
    none of them is left pending when the test's event loop shuts down.
    """

    event_loop = asyncio.get_running_loop()

    orig_class = mqtt.Client

    async def poll_sockets(client: mqtt.Client) -> None:
        """Poll the mqtt client sockets in a loop to pick up new data."""
        while True:
            event_loop.call_soon_threadsafe(client.loop_read)
            event_loop.call_soon_threadsafe(client.loop_write)
            await asyncio.sleep(0.1)

    # One polling task per constructed client. Remembering only the latest
    # task would leak the tasks of any earlier clients if more than one is
    # created during a test.
    tasks: list[asyncio.Task[None]] = []

    def new_client(*args: Any, **kwargs: Any) -> mqtt.Client:
        """Create a new mqtt client and start the socket polling task."""
        client = orig_class(*args, **kwargs)
        tasks.append(event_loop.create_task(poll_sockets(client)))
        return client

    with patch("aiomqtt.client.Client._on_socket_open"), patch("aiomqtt.client.Client._on_socket_close"), patch(
        "aiomqtt.client.Client._on_socket_register_write"
    ), patch("aiomqtt.client.Client._on_socket_unregister_write"), patch(
        "aiomqtt.client.mqtt.Client", side_effect=new_client
    ):
        try:
            yield
        finally:
            for task in tasks:
                task.cancel()
            # Wait until the cancellations have been processed so no task is
            # destroyed while still pending; return_exceptions=True collects
            # the resulting CancelledErrors instead of raising them here.
            await asyncio.gather(*tasks, return_exceptions=True)


@pytest.fixture
def push_response(response_queue: Queue, fake_socket_handler: FakeSocketHandler) -> Callable[[bytes], None]:
    """Provide a callable that feeds one raw packet to the fake socket.

    Each call puts ``message`` on ``response_queue`` and then calls
    ``fake_socket_handler.push_response()`` (presumably to signal that data is
    available to read).
    """

    def _deliver(message: bytes) -> None:
        response_queue.put(message)
        fake_socket_handler.push_response()

    return _deliver


class Subscriber:
    """Collect the payloads delivered to a subscription callback.

    Tests pass ``append`` as the callback to ``session.subscribe`` and await
    ``wait`` to block until the next delivery has been recorded.
    """

    def __init__(self) -> None:
        """Start with no recorded payloads and an unset notification event."""
        self.event: asyncio.Event = asyncio.Event()
        self.messages: list[bytes] = []

    def append(self, message: bytes) -> None:
        """Record ``message`` and wake any coroutine blocked in ``wait``."""
        recorded = self.messages
        recorded.append(message)
        self.event.set()

    async def wait(self) -> None:
        """Block until ``append`` has been called, then re-arm the event."""
        notifier = self.event
        await notifier.wait()
        notifier.clear()


async def test_session(push_response: Callable[[bytes], None]) -> None:
    """Test subscribing, message routing and unsubscribing on one MQTT session.

    Two subscribers on different topics must each receive only their own
    topic's messages, and a subscriber must stop receiving messages once its
    unsubscribe callable has been invoked.
    """

    push_response(mqtt_packet.gen_connack(rc=0, flags=2))
    session = await create_mqtt_session(FAKE_PARAMS)
    assert session.connected

    push_response(mqtt_packet.gen_suback(mid=1))
    subscriber1 = Subscriber()
    unsub1 = await session.subscribe("topic-1", subscriber1.append)

    push_response(mqtt_packet.gen_suback(mid=2))
    subscriber2 = Subscriber()
    await session.subscribe("topic-2", subscriber2.append)

    push_response(mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345"))
    await subscriber1.wait()
    assert subscriber1.messages == [b"12345"]
    assert not subscriber2.messages

    push_response(mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890"))
    await subscriber2.wait()
    assert subscriber2.messages == [b"67890"]

    push_response(mqtt_packet.gen_publish("topic-1", mid=5, payload=b"ABC"))
    await subscriber1.wait()
    assert subscriber1.messages == [b"12345", b"ABC"]
    assert subscriber2.messages == [b"67890"]

    # Messages are no longer received after unsubscribing
    unsub1()
    push_response(mqtt_packet.gen_publish("topic-1", payload=b"ignored"))
    # Asserting right away would pass trivially, because the packet above has
    # not been processed yet. Follow it with a message for the still-subscribed
    # topic-2 and wait for that one. Assuming packets are dispatched in arrival
    # order, the topic-1 packet has then been handled too, so the check on
    # subscriber1 is meaningful.
    push_response(mqtt_packet.gen_publish("topic-2", mid=6, payload=b"sentinel"))
    await subscriber2.wait()
    assert subscriber2.messages == [b"67890", b"sentinel"]
    assert subscriber1.messages == [b"12345", b"ABC"]

    assert session.connected
    await session.close()
    assert not session.connected


async def test_session_no_subscribers(push_response: Callable[[bytes], None]) -> None:
    """Test a session that receives publishes for topics nobody subscribed to.

    Two publish packets are queued right behind the CONNACK; connecting and
    then closing the session must still succeed.
    """

    for packet in (
        mqtt_packet.gen_connack(rc=0, flags=2),
        mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345"),
        mqtt_packet.gen_publish("topic-2", mid=4, payload=b"67890"),
    ):
        push_response(packet)

    mqtt_session = await create_mqtt_session(FAKE_PARAMS)
    assert mqtt_session.connected

    await mqtt_session.close()
    assert not mqtt_session.connected


async def test_publish_command(push_response: Callable[[bytes], None]) -> None:
    """Test that publishing on a connected session succeeds and keeps it connected."""

    connack = mqtt_packet.gen_connack(rc=0, flags=2)
    push_response(connack)
    mqtt_session = await create_mqtt_session(FAKE_PARAMS)

    # NOTE(review): an inbound publish is queued before publishing; its exact
    # purpose is not visible here — confirm whether it is still needed.
    inbound = mqtt_packet.gen_publish("topic-1", mid=3, payload=b"12345")
    push_response(inbound)
    await mqtt_session.publish("topic-1", message=b"payload")

    assert mqtt_session.connected
    await mqtt_session.close()
    assert not mqtt_session.connected


class FakeAsyncIterator:
    """Async iterator whose ``__anext__`` never produces an item.

    Tests assign it to a mocked client's ``messages`` attribute so that anything
    iterating over it (presumably the session's receive loop) just waits. This
    lets the tests exercise failures in the other client methods.
    """

    def __aiter__(self) -> "FakeAsyncIterator":
        """Return this object, which acts as its own iterator."""
        return self

    async def __anext__(self) -> None:
        """Sleep in one-second steps forever instead of returning an item."""
        while True:
            await asyncio.sleep(1)


async def test_publish_failure() -> None:
    """Test an MQTT error is received when publishing a message.

    ``publish`` on the mocked client raises ``aiomqtt.MqttError``. The session
    must surface this as ``MqttSessionException`` and can then be closed.
    """

    mock_client = AsyncMock()
    mock_client.messages = FakeAsyncIterator()

    mock_aenter = AsyncMock()
    mock_aenter.return_value = mock_client

    # Replace the whole client class rather than only ``__aenter__``. That way
    # ``__aexit__`` is mocked too, and ``session.close()`` below does not run
    # the real client's exit logic on a client that never connected.
    mock_shim = Mock()
    mock_shim.return_value.__aenter__ = mock_aenter
    mock_shim.return_value.__aexit__ = AsyncMock()

    with patch("roborock.mqtt.roborock_session.aiomqtt.Client", mock_shim):
        session = await create_mqtt_session(FAKE_PARAMS)
        assert session.connected

        mock_client.publish.side_effect = aiomqtt.MqttError

        with pytest.raises(MqttSessionException, match="Error publishing message"):
            await session.publish("topic-1", message=b"payload")

        # Close the session so nothing it started outlives the test.
        await session.close()


async def test_subscribe_failure() -> None:
    """Test that an MQTT error while subscribing surfaces as a session error.

    The aiomqtt client class is swapped for a mock whose async context manager
    yields ``client``. Making ``client.subscribe`` raise must produce
    ``MqttSessionException`` and leave the subscriber without any messages.
    """

    client = AsyncMock()
    client.messages = FakeAsyncIterator()

    client_instance = Mock()
    client_instance.__aenter__ = AsyncMock(return_value=client)
    client_instance.__aexit__ = AsyncMock()
    client_cls = Mock(return_value=client_instance)

    with patch("roborock.mqtt.roborock_session.aiomqtt.Client", client_cls):
        session = await create_mqtt_session(FAKE_PARAMS)
        assert session.connected

        client.subscribe.side_effect = aiomqtt.MqttError

        subscriber = Subscriber()
        with pytest.raises(MqttSessionException, match="Error subscribing to topic"):
            await session.subscribe("topic-1", subscriber.append)

        assert not subscriber.messages
        await session.close()