File: broadcast_protocol.py

package info (click to toggle)
python-roborock 2.49.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,360 kB
  • sloc: python: 11,539; makefile: 17
file content (114 lines) | stat: -rw-r--r-- 4,057 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from __future__ import annotations

import asyncio
import hashlib
import json
import logging
from asyncio import BaseTransport, Lock

from construct import (  # type: ignore
    Bytes,
    Checksum,
    GreedyBytes,
    Int16ub,
    Int32ub,
    Prefixed,
    RawCopy,
    Struct,
)
from Crypto.Cipher import AES

from roborock import RoborockException
from roborock.containers import BroadcastMessage
from roborock.protocol import EncryptionAdapter, Utils, _Parser

# Module logger: decoded broadcasts are logged at debug level, undecodable
# datagrams at warning level.
_LOGGER = logging.getLogger(name=__name__)

# Shared broadcast secret. It keys the legacy payload's EncryptionAdapter and,
# after SHA-256 hashing, serves as the AES-GCM key for L01 frames.
BROADCAST_TOKEN = b"qWKYcdQWrbm9hPqe"


class RoborockProtocol(asyncio.DatagramProtocol):
    """Collect Roborock device broadcast announcements over a datagram endpoint.

    ``discover()`` binds ``0.0.0.0:58866``, listens for ``timeout`` seconds and
    returns every announcement decoded in that window. Two frame formats are
    handled: ``L01`` frames (AES-GCM encrypted JSON, parsed by ``L01Parser``)
    and, for any other 3-byte version prefix, the legacy format parsed by
    ``BroadcastParser``.
    """

    def __init__(self, timeout: int = 5):
        """Create an idle protocol instance.

        :param timeout: Seconds ``discover()`` listens before returning.
        """
        self.timeout = timeout
        self.transport: BaseTransport | None = None
        # Announcements decoded during the current discover() window.
        self.devices_found: list[BroadcastMessage] = []
        # Serialises discover() calls: each binds the same fixed port and
        # accumulates into the same instance list.
        self._mutex = Lock()

    def datagram_received(self, data: bytes, _):
        """Decode one broadcast datagram and record the announced device.

        Any failure while decoding (parse, checksum, decryption, JSON, missing
        ``duid``/``ip`` keys) is logged at warning level and the datagram is
        dropped, so one malformed packet does not abort discovery.
        """
        try:
            version = data[:3]
            if version == b"L01":
                json_payload = self._decrypt_l01_payload(data)
                log_template = "Received L01 broadcast: %s"
            else:
                # Fallback to the original protocol parser for other versions
                [broadcast_message], _ = BroadcastParser.parse(data)
                if not broadcast_message.payload:
                    # Legacy frame without a payload: nothing to record.
                    return
                json_payload = json.loads(broadcast_message.payload)
                log_template = "Received broadcast: %s"
            parsed_message = BroadcastMessage(duid=json_payload["duid"], ip=json_payload["ip"], version=version)
            # Lazy %-formatting: the message is only rendered if debug is enabled.
            _LOGGER.debug(log_template, parsed_message)
            self.devices_found.append(parsed_message)
        except Exception as e:
            # Boundary handler for an event-loop callback: log and drop.
            _LOGGER.warning("Failed to decode message: %r. Error: %s", data, e)

    @staticmethod
    def _decrypt_l01_payload(data: bytes):
        """Parse an ``L01`` frame, authenticate/decrypt its payload and return the decoded JSON.

        :raises RoborockException: if the parsed frame carries no payload.
        Parse, tag-verification and JSON errors propagate to the caller.
        """
        [parsed_msg], _ = L01Parser.parse(data)
        encrypted_payload = parsed_msg.payload
        if encrypted_payload is None:
            raise RoborockException("No encrypted payload found in broadcast message")
        # The payload is AES-GCM ciphertext followed by its 16-byte auth tag.
        ciphertext, tag = encrypted_payload[:-16], encrypted_payload[-16:]
        key = hashlib.sha256(BROADCAST_TOKEN).digest()
        # The 12-byte nonce is derived from the 9-byte cleartext header
        # (version 3 + field1 4 + field2 2 bytes in _L01BroadcastMessage).
        nonce = hashlib.sha256(data[:9]).digest()[:12]
        cipher = AES.new(key, AES.MODE_GCM, nonce=nonce)
        return json.loads(cipher.decrypt_and_verify(ciphertext, tag))

    async def discover(self) -> list[BroadcastMessage]:
        """Listen for broadcasts for ``self.timeout`` seconds and return them.

        The endpoint is closed and the internal list reset afterwards, even on
        error, so each call returns only what was received during that call.
        Concurrent calls are serialised by ``self._mutex``.
        """
        async with self._mutex:
            try:
                # get_running_loop() is the recommended accessor inside coroutines.
                loop = asyncio.get_running_loop()
                self.transport, _ = await loop.create_datagram_endpoint(lambda: self, local_addr=("0.0.0.0", 58866))
                await asyncio.sleep(self.timeout)
                return self.devices_found
            finally:
                self.close()
                # Rebind (rather than clear) so the list returned above stays intact.
                self.devices_found = []

    def close(self):
        """Close the datagram endpoint if one has been opened."""
        if self.transport is not None:
            self.transport.close()


