File: test_local_session.py

package info (click to toggle)
python-roborock 4.12.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,480 kB
  • sloc: python: 16,602; makefile: 17; sh: 6
file content (243 lines) | stat: -rw-r--r-- 8,974 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
"""End-to-end tests for LocalChannel using fake sockets."""

import asyncio
from collections.abc import AsyncGenerator

import pytest
import syrupy

from roborock.devices.transport.local_channel import LocalChannel
from roborock.protocol import MessageParser, create_local_decoder
from roborock.protocols.v1_protocol import LocalProtocolVersion
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
from tests.fixtures.logging import CapturedRequestLog
from tests.fixtures.mqtt import Subscriber
from tests.mock_data import LOCAL_KEY

# Host address the local_channel fixture passes to LocalChannel.
TEST_HOST: str = "192.168.1.100"
# Device UID the local_channel fixture passes to LocalChannel.
TEST_DEVICE_UID: str = "test_device_uid"
# Used as the ack nonce when encoding/decoding L01 command messages in
# test_l01_session; equal to the ``random`` value (23) that build_raw_response
# writes into fake device responses.
# NOTE(review): presumably the L01 ack nonce is derived from the HELLO
# response's random field — confirm against the protocol implementation.
TEST_RANDOM: int = 23


@pytest.fixture(name="local_channel")
async def local_channel_fixture(mock_async_create_local_connection: None) -> AsyncGenerator[LocalChannel, None]:
    """Yield a ``LocalChannel`` for the test device and close it afterwards.

    ``mock_async_create_local_connection`` is requested only so that fixture is
    active; it is defined outside this module (presumably it wires up the fake
    socket connection — confirm in the fixtures package).
    """
    ch = LocalChannel(
        host=TEST_HOST,
        local_key=LOCAL_KEY,
        device_uid=TEST_DEVICE_UID,
    )
    yield ch
    # Teardown: runs when pytest resumes the fixture after the test.
    ch.close()


def build_raw_response(
    protocol: RoborockMessageProtocol,
    seq: int,
    payload: bytes,
    version: LocalProtocolVersion = LocalProtocolVersion.V1,
    connect_nonce: int | None = None,
    ack_nonce: int | None = None,
) -> bytes:
    """Build an encoded response message, as the fake device would send it.

    Args:
        protocol: Protocol of the response message (e.g. HELLO_RESPONSE).
        seq: Sequence number written into the message.
        payload: Raw payload bytes.
        version: Local protocol version; its string value is encoded into the
            message's version field.
        connect_nonce: Forwarded to ``MessageParser.build`` (the L01 tests
            supply one; 1.0 tests leave it as None).
        ack_nonce: Forwarded to ``MessageParser.build``.

    Returns:
        The bytes produced by ``MessageParser.build`` with ``LOCAL_KEY``.
    """
    message = RoborockMessage(
        protocol=protocol,
        # Use the shared constant instead of a duplicated literal so that tests
        # passing ``ack_nonce=TEST_RANDOM`` stay in sync with the value written
        # here (presumably the L01 ack nonce derives from this field — confirm).
        random=TEST_RANDOM,
        seq=seq,
        payload=payload,
        version=version.value.encode(),
    )
    return MessageParser.build(message, local_key=LOCAL_KEY, connect_nonce=connect_nonce, ack_nonce=ack_nonce)


