File: run_mock_server.py

package info (click to toggle)
zwave-js-server-python 0.67.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,820 kB
  • sloc: python: 15,886; sh: 21; javascript: 16; makefile: 2
file content (387 lines) | stat: -rw-r--r-- 13,655 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
"""Run a mock zwave-js-server instance off of a network state dump."""

from __future__ import annotations

import argparse
import asyncio
from collections import defaultdict
from collections.abc import Hashable
import json
import logging
from typing import Any

from aiohttp import WSMsgType, web, web_request

from zwave_js_server.client import SIZE_PARSE_JSON_EXECUTOR
from zwave_js_server.const import MAX_SERVER_SCHEMA_VERSION, MIN_SERVER_SCHEMA_VERSION
from zwave_js_server.model.version import VersionInfoDataType

DATEFMT = "%Y-%m-%d %H:%M:%S"
FMT = "%(asctime)s [%(levelname)s] %(message)s"

# Use colorized output when the optional colorlog package is available; fall back
# to the plain stdlib formatter (same layout, no colors) when it is not installed.
try:
    from colorlog import ColoredFormatter

    logging_formatter = ColoredFormatter(
        "%(log_color)s" + FMT + "%(reset)s",
        datefmt=DATEFMT,
        log_colors=dict(
            DEBUG="cyan",
            INFO="green",
            WARNING="yellow",
            ERROR="red",
            CRITICAL="red",
        ),
    )
except ImportError:
    logging_formatter = logging.Formatter(fmt=FMT, datefmt=DATEFMT)


class ExitException(Exception):
    """Signal a fatal condition such as a malformed dump file or an unexpected client message."""


# https://stackoverflow.com/a/1151686
class HashableDict(dict):
    """
    Dictionary that can be used as a key in a dictionary.

    Hashing and equality are based on the (key, value) pairs sorted by key, so two
    HashableDicts with the same items are equal and hash alike regardless of
    insertion order. Keys must be mutually sortable and values must be hashable
    (make_dict_hashable converts nested dicts/lists before wrapping).

    A HashableDict never compares equal to a plain dict, even one with the same
    items. Mutating an instance after using it as a key changes its hash, so treat
    instances as read-only once hashed.
    """

    def __key(self) -> tuple:
        """Return key representation of HashableDict."""
        return tuple((k, self[k]) for k in sorted(self))

    def __hash__(self) -> int:  # type: ignore
        """Return hash representation of HashableDict."""
        return hash(self.__key())

    def __eq__(self, other: Any) -> bool:
        """Return whether HashableDict is equal to other."""
        # pylint: disable=protected-access
        return isinstance(other, HashableDict) and self.__key() == other.__key()

    def __ne__(self, other: Any) -> bool:
        """Return whether HashableDict is not equal to other."""
        # Without this, `!=` resolves to dict.__ne__, which compares plain dict
        # contents and can disagree with the custom __eq__ above.
        return not self.__eq__(other)


