1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
|
"""End-to-end tests for LocalChannel using fake sockets."""
import asyncio
from collections.abc import AsyncGenerator
import pytest
import syrupy
from roborock.devices.transport.local_channel import LocalChannel
from roborock.protocol import MessageParser, create_local_decoder
from roborock.protocols.v1_protocol import LocalProtocolVersion
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
from tests.fixtures.logging import CapturedRequestLog
from tests.fixtures.mqtt import Subscriber
from tests.mock_data import LOCAL_KEY
# Address of the fake device that the LocalChannel under test connects to.
TEST_HOST: str = "192.168.1.100"
# Device identifier handed to LocalChannel.
TEST_DEVICE_UID: str = "test_device_uid"
# Deterministic nonce value; the L01 session test passes it as ``ack_nonce``.
TEST_RANDOM: int = 23
@pytest.fixture(name="local_channel")
async def local_channel_fixture(mock_async_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]:
    """Yield a LocalChannel aimed at the test host and close it on teardown.

    ``mock_async_create_local_connection`` is requested only for its side
    effect. Presumably it installs the fake socket connection that these
    tests rely on (TODO confirm against the fixture definition).
    """
    test_channel = LocalChannel(
        host=TEST_HOST,
        local_key=LOCAL_KEY,
        device_uid=TEST_DEVICE_UID,
    )
    yield test_channel
    test_channel.close()
def build_raw_response(
    protocol: RoborockMessageProtocol,
    seq: int,
    payload: bytes,
    version: LocalProtocolVersion = LocalProtocolVersion.V1,
    connect_nonce: int | None = None,
    ack_nonce: int | None = None,
) -> bytes:
    """Build an encoded response message, as the fake device would send it.

    Args:
        protocol: Protocol of the response message (e.g. ``HELLO_RESPONSE``).
        seq: Sequence number stamped on the response.
        payload: Raw payload bytes of the response.
        version: Local protocol version; its string value is encoded and used
            as the message's ``version`` field.
        connect_nonce: Passed through to ``MessageParser.build``. The L01
            tests set it; it defaults to ``None``.
        ack_nonce: Passed through to ``MessageParser.build``. It defaults to
            ``None``.

    Returns:
        The bytes produced by ``MessageParser.build`` using ``LOCAL_KEY``.
    """
    message = RoborockMessage(
        protocol=protocol,
        # Use the module-level constant instead of a duplicated literal so the
        # response's random value cannot drift from TEST_RANDOM.
        random=TEST_RANDOM,
        seq=seq,
        payload=payload,
        version=version.value.encode(),
    )
    return MessageParser.build(message, local_key=LOCAL_KEY, connect_nonce=connect_nonce, ack_nonce=ack_nonce)
async def test_connect(
    local_channel: LocalChannel,
    local_response_queue: asyncio.Queue[bytes],
    local_received_requests: asyncio.Queue[bytes],
    log: CapturedRequestLog,
    snapshot: syrupy.SnapshotAssertion,
) -> None:
    """Test connecting to the device."""
    # The fake device answers the handshake with a HELLO response. The
    # response carries a payload so that it can be parsed.
    hello_response = build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok")
    local_response_queue.put_nowait(hello_response)

    await local_channel.connect()

    assert local_channel.is_connected
    assert local_received_requests.qsize() == 1

    # Inspect the single HELLO request the client sent.
    sent = await local_received_requests.get()
    # create_local_decoder cannot be used here: HELLO_REQUEST is built with
    # payload=None, which MessageParser fails to parse. Check the raw bytes
    # instead. The protocol field is a 2-byte big-endian value that follows
    # the header: prefix(4) + version(3) + seq(4) + random(4) + timestamp(4).
    protocol_start = 4 + 3 + 4 + 4 + 4
    protocol_end = protocol_start + 2
    assert len(sent) >= protocol_end
    assert int.from_bytes(sent[protocol_start:protocol_end], "big") == RoborockMessageProtocol.HELLO_REQUEST

    assert snapshot == log
