1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
|
from __future__ import annotations
import asyncio
import hashlib
import json
import logging
from asyncio import BaseTransport, Lock
from construct import ( # type: ignore
Bytes,
Checksum,
GreedyBytes,
Int16ub,
Int32ub,
Prefixed,
RawCopy,
Struct,
)
from Crypto.Cipher import AES
from roborock import RoborockException
from roborock.containers import BroadcastMessage
from roborock.protocol import EncryptionAdapter, Utils, _Parser
# Logger for this module, named after the module itself.
_LOGGER = logging.getLogger(name=__name__)
# Token used for broadcast payload decryption: it is handed to EncryptionAdapter
# for legacy broadcasts and hashed with SHA-256 to form the AES-GCM key for L01
# broadcasts.
BROADCAST_TOKEN = b"qWKYcdQWrbm9hPqe"
class RoborockProtocol(asyncio.DatagramProtocol):
    """Datagram protocol that collects Roborock discovery broadcasts.

    While :meth:`discover` runs, every datagram received on UDP port 58866 is
    decoded. Each successfully decoded announcement is appended to
    ``devices_found`` as a :class:`BroadcastMessage`.
    """

    def __init__(self, timeout: int = 5):
        """Create the protocol.

        :param timeout: number of seconds :meth:`discover` listens for broadcasts.
        """
        self.timeout = timeout
        self.transport: BaseTransport | None = None
        self.devices_found: list[BroadcastMessage] = []
        # Serializes discover(): concurrent calls would share the fixed local
        # port and the devices_found list.
        self._mutex = Lock()

    def datagram_received(self, data: bytes, _):
        """Handle incoming broadcast datagrams.

        Datagrams whose first three bytes are ``b"L01"`` use the encrypted L01
        format. Any other version is decoded with the legacy ``BroadcastParser``.
        A legacy datagram with an empty payload is ignored. Any decoding error is
        logged as a warning and swallowed, so that one malformed packet cannot
        abort discovery.
        """
        try:
            version = data[:3]
            if version == b"L01":
                parsed_message = self._decode_l01(data, version)
            else:
                # Fallback to the original protocol parser for other versions
                parsed_message = self._decode_legacy(data, version)
            if parsed_message is not None:
                self.devices_found.append(parsed_message)
        except Exception as e:
            # Event-loop callback boundary: never let a bad packet propagate.
            _LOGGER.warning("Failed to decode message: %r. Error: %s", data, e)

    @staticmethod
    def _decode_l01(data: bytes, version: bytes) -> BroadcastMessage:
        """Decrypt and parse one L01 broadcast datagram.

        The payload is AES-GCM encrypted. The key is SHA-256(BROADCAST_TOKEN).
        The nonce is the first 12 bytes of SHA-256 over the 9-byte header
        (version + field1 + field2). The final 16 payload bytes are the GCM tag.

        :raises RoborockException: if the parsed message carries no payload.
        :raises ValueError: (from PyCryptodome) if the GCM tag does not verify.
        """
        [parsed_msg], _ = L01Parser.parse(data)
        encrypted_payload = parsed_msg.payload
        if encrypted_payload is None:
            raise RoborockException("No encrypted payload found in broadcast message")
        ciphertext, tag = encrypted_payload[:-16], encrypted_payload[-16:]
        key = hashlib.sha256(BROADCAST_TOKEN).digest()
        nonce = hashlib.sha256(data[:9]).digest()[:12]
        cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
        json_payload = json.loads(cipher.decrypt_and_verify(ciphertext, tag))
        parsed_message = BroadcastMessage(duid=json_payload["duid"], ip=json_payload["ip"], version=version)
        _LOGGER.debug("Received L01 broadcast: %s", parsed_message)
        return parsed_message

    @staticmethod
    def _decode_legacy(data: bytes, version: bytes) -> BroadcastMessage | None:
        """Parse one legacy (non-L01) broadcast datagram.

        Returns ``None`` when the decoded payload is empty.
        """
        [broadcast_message], _ = BroadcastParser.parse(data)
        if not broadcast_message.payload:
            return None
        json_payload = json.loads(broadcast_message.payload)
        parsed_message = BroadcastMessage(duid=json_payload["duid"], ip=json_payload["ip"], version=version)
        _LOGGER.debug("Received broadcast: %s", parsed_message)
        return parsed_message

    async def discover(self) -> list[BroadcastMessage]:
        """Listen for broadcasts for ``self.timeout`` seconds and return the devices seen.

        Binds UDP ``0.0.0.0:58866`` for the duration of the call. The endpoint is
        always closed afterwards, including on error or cancellation, and
        ``devices_found`` is replaced by a fresh list. The list that is returned
        is the one that was filled during this call.
        """
        async with self._mutex:
            try:
                # Inside a coroutine the running loop is the correct one to use.
                loop = asyncio.get_running_loop()
                self.transport, _ = await loop.create_datagram_endpoint(lambda: self, local_addr=("0.0.0.0", 58866))
                await asyncio.sleep(self.timeout)
                return self.devices_found
            finally:
                self.close()
                self.devices_found = []

    def close(self):
        """Close the datagram endpoint if one is open; repeated calls are no-ops."""
        if self.transport is not None:
            self.transport.close()
            self.transport = None
# Legacy (non-L01) discovery datagram layout. The "message" part is wrapped in
# RawCopy so that its raw bytes are available for the trailing checksum.
_BroadcastMessage = Struct(
    "message" / RawCopy(
        Struct(
            "version" / Bytes(3),
            "seq" / Int32ub,
            "protocol" / Int16ub,
            # Payload handled by EncryptionAdapter keyed with BROADCAST_TOKEN
            # (presumably decrypted on parse; see roborock.protocol).
            "payload" / EncryptionAdapter(lambda _context: BROADCAST_TOKEN),
        )
    ),
    # Big-endian 32-bit checksum computed by Utils.crc over the raw "message" bytes.
    "checksum" / Checksum(Int32ub, Utils.crc, lambda this: this.message.data),
)
# L01 discovery datagram layout. Only the framing is described here: "payload"
# stays as raw (still encrypted) bytes and is not decrypted by this Struct.
_L01BroadcastMessage = Struct(
    "message" / RawCopy(
        Struct(
            "version" / Bytes(3),
            "field1" / Bytes(4),  # meaning unknown
            "field2" / Bytes(2),  # meaning unknown
            # Big-endian 16-bit length prefix followed by the encrypted bytes.
            "payload" / Prefixed(Int16ub, GreedyBytes),
        )
    ),
    # Big-endian 32-bit checksum computed by Utils.crc over the raw "message" bytes.
    "checksum" / Checksum(Int32ub, Utils.crc, lambda this: this.message.data),
)
# Module-level parsers built from the Struct layouts defined above. Callers
# unpack the result of .parse(data) as a one-element list plus a second value.
# NOTE(review): the meaning of the positional False argument is defined by
# _Parser in roborock.protocol — TODO confirm before changing it.
BroadcastParser: _Parser = _Parser(_BroadcastMessage, False)
L01Parser: _Parser = _Parser(_L01BroadcastMessage, False)
|