1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
|
"""Roborock V1 Protocol Encoder."""
from __future__ import annotations
import base64
import json
import logging
import math
import secrets
import struct
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Any, Protocol, TypeVar, overload
from roborock.data import RoborockBase, RRiot
from roborock.exceptions import RoborockException, RoborockUnsupportedFeature
from roborock.protocol import Utils
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
from roborock.roborock_typing import RoborockCommand
from roborock.util import get_next_int
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# Names exported by a star-import of this module.
__all__ = [
    "SecurityData",
    "create_security_data",
    "decode_rpc_response",
    "V1RpcChannel",
]

# A command is given either as a RoborockCommand or as a plain method-name string.
CommandType = RoborockCommand | str
# Accepted shapes for the "params" argument of a command.
ParamsType = list | dict | int | None
class LocalProtocolVersion(StrEnum):
    """Supported local protocol versions. Different from vacuum protocol versions.

    Each member's value is the version string; `RequestMessage.encode_message`
    encodes that string to bytes for the `version` field of the outgoing message.
    """

    L01 = "L01"
    V1 = "1.0"
@dataclass(frozen=True, kw_only=True)
class SecurityData:
"""Security data included in the request for some V1 commands."""
endpoint: str
nonce: bytes
def to_dict(self) -> dict[str, Any]:
"""Convert security data to a dictionary for sending in the payload."""
return {"security": {"endpoint": self.endpoint, "nonce": self.nonce.hex().lower()}}
def to_diagnostic_data(self) -> dict[str, Any]:
"""Convert security data to a dictionary for debugging purposes."""
return {"nonce": self.nonce.hex().lower()}
def create_security_data(rriot: RRiot) -> SecurityData:
    """Build a new SecurityData from the given ``rriot`` credentials.

    The nonce is 16 freshly generated random bytes on every call. The
    endpoint is the base64 text of bytes 8 through 13 of ``Utils.md5``
    applied to the encoded ``rriot.k`` value.
    """
    fresh_nonce = secrets.token_bytes(16)
    key_digest = Utils.md5(rriot.k.encode())
    endpoint_text = base64.b64encode(key_digest[8:14]).decode()
    return SecurityData(endpoint=endpoint_text, nonce=fresh_nonce)
@dataclass
class RequestMessage:
    """Data structure for v1 RoborockMessage payloads.

    Attributes:
        method: Command to invoke, as a RoborockCommand or a raw string.
        params: Command arguments. ``None`` and empty containers are sent
            as an empty list; integers (including 0) are sent unchanged.
        timestamp: Whole seconds since the epoch; defaults to creation time.
        request_id: Value placed in the payload's ``id`` field; defaults to
            ``get_next_int(10000, 32767)``.
    """

    method: RoborockCommand | str
    params: ParamsType
    timestamp: int = field(default_factory=lambda: math.floor(time.time()))
    request_id: int = field(default_factory=lambda: get_next_int(10000, 32767))

    def encode_message(
        self,
        protocol: RoborockMessageProtocol,
        security_data: SecurityData | None = None,
        version: LocalProtocolVersion = LocalProtocolVersion.V1,
    ) -> RoborockMessage:
        """Convert the request message to a RoborockMessage.

        Args:
            protocol: Protocol identifier to set on the outgoing message.
            security_data: Optional security data merged into the payload.
            version: Local protocol version; its string value is encoded to
                bytes for the message's ``version`` field.

        Returns:
            A RoborockMessage carrying this request's timestamp and payload.
        """
        return RoborockMessage(
            timestamp=self.timestamp,
            protocol=protocol,
            payload=self._as_payload(security_data=security_data),
            version=version.value.encode(),
        )

    def _as_payload(self, security_data: SecurityData | None) -> bytes:
        """Serialise the request to compact JSON bytes.

        The command is JSON-encoded into a string stored under data point
        "101"; that envelope is then JSON-encoded again together with the
        timestamp.

        Args:
            security_data: When given, its ``to_dict()`` entries are merged
                into the inner command object.

        Returns:
            The UTF-8 encoded JSON envelope.
        """
        params = self.params
        # None and empty containers are normalised to an empty list. An integer
        # argument is a real value even when it is 0, so a plain truthiness
        # check (`params or []`) must not be used here.
        if not params and not isinstance(params, int):
            params = []
        inner = {
            "id": self.request_id,
            "method": self.method,
            "params": params,
            **(security_data.to_dict() if security_data else {}),
        }
        envelope = {
            "dps": {"101": json.dumps(inner, separators=(",", ":"))},
            "t": self.timestamp,
        }
        return json.dumps(envelope, separators=(",", ":")).encode()
