1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
|
from __future__ import annotations
import sys
import json
import asyncio
import zigpy.state
import zigpy.zdo.types as zdo_t
import zigpy_znp.const as const
import zigpy_znp.types as t
from zigpy_znp.api import ZNP
from zigpy_znp.tools.common import ClosableFileType, setup_parser, validate_backup_json
from zigpy_znp.zigbee.application import ControllerApplication
def json_backup_to_zigpy_state(
    backup: t.JSONType,
) -> tuple[zigpy.state.NetworkInfo, zigpy.state.NodeInfo]:
    """
    Convert a JSON backup into a ``(NetworkInfo, NodeInfo)`` pair for zigpy.

    Address-like fields (``coordinator_ieee``, ``pan_id``, ``extended_pan_id`` and
    each device's ``nwk_address`` / ``ieee_address``) are hex strings whose decoded
    bytes are reversed before deserialization. Key material (``network_key.key``
    and each device's ``link_key.key``) is deserialized from hex without reversal.

    The returned ``NodeInfo`` describes the coordinator (NWK ``0x0000``). The
    trust center link key is always set to ``const.DEFAULT_TC_LINK_KEY`` rather
    than being read from the backup.
    """

    def parse_reversed_hex(kind, hex_text):
        # Address fields are stored with their byte order reversed.
        value, _ = kind.deserialize(bytes.fromhex(hex_text)[::-1])
        return value

    def parse_key_hex(hex_text):
        # Key material is deserialized in the order it appears in the backup.
        value, _ = t.KeyData.deserialize(bytes.fromhex(hex_text))
        return value

    coordinator = zigpy.state.NodeInfo()
    coordinator.nwk = 0x0000
    coordinator.logical_type = zdo_t.LogicalType.Coordinator
    coordinator.ieee = parse_reversed_hex(t.EUI64, backup["coordinator_ieee"])

    network = zigpy.state.NetworkInfo()
    network.pan_id = parse_reversed_hex(t.NWK, backup["pan_id"])
    network.extended_pan_id = parse_reversed_hex(t.EUI64, backup["extended_pan_id"])
    network.nwk_update_id = backup["nwk_update_id"]
    network.nwk_manager_id = 0x0000
    network.channel = backup["channel"]
    network.channel_mask = t.Channels.from_channel_list(backup["channel_mask"])
    network.security_level = backup["security_level"]
    network.stack_specific = backup.get("stack_specific")

    # The TC link key is not part of the backup; the default key is always used.
    network.tc_link_key = zigpy.state.Key()
    network.tc_link_key.key = const.DEFAULT_TC_LINK_KEY

    nwk_key_json = backup["network_key"]
    network.network_key = zigpy.state.Key()
    network.network_key.key = parse_key_hex(nwk_key_json["key"])
    network.network_key.tx_counter = nwk_key_json["frame_counter"]
    network.network_key.rx_counter = 0
    network.network_key.partner_ieee = None
    network.network_key.seq = nwk_key_json["sequence_number"]

    network.children = []
    network.nwk_addresses = {}

    for device in backup["devices"]:
        node = zigpy.state.NodeInfo()
        raw_nwk = device["nwk_address"]
        node.nwk = None if raw_nwk is None else parse_reversed_hex(t.NWK, raw_nwk)
        node.ieee = parse_reversed_hex(t.EUI64, device["ieee_address"])
        node.logical_type = None

        # `is_child` is optional in the backup; a missing key counts as a child.
        if device.get("is_child", True):
            network.children.append(node.ieee)

        if node.nwk is not None:
            network.nwk_addresses[node.ieee] = node.nwk

        if "link_key" in device:
            link_key_json = device["link_key"]
            link_key = zigpy.state.Key()
            link_key.key = parse_key_hex(link_key_json["key"])
            link_key.tx_counter = link_key_json["tx_counter"]
            link_key.rx_counter = link_key_json["rx_counter"]
            link_key.partner_ieee = node.ieee
            link_key.seq = 0

            network.key_table.append(link_key)

        # XXX: Devices that are not children, have no NWK address, and have no link key
        # are effectively ignored, since there is no place to write them

    return network, coordinator
async def restore_network(
    radio_path: str,
    backup: t.JSONType,
    counter_increment: int,
) -> None:
    """
    Write a JSON network backup onto the ZNP adapter at ``radio_path``.

    :param radio_path: Device path of the adapter, passed as ``device.path`` in
        the controller config.
    :param backup: Parsed JSON backup (callers are expected to have validated it).
    :param counter_increment: Added to the backed-up network key frame counter
        before writing (presumably so post-restore frames use counters above any
        already seen by other devices).
    """
    network_info, node_info = json_backup_to_zigpy_state(backup)
    network_info.network_key.tx_counter += counter_increment

    znp = ZNP(ControllerApplication.SCHEMA({"device": {"path": radio_path}}))
    await znp.connect()

    # Release the adapter connection even if writing or resetting fails;
    # previously an exception here left the serial port open.
    try:
        await znp.write_network_info(network_info=network_info, node_info=node_info)
        await znp.reset()
    finally:
        await znp.disconnect()
async def main(argv: list[str]) -> None:
    """
    Command-line entry point: restore a JSON network backup onto an adapter.

    :param argv: Command-line arguments, excluding the program name.
    """
    parser = setup_parser("Restore adapter network settings")
    parser.add_argument(
        "--input",
        "-i",
        type=ClosableFileType("r"),
        help="Input file",
        required=True,
    )
    parser.add_argument(
        "--counter-increment",
        "-c",
        type=t.uint32_t,
        help="Counter increment",
        default=2500,
    )
    parsed = parser.parse_args(argv)

    with parsed.input as backup_file:
        backup_json = json.load(backup_file)

    # Reject malformed backups before touching the radio.
    validate_backup_json(backup_json)

    await restore_network(
        radio_path=parsed.serial,
        backup=backup_json,
        counter_increment=parsed.counter_increment,
    )
if __name__ == "__main__":
    # Run the async CLI with every argument after the program name.
    cli_args = sys.argv[1:]
    asyncio.run(main(cli_args))  # pragma: no cover
|