File: test_mqtt_channel.py

package info (click to toggle)
python-roborock 3.7.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,860 kB
  • sloc: python: 14,542; makefile: 17
file content (272 lines) | stat: -rw-r--r-- 9,831 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
"""Tests for the MqttChannel class."""

import asyncio
import json
import logging
from collections.abc import AsyncGenerator, Callable
from unittest.mock import AsyncMock, Mock

import pytest

from roborock.data import HomeData, UserData
from roborock.devices.mqtt_channel import MqttChannel
from roborock.mqtt.session import MqttParams
from roborock.protocol import create_mqtt_decoder, create_mqtt_encoder
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol

from .. import mock_data

# User data parsed from the shared mock payload; its ``rriot`` section is handed to the channel.
USER_DATA = UserData.from_dict(mock_data.USER_DATA)
# MQTT connection parameters passed to the channel under test (the session itself is mocked).
TEST_MQTT_PARAMS = MqttParams(
    host="localhost",
    port=1883,
    tls=False,
    username="username",
    password="password",
    timeout=10.0,
)
# Local key shared by the channel fixture and by the ENCODER/DECODER helpers below.
TEST_LOCAL_KEY = "local_key"

# Request/response pairs. Each payload is a JSON object whose "dps" entry maps a key
# ("101" for requests, "102" for responses) to a JSON-encoded string carrying the id.
TEST_REQUEST = RoborockMessage(
    protocol=RoborockMessageProtocol.RPC_REQUEST,
    payload=json.dumps({"dps": {"101": json.dumps({"id": 12345, "method": "get_status"})}}).encode(),
)
TEST_RESPONSE = RoborockMessage(
    protocol=RoborockMessageProtocol.RPC_RESPONSE,
    payload=json.dumps({"dps": {"102": json.dumps({"id": 12345, "result": {"state": "cleaning"}})}}).encode(),
)
# Second pair, identical except for the id (54321).
TEST_REQUEST2 = RoborockMessage(
    protocol=RoborockMessageProtocol.RPC_REQUEST,
    payload=json.dumps({"dps": {"101": json.dumps({"id": 54321, "method": "get_status"})}}).encode(),
)
TEST_RESPONSE2 = RoborockMessage(
    protocol=RoborockMessageProtocol.RPC_RESPONSE,
    payload=json.dumps({"dps": {"102": json.dumps({"id": 54321, "result": {"state": "cleaning"}})}}).encode(),
)
# The tests use ENCODER to build the raw bytes fed to subscriber handlers and DECODER to
# inspect the raw bytes the channel passed to ``mqtt_session.publish``.
ENCODER = create_mqtt_encoder(TEST_LOCAL_KEY)
DECODER = create_mqtt_decoder(TEST_LOCAL_KEY)


@pytest.fixture(name="mqtt_session", autouse=True)
def setup_mqtt_session() -> Mock:
    """Provide a fresh ``AsyncMock`` that stands in for the MQTT session."""
    session_double = AsyncMock()
    return session_double


@pytest.fixture(name="mqtt_channel", autouse=True)
def setup_mqtt_channel(mqtt_session: Mock) -> MqttChannel:
    """Build the ``MqttChannel`` under test on top of the mocked session."""
    channel = MqttChannel(
        mqtt_session,
        duid="abc123",
        local_key=TEST_LOCAL_KEY,
        rriot=USER_DATA.rriot,
        mqtt_params=TEST_MQTT_PARAMS,
    )
    return channel


@pytest.fixture(name="mqtt_subscribers", autouse=True)
async def setup_subscribe_callback(mqtt_session: Mock) -> AsyncGenerator[list[Callable[[bytes], None]], None]:
    """Track the callbacks currently subscribed through the mocked session.

    ``mqtt_session.subscribe`` is given a side effect that stores each callback in the
    yielded list and returns a function that removes it again. On teardown the list
    must be empty, i.e. every subscription made during the test has been undone.
    """
    active_callbacks: list[Callable[[bytes], None]] = []

    def fake_subscribe(_: str, callback: Callable[[bytes], None]) -> Callable[[], None]:
        def unsubscribe() -> None:
            active_callbacks.remove(callback)

        active_callbacks.append(callback)
        return unsubscribe

    mqtt_session.subscribe.side_effect = fake_subscribe
    yield active_callbacks
    assert not active_callbacks, "Not all subscribers were unsubscribed"