# Types a decoded response's data may take: a JSON object, a list, or an integer.
ResponseData = dict[str, Any] | list | int
@dataclass(kw_only=True, frozen=True)
class ResponseMessage:
    """Data structure for v1 RoborockMessage responses.

    Instances are frozen and must be constructed with keyword arguments.
    """

    # May be None when the decoded response carried no "id" field.
    request_id: int | None
    """The request ID of the response."""

    data: ResponseData
    """The data of the response, where the type depends on the command."""

    # Set when the response could be matched to a request but reported a failure.
    api_error: RoborockException | None = None
    """The API error message of the response if any."""
def decode_rpc_response(message: RoborockMessage) -> ResponseMessage:
    """Decode a V1 RPC_RESPONSE message.

    This will raise a RoborockException if the message cannot be parsed. A
    response object will be returned even if there is an error in the
    response, as long as we can extract the request ID. This is so we can
    associate an API response with a request even if there was an error.

    Args:
        message: The received message whose payload holds the JSON response.

    Returns:
        A ResponseMessage. An empty payload yields ``message.seq`` as the
        request ID with an empty dict as data. A string result is replaced
        by an empty dict, with ``api_error`` set unless the string is "ok".

    Raises:
        RoborockException: If the payload is not valid UTF-8 JSON, does not
            have the expected shape, or reports an error without a request ID.
    """
    if not message.payload:
        return ResponseMessage(request_id=message.seq, data={})
    try:
        payload = json.loads(message.payload.decode())
    except (UnicodeDecodeError, json.JSONDecodeError, TypeError) as e:
        # UnicodeDecodeError is not a JSONDecodeError, so it is listed explicitly
        # to keep the "always RoborockException" contract for garbage bytes.
        raise RoborockException(f"Invalid V1 message payload: {e} for {message.payload!r}") from e

    _LOGGER.debug("Decoded V1 message payload: %s", payload)
    # Valid JSON is not necessarily an object (e.g. a list or a number).
    if not isinstance(payload, dict):
        raise RoborockException(f"Invalid V1 message format: payload should be a dictionary for {message.payload!r}")
    datapoints = payload.get("dps", {})
    if not isinstance(datapoints, dict):
        raise RoborockException(f"Invalid V1 message format: 'dps' should be a dictionary for {message.payload!r}")

    if not (data_point := datapoints.get(str(RoborockMessageProtocol.RPC_RESPONSE))):
        raise RoborockException(
            f"Invalid V1 message format: missing '{RoborockMessageProtocol.RPC_RESPONSE}' data point"
        )

    # The data point is itself a JSON document encoded as a string.
    try:
        data_point_response = json.loads(data_point)
    except (json.JSONDecodeError, TypeError) as e:
        raise RoborockException(
            f"Invalid V1 message data point '{RoborockMessageProtocol.RPC_RESPONSE}': {e} for {message.payload!r}"
        ) from e
    if not isinstance(data_point_response, dict):
        raise RoborockException(
            f"Invalid V1 message data point '{RoborockMessageProtocol.RPC_RESPONSE}': "
            f"expected a dictionary for {message.payload!r}"
        )

    request_id: int | None = data_point_response.get("id")
    exc: RoborockException | None = None
    if error := data_point_response.get("error"):
        exc = RoborockException(error)
    if (result := data_point_response.get("result")) is None:
        # NOTE(review): this replaces any exception built from "error" above;
        # the original error is then only visible through the payload repr.
        exc = RoborockException(f"Invalid V1 message format: missing 'result' in data point for {message.payload!r}")
    else:
        _LOGGER.debug("Decoded V1 message result: %s", result)
        if isinstance(result, str):
            # A bare string is a status marker rather than data.
            if result == "unknown_method":
                exc = RoborockUnsupportedFeature("The method called is not recognized by the device.")
            elif result != "ok":
                exc = RoborockException(f"Unexpected API Result: {result}")
            result = {}
        if not isinstance(result, dict | list | int):
            raise RoborockException(
                f"Invalid V1 message format: 'result' was unexpected type {type(result)}. {message.payload!r}"
            )

    # Without a request ID the error cannot be routed to a caller, so raise it here.
    if not request_id and exc:
        raise exc
    return ResponseMessage(request_id=request_id, data=result, api_error=exc)
