1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
|
from __future__ import annotations
import typing
import asyncio
import logging
import zigpy.config
import zigpy.serial
import zigpy_znp.config as conf
import zigpy_znp.frames as frames
import zigpy_znp.logger as log
from zigpy_znp.types import Bytes
from zigpy_znp.exceptions import InvalidFrame
# Logger named after this module, so its verbosity can be configured separately
# from the rest of the package (this module also logs at `log.TRACE`).
LOGGER = logging.getLogger(__name__)
class BufferTooShort(Exception):
    """Signals that the receive buffer does not yet contain a complete frame.

    Raised by the frame extractor and caught by the frame iterator, which then
    stops and waits for more data to arrive.
    """
class ZnpMtProtocol(zigpy.serial.SerialProtocol):
    """Serial protocol speaking the ZNP MT transport framing.

    Received bytes end up in ``self._buffer``. This class never assigns that
    attribute; presumably the ``zigpy.serial.SerialProtocol`` base class does,
    in its ``data_received``.

    The buffer is split into ``frames.TransportFrame`` objects, and each
    frame's payload is passed to ``api.frame_received``. Outgoing
    ``GeneralFrame`` payloads are wrapped in a ``TransportFrame`` before being
    written to the transport.
    """

    def __init__(self, api, *, url: str | None = None) -> None:
        """
        :param api: Receiver of connection events and of received frame
            payloads. It is detached (set to ``None``) by :meth:`close`.
        :param url: Device path or URL; only used in ``__repr__``.
        """
        super().__init__()
        self._api = api
        self.url = url

    def close(self) -> None:
        """Closes the port."""
        super().close()
        # Detach the API so that no further callbacks are delivered to it
        self._api = None

    def connection_lost(self, exc: Exception | None) -> None:
        """Connection lost; notify the API if it is still attached."""
        super().connection_lost(exc)

        if self._api is not None:
            self._api.connection_lost(exc)

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Connection established; notify the API if it is still attached."""
        super().connection_made(transport)

        if self._api is not None:
            self._api.connection_made()

    def data_received(self, data: bytes) -> None:
        """Callback when data is received.

        Every complete frame now in the buffer is parsed, and its payload is
        forwarded to the API. Exceptions raised by the API are logged rather
        than propagated, so one bad frame does not break the transport.
        """
        super().data_received(data)

        LOGGER.log(log.TRACE, "Received data: %s", Bytes.__repr__(data))

        for frame in self._extract_frames():
            LOGGER.log(log.TRACE, "Parsed frame: %s", frame)

            # The API is detached by `close()`, which may even be triggered
            # from inside `frame_received` while we are still iterating.
            # Frames are still consumed from the buffer, but there is nobody
            # to deliver them to.
            if self._api is None:
                LOGGER.debug("Dropping frame received after close: %s", frame)
                continue

            try:
                self._api.frame_received(frame.payload)
            except Exception as e:
                LOGGER.error(
                    "Received an exception while passing frame to API: %s",
                    frame,
                    exc_info=e,
                )

    def send(self, payload: frames.GeneralFrame) -> None:
        """Sends data taking care of framing."""
        self.write(frames.TransportFrame(payload).serialize())

    def write(self, data: bytes) -> None:
        """
        Writes raw bytes to the transport. This method should be used instead of
        directly writing to the transport with `transport.write`.
        """
        LOGGER.log(log.TRACE, "Sending data: %s", Bytes.__repr__(data))
        self._transport.write(data)

    def set_dtr_rts(self, *, dtr: bool, rts: bool) -> None:
        """Set the DTR and RTS pin states on a serial transport.

        This is a no-op when the transport has no ``serial`` attribute.
        """
        # TCP transport does not have DTR or RTS pins
        if not hasattr(self._transport, "serial"):
            return

        LOGGER.debug("Setting serial pin states: DTR=%s, RTS=%s", dtr, rts)

        self._transport.serial.dtr = dtr
        self._transport.serial.rts = rts

    def _extract_frames(self) -> typing.Iterator[frames.TransportFrame]:
        """Yield frames from the buffer until no complete frame remains.

        Invalid data is discarded up to the next SoF byte, so the parser
        resynchronises on the following frame boundary.
        """
        while True:
            try:
                yield self._extract_frame()
            except BufferTooShort:
                # If the buffer is too short, there is nothing more we can do
                break
            except InvalidFrame:
                # Search from index 1 so the byte at index 0 (the rejected
                # frame start) is always dropped. This guarantees progress
                # even if it was itself a SoF.
                sof_index = self._buffer.find(frames.TransportFrame.SOF, 1)

                if sof_index < 0:
                    # If we don't have a SoF in the buffer, drop everything
                    self._buffer.clear()
                else:
                    del self._buffer[:sof_index]

    def _extract_frame(self) -> frames.TransportFrame:
        """Extract a single frame from the start of the buffer.

        On success, the consumed bytes are removed from the buffer.

        :raises BufferTooShort: the buffer does not yet hold a complete frame.
        :raises InvalidFrame: the buffer does not start with a valid frame.
        """
        # The shortest possible frame is 5 bytes long
        # (SoF + Length + 2-byte Command + FCS, with no data)
        if len(self._buffer) < 5:
            raise BufferTooShort()

        # The buffer must start with a SoF
        if self._buffer[0] != frames.TransportFrame.SOF:
            raise InvalidFrame()

        length = self._buffer[1]

        # If the packet length field exceeds 250, our packet is not valid
        if length > 250:
            raise InvalidFrame()

        # Don't bother deserializing anything if the packet is too short
        # [SoF:1] [Length:1] [Command:2] [Data:(Length)] [FCS:1]
        if len(self._buffer) < length + 5:
            raise BufferTooShort()

        # At this point we should have a complete frame
        # If not, deserialization will fail and the error will propagate up
        frame, rest = frames.TransportFrame.deserialize(self._buffer)

        # If we get this far then we have a valid frame. Update the buffer.
        del self._buffer[: len(self._buffer) - len(rest)]

        return frame

    def __repr__(self) -> str:
        return (
            f"<"
            f"{type(self).__name__} connected to {self.url!r}"
            f" (api: {self._api})"
            f">"
        )
async def connect(config: conf.ConfigType, api) -> ZnpMtProtocol:
    """Open the device described by ``config`` and return its protocol.

    :param config: Device configuration. The path, baudrate and flow-control
        settings are read through the ``zigpy.config.CONF_DEVICE_*`` keys.
    :param api: Handed to the ``ZnpMtProtocol`` built by the protocol factory.
    :returns: The ``ZnpMtProtocol`` instance, once ``wait_until_connected``
        has completed.
    """
    device_path = config[zigpy.config.CONF_DEVICE_PATH]
    running_loop = asyncio.get_running_loop()

    def build_protocol() -> ZnpMtProtocol:
        # The device path is also what the protocol shows in its `__repr__`
        return ZnpMtProtocol(api, url=device_path)

    _transport, protocol = await zigpy.serial.create_serial_connection(
        loop=running_loop,
        protocol_factory=build_protocol,
        url=device_path,
        baudrate=config[zigpy.config.CONF_DEVICE_BAUDRATE],
        flow_control=config[zigpy.config.CONF_DEVICE_FLOW_CONTROL],
    )

    # Only hand the protocol back once its transport reports being connected
    await protocol.wait_until_connected()
    return protocol
|