1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330
|
from __future__ import annotations
import asyncio
import binascii
from functools import partial
import logging
from struct import Struct
from typing import TYPE_CHECKING, Any, Callable
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
from cryptography.exceptions import InvalidTag
from noise.backends.default import DefaultNoiseBackend # type: ignore[import-untyped]
from noise.backends.default.ciphers import ( # type: ignore[import-untyped]
ChaCha20Cipher,
)
from noise.connection import NoiseConnection # type: ignore[import-untyped]
from ..core import (
APIConnectionError,
BadNameAPIError,
HandshakeAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
)
from .base import _LOGGER, APIFrameHelper
if TYPE_CHECKING:
from ..connection import APIConnection
# Packs a 12-byte nonce: a little-endian uint32 fixed to 0 (bound via
# partial) followed by the caller-supplied value as a little-endian uint64.
PACK_NONCE = partial(Struct("<LQ").pack, 0)
class ChaCha20CipherReuseable(ChaCha20Cipher):  # type: ignore[misc]
    """ChaCha20 cipher that can be reused.

    Uses ``ChaCha20Poly1305Reusable`` as the AEAD class and ``PACK_NONCE``
    to format nonces.
    """

    # Wrapped in staticmethod so the partial is never bound to ``self`` when
    # looked up through an instance (``self.format_nonce(n)``). A bare
    # functools.partial class attribute triggers a FutureWarning on Python
    # 3.13, and Python 3.14 makes partial a method descriptor, which would
    # turn the call into PACK_NONCE(self, n).
    format_nonce = staticmethod(PACK_NONCE)

    @property
    def klass(self) -> type[ChaCha20Poly1305Reusable]:
        """Return the AEAD class to instantiate (presumably read by the base cipher)."""
        return ChaCha20Poly1305Reusable
class ESPHomeNoiseBackend(DefaultNoiseBackend):  # type: ignore[misc]
    """Default noise backend with ChaChaPoly mapped to the reusable cipher."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the default backend, then override its ChaChaPoly entry."""
        super().__init__(*args, **kwargs)
        cipher_table = self.ciphers
        cipher_table["ChaChaPoly"] = ChaCha20CipherReuseable
