1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
|
"""V1 RPC channel for Roborock devices.

This is a wrapper around the V1 channel that provides a higher-level interface
for sending typed commands and receiving typed responses. It also provides a
simple interface for sending commands and receiving responses over both MQTT
and local connections: a pick-first-available channel tries each configured
channel in order and uses the first one that is currently available, so callers
can prefer the local connection by listing it first.
"""
import asyncio
import logging
from collections.abc import Callable
from typing import Any, Protocol, TypeVar, overload
from roborock.containers import RoborockBase
from roborock.exceptions import RoborockException
from roborock.protocols.v1_protocol import (
CommandType,
ParamsType,
RequestMessage,
ResponseData,
SecurityData,
decode_rpc_response,
)
from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol
from .local_channel import LocalChannel
from .mqtt_channel import MqttChannel
# Logger for this module; RPC traffic is logged at DEBUG level.
_LOGGER = logging.getLogger(name=__name__)

# Seconds to wait for a response with a matching request id before giving up.
_TIMEOUT = 10.0

# Type variable for callers that ask for a response parsed into a RoborockBase subclass.
_T = TypeVar("_T", bound=RoborockBase)
class V1RpcChannel(Protocol):
    """Structural interface implemented by every V1 RPC channel.

    An implementation sits on top of a raw transport and exposes a single
    ``send_command`` coroutine. Called without ``response_type`` it yields the
    decoded response unchanged; called with ``response_type`` it yields an
    instance of that ``RoborockBase`` subclass.
    """

    @overload
    async def send_command(
        self,
        method: CommandType,
        *,
        params: ParamsType = None,
    ) -> Any:
        """Send ``method`` and return the decoded, untyped response."""

    @overload
    async def send_command(
        self,
        method: CommandType,
        *,
        response_type: type[_T],
        params: ParamsType = None,
    ) -> _T:
        """Send ``method`` and return the response parsed as ``response_type``."""
class BaseV1RpcChannel(V1RpcChannel):
    """Shared ``send_command`` logic for concrete V1 RPC channels.

    Subclasses only supply ``_send_raw_command``. This class converts the raw
    decoded response into a typed object when a ``response_type`` is given.
    """

    async def send_command(
        self,
        method: CommandType,
        *,
        response_type: type[_T] | None = None,
        params: ParamsType = None,
    ) -> _T | Any:
        """Send ``method`` and return its response.

        Without ``response_type`` the decoded response is returned unchanged;
        otherwise it is passed through ``response_type.from_dict`` and that
        result is returned.
        """
        raw_response = await self._send_raw_command(method, params=params)
        if response_type is None:
            return raw_response
        return response_type.from_dict(raw_response)

    async def _send_raw_command(
        self,
        method: CommandType,
        *,
        params: ParamsType = None,
    ) -> Any:
        """Send ``method`` and return its decoded response.

        Concrete subclasses must override this hook; the base version always
        raises ``NotImplementedError``.
        """
        raise NotImplementedError
class PickFirstAvailable(BaseV1RpcChannel):
    """A V1 RPC channel that delegates to the first available underlying channel.

    On every command, each callback is asked in list order for a channel. The
    first callback that returns a channel (rather than ``None``) handles the
    command. Because the callbacks are re-evaluated per command, the channel
    that is used can change between calls.
    """

    def __init__(
        self,
        channel_cbs: list[Callable[[], V1RpcChannel | None]],
    ) -> None:
        """Initialize the pick-first-available channel.

        Args:
            channel_cbs: Callbacks that each return a channel, or ``None`` when
                that channel is currently unavailable. List order is priority
                order.
        """
        self._channel_cbs = channel_cbs

    async def _send_raw_command(
        self,
        method: CommandType,
        *,
        params: ParamsType = None,
    ) -> Any:
        """Send ``method`` via the first available channel and return its decoded response.

        Raises:
            RoborockException: If every callback returns ``None``.
        """
        for channel_cb in self._channel_cbs:
            channel = channel_cb()
            # Compare against None explicitly: the callback contract is
            # "channel or None", not "truthy channel".
            if channel is not None:
                # Ask for the untyped response; the inherited ``send_command``
                # applies any ``response_type`` parsing exactly once.
                return await channel.send_command(method, params=params)
        raise RoborockException("No available connection to send command")