@pytest.fixture(name="mqtt_message_handler")
async def setup_message_handler(mqtt_subscribers: list[Callable[[bytes], None]]) -> Callable[[bytes], None]:
    """Return a function that delivers a raw MQTT payload to every current subscriber."""

    def deliver_to_subscribers(message: bytes) -> None:
        # Iterate the live list so subscriptions made after fixture setup are included.
        for subscriber in mqtt_subscribers:
            subscriber(message)

    return deliver_to_subscribers


@pytest.fixture
def warning_caplog(caplog: pytest.LogCaptureFixture) -> pytest.LogCaptureFixture:
    """Return ``caplog`` after setting its capture level to ``WARNING``."""
    caplog.set_level(logging.WARNING)
    return caplog


async def home_home_data_no_devices() -> HomeData:
    """Stand-in for the home data API: a home with no devices and no products."""
    empty_home = HomeData(id=1, name="Test Home", devices=[], products=[])
    return empty_home


async def mock_home_data() -> HomeData:
    """Stand-in for the home data API, parsed from ``mock_data.HOME_DATA_RAW``."""
    raw_home = mock_data.HOME_DATA_RAW
    return HomeData.from_dict(raw_home)


async def test_publish_success(
    mqtt_session: Mock,
    mqtt_channel: MqttChannel,
    mqtt_message_handler: Callable[[bytes], None],
) -> None:
    """Test that publishing a request sends the encoded message on the expected topic."""
    # Publish the request through the channel. The call is awaited directly (no task).
    await mqtt_channel.publish(TEST_REQUEST)
    await asyncio.sleep(0.01)  # yield

    # Simulate receiving the response message via MQTT.
    # NOTE(review): no assertion below depends on this response; it only checks that
    # delivering one to the current subscribers does not raise.
    mqtt_message_handler(ENCODER(TEST_RESPONSE))
    await asyncio.sleep(0.01)  # yield

    # Verify the command was sent
    assert mqtt_session.publish.called
    assert mqtt_session.publish.call_args[0][0] == "rr/m/i/user123/username/abc123"
    raw_sent_msg = mqtt_session.publish.call_args[0][1]

    # Materialize the decoder output and check it is non-empty before indexing.
    # ``next(iter(...))`` on an empty result would raise StopIteration, which a
    # coroutine converts into a RuntimeError that hides the real cause.
    decoded_messages = list(DECODER(raw_sent_msg))
    assert decoded_messages, "Published payload did not decode to any message"
    decoded_message = decoded_messages[0]
    assert decoded_message == TEST_REQUEST
    assert decoded_message.protocol == RoborockMessageProtocol.RPC_REQUEST


@pytest.mark.parametrize("connected", [True, False])
async def test_connection_status(
    mqtt_session: Mock,
    mqtt_channel: MqttChannel,
    connected: bool,
) -> None:
    """Test that the channel reports the MQTT session's connection state.

    ``is_connected`` must mirror ``mqtt_session.connected`` and
    ``is_local_connected`` must be ``False`` in both cases.
    """
    mqtt_session.connected = connected
    assert mqtt_channel.is_connected is connected
    assert mqtt_channel.is_local_connected is False