# Single module-level backend instance; passed as ``backend=`` to
# NoiseConnection.from_name in APINoiseFrameHelper._setup_proto.
ESPHOME_NOISE_BACKEND = ESPHomeNoiseBackend()
# This is effectively an enum, but we don't use an enum because
# data_received does a simple dispatch on these values that an enum
# would complicate, and we want to add cdefs for each state so we have
# a good test for each state receiving data. We found that the proactor
# event loop sends a bytearray instead of bytes, which was not handled
# correctly.
NOISE_STATE_HELLO = 1  # waiting for the server hello
NOISE_STATE_HANDSHAKE = 2  # hello accepted, waiting for the handshake reply
NOISE_STATE_READY = 3  # handshake complete, frames are decrypted
NOISE_STATE_CLOSED = 4  # helper closed; incoming frames are reported as errors
# Empty frame (0x01 marker, zero 16-bit length) sent ahead of the first
# handshake frame.
NOISE_HELLO = b"\x01\x00\x00"
# NOTE(review): alias not referenced in this chunk; presumably used by
# Cython type declarations elsewhere -- verify before removing.
int_ = int
class APINoiseFrameHelper(APIFrameHelper):
    """Frame helper for noise encrypted connections.

    Drives the Noise_NNpsk0_25519_ChaChaPoly_SHA256 handshake through the
    HELLO -> HANDSHAKE -> READY states, then encrypts outgoing packets and
    decrypts incoming frames.
    """

    __slots__ = (
        "_noise_psk",
        "_expected_name",
        "_state",
        "_server_name",
        "_proto",
        "_decrypt",
        "_encrypt",
    )

    def __init__(
        self,
        connection: APIConnection,
        noise_psk: str,
        expected_name: str | None,
        client_info: str,
        log_name: str,
    ) -> None:
        """Initialize the API frame helper.

        The PSK is decoded here (via ``_setup_proto``), so a malformed key
        raises ``InvalidEncryptionKeyAPIError`` from the constructor.
        """
        super().__init__(connection, client_info, log_name)
        self._noise_psk = noise_psk
        self._expected_name = expected_name
        self._state = NOISE_STATE_HELLO
        self._server_name: str | None = None
        # Filled in by _handle_handshake once the session cipher states exist.
        self._decrypt: Callable[[bytes], bytes] | None = None
        self._encrypt: Callable[[bytes], bytes] | None = None
        self._setup_proto()

    def close(self) -> None:
        """Close the connection."""
        # Make sure the ready future is failed if it is not already done so
        # that nothing waiting for the handshake to complete blocks forever.
        self._set_ready_future_exception(
            APIConnectionError(f"{self._log_name}: Connection closed")
        )
        self._state = NOISE_STATE_CLOSED
        super().close()

    def _handle_error(self, exc: Exception) -> None:
        """Translate low-level errors into clearer API errors, then delegate.

        A connection reset while still in the hello state becomes a
        HandshakeAPIError suggesting an encryption mismatch; an InvalidTag
        becomes an InvalidEncryptionKeyAPIError. The original exception is
        kept as ``__cause__``.
        """
        if self._state == NOISE_STATE_HELLO and isinstance(exc, ConnectionResetError):
            original_exc: Exception = exc
            exc = HandshakeAPIError(
                f"{self._log_name}: The connection dropped immediately after encrypted hello; "
                "Try enabling encryption on the device or turning off "
                f"encryption on the client ({self._client_info})."
            )
            exc.__cause__ = original_exc
        elif isinstance(exc, InvalidTag):
            original_exc = exc
            exc = InvalidEncryptionKeyAPIError(
                f"{self._log_name}: Invalid encryption key", self._server_name
            )
            exc.__cause__ = original_exc
        super()._handle_error(exc)

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Handle a new connection by immediately sending the client hello."""
        super().connection_made(transport)
        self._send_hello_handshake()

    def data_received(self, data: bytes | bytearray | memoryview) -> None:
        """Buffer incoming bytes and dispatch every complete frame.

        Each frame is a 3-byte header (a 0x01 marker followed by a
        big-endian 16-bit length) and then that many payload bytes. The
        payload is routed according to the current handshake state.
        """
        self._add_to_buffer(data)
        while self._buffer_len:
            self._pos = 0
            if (header := self._read(3)) is None:
                return
            preamble = header[0]
            if preamble != 0x01:
                self._handle_error_and_close(
                    ProtocolAPIError(
                        f"{self._log_name}: Marker byte invalid: {header[0]}"
                    )
                )
                return
            msg_size_high = header[1]
            msg_size_low = header[2]
            if (frame := self._read((msg_size_high << 8) | msg_size_low)) is None:
                # The complete frame is not yet available, wait for more data
                # to arrive before continuing, since callback_packet has not
                # been called yet the buffer will not be cleared and the next
                # call to data_received will continue processing the packet
                # at the start of the frame.
                return
            # asyncio already runs data_received in a try block
            # which will call connection_lost if an exception is raised
            if self._state == NOISE_STATE_READY:
                self._handle_frame(frame)
            elif self._state == NOISE_STATE_HELLO:
                self._handle_hello(frame)
            elif self._state == NOISE_STATE_HANDSHAKE:
                self._handle_handshake(frame)
            else:
                self._handle_closed(frame)
            self._remove_from_buffer()

    def _send_hello_handshake(self) -> None:
        """Send a ClientHello to the server."""
        handshake_frame = self._proto.write_message()
        # +1 accounts for the 0x00 byte that precedes the handshake message.
        frame_len = len(handshake_frame) + 1
        header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
        # The hello frame and the first handshake frame go out in one write.
        self._write_bytes(
            b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)),
            _LOGGER.isEnabledFor(logging.DEBUG),
        )

    def _handle_hello(self, server_hello: bytes) -> None:
        """Validate the server hello and advance to the handshake state."""
        if not server_hello:
            self._handle_error_and_close(
                HandshakeAPIError(f"{self._log_name}: ServerHello is empty")
            )
            return

        # First byte of server hello is the protocol the server chose
        # for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
        # exists.
        chosen_proto = server_hello[0]
        if chosen_proto != 0x01:
            self._handle_error_and_close(
                HandshakeAPIError(
                    f"{self._log_name}: Unknown protocol selected by client {chosen_proto}"
                )
            )
            return

        # Check name matches expected name (for noise sessions, this is done
        # during hello phase before a connection is set up)
        # Server name is encoded as a string followed by a zero byte after the chosen proto byte
        server_name_i = server_hello.find(b"\0", 1)
        if server_name_i != -1:
            # server name found, this extension was added in 2022.2
            server_name = server_hello[1:server_name_i].decode()
            self._server_name = server_name

            if self._expected_name is not None and self._expected_name != server_name:
                self._handle_error_and_close(
                    BadNameAPIError(
                        f"{self._log_name}: Server sent a different name '{server_name}'",
                        server_name,
                    )
                )
                return

        self._state = NOISE_STATE_HANDSHAKE

    def _decode_noise_psk(self) -> bytes:
        """Decode the given noise psk from base64 format to raw bytes.

        Raises:
            InvalidEncryptionKeyAPIError: if the PSK is not valid base64 or
                does not decode to exactly 32 bytes.
        """
        psk = self._noise_psk
        server_name = self._server_name
        try:
            psk_bytes = binascii.a2b_base64(psk)
        except ValueError as err:
            # binascii.Error is a ValueError subclass, so bad padding and
            # non-ASCII input both land here; keep the original as the cause.
            raise InvalidEncryptionKeyAPIError(
                f"{self._log_name}: Malformed PSK `{psk}`, expected "
                "base64-encoded value",
                server_name,
            ) from err
        if len(psk_bytes) != 32:
            raise InvalidEncryptionKeyAPIError(
                f"{self._log_name}: Malformed PSK `{psk}`, expected"
                " 32-bytes of base64 data",
                server_name,
            )
        return psk_bytes

    def _setup_proto(self) -> None:
        """Set up the noise protocol as the NNpsk0 initiator with the decoded PSK."""
        proto = NoiseConnection.from_name(
            b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
        )
        proto.set_as_initiator()
        proto.set_psks(self._decode_noise_psk())
        proto.set_prologue(b"NoiseAPIInit\x00\x00")
        proto.start_handshake()
        self._proto = proto

    def _error_on_incorrect_preamble(self, msg: bytes) -> None:
        """Fail the handshake using the explanation text the server sent.

        ``msg`` is a handshake frame with a non-zero first byte; the rest of
        the frame is treated as a human-readable explanation.
        """
        # errors="replace" so malformed text still produces a handshake error
        # rather than an unrelated UnicodeDecodeError.
        explanation = msg[1:].decode(errors="replace")
        if explanation != "Handshake MAC failure":
            exc = HandshakeAPIError(
                f"{self._log_name}: Handshake failure: {explanation}"
            )
        else:
            # A MAC failure is reported as a wrong encryption key.
            exc = InvalidEncryptionKeyAPIError(
                f"{self._log_name}: Invalid encryption key", self._server_name
            )
        self._handle_error_and_close(exc)

    def _handle_handshake(self, msg: bytes) -> None:
        """Process the server's handshake reply and switch to the ready state.

        A leading 0x00 byte marks a successful reply whose remainder is the
        noise handshake message; any other leading byte is a rejection whose
        remainder explains why.
        """
        if not msg:
            self._handle_error_and_close(
                HandshakeAPIError(f"{self._log_name}: ServerHandshake is empty")
            )
            return
        if msg[0] != 0:
            self._error_on_incorrect_preamble(msg)
            return
        self._proto.read_message(msg[1:])
        self._state = NOISE_STATE_READY
        noise_protocol = self._proto.noise_protocol
        # Pre-bind the first argument (the associated data, per the method
        # name) to None so the hot paths can pass just the payload.
        self._decrypt = partial(
            noise_protocol.cipher_state_decrypt.decrypt_with_ad,  # pylint: disable=no-member
            None,
        )
        self._encrypt = partial(
            noise_protocol.cipher_state_encrypt.encrypt_with_ad,  # pylint: disable=no-member
            None,
        )
        self.ready_future.set_result(None)

    def write_packets(
        self, packets: list[tuple[int, bytes]], debug_enabled: bool
    ) -> None:
        """Encrypt packets and write them to the socket in a single write.

        Packets are in the format of tuple[protobuf_type, protobuf_data]
        """
        if TYPE_CHECKING:
            assert self._encrypt is not None, "Handshake should be complete"

        out: list[bytes] = []
        for packet in packets:
            type_: int = packet[0]
            data: bytes = packet[1]
            data_len = len(data)
            # Plaintext header: 16-bit big-endian type, 16-bit big-endian length.
            # NOTE(review): lengths above 0xFFFF would be silently truncated
            # by the & 0xFF masks here and in the frame header below.
            data_header = bytes(
                (
                    (type_ >> 8) & 0xFF,
                    type_ & 0xFF,
                    (data_len >> 8) & 0xFF,
                    data_len & 0xFF,
                )
            )
            frame = self._encrypt(data_header + data)
            frame_len = len(frame)
            header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
            out.append(header)
            out.append(frame)

        self._write_bytes(b"".join(out), debug_enabled)

    def _handle_frame(self, frame: bytes) -> None:
        """Decrypt an incoming frame and pass its message to the connection."""
        if TYPE_CHECKING:
            assert self._decrypt is not None, "Handshake should be complete"
        msg = self._decrypt(frame)
        # Message layout is
        # 2 bytes: message type
        # 2 bytes: message length
        # N bytes: message data
        if len(msg) < 4:
            # Too short to hold the header; indexing would raise IndexError.
            self._handle_error_and_close(
                ProtocolAPIError(
                    f"{self._log_name}: Decrypted message too short: {len(msg)} bytes"
                )
            )
            return
        type_high = msg[0]
        type_low = msg[1]
        self._connection.process_packet((type_high << 8) | type_low, msg[4:])

    def _handle_closed(self, frame: bytes) -> None:  # pylint: disable=unused-argument
        """Handle a frame that arrives after the helper was closed."""
        self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed"))
|