# Wire layout of a legacy (non-``L01``) broadcast frame. Field order is the
# on-the-wire order, so do not reorder entries.
_BroadcastMessage = Struct(
    # RawCopy keeps the raw bytes of this sub-struct (``message.data``) so the
    # checksum below can be computed over exactly what was on the wire.
    "message"
    / RawCopy(
        Struct(
            "version" / Bytes(3),  # 3-byte protocol/version prefix
            "seq" / Int32ub,  # unsigned 32-bit big-endian
            "protocol" / Int16ub,  # unsigned 16-bit big-endian
            # Encrypted payload handled by EncryptionAdapter keyed with
            # BROADCAST_TOKEN. NOTE(review): presumably it consumes the bytes
            # up to the trailing checksum - confirm in roborock.protocol.
            "payload" / EncryptionAdapter(lambda ctx: BROADCAST_TOKEN),
        )
    ),
    # 4-byte big-endian checksum, compared against Utils.crc over the raw
    # bytes of ``message``.
    "checksum" / Checksum(Int32ub, Utils.crc, lambda ctx: ctx.message.data),
)

# Wire layout of an ``L01`` broadcast frame. The first 9 bytes (version +
# field1 + field2) are the cleartext header; RoborockProtocol hashes exactly
# those bytes (``data[:9]``) to derive the AES-GCM nonce. It treats ``payload``
# as ciphertext followed by a 16-byte GCM tag. Field order is the on-the-wire
# order, so do not reorder entries.
_L01BroadcastMessage = Struct(
    # RawCopy keeps the raw bytes of this sub-struct for the checksum below.
    "message"
    / RawCopy(
        Struct(
            "version" / Bytes(3),  # b"L01" for frames routed to this layout
            "field1" / Bytes(4),  # Unknown field
            "field2" / Bytes(2),  # Unknown field
            "payload" / Prefixed(Int16ub, GreedyBytes),  # Encrypted payload with 16-bit big-endian length prefix
        )
    ),
    # 4-byte big-endian checksum, compared against Utils.crc over the raw
    # bytes of ``message``.
    "checksum" / Checksum(Int32ub, Utils.crc, lambda ctx: ctx.message.data),
)


# Parsers for the two broadcast frame layouts. RoborockProtocol picks one by
# the 3-byte version prefix and unpacks ``.parse(data)`` as
# ``[message], rest`` (presumably parsed messages plus leftover bytes - verify
# against _Parser).
# NOTE(review): the meaning of the positional ``False`` passed to ``_Parser`` is
# not visible here - check its definition in roborock.protocol.
BroadcastParser: _Parser = _Parser(_BroadcastMessage, False)
L01Parser: _Parser = _Parser(_L01BroadcastMessage, False)