class MockZwaveJsServer:
    """
    Class to represent a mock zwave-js-server instance.

    The last client that connected to the server will be the one that receives ws msgs.

    Routes:
    - ``GET /``: websocket endpoint. On connect it sends ``network_state_dump[0]``
      (version info), then answers client commands from the dump and from the
      recorded command results.
    - ``POST /replay``: accepts one record or a list of records. ``event`` records
      are sent straight to the connected client; ``command`` records are queued as
      command results.
    """

    def __init__(
        self,
        network_state_dump: list[dict],
        events_to_replay: list[dict],
        command_results: defaultdict[HashableDict, list],
    ) -> None:
        """
        Initialize class.

        :param network_state_dump: messages used for the handshake. Index 0 is sent
            on connect, index 1 answers ``initialize`` and index 2 answers
            ``start_listening``.
        :param events_to_replay: event messages sent after ``start_listening``.
        :param command_results: sanitized command message -> queue of recorded
            result messages; each matching command consumes the first queued result.
        """
        self.network_state_dump = network_state_dump
        self.app = web.Application()
        self.app.add_routes(
            [
                web.get("/", self.server_handler),
                web.post("/replay", self.replay_handler),
            ]
        )
        self.primary_ws_resp: web.WebSocketResponse | None = None
        self.events_to_replay = events_to_replay
        self.command_results = command_results

    async def send_json(self, data: dict) -> None:
        """
        Send JSON to the most recently connected websocket client.

        :raises RuntimeError: if no client has connected yet.
        """
        logging.debug("Sending JSON: %s", data)
        if self.primary_ws_resp is None:
            # Explicit check instead of `assert`: survives `python -O` and gives the
            # /replay endpoint a readable error message to report.
            raise RuntimeError("No websocket client is connected")
        await self.primary_ws_resp.send_json(data)

    async def send_command_result(
        self,
        data: dict,
        message_id: str,
    ) -> None:
        """Send a command result, stamped with (overriding any recorded) messageId."""
        await self.send_json({**data, "messageId": message_id})

    async def send_success_command_result(
        self, result: dict | None, message_id: str
    ) -> None:
        """Send a successful command result; a None result is sent as ``{}``."""
        if result is None:
            result = {}
        await self.send_command_result(
            {"result": result, "type": "result", "success": True}, message_id
        )

    async def process_record(self, record: dict) -> None:
        """
        Process a replay dump record.

        ``event`` records are sent to the client immediately; ``command`` records
        are added to the command results map.

        :raises TypeError: if record is not a dict with a known ``record_type``.
        """
        if not isinstance(record, dict) or record.get("record_type") not in (
            "event",
            "command",
        ):
            raise TypeError(f"Malformed record: {record}")
        if record["record_type"] == "event":
            await self.send_json(record["event_msg"])
        else:
            add_command_result(self.command_results, record)

    async def server_handler(
        self, request: web_request.Request
    ) -> web.WebSocketResponse:
        """Handle websocket requests to the server."""
        ws_resp = web.WebSocketResponse(autoclose=False)
        self.primary_ws_resp = ws_resp
        await ws_resp.prepare(request)

        version_info: VersionInfoDataType = self.network_state_dump[0]
        # adjust min/max schemas if needed to get things to work
        version_info["maxSchemaVersion"] = max(
            MAX_SERVER_SCHEMA_VERSION, version_info["maxSchemaVersion"]
        )
        version_info["minSchemaVersion"] = min(
            MIN_SERVER_SCHEMA_VERSION, version_info["minSchemaVersion"]
        )
        await self.send_json(version_info)

        async for msg in ws_resp:
            if msg.type == WSMsgType.TEXT:
                logging.debug("Message received: %s", msg.data)
                # The "close"/"error" control strings are not JSON commands, so they
                # must not reach the JSON decoder (which would reject them).
                if msg.data == "close":
                    await ws_resp.close()
                elif msg.data == "error":
                    logging.warning("Error from client: %s", msg.data)
                else:
                    await self._handle_command(await self._decode_json(msg.data))
            elif msg.type == WSMsgType.ERROR:
                logging.error(
                    "Connection closed with exception %s",
                    ws_resp.exception(),
                )

        logging.info("Connection closed")

        return ws_resp

    @staticmethod
    async def _decode_json(raw: str) -> Any:
        """
        Decode a text frame as JSON.

        :raises ExitException: if raw is not valid JSON.
        """
        try:
            if len(raw) > SIZE_PARSE_JSON_EXECUTOR:
                # Large payloads are decoded in the default executor, presumably so
                # the event loop is not blocked while parsing.
                return await asyncio.get_running_loop().run_in_executor(
                    None, json.loads, raw
                )
            return json.loads(raw)
        except ValueError as err:
            raise ExitException(f"Received invalid JSON {raw}") from err

    async def _handle_command(self, data: Any) -> None:
        """
        Answer one decoded client command.

        :raises ExitException: if the message is malformed or no answer is known.
        """
        if not isinstance(data, dict) or "command" not in data or (
            "messageId" not in data
        ):
            raise ExitException(f"Malformed message: {data}")

        cmd = data["command"]
        message_id = data["messageId"]
        if cmd == "initialize":
            await self.send_json(self.network_state_dump[1])
        elif cmd == "driver.get_log_config":
            await self.send_success_command_result(
                {
                    "config": {
                        "enabled": True,
                        "level": "silly",
                        "logToFile": False,
                        "nodeFilter": [],
                        "filename": None,
                        "forceConsole": False,
                    }
                },
                message_id,
            )
        elif cmd == "start_listening":
            await self.send_json(self.network_state_dump[2])
            await asyncio.sleep(1)
            for event in self.events_to_replay:
                await self.send_json(event)
        # .get() instead of [] so that unmatched commands do not insert empty lists
        # into the defaultdict.
        elif resp_list := self.command_results.get(sanitize_msg(data)):
            await self.send_command_result(resp_list.pop(0), message_id)
        else:
            raise ExitException(f"Unhandled command received: {data}")

    async def replay_handler(self, request: web_request.Request) -> web.Response:
        """
        Handle requests to replay dump.

        Accepts a single record (JSON object) or a list of records. Records are
        processed in order; the first failing record aborts with a 400 response
        (records before it have already been applied).
        """
        try:
            data = await request.json()
        except json.decoder.JSONDecodeError:
            return self._bad_request("Invalid JSON.")

        if isinstance(data, dict):
            records = [data]
        elif isinstance(data, list):
            records = data
        else:
            return self._bad_request(f"Malformed message: {data}")

        for record in records:
            try:
                await self.process_record(record)
            except Exception as err:  # pylint: disable=broad-except
                # HTTP boundary: report any failure to the caller. str(err) is used
                # because err.args may be empty (e.g. a bare AssertionError).
                return self._bad_request(str(err) or type(err).__name__)
        return web.Response(status=200)

    @staticmethod
    def _bad_request(reason: str) -> web.Response:
        """Build a 400 response; line breaks are removed because reason phrases cannot contain them."""
        return web.Response(status=400, reason=" ".join(reason.splitlines()))


