File: a01_protocol.py

package info (click to toggle)
python-roborock 4.12.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 2,480 kB
  • sloc: python: 16,602; makefile: 17; sh: 6
file content (74 lines) | stat: -rw-r--r-- 2,396 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
"""Roborock A01 Protocol encoding and decoding."""

import json
import logging
from collections.abc import Callable
from typing import Any

from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad

from roborock.exceptions import RoborockException
from roborock.roborock_message import (
    RoborockDyadDataProtocol,
    RoborockMessage,
    RoborockMessageProtocol,
    RoborockZeoProtocol,
)

_LOGGER = logging.getLogger(__name__)

A01_VERSION = b"A01"


def _no_encode(value: Any) -> Any:
    return value


def encode_mqtt_payload(
    data: dict[RoborockDyadDataProtocol, Any]
    | dict[RoborockZeoProtocol, Any]
    | dict[RoborockDyadDataProtocol | RoborockZeoProtocol, Any],
    value_encoder: Callable[[Any], Any] | None = None,
) -> RoborockMessage:
    """Encode payload for A01 commands over MQTT.

    Args:
        data: The data to encode.
        value_encoder: A function to encode the values of the dictionary.

    Returns:
        RoborockMessage: The encoded message.
    """
    if value_encoder is None:
        value_encoder = _no_encode
    dps_data = {"dps": {key: value_encoder(value) for key, value in data.items()}}
    payload = pad(json.dumps(dps_data).encode("utf-8"), AES.block_size)
    return RoborockMessage(
        protocol=RoborockMessageProtocol.RPC_REQUEST,
        version=A01_VERSION,
        payload=payload,
    )


def decode_rpc_response(message: RoborockMessage) -> dict[int, Any]:
    """Decode a V1 RPC_RESPONSE message."""
    if not message.payload:
        raise RoborockException("Invalid A01 message format: missing payload")
    try:
        unpadded = unpad(message.payload, AES.block_size)
    except ValueError as err:
        raise RoborockException(f"Unable to unpad A01 payload: {err}")

    try:
        payload = json.loads(unpadded.decode())
    except (json.JSONDecodeError, TypeError) as e:
        raise RoborockException(f"Invalid A01 message payload: {e} for {message.payload!r}") from e

    datapoints = payload.get("dps", {})
    if not isinstance(datapoints, dict):
        raise RoborockException(f"Invalid A01 message format: 'dps' should be a dictionary for {message.payload!r}")
    try:
        return {int(key): value for key, value in datapoints.items()}
    except ValueError:
        raise RoborockException(f"Invalid A01 message format: 'dps' key should be an integer for {message.payload!r}")