File: mqtt.py

package info (click to toggle)
python-roborock 4.10.0-1
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 2,476 kB
  • sloc: python: 16,570; makefile: 17; sh: 6
file content (104 lines) | stat: -rw-r--r-- 3,698 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Common code for MQTT tests."""

import asyncio
import io
import logging
from collections.abc import Callable
from queue import Queue

from roborock.mqtt.session import MqttParams
from roborock.roborock_message import RoborockMessage

from .logging import CapturedRequestLog

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# Used by fixtures to handle incoming requests and prepare responses.
# The handler is given the raw bytes of a client request and returns the raw
# bytes to queue as the reply, or None when no reply should be queued.
MqttRequestHandler = Callable[[bytes], bytes | None]


class FakeMqttSocketHandler:
    """Fake socket used by the test to simulate a connection to the broker.

    The socket handler is used to intercept the socket send and recv calls and
    populate the response buffer with data to be sent back to the client. The
    handle request callback handles the incoming requests and prepares the responses.
    """

    def __init__(
        self, handle_request: MqttRequestHandler, response_queue: Queue[bytes], log: CapturedRequestLog
    ) -> None:
        """Initialize the handler with an empty response buffer.

        Args:
            handle_request: Callback invoked with every client request; a
                non-None return value is queued as the reply.
            response_queue: Queue of responses that push_response() moves into
                the buffer once the client has sent its first request.
            log: Capture log that records each request and queued response.
        """
        self.response_buf = io.BytesIO()
        self.handle_request = handle_request
        self.response_queue = response_queue
        self.log = log
        # Flipped to True by the first handle_socket_send() call.
        self.client_connected = False

    def pending(self) -> int:
        """Return the number of bytes in the response buffer."""
        return len(self.response_buf.getvalue())

    def handle_socket_recv(self, read_size: int) -> bytes:
        """Intercept a client recv() and return up to read_size buffered bytes.

        Raises:
            BlockingIOError: If the response buffer is empty.
        """
        if self.pending() == 0:
            raise BlockingIOError("No response queued")

        self.response_buf.seek(0)
        data = self.response_buf.read(read_size)
        _LOGGER.debug("Response: 0x%s", data.hex())
        # Keep whatever was not consumed for the next recv() call.
        remaining_data = self.response_buf.read()
        self.response_buf = io.BytesIO(remaining_data)
        # A BytesIO built from initial bytes starts at position 0. Move to the
        # end so a later write appends after the unread tail instead of
        # overwriting it.
        self.response_buf.seek(0, io.SEEK_END)
        return data

    def handle_socket_send(self, client_request: bytes) -> int:
        """Receive an incoming request from the client.

        The request is logged and passed to the request handler; any response
        it returns is queued in the buffer. Returns the number of request bytes
        accepted, which is always the full request.
        """
        self.client_connected = True
        _LOGGER.debug("Request: 0x%s", client_request.hex())
        self.log.add_log_entry("[mqtt >]", client_request)
        if (response := self.handle_request(client_request)) is not None:
            self._enqueue_response(response)
        return len(client_request)

    def push_response(self) -> None:
        """Move at most one response from the queue into the buffer.

        Does nothing until the client has sent at least one request, or when
        the response queue is empty.
        """
        if not self.client_connected or self.response_queue.empty():
            return
        self._enqueue_response(self.response_queue.get())

    def _enqueue_response(self, response: bytes) -> None:
        """Log a response and append it to the buffer drained by recv()."""
        # Enqueue a response to be sent back to the client in the buffer.
        # The buffer will be emptied when the client calls recv() on the socket
        _LOGGER.debug("Queued: 0x%s", response.hex())
        self.log.add_log_entry("[mqtt <]", response)
        # Always append, regardless of where a previous read left the position.
        self.response_buf.seek(0, io.SEEK_END)
        self.response_buf.write(response)


# Connection parameters shared by the MQTT tests: an unencrypted (tls=False)
# connection to localhost on port 1883 with placeholder credentials.
FAKE_PARAMS = MqttParams(
    host="localhost",
    port=1883,
    tls=False,
    username="username",
    password="password",
    timeout=10.0,  # NOTE(review): presumably seconds; confirm against MqttParams
)


class Subscriber:
    """Mock subscriber class.

    We use this to hold on to received messages for verification.
    """

    def __init__(self) -> None:
        self.messages: list[RoborockMessage | bytes] = []
        self._event = asyncio.Event()

    def append(self, message: RoborockMessage | bytes) -> None:
        self.messages.append(message)
        self._event.set()

    async def wait(self) -> None:
        await asyncio.wait_for(self._event.wait(), timeout=1.0)
        self._event.clear()