@dataclass
class MapResponse:
    """Data structure for the V1 Map response.

    Built by the decoder returned from `create_map_response_decoder`.
    """

    # Read from the 16-bit request-id field of the binary response header.
    request_id: int
    """The request ID of the map response."""

    data: bytes
    """The map data, decrypted and decompressed."""
def create_map_response_decoder(security_data: SecurityData) -> Callable[[RoborockMessage], MapResponse | None]:
    """Create a decoder for V1 map response messages.

    Args:
        security_data: Supplies the endpoint used to recognise responses
            addressed to us and the nonce passed to ``Utils.decrypt_cbc``.

    Returns:
        A function that decodes one message into a MapResponse, or returns
        None when the message's endpoint does not match ``security_data``.
        That function raises RoborockException when the payload is shorter
        than the 24-byte header or cannot be decrypted.
    """
    # Match on raw bytes: the header of a message addressed to someone else is
    # arbitrary data and is not guaranteed to be decodable as text.
    expected_endpoint = security_data.endpoint.encode()

    def _decode_map_response(message: RoborockMessage) -> MapResponse | None:
        """Decode a V1 map response message."""
        if not message.payload or len(message.payload) < 24:
            raise RoborockException("Invalid V1 map response format: missing payload")
        header, body = message.payload[:24], message.payload[24:]
        # Little-endian header: 8-byte endpoint, 8 unused bytes,
        # unsigned 16-bit request id, 6 unused bytes.
        endpoint, _, request_id, _ = struct.unpack("<8s8sH6s", header)
        if not endpoint.startswith(expected_endpoint):
            _LOGGER.debug("Received map response not requested by this device, ignoring.")
            return None
        try:
            decrypted = Utils.decrypt_cbc(body, security_data.nonce)
        except ValueError as err:
            raise RoborockException("Failed to decode map message payload") from err
        decompressed = Utils.decompress(decrypted)
        return MapResponse(request_id=request_id, data=decompressed)

    return _decode_map_response
# Type variable for parsed response types; restricted to RoborockBase subclasses.
_T = TypeVar("_T", bound=RoborockBase)
class V1RpcChannel(Protocol):
    """Protocol for V1 RPC channels.

    This is a wrapper around a raw channel that provides a high-level interface
    for sending commands and receiving responses.

    Only the overloaded signatures of ``send_command`` are declared here:
    one without ``response_type`` returning ``Any``, and one with
    ``response_type`` returning an instance of that type.
    """

    @overload
    async def send_command(
        self,
        method: CommandType,
        *,
        params: ParamsType = None,
    ) -> Any:
        """Send a command and return a decoded response.

        Args:
            method: The command to send.
            params: Optional command parameters (keyword-only).
        """
        ...

    @overload
    async def send_command(
        self,
        method: CommandType,
        *,
        response_type: type[_T],
        params: ParamsType = None,
    ) -> _T:
        """Send a command and return a parsed response RoborockBase type.

        Args:
            method: The command to send.
            response_type: The RoborockBase subclass to return (keyword-only).
            params: Optional command parameters (keyword-only).
        """
        ...
|