async def test_connect(
    local_channel: LocalChannel,
    local_response_queue: asyncio.Queue[bytes],
    local_received_requests: asyncio.Queue[bytes],
    log: CapturedRequestLog,
    snapshot: syrupy.SnapshotAssertion,
) -> None:
    """Test connecting to the device.

    The fake device answers the handshake with a HELLO_RESPONSE. Afterwards the
    channel must report itself connected, exactly one frame (the HELLO request)
    must have reached the fake device, and the captured log must match the
    stored snapshot.
    """
    # The HELLO response carries a payload so that it can be parsed.
    hello_reply = build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok")
    local_response_queue.put_nowait(hello_reply)

    await local_channel.connect()

    assert local_channel.is_connected
    assert local_received_requests.qsize() == 1

    sent_frame = await local_received_requests.get()
    # create_local_decoder is unusable here: HELLO_REQUEST has payload=None,
    # which MessageParser fails to parse, so the raw frame is inspected instead.
    # Header before the protocol field:
    #   Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19 bytes,
    # followed by a 2-byte big-endian protocol id.
    proto_start = 19
    proto_end = proto_start + 2
    assert len(sent_frame) >= proto_end
    proto_id = int.from_bytes(sent_frame[proto_start:proto_end], "big")
    assert proto_id == RoborockMessageProtocol.HELLO_REQUEST

    assert snapshot == log


async def test_send_command(
    local_channel: LocalChannel,
    local_response_queue: asyncio.Queue[bytes],
    local_received_requests: asyncio.Queue[bytes],
    log: CapturedRequestLog,
    snapshot: syrupy.SnapshotAssertion,
) -> None:
    """Test publishing an RPC command over an established 1.0 session.

    The fake device must receive exactly one decodable RPC_REQUEST with the
    expected payload and 1.0 version. A subscriber must receive the queued
    RPC_RESPONSE, and the captured log must match the snapshot.
    """
    request_body = b'{"method":"get_status"}'
    reply_body = b'{"status": "ok"}'
    cmd_seq = 123

    # Handshake: the fake device answers the HELLO request.
    local_response_queue.put_nowait(build_raw_response(RoborockMessageProtocol.HELLO_RESPONSE, 1, payload=b"ok"))
    await local_channel.connect()

    # Throw away the handshake traffic so only the command is left to inspect.
    while not local_received_requests.empty():
        local_received_requests.get_nowait()

    command = RoborockMessage(
        protocol=RoborockMessageProtocol.RPC_REQUEST,
        seq=cmd_seq,
        payload=request_body,
    )
    # Queue the fake device's reply before the command goes out.
    local_response_queue.put_nowait(
        build_raw_response(RoborockMessageProtocol.RPC_RESPONSE, cmd_seq, payload=reply_body)
    )

    subscriber = Subscriber()
    unsubscribe = await local_channel.subscribe(subscriber.append)
    await local_channel.publish(command)

    # Exactly one frame, the command, must have reached the fake device.
    raw_frame = await local_received_requests.get()
    assert local_received_requests.empty()

    decoded = list(create_local_decoder(local_key=LOCAL_KEY)(raw_frame))
    assert len(decoded) == 1
    outgoing = decoded[0]
    assert outgoing.protocol == RoborockMessageProtocol.RPC_REQUEST
    assert outgoing.payload == request_body
    assert outgoing.version == LocalProtocolVersion.V1.value.encode()

    # The subscriber must have been handed the fake reply.
    await subscriber.wait()
    assert len(subscriber.messages) == 1
    incoming = subscriber.messages[0]
    assert isinstance(incoming, RoborockMessage)
    assert incoming.protocol == RoborockMessageProtocol.RPC_RESPONSE
    assert incoming.payload == reply_body

    unsubscribe()

    assert snapshot == log


def _assert_hello_request(request_bytes: bytes, expected_version: LocalProtocolVersion) -> None:
    """Assert that a raw client frame is a HELLO request of ``expected_version``.

    HELLO_REQUEST has payload=None, which MessageParser fails to parse, so the
    raw frame is inspected directly. Header layout:
    Prefix(4) + Version(3) + Seq(4) + Random(4) + Timestamp(4) = 19 bytes,
    followed by the 2-byte big-endian protocol id.
    """
    assert len(request_bytes) >= 21
    # Version field sits right after the 4-byte prefix.
    assert request_bytes[4:7] == expected_version.value.encode()
    assert int.from_bytes(request_bytes[19:21], "big") == RoborockMessageProtocol.HELLO_REQUEST