class PayloadEncodedV1RpcChannel(BaseV1RpcChannel):
    """V1 RPC channel that encodes requests and matches replies by request id.

    Each command is wrapped in a ``RequestMessage``, encoded with
    ``payload_encoder`` and published on the underlying MQTT or local channel.
    Incoming messages are decoded, and the one whose request id matches the
    outstanding command resolves it.
    """

    def __init__(
        self,
        name: str,
        channel: MqttChannel | LocalChannel,
        payload_encoder: Callable[[RequestMessage], RoborockMessage],
    ) -> None:
        """Initialize the channel with a raw channel and an encoder function.

        Args:
            name: Short label for this channel, included in debug log messages.
            channel: Transport used to subscribe to and publish messages.
            payload_encoder: Converts a ``RequestMessage`` into the
                ``RoborockMessage`` that is published.
        """
        self._name = name
        self._channel = channel
        self._payload_encoder = payload_encoder

    async def _send_raw_command(
        self,
        method: CommandType,
        *,
        params: ParamsType = None,
    ) -> ResponseData:
        """Publish ``method`` and wait for the reply with a matching request id.

        Returns:
            The decoded response data for this request.

        Raises:
            RoborockException: If no matching reply arrives within ``_TIMEOUT``
                seconds.
            Exception: Whatever ``api_error`` the matching reply carries, when
                it reports one.
        """
        request_message = RequestMessage(method, params=params)
        _LOGGER.debug(
            "Sending command (%s, request_id=%s): %s, params=%s", self._name, request_message.request_id, method, params
        )
        message = self._payload_encoder(request_message)
        # Bind the future to the loop this coroutine is running on.
        future: asyncio.Future[ResponseData] = asyncio.get_running_loop().create_future()

        def find_response(response_message: RoborockMessage) -> None:
            """Resolve ``future`` if ``response_message`` answers this request."""
            try:
                decoded = decode_rpc_response(response_message)
            except RoborockException as ex:
                # Messages that fail to decode are logged and otherwise ignored.
                _LOGGER.debug("Exception while decoding message (%s): %s", response_message, ex)
                return
            _LOGGER.debug("Received response (%s, request_id=%s)", self._name, decoded.request_id)
            if decoded.request_id != request_message.request_id:
                return
            if future.done():
                # A duplicate reply (or one after cancellation) must not try to
                # resolve the future again; that raises InvalidStateError inside
                # the channel's subscriber callback.
                return
            if decoded.api_error:
                future.set_exception(decoded.api_error)
            else:
                future.set_result(decoded.data)

        unsub = await self._channel.subscribe(find_response)
        try:
            await self._channel.publish(message)
            return await asyncio.wait_for(future, timeout=_TIMEOUT)
        except asyncio.TimeoutError as ex:
            # asyncio.TimeoutError is the builtin TimeoutError on 3.11+ and a
            # distinct class on older versions; catching it covers both.
            future.cancel()
            raise RoborockException(f"Command timed out after {_TIMEOUT}s") from ex
        finally:
            unsub()
def create_mqtt_rpc_channel(mqtt_channel: MqttChannel, security_data: SecurityData) -> V1RpcChannel:
    """Build a V1 RPC channel that talks to the device over MQTT.

    Outgoing requests are encoded as ``RPC_REQUEST`` messages using the given
    ``security_data``, and the channel is labelled ``"mqtt"``.
    """

    def encode_request(request):
        return request.encode_message(RoborockMessageProtocol.RPC_REQUEST, security_data=security_data)

    return PayloadEncodedV1RpcChannel("mqtt", mqtt_channel, encode_request)
def create_local_rpc_channel(local_channel: LocalChannel) -> V1RpcChannel:
    """Build a V1 RPC channel that talks to the device over its local connection.

    Outgoing requests are encoded as ``GENERAL_REQUEST`` messages (no security
    data is passed to the encoder), and the channel is labelled ``"local"``.
    """

    def encode_request(request):
        return request.encode_message(RoborockMessageProtocol.GENERAL_REQUEST)

    return PayloadEncodedV1RpcChannel("local", local_channel, encode_request)
|