File: local_channel.py

package info (click to toggle)
python-roborock 2.49.1-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,360 kB
  • sloc: python: 11,539; makefile: 17
file content (131 lines) | stat: -rw-r--r-- 5,075 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Module for communicating with Roborock devices over a local network."""

import asyncio
import logging
from collections.abc import Callable
from dataclasses import dataclass

from roborock.callbacks import CallbackList, decoder_callback
from roborock.exceptions import RoborockConnectionException, RoborockException
from roborock.protocol import Decoder, Encoder, create_local_decoder, create_local_encoder
from roborock.roborock_message import RoborockMessage

from .channel import Channel

_LOGGER = logging.getLogger(__name__)
_PORT = 58867


@dataclass
class _LocalProtocol(asyncio.Protocol):
    """Callbacks for the Roborock local client transport."""

    messages_cb: Callable[[bytes], None]
    connection_lost_cb: Callable[[Exception | None], None]

    def data_received(self, data: bytes) -> None:
        """Called when data is received from the transport."""
        self.messages_cb(data)

    def connection_lost(self, exc: Exception | None) -> None:
        """Called when the transport connection is lost."""
        self.connection_lost_cb(exc)


class LocalChannel(Channel):
    """Simple RPC-style channel for communicating with a device over a local network.

    Owns the TCP connection to the device on ``_PORT``. It encodes outgoing
    messages with the device's local key and decodes incoming bytes before
    dispatching them to subscribers. Request/response correlation, timeouts
    and most message parsing are left to higher-level components.
    """

    def __init__(self, host: str, local_key: str):
        """Create an unconnected channel.

        Args:
            host: Address of the device, passed to ``loop.create_connection``.
            local_key: Device key used to build the local encoder and decoder.
        """
        self._host = host
        # Set by connect(); cleared again by close() or on connection loss.
        self._transport: asyncio.Transport | None = None
        self._protocol: _LocalProtocol | None = None
        self._subscribers: CallbackList[RoborockMessage] = CallbackList(_LOGGER)
        self._is_connected = False

        self._decoder: Decoder = create_local_decoder(local_key)
        self._encoder: Encoder = create_local_encoder(local_key)
        # Callback to decode raw transport bytes and dispatch the resulting
        # messages to subscribers.
        self._data_received: Callable[[bytes], None] = decoder_callback(self._decoder, self._subscribers, _LOGGER)

    @property
    def is_connected(self) -> bool:
        """Check if the channel is currently connected."""
        return self._is_connected

    async def connect(self) -> None:
        """Connect to the device.

        If the channel is already connected, logs a warning and returns.

        Raises:
            RoborockConnectionException: If the TCP connection cannot be opened.
        """
        if self._is_connected:
            _LOGGER.warning("Already connected")
            return
        _LOGGER.debug("Connecting to %s:%s", self._host, _PORT)
        loop = asyncio.get_running_loop()
        protocol = _LocalProtocol(self._data_received, self._connection_lost)
        try:
            self._transport, self._protocol = await loop.create_connection(lambda: protocol, self._host, _PORT)
            self._is_connected = True
        except OSError as e:
            raise RoborockConnectionException(f"Failed to connect to {self._host}:{_PORT}") from e

    def close(self) -> None:
        """Disconnect from the device.

        Safe to call when no transport exists; a warning is logged in that case.
        """
        if self._transport:
            self._transport.close()
        else:
            _LOGGER.warning("Close called but transport is already None")
        # Drop every per-connection reference so a later connect() starts clean.
        self._transport = None
        self._protocol = None
        self._is_connected = False

    def _connection_lost(self, exc: Exception | None) -> None:
        """Handle connection loss reported by the protocol.

        NOTE(review): asyncio also invokes the protocol's connection_lost after
        an intentional close(), so this warning fires on normal disconnects too.
        """
        _LOGGER.warning("Connection lost to %s", self._host, exc_info=exc)
        self._transport = None
        self._protocol = None
        self._is_connected = False

    async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]:
        """Subscribe to all messages from the device.

        Returns the callable produced by ``CallbackList.add_callback``
        (presumably an unsubscribe function; confirm against CallbackList).
        """
        return self._subscribers.add_callback(callback)

    async def publish(self, message: RoborockMessage) -> None:
        """Send a command message.

        The caller is responsible for associating the message with its response.

        Raises:
            RoborockConnectionException: If the channel is not connected.
            RoborockException: If encoding or writing the message fails.
        """
        if not self._transport or not self._is_connected:
            raise RoborockConnectionException("Not connected to device")

        try:
            encoded_msg = self._encoder(message)
        except Exception as err:
            # This is the local TCP channel, so the messages say "local", not "MQTT".
            _LOGGER.exception("Error encoding local message: %s", err)
            raise RoborockException(f"Failed to encode local message: {err}") from err
        try:
            self._transport.write(encoded_msg)
        except Exception as err:
            # Use the module logger rather than the root logger.
            _LOGGER.exception("Uncaught error sending command")
            raise RoborockException(f"Failed to send message: {message}") from err


# Type alias for a factory that takes a device host and returns a LocalChannel
# for it. create_local_session below returns a callable of this type.
#
# TODO: Make a separate LocalSession and use it to manage retries with the host,
# similar to how MqttSession works. For now this is a simple factory function
# for creating channels.
LocalSession = Callable[[str], LocalChannel]


def create_local_session(local_key: str) -> LocalSession:
    """Build a factory that produces ``LocalChannel`` objects sharing one key.

    Analogous to the MqttSession, but there is no shared connection state here.
    The returned callable takes a host and constructs a fresh ``LocalChannel``
    for that host using ``local_key``.
    """

    def _channel_for_host(host: str) -> LocalChannel:
        """Return a new, unconnected ``LocalChannel`` for ``host``."""
        channel = LocalChannel(host, local_key)
        return channel

    return _channel_for_host