def _hashable_value(item: dict | list | Hashable) -> tuple | list | Hashable:
    """Return hashable value from item."""
    if isinstance(item, dict):
        return make_dict_hashable(item)
    if isinstance(item, list):
        return make_list_hashable(item)
    return item


def make_list_hashable(lst: list) -> tuple:
    """Return a tuple holding the hashable form of each element of lst, in order."""
    converted = [_hashable_value(element) for element in lst]
    return tuple(converted)


def make_dict_hashable(dct: dict) -> HashableDict:
    """Return a HashableDict copy of dct whose values are recursively made hashable."""
    converted: dict = {}
    for key, value in dct.items():
        converted[key] = _hashable_value(value)
    return HashableDict(converted)


def sanitize_msg(msg: dict) -> HashableDict:
    """
    Return a hashable copy of a command message without its ``messageId``.

    The caller's dict is left unmodified.
    """
    stripped = {key: value for key, value in msg.items() if key != "messageId"}
    return make_dict_hashable(stripped)


def add_command_result(
    command_results: defaultdict[HashableDict, list],
    record: dict,
) -> None:
    """
    Queue a recorded command result in command_results.

    The key is the command message with its messageId stripped (via sanitize_msg).
    Records without a ``result_msg`` are logged and skipped.
    """
    if "result_msg" in record:
        key = sanitize_msg(record["command_msg"])
        # The recorded messageId inside result_msg is left as-is:
        # MockZwaveJsServer.send_command_result overwrites it with the id of the
        # live request.
        command_results[key].append(record["result_msg"])
    else:
        logging.warning(
            "The following record cannot be used because the client did not wait for "
            "a response: %s",
            record,
        )