async def test_send_command(
    local_channel: LocalChannel,
    local_response_queue: asyncio.Queue[bytes],
    local_received_requests: asyncio.Queue[bytes],
    log: CapturedRequestLog,
    snapshot: syrupy.SnapshotAssertion,
) -> None:
    """Test sending a command."""
    # Let the handshake succeed with a HELLO response.
    local_response_queue.put_nowait(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok"))
    await local_channel.connect()

    # Discard the handshake traffic so that only the command is left to check.
    while not local_received_requests.empty():
        local_received_requests.get_nowait()

    command_seq = 123
    request_payload = b'{"method":"get_status"}'
    response_payload = b'{"status": "ok"}'
    command = RoborockMessage(
        protocol=RoborockMessageProtocol.RPC_REQUEST,
        seq=command_seq,
        payload=request_payload,
    )

    # Queue the fake device's reply to the command.
    local_response_queue.put_nowait(
        build_raw_response(RoborockMessageProtocol.RPC_RESPONSE, command_seq, payload=response_payload)
    )

    subscriber = Subscriber()
    unsubscribe = await local_channel.subscribe(subscriber.append)
    await local_channel.publish(command)

    # Exactly one request should have reached the fake device.
    sent = await local_received_requests.get()
    assert local_received_requests.empty()

    # The request decodes as a 1.0 RPC request carrying the command payload.
    decode = create_local_decoder(local_key=LOCAL_KEY)
    decoded = list(decode(sent))
    assert len(decoded) == 1
    request = decoded[0]
    assert request.protocol == RoborockMessageProtocol.RPC_REQUEST
    assert request.payload == request_payload
    assert request.version == LocalProtocolVersion.V1.value.encode()

    # The subscriber receives the fake device's reply.
    await subscriber.wait()
    assert len(subscriber.messages) == 1
    response = subscriber.messages[0]
    assert isinstance(response, RoborockMessage)
    assert response.protocol == RoborockMessageProtocol.RPC_RESPONSE
    assert response.payload == response_payload

    unsubscribe()
    assert snapshot == log
async def test_l01_session(
    local_channel: LocalChannel,
    local_response_queue: asyncio.Queue[bytes],
    local_received_requests: asyncio.Queue[bytes],
    log: CapturedRequestLog,
    snapshot: syrupy.SnapshotAssertion,
) -> None:
    """Test connecting to a device that speaks the L01 protocol.

    This test is currently slow. After a failed 1.0 attempt, the real local
    client waits before retrying with L01. The client should be improved
    here as well. That likely requires examining a real device's response in
    this scenario, or moving to a serial request/response model instead of
    publish/subscribe.
    """
    # Handshake, step 1: the client tries 1.0 and the fake device replies with
    # an arbitrary invalid response. Capturing a real L01 device's reply to a
    # 1.0 message would make this more realistic.
    local_response_queue.put_nowait(b"\x00")

    # Handshake, step 2: the client follows up with L01. The connect nonce is
    # the deterministic number supplied by deterministic_message_fixtures.
    connect_nonce = 9090
    local_response_queue.put_nowait(
        build_raw_response(
            RoborockMessageProtocol.HELLO_RESPONSE,
            1,
            payload=b"ok",
            version=LocalProtocolVersion.L01,
            connect_nonce=connect_nonce,
            ack_nonce=None,
        )
    )

    await local_channel.connect()
    assert local_channel.is_connected

    # The 1.0 attempt and the L01 retry are both HELLO requests, checked on the
    # raw bytes. The protocol field is a 2-byte big-endian value that follows
    # the 19-byte header: prefix(4) + version(3) + seq(4) + random(4) +
    # timestamp(4).
    protocol_field = slice(19, 21)
    for _attempt in ("1.0", "L01"):
        hello_request = await local_received_requests.get()
        assert len(hello_request) >= protocol_field.stop
        assert int.from_bytes(hello_request[protocol_field], "big") == RoborockMessageProtocol.HELLO_REQUEST
    assert local_received_requests.empty()

    # The channel should now have switched to L01.
    assert local_channel.protocol_version == LocalProtocolVersion.L01.value

    # With the session established, exchange a command. The caller currently
    # picks the protocol version for outgoing messages; this could move into
    # the channel in the future.
    command_seq = 123
    request_payload = b'{"method":"get_status"}'
    response_payload = b'{"status": "ok"}'
    command = RoborockMessage(
        protocol=RoborockMessageProtocol.RPC_REQUEST,
        seq=command_seq,
        payload=request_payload,
        version=b"L01",
    )

    # Queue the fake device's reply to the command.
    local_response_queue.put_nowait(
        build_raw_response(
            RoborockMessageProtocol.RPC_RESPONSE,
            command_seq,
            payload=response_payload,
            version=LocalProtocolVersion.L01,
            connect_nonce=connect_nonce,
            ack_nonce=TEST_RANDOM,
        )
    )

    # Subscribe to receive the reply, then publish the command.
    subscriber = Subscriber()
    unsubscribe = await local_channel.subscribe(subscriber.append)
    await local_channel.publish(command)

    # The request reaching the server must decode using the L01 nonces.
    sent = await local_received_requests.get()
    decode = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=connect_nonce, ack_nonce=TEST_RANDOM)
    decoded = list(decode(sent))
    assert len(decoded) == 1
    request = decoded[0]
    assert request.protocol == RoborockMessageProtocol.RPC_REQUEST
    assert request.payload == request_payload
    assert request.version == LocalProtocolVersion.L01.value.encode()

    # The subscriber receives the fake device's L01 reply.
    await subscriber.wait()
    assert len(subscriber.messages) == 1
    response = subscriber.messages[0]
    assert isinstance(response, RoborockMessage)
    assert response.protocol == RoborockMessageProtocol.RPC_RESPONSE
    assert response.payload == response_payload
    assert response.version == LocalProtocolVersion.L01.value.encode()

    unsubscribe()
    assert snapshot == log
|