async def test_l01_session(
    local_channel: LocalChannel,
    local_response_queue: asyncio.Queue[bytes],
    local_received_requests: asyncio.Queue[bytes],
    log: CapturedRequestLog,
    snapshot: syrupy.SnapshotAssertion,
) -> None:
    """Test connecting to a device that speaks the L01 protocol.

    Note that this test currently has a delay because the actual local client
    will delay before retrying with L01 after a failed 1.0 attempt. This should
    also be improved in the actual client itself, but likely requires a closer
    look at the actual device response in that scenario or moving to a serial
    request/response behavior rather than publish/subscribe.
    """
    # Client first attempts 1.0 and we reply with a fake invalid response. The
    # response is arbitrary, and this could be improved by capturing a real L01
    # device response to a 1.0 message.
    local_response_queue.put_nowait(b"\x00")
    # The client attempts L01 protocol as a followup. The connect nonce uses
    # a deterministic number from deterministic_message_fixtures.
    connect_nonce = 9090
    local_response_queue.put_nowait(
        build_raw_response(
            RoborockMessageProtocol.HELLO_RESPONSE,
            1,
            payload=b"ok",
            version=LocalProtocolVersion.L01,
            connect_nonce=connect_nonce,
            # The HELLO response is encoded without an ack nonce.
            ack_nonce=None,
        )
    )

    await local_channel.connect()

    assert local_channel.is_connected

    # The first HELLO request must be the 1.0 attempt...
    _assert_hello_request(await local_received_requests.get(), LocalProtocolVersion.V1)
    # ...and the retry must actually use L01. The protocol id alone cannot
    # tell the two attempts apart, so the version field is checked too.
    _assert_hello_request(await local_received_requests.get(), LocalProtocolVersion.L01)

    assert local_received_requests.empty()

    # Verify the channel switched to L01 protocol
    assert local_channel.protocol_version == LocalProtocolVersion.L01.value

    # We have established a connection. Now send some messages.
    # Publish an L01 command. Currently the caller of the local channel needs to
    # determine the protocol version to use, but this could be pushed inside of
    # the channel in the future.
    cmd_seq = 123
    msg = RoborockMessage(
        protocol=RoborockMessageProtocol.RPC_REQUEST,
        seq=cmd_seq,
        payload=b'{"method":"get_status"}',
        version=LocalProtocolVersion.L01.value.encode(),
    )
    # Prepare a fake response to the command.
    local_response_queue.put_nowait(
        build_raw_response(
            RoborockMessageProtocol.RPC_RESPONSE,
            cmd_seq,
            payload=b'{"status": "ok"}',
            version=LocalProtocolVersion.L01,
            connect_nonce=connect_nonce,
            ack_nonce=TEST_RANDOM,
        )
    )

    # Set up a subscriber to listen for the response then publish the message.
    subscriber = Subscriber()
    unsub = await local_channel.subscribe(subscriber.append)
    await local_channel.publish(msg)

    # Verify request received by the server
    request_bytes = await local_received_requests.get()
    decoder = create_local_decoder(local_key=LOCAL_KEY, connect_nonce=connect_nonce, ack_nonce=TEST_RANDOM)
    msgs = list(decoder(request_bytes))
    assert len(msgs) == 1
    assert msgs[0].protocol == RoborockMessageProtocol.RPC_REQUEST
    assert msgs[0].payload == b'{"method":"get_status"}'
    assert msgs[0].version == LocalProtocolVersion.L01.value.encode()

    # Verify fake response published by the server, received by subscriber
    await subscriber.wait()
    assert len(subscriber.messages) == 1
    response_message = subscriber.messages[0]
    assert isinstance(response_message, RoborockMessage)
    assert response_message.protocol == RoborockMessageProtocol.RPC_RESPONSE
    assert response_message.payload == b'{"status": "ok"}'
    assert response_message.version == LocalProtocolVersion.L01.value.encode()

    unsub()

    assert snapshot == log