async def test_message_decode_error(
    mqtt_channel: MqttChannel,
    mqtt_message_handler: Callable[[bytes], None],
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test that an undecodable payload produces exactly one warning log record."""
    callback = Mock()
    unsub = await mqtt_channel.subscribe(callback)

    # Always unsubscribe, even when an assertion fails. Otherwise the autouse
    # ``mqtt_subscribers`` fixture's teardown assertion fails too and reports a
    # second, misleading error on top of the real one.
    try:
        mqtt_message_handler(b"invalid_payload")
        await asyncio.sleep(0.01)  # yield

        assert len(caplog.records) == 1
        assert caplog.records[0].levelname == "WARNING"
        assert "Failed to decode message" in caplog.records[0].message
    finally:
        unsub()


async def test_concurrent_subscribers(mqtt_session: Mock, mqtt_channel: MqttChannel) -> None:
    """Test that several simultaneous subscribers each receive every message."""
    # One inbox per subscriber; the bound ``append`` of each inbox is the callback.
    inboxes: list[list[RoborockMessage]] = [[], [], []]
    unsubscribers = []
    for inbox in inboxes:
        unsubscribers.append(await mqtt_channel.subscribe(inbox.append))

    # Every subscription results in its own call to the MQTT session ...
    assert mqtt_session.subscribe.call_count == 3

    # ... and all of them use the same topic.
    for subscribe_call in mqtt_session.subscribe.call_args_list:
        assert subscribe_call[0][0] == "rr/m/o/user123/username/abc123"

    # The second positional argument of each call is the handler for that subscriber.
    handlers = [subscribe_call[0][1] for subscribe_call in mqtt_session.subscribe.call_args_list]

    # First round: feed the encoded request to each handler, in subscription order.
    for handler in handlers:
        handler(ENCODER(TEST_REQUEST))
    await asyncio.sleep(0.01)  # yield

    # Each inbox now holds exactly that one message.
    for inbox in inboxes:
        assert inbox == [TEST_REQUEST]

    # Second round: feed the encoded response to each handler.
    for handler in handlers:
        handler(ENCODER(TEST_RESPONSE))
    await asyncio.sleep(0.01)  # yield

    # Each inbox now holds both messages, in arrival order.
    for inbox in inboxes:
        assert inbox == [TEST_REQUEST, TEST_RESPONSE]

    # Drop the first subscriber and send a third message to the remaining handlers only.
    first_unsub, *remaining_unsubs = unsubscribers
    first_unsub()

    for handler in handlers[1:]:
        handler(ENCODER(TEST_REQUEST2))
    await asyncio.sleep(0.01)  # yield

    # The first inbox is unchanged; the other two gained the new message.
    assert len(inboxes[0]) == 2
    for inbox in inboxes[1:]:
        assert len(inbox) == 3
        assert inbox[2] == TEST_REQUEST2

    # Unsubscribe the remaining subscribers
    for unsub in remaining_unsubs:
        unsub()


async def test_concurrent_subscribers_with_callback_exception(
    mqtt_session: Mock, mqtt_channel: MqttChannel, caplog: pytest.LogCaptureFixture
) -> None:
    """Test that a raising subscriber callback is logged and does not disturb the others."""
    caplog.set_level(logging.ERROR)

    def failing_callback(message: RoborockMessage) -> None:
        raise ValueError("Callback error")

    healthy_inbox: list[RoborockMessage] = []

    unsub_failing = await mqtt_channel.subscribe(failing_callback)
    unsub_healthy = await mqtt_channel.subscribe(healthy_inbox.append)

    # Handlers registered with the session, in subscription order
    subscribe_calls = mqtt_session.subscribe.call_args_list
    failing_handler = subscribe_calls[0][0][1]
    healthy_handler = subscribe_calls[1][0][1]

    # Deliver the same message to both; the first subscriber's callback raises
    failing_handler(ENCODER(TEST_REQUEST))
    healthy_handler(ENCODER(TEST_REQUEST))
    await asyncio.sleep(0.01)  # yield

    # The well-behaved subscriber still got its message
    assert healthy_inbox == [TEST_REQUEST]

    # Exactly one ERROR record, naming the failing callback
    logged_errors = list(filter(lambda record: record.levelname == "ERROR", caplog.records))
    assert len(logged_errors) == 1
    assert "Uncaught error in callback 'failing_callback'" in logged_errors[0].message

    # Clean up both subscriptions
    unsub_failing()
    unsub_healthy()