File: test_a01_protocol.py

package info (click to toggle)
python-roborock 2.39.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,128 kB
  • sloc: python: 10,342; makefile: 17
file content (218 lines) | stat: -rw-r--r-- 7,782 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""Tests for A01 protocol encoding and decoding."""

import json
from typing import Any

import pytest

from roborock.exceptions import RoborockException
from roborock.protocols.a01_protocol import decode_rpc_response, encode_mqtt_payload
from roborock.roborock_message import (
    RoborockDyadDataProtocol,
    RoborockMessage,
    RoborockMessageProtocol,
    RoborockZeoProtocol,
)


def test_encode_mqtt_payload_basic():
    """Encode a single Dyad START datapoint and check the resulting message.

    The produced message must be an A01 RPC request whose payload is bytes,
    padded to the 16-byte AES block size, and which decodes back to the
    original datapoint keyed by its integer protocol value.
    """
    start_value = {"test": "data", "number": 42}
    request: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any] = {
        RoborockDyadDataProtocol.START: start_value
    }

    message = encode_mqtt_payload(request)

    # Envelope checks.
    assert isinstance(message, RoborockMessage)
    assert message.protocol == RoborockMessageProtocol.RPC_REQUEST
    assert message.version == b"A01"

    body = message.payload
    assert body is not None
    assert isinstance(body, bytes)
    assert len(body) % 16 == 0  # padded to the AES block size

    # Round trip: the decoded mapping is keyed by the integer value (200).
    assert decode_rpc_response(message) == {200: {"test": "data", "number": 42}}


def test_encode_mqtt_payload_empty_data():
    """Test encoding with empty data.

    An empty request must still produce a well-formed A01 RPC request (same
    envelope guarantees as a non-empty one) whose payload decodes back to an
    empty datapoint mapping.
    """
    data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any] = {}

    result = encode_mqtt_payload(data)

    assert isinstance(result, RoborockMessage)
    assert result.protocol == RoborockMessageProtocol.RPC_REQUEST
    # The envelope must not degrade just because there are no datapoints:
    # same version tag, bytes payload, padded to the AES block size.
    assert result.version == b"A01"
    assert result.payload is not None
    assert isinstance(result.payload, bytes)
    assert len(result.payload) % 16 == 0

    # Decode the payload to verify structure
    decoded_data = decode_rpc_response(result)
    assert decoded_data == {}


def test_encode_mqtt_payload_complex_data():
    """Test encoding with complex nested data.

    Mixes a Dyad and a Zeo datapoint whose values include nested dicts,
    lists, booleans and ``None``; all of it must survive an encode/decode
    round trip inside a well-formed A01 RPC request.
    """
    data: dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any] = {
        RoborockDyadDataProtocol.STATUS: {
            "nested": {"deep": {"value": 123}},
            "list": [1, 2, 3, "test"],
            "boolean": True,
            "null": None,
        },
        RoborockZeoProtocol.MODE: "simple_value",
    }

    result = encode_mqtt_payload(data)

    assert isinstance(result, RoborockMessage)
    assert result.protocol == RoborockMessageProtocol.RPC_REQUEST
    # Same envelope guarantees as a single-datapoint request.
    assert result.version == b"A01"
    assert result.payload is not None
    assert isinstance(result.payload, bytes)
    assert len(result.payload) % 16 == 0

    # Decode the payload to verify structure; keys come back as the enums'
    # integer values (STATUS -> 201, MODE -> 204).
    decoded_data = decode_rpc_response(result)
    assert decoded_data == {
        201: {
            "nested": {"deep": {"value": 123}},
            "list": [1, 2, 3, "test"],
            "boolean": True,
            "null": None,
        },
        204: "simple_value",
    }


def test_decode_rpc_response_valid_message():
    """Test decoding a valid RPC response.

    Every ``dps`` entry must come back keyed by ``int`` with its value
    untouched, and nothing else may appear in the result.
    """
    # Create a valid padded JSON payload
    payload_data = {"dps": {"1": {"key": "value"}, "2": 42, "10": ["list", "data"]}}
    json_payload = json.dumps(payload_data).encode("utf-8")

    # Pad to AES block size (16 bytes). When the length is already a multiple
    # of 16 this appends a full block of 16s, matching PKCS#7.
    padding_length = 16 - (len(json_payload) % 16)
    padded_payload = json_payload + bytes([padding_length] * padding_length)

    message = RoborockMessage(protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=padded_payload)

    result = decode_rpc_response(message)

    assert isinstance(result, dict)
    # A single exact comparison (instead of per-key membership checks) also
    # fails on extra entries or keys that were left as strings.
    assert result == {1: {"key": "value"}, 2: 42, 10: ["list", "data"]}


def test_decode_rpc_response_string_keys():
    """Test decoding with string keys that can be converted to integers.

    JSON object keys are always strings on the wire; the decoder must turn
    each one into an ``int`` and leave the associated value unchanged.
    """
    payload_data = {"dps": {"1": "first", "100": "hundred", "999": {"nested": "data"}}}
    json_payload = json.dumps(payload_data).encode("utf-8")

    # Pad to AES block size
    padding_length = 16 - (len(json_payload) % 16)
    padded_payload = json_payload + bytes([padding_length] * padding_length)

    message = RoborockMessage(protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=padded_payload)

    result = decode_rpc_response(message)

    # Exact equality: no extra entries, and a leftover "1" would not compare
    # equal to 1, so string keys cannot slip through.
    assert result == {1: "first", 100: "hundred", 999: {"nested": "data"}}
    assert all(isinstance(key, int) for key in result)


def test_decode_rpc_response_missing_payload():
    """A response message carrying no payload is rejected with RoborockException."""
    empty_response = RoborockMessage(payload=None, protocol=RoborockMessageProtocol.RPC_RESPONSE)

    expected_error = "Invalid A01 message format: missing payload"
    with pytest.raises(RoborockException, match=expected_error):
        decode_rpc_response(empty_response)


def test_decode_rpc_response_invalid_padding():
    """A payload that is not validly padded must fail at the unpad step."""
    # 20 bytes, i.e. not a whole number of 16-byte blocks.
    unpadded_bytes = b"invalid padding data"

    response = RoborockMessage(protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=unpadded_bytes)

    expected_error = "Unable to unpad A01 payload"
    with pytest.raises(RoborockException, match=expected_error):
        decode_rpc_response(response)


def test_decode_rpc_response_invalid_json():
    """Correctly padded bytes that are not JSON must be rejected after unpadding."""
    raw = b"invalid json data"

    # Append padding bytes (each equal to the pad length) up to the next
    # 16-byte boundary so that unpadding itself succeeds.
    block = 16
    fill = block - len(raw) % block
    framed = raw + bytes([fill]) * fill

    response = RoborockMessage(protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=framed)

    with pytest.raises(RoborockException, match="Invalid A01 message payload"):
        decode_rpc_response(response)


def test_decode_rpc_response_missing_dps():
    """A JSON body without a 'dps' key decodes to an empty mapping."""
    raw = json.dumps({"other_key": "value"}).encode("utf-8")

    # Append padding bytes (each equal to the pad length) up to the next
    # 16-byte boundary.
    block = 16
    fill = block - len(raw) % block
    framed = raw + bytes([fill]) * fill

    response = RoborockMessage(protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=framed)

    assert decode_rpc_response(response) == {}


def test_decode_rpc_response_dps_not_dict():
    """A 'dps' value that is not a JSON object is rejected."""
    raw = json.dumps({"dps": "not_a_dict"}).encode("utf-8")

    # Append padding bytes (each equal to the pad length) up to the next
    # 16-byte boundary.
    block = 16
    fill = block - len(raw) % block
    framed = raw + bytes([fill]) * fill

    response = RoborockMessage(protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=framed)

    expected_error = r"Invalid A01 message format.*'dps' should be a dictionary"
    with pytest.raises(RoborockException, match=expected_error):
        decode_rpc_response(response)


def test_decode_rpc_response_invalid_key():
    """A 'dps' key that cannot be parsed as an integer is rejected."""
    raw = json.dumps({"dps": {"1": "valid", "not_a_number": "invalid"}}).encode("utf-8")

    # Append padding bytes (each equal to the pad length) up to the next
    # 16-byte boundary.
    block = 16
    fill = block - len(raw) % block
    framed = raw + bytes([fill]) * fill

    response = RoborockMessage(protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=framed)

    expected_error = r"Invalid A01 message format:.*'dps' key should be an integer"
    with pytest.raises(RoborockException, match=expected_error):
        decode_rpc_response(response)


def test_decode_rpc_response_empty_dps():
    """An empty 'dps' object decodes to an empty mapping."""
    body: dict[str, Any] = {"dps": {}}
    raw = json.dumps(body).encode("utf-8")

    # Append padding bytes (each equal to the pad length) up to the next
    # 16-byte boundary.
    block = 16
    fill = block - len(raw) % block
    framed = raw + bytes([fill]) * fill

    response = RoborockMessage(protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=framed)

    decoded = decode_rpc_response(response)
    assert decoded == {}