1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
|
"""End-to-end tests for MQTT session.
These tests use a fake MQTT broker to verify the session implementation. We
mock out the lower level socket connections to simulate a broker which gets us
close to an "end to end" test without needing an actual MQTT broker server.
These are higher level tests that the similar tests in tests/mqtt/test_roborock_session.py
which use mocks to verify specific behaviors.
"""
from collections.abc import AsyncGenerator, Callable
from queue import Queue
import pytest
import syrupy
from roborock.mqtt.roborock_session import create_mqtt_session
from roborock.mqtt.session import MqttSession
from roborock.protocol import MessageParser
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
from tests import mqtt_packet
from tests.fixtures.logging import CapturedRequestLog
from tests.fixtures.mqtt import FAKE_PARAMS, Subscriber
from tests.mock_data import LOCAL_KEY
@pytest.fixture(autouse=True)
def auto_mock_mqtt_client(mock_aiomqtt_client: None) -> None:
"""Automatically use the mock mqtt client fixture."""
@pytest.fixture(autouse=True)
def auto_fast_backoff(fast_backoff_fixture: None) -> None:
"""Automatically use the fast backoff fixture."""
@pytest.fixture(autouse=True)
def mqtt_server_fixture(mock_paho_mqtt_create_connection: None, mock_paho_mqtt_select: None) -> None:
"""Fixture to mock the MQTT connection.
This is here to pull in the mock socket pixtures into all tests used here.
"""
@pytest.fixture(name="session")
async def session_fixture(
push_mqtt_response: Callable[[bytes], None],
) -> AsyncGenerator[MqttSession, None]:
"""Fixture to create a new connected MQTT session."""
push_mqtt_response(mqtt_packet.gen_connack(rc=0, flags=2))
session = await create_mqtt_session(FAKE_PARAMS)
assert session.connected
try:
yield session
finally:
await session.close()
async def test_session_e2e_receive_message(
push_mqtt_response: Callable[[bytes], None],
session: MqttSession,
log: CapturedRequestLog,
snapshot: syrupy.SnapshotAssertion,
) -> None:
"""Test receiving a real Roborock message through the session."""
assert session.connected
# Subscribe to the topic. We'll next construct and push a message.
push_mqtt_response(mqtt_packet.gen_suback(mid=1))
subscriber = Subscriber()
await session.subscribe("topic-1", subscriber.append)
msg = RoborockMessage(
protocol=RoborockMessageProtocol.RPC_RESPONSE,
payload=b'{"result":"ok"}',
seq=123,
)
payload = MessageParser.build(msg, local_key=LOCAL_KEY, prefixed=False)
# Simulate receiving the message from the broker
push_mqtt_response(mqtt_packet.gen_publish("topic-1", mid=2, payload=payload))
# Verify it was dispatched to the subscriber
await subscriber.wait()
assert len(subscriber.messages) == 1
received_payload = subscriber.messages[0]
assert isinstance(received_payload, bytes)
assert received_payload == payload
# Verify the message payload contents
parsed_msgs, _ = MessageParser.parse(received_payload, local_key=LOCAL_KEY)
assert len(parsed_msgs) == 1
parsed_msg = parsed_msgs[0]
assert parsed_msg.protocol == RoborockMessageProtocol.RPC_RESPONSE
assert parsed_msg.seq == 123
# The payload in parsed_msg should be the decrypted bytes
assert parsed_msg.payload == b'{"result":"ok"}'
assert snapshot == log
async def test_session_e2e_publish_message(
push_mqtt_response: Callable[[bytes], None],
mqtt_received_requests: Queue,
session: MqttSession,
log: CapturedRequestLog,
snapshot: syrupy.SnapshotAssertion,
) -> None:
"""Test publishing a real Roborock message."""
# Publish a message to the brokwer
msg = RoborockMessage(
protocol=RoborockMessageProtocol.RPC_REQUEST,
payload=b'{"method":"get_status"}',
seq=456,
)
payload = MessageParser.build(msg, local_key=LOCAL_KEY, prefixed=False)
await session.publish("topic-1", payload)
# Verify what was sent to the broker
# We expect the payload to be present in the sent bytes
found = False
while not mqtt_received_requests.empty():
request = mqtt_received_requests.get()
if payload in request:
found = True
break
assert found, "Published payload not found in sent requests"
assert snapshot == log
|