def get_args(argv: list[str] | None = None) -> argparse.Namespace:
    """
    Get arguments.

    :param argv: arguments to parse, without the program name. When None (the
        default), argparse reads ``sys.argv[1:]`` as before.
    :returns: the parsed namespace.
    """
    parser = argparse.ArgumentParser(description="Mock Z-Wave JS Server")
    parser.add_argument(
        "network_state_path", type=str, help="File path to network state dump JSON."
    )
    parser.add_argument(
        "--host",
        type=str,
        help="Host to bind to (defaults to 127.0.0.1)",
        default="127.0.0.1",
    )
    parser.add_argument(
        "--port", type=int, help="Port to run on (defaults to 3000)", default=3000
    )
    parser.add_argument(
        "--log-level",
        # Upper-case before validating so e.g. "debug" is accepted.
        type=str.upper,
        help="Log level for the mock server instance",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
    )
    parser.add_argument(
        "--events-to-replay-path",
        type=str,
        help=(
            "File path to events to replay JSON. Events provided by "
            "--combined-replay-dump-path option will be first, followed by events "
            "from this file."
        ),
        default=None,
    )
    parser.add_argument(
        "--command-results-path",
        type=str,
        help=(
            "File path to command result JSON. Command results provided by "
            "--combined-replay-dump-path option will be first, followed by results "
            "from this file."
        ),
        default=None,
    )
    parser.add_argument(
        "--combined-replay-dump-path",
        type=str,
        help=(
            "File path to the combined event and command result dump JSON. Events and "
            "command results will be extracted in the order they were received."
        ),
        default=None,
    )
    return parser.parse_args(argv)


def _load_records(path: str, record_type: str, description: str) -> list[dict]:
    """
    Load a JSON list of dump records from path and check their record_type.

    :param description: human-readable file description used in the error message.
    :raises ExitException: if any record's ``record_type`` differs from record_type.
    """
    with open(path, encoding="utf8") as fp:
        records: list[dict] = json.load(fp)
    bad_record = next(
        (record for record in records if record.get("record_type") != record_type),
        None,
    )
    if bad_record is not None:
        raise ExitException(f"Malformed record in {description}: {bad_record}")
    return records


def main() -> None:
    """Run main entrypoint."""
    args = get_args()

    # Configure logging before loading any dump file: add_command_result may call
    # logging.warning() while loading, and a module-level logging call on an
    # unconfigured root logger implicitly runs basicConfig() at WARNING level,
    # which would turn a later basicConfig(level=...) into a no-op and silently
    # ignore --log-level. force=True replaces any handler installed before now.
    # adapted from homeassistant.bootstrap.async_enable_logging
    logging.basicConfig(level=args.log_level, force=True)
    logging.getLogger().handlers[0].setFormatter(logging_formatter)

    with open(args.network_state_path, encoding="utf8") as fp:
        network_state_dump: list[dict] = json.load(fp)

    events_to_replay: list[dict] = []
    command_results: defaultdict[HashableDict, list] = defaultdict(list)

    # Combined dump first, so its events/results precede those from the
    # dedicated files (as documented in the CLI help).
    if args.combined_replay_dump_path:
        with open(args.combined_replay_dump_path, encoding="utf8") as fp:
            records: list[dict] = json.load(fp)
        for record in records:
            if record.get("record_type") not in ("event", "command"):
                raise ExitException(
                    f"Invalid record in combined replay dump file: {record}"
                )
            if record["record_type"] == "event":
                events_to_replay.append(record["event_msg"])
            else:
                add_command_result(command_results, record)

    if args.events_to_replay_path:
        event_records = _load_records(
            args.events_to_replay_path, "event", "events to replay file"
        )
        events_to_replay.extend(record["event_msg"] for record in event_records)

    if args.command_results_path:
        command_records = _load_records(
            args.command_results_path, "command", "command results dump file"
        )
        for record in command_records:
            add_command_result(command_results, record)

    server = MockZwaveJsServer(network_state_dump, events_to_replay, command_results)
    web.run_app(server.app, host=args.host, port=args.port)


if __name__ == "__main__":
    try:
        main()
    except ExitException as error:
        logging.error("Fatal error: %s", error)
        # Report the failure to the calling shell/CI with a non-zero exit status.
        raise SystemExit(1) from None