File: test_b01_q07_protocol.py

package info (click to toggle)
python-roborock 4.12.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,480 kB
  • sloc: python: 16,602; makefile: 17; sh: 6
file content (71 lines) | stat: -rw-r--r-- 2,415 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""Tests for the B01 protocol message encoding and decoding."""

import json
import pathlib
from collections.abc import Generator

import pytest
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
from freezegun import freeze_time
from syrupy import SnapshotAssertion

from roborock.protocols.b01_q7_protocol import Q7RequestMessage, decode_rpc_response, encode_mqtt_payload
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol

# Captured B01 Q7 response fixtures used by ``test_decode_rpc_payload``.
# The path is relative, so it is resolved against the current working
# directory pytest is launched from. The files are sorted so the order of the
# parametrized cases (and their ids) does not depend on filesystem iteration
# order.
TESTDATA_PATH = pathlib.Path("tests/protocols/testdata/b01_q7_protocol")
TESTDATA_FILES = sorted(TESTDATA_PATH.glob("*.json"))
TESTDATA_IDS = [path.stem for path in TESTDATA_FILES]


@pytest.fixture(autouse=True)
def fixed_time_fixture() -> Generator[None, None, None]:
    """Pin the clock for every test so generated request IDs are predictable.

    The freezer is started before the test body runs and is always stopped
    afterwards, even when the test fails.
    """
    freezer = freeze_time("2025-01-20T12:00:00")
    freezer.start()
    try:
        yield
    finally:
        freezer.stop()


@pytest.mark.parametrize("filename", TESTDATA_FILES, ids=TESTDATA_IDS)
def test_decode_rpc_payload(filename: pathlib.Path, snapshot: SnapshotAssertion) -> None:
    """Test decoding a B01 RPC response protocol message.

    ``filename`` is one of the ``*.json`` fixture paths collected into
    ``TESTDATA_FILES`` (a ``pathlib.Path``, not a ``str``). Its raw bytes are
    wrapped in a ``RoborockMessage`` with constant header fields, decoded, and
    the pretty-printed JSON result is compared against the stored snapshot.
    """
    payload = filename.read_bytes()

    message = RoborockMessage(
        protocol=RoborockMessageProtocol.RPC_RESPONSE,
        payload=payload,
        seq=12750,
        version=b"B01",
        random=97431,
        timestamp=1652547161,
    )

    decoded_message = decode_rpc_response(message)
    # Serialize with a fixed indent so the snapshot text is stable.
    assert json.dumps(decoded_message, indent=2) == snapshot


@pytest.mark.parametrize(
    ("dps", "command", "params", "msg_id"),
    [
        (10000, "prop.get", {"property": ["status", "fault"]}, 123456789),
    ],
)
def test_encode_mqtt_payload(dps: int, command: str, params: dict[str, list[str]], msg_id: int) -> None:
    """Encode a Q7 request and check that the padded JSON body round-trips.

    The encoded message must be an RPC request tagged with version ``B01``.
    After PKCS#7-style unpadding to the AES block size, its JSON body must
    carry the method, message id (as a string) and params under the entry
    for ``dps``.
    """
    request = Q7RequestMessage(dps, command, params, msg_id)
    encoded = encode_mqtt_payload(request)

    assert isinstance(encoded, RoborockMessage)
    assert encoded.protocol == RoborockMessageProtocol.RPC_REQUEST
    assert encoded.version == b"B01"
    assert encoded.payload is not None

    raw_body = unpad(encoded.payload, AES.block_size)
    body = json.loads(raw_body.decode("utf-8"))
    entry = body["dps"][str(dps)]

    assert entry["method"] == command
    assert entry["msgId"] == str(msg_id)
    assert entry["params"] == params