File: gateway_scanner.py

package info (click to toggle)
python-xknx 3.6.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 4,012 kB
  • sloc: python: 39,710; javascript: 8,556; makefile: 27; sh: 12
file content (344 lines) | stat: -rw-r--r-- 12,858 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
"""
GatewayScanner is an abstraction for searching for KNX/IP devices on the local network.

It walks through all network interfaces and sends UDP multicast
SearchRequest and SearchRequestExtended frames.
"""

from __future__ import annotations

import asyncio
from collections.abc import AsyncGenerator
from functools import partial
import logging
from typing import TYPE_CHECKING

from xknx.exceptions import XKNXException
from xknx.io import util
from xknx.knxip import (
    HPAI,
    SRP,
    DIBServiceFamily,
    DIBTypeCode,
    KNXIPFrame,
    KNXIPServiceType,
    SearchRequest,
    SearchRequestExtended,
    SearchResponse,
    SearchResponseExtended,
)
from xknx.knxip.dib import (
    DIB,
    DIBDeviceInformation,
    DIBSecuredServiceFamilies,
    DIBSuppSVCFamilies,
    DIBTunnelingInfo,
    TunnelingSlotStatus,
)
from xknx.telegram import IndividualAddress
from xknx.util import asyncio_timeout

from .transport import UDPTransport

if TYPE_CHECKING:
    from xknx.xknx import XKNX

logger = logging.getLogger("xknx.log")


class GatewayDescriptor:
    """Used to return information about the discovered gateways."""

    def __init__(
        self,
        ip_addr: str,
        port: int,
        local_ip: str = "",
        local_interface: str = "",
        name: str = "UNKNOWN",
        supports_routing: bool = False,
        supports_tunnelling: bool = False,
        supports_tunnelling_tcp: bool = False,
        supports_secure: bool = False,
        individual_address: IndividualAddress | None = None,
    ) -> None:
        """Initialize GatewayDescriptor class."""
        self.name = name
        self.ip_addr = ip_addr
        self.port = port
        self.individual_address = individual_address
        self.local_interface = local_interface
        self.local_ip = local_ip
        self.supports_routing = supports_routing
        self.supports_tunnelling = supports_tunnelling
        self.supports_tunnelling_tcp = supports_tunnelling_tcp
        self.supports_secure = supports_secure

        self.core_version: int = 0
        self.routing_requires_secure: bool | None = None
        self.tunnelling_requires_secure: bool | None = None
        self.tunnelling_slots: dict[IndividualAddress, TunnelingSlotStatus] = {}

    def parse_dibs(self, dibs: list[DIB]) -> None:
        """Parse DIBs for gateway information."""
        for dib in dibs:
            if isinstance(dib, DIBDeviceInformation):
                self.name = dib.name
                self.individual_address = dib.individual_address
                continue
            if isinstance(dib, DIBSuppSVCFamilies):
                self.core_version = dib.version(DIBServiceFamily.CORE) or 0
                self.supports_routing = dib.supports(DIBServiceFamily.ROUTING)
                if _tunnelling_version := dib.version(DIBServiceFamily.TUNNELING):
                    self.supports_tunnelling = True
                    self.supports_tunnelling_tcp = _tunnelling_version >= 2
                self.supports_secure = dib.supports(
                    DIBServiceFamily.SECURITY, version=1
                )
                continue
            if isinstance(dib, DIBSecuredServiceFamilies):
                self.tunnelling_requires_secure = dib.supports(
                    DIBServiceFamily.TUNNELING
                )
                self.routing_requires_secure = dib.supports(DIBServiceFamily.ROUTING)
                continue
            if isinstance(dib, DIBTunnelingInfo):
                self.tunnelling_slots = dib.slots
                continue

    def __repr__(self) -> str:
        """Return object as representation string."""
        return (
            "GatewayDescriptor(\n"
            f"    name={self.name},\n"
            f"    ip_addr={self.ip_addr},\n"
            f"    port={self.port},\n"
            f"    individual_address={self.individual_address}\n"
            f"    local_interface={self.local_interface},\n"
            f"    local_ip={self.local_ip},\n"
            f"    core_version={self.core_version},\n"
            f"    supports_routing={self.supports_routing},\n"
            f"    supports_tunnelling={self.supports_tunnelling},\n"
            f"    supports_tunnelling_tcp={self.supports_tunnelling_tcp},\n"
            f"    supports_secure={self.supports_secure},\n"
            f"    routing_requires_secure={self.routing_requires_secure}\n"
            f"    tunnelling_requires_secure={self.tunnelling_requires_secure}\n"
            f"    tunnelling_slots={self.tunnelling_slots}\n"
            ")"
        )

    def __str__(self) -> str:
        """Return object as readable string."""
        return f"{self.individual_address} - {self.name} @ {self.ip_addr}:{self.port}"


class GatewayScanFilter:
    """
    Filter to limit gateway scan results.

    If `name` doesn't match the gateway name, the gateway will be ignored.

    Connection methods are treated as OR if `True` is set for multiple methods.
    Non-secure methods don't match if secure is required.
    """

    def __init__(
        self,
        name: str | None = None,
        tunnelling: bool | None = True,
        tunnelling_tcp: bool | None = True,
        routing: bool | None = True,
        secure_tunnelling: bool | None = True,
        secure_routing: bool | None = True,
    ) -> None:
        """Initialize GatewayScanFilter class."""
        self.name = name
        self.tunnelling = tunnelling
        self.tunnelling_tcp = tunnelling_tcp
        self.routing = routing
        self.secure_tunnelling = secure_tunnelling
        self.secure_routing = secure_routing

    def match(self, gateway: GatewayDescriptor) -> bool:
        """Check whether the device is a gateway and given GatewayDescriptor matches the filter."""
        if self.name is not None and self.name != gateway.name:
            return False
        return (
            bool(
                self.tunnelling
                and gateway.supports_tunnelling
                and not gateway.tunnelling_requires_secure
            )
            or bool(
                self.tunnelling_tcp
                and gateway.supports_tunnelling_tcp
                and not gateway.tunnelling_requires_secure
            )
            or bool(
                self.routing
                and gateway.supports_routing
                and not gateway.routing_requires_secure
            )
            or bool(
                self.secure_tunnelling
                and gateway.supports_tunnelling_tcp
                and gateway.tunnelling_requires_secure
            )
            or bool(
                self.secure_routing
                and gateway.supports_routing
                and gateway.routing_requires_secure
            )
        )

    def __eq__(self, other: object) -> bool:
        """Equality for GatewayScanFilter class."""
        return self.__dict__ == other.__dict__


class GatewayScanner:
    """Class for searching KNX/IP devices."""

    def __init__(
        self,
        xknx: XKNX,
        local_ip: str | None = None,
        timeout_in_seconds: float = 3.0,
        stop_on_found: int | None = None,
        scan_filter: GatewayScanFilter | None = None,
    ) -> None:
        """Initialize GatewayScanner class."""
        self.xknx = xknx
        self.local_ip = local_ip
        self.timeout_in_seconds = timeout_in_seconds
        self.stop_on_found = stop_on_found
        self.scan_filter = scan_filter or GatewayScanFilter()
        self.found_gateways: dict[HPAI, GatewayDescriptor] = {}
        self._response_received_event = asyncio.Event()

    async def scan(self) -> list[GatewayDescriptor]:
        """Scan and return a list of GatewayDescriptors on success."""
        await self._scan()
        return list(self.found_gateways.values())

    async def async_scan(self) -> AsyncGenerator[GatewayDescriptor, None]:
        """Search and yield found gateways."""
        queue: asyncio.Queue[GatewayDescriptor | None] = asyncio.Queue()
        scan_task = asyncio.create_task(self._scan(queue=queue))
        try:
            while True:
                gateway = await queue.get()
                if gateway is None:
                    return
                yield gateway
        finally:
            # cleanup after GeneratorExit or XKNXExceptions
            if not scan_task.done():
                scan_task.cancel()
            await scan_task  # to bubble up exceptions

    async def _scan(
        self, queue: asyncio.Queue[GatewayDescriptor | None] | None = None
    ) -> None:
        """Scan for gateways."""
        _local_ip = self.local_ip or await util.get_default_local_ip(
            remote_ip=self.xknx.multicast_group
        )
        if _local_ip is None:
            if queue is not None:
                queue.put_nowait(None)
            raise XKNXException("No usable network interface found.")
        local_ip = await util.validate_ip(_local_ip)
        interface_name = util.get_local_interface_name(local_ip=local_ip)
        logger.debug("Searching on %s / %s", interface_name, local_ip)

        udp_transport = UDPTransport(
            local_addr=(local_ip, 0),
            remote_addr=(self.xknx.multicast_group, self.xknx.multicast_port),
        )
        udp_transport.register_callback(
            partial(self._response_rec_callback, interface=interface_name, queue=queue),
            [
                KNXIPServiceType.SEARCH_RESPONSE,
                KNXIPServiceType.SEARCH_RESPONSE_EXTENDED,
            ],
        )
        try:
            await self._send_search_requests(udp_transport=udp_transport)
            async with asyncio_timeout(self.timeout_in_seconds):
                await self._response_received_event.wait()
        except asyncio.TimeoutError:
            pass
        except asyncio.CancelledError:
            pass
        finally:
            udp_transport.stop()
            if queue is not None:
                queue.put_nowait(None)

    @staticmethod
    async def _send_search_requests(udp_transport: UDPTransport) -> None:
        """Send search requests on a specific interface."""
        await udp_transport.connect()
        discovery_endpoint = HPAI(*udp_transport.getsockname())
        # send SearchRequestExtended requesting needed DIBs
        search_request_extended = SearchRequestExtended(
            discovery_endpoint=discovery_endpoint,
            srps=[
                SRP.request_device_description(
                    [
                        DIBTypeCode.DEVICE_INFO,
                        DIBTypeCode.SUPP_SVC_FAMILIES,
                        DIBTypeCode.SECURED_SERVICE_FAMILIES,
                        DIBTypeCode.TUNNELING_INFO,
                    ]
                )
            ],
        )
        udp_transport.send(KNXIPFrame.init_from_body(search_request_extended))
        # send SearchRequest for Core-V1 devices
        search_request = SearchRequest(discovery_endpoint=discovery_endpoint)
        udp_transport.send(KNXIPFrame.init_from_body(search_request))

    def _response_rec_callback(
        self,
        knx_ip_frame: KNXIPFrame,
        source: HPAI,
        udp_transport: UDPTransport,
        interface: str = "",
        queue: asyncio.Queue[GatewayDescriptor | None] | None = None,
    ) -> None:
        """Verify and handle knxipframe. Callback from internal udp_transport."""
        if not isinstance(knx_ip_frame.body, SearchResponse | SearchResponseExtended):
            logger.warning("Could not understand knxipframe")
            return

        # skip non-extended SearchResponse for Core-V2 devices
        if knx_ip_frame.header.service_type_ident == KNXIPServiceType.SEARCH_RESPONSE:
            if svc_families_dib := next(
                (
                    dib
                    for dib in knx_ip_frame.body.dibs
                    if isinstance(dib, DIBSuppSVCFamilies)
                ),
                None,
            ):
                if svc_families_dib.supports(DIBServiceFamily.CORE, version=2):
                    logger.debug("Skipping SearchResponse for Core-V2 device")
                    return

        gateway = GatewayDescriptor(
            ip_addr=knx_ip_frame.body.control_endpoint.ip_addr,
            port=knx_ip_frame.body.control_endpoint.port,
            local_ip=udp_transport.local_addr[0],
            local_interface=interface,
        )
        gateway.parse_dibs(knx_ip_frame.body.dibs)

        logger.debug("Found KNX/IP device at %s: %s", source, repr(gateway))
        if self.scan_filter.match(gateway):
            self.found_gateways[knx_ip_frame.body.control_endpoint] = gateway
            if queue is not None:
                queue.put_nowait(gateway)
            if self.stop_on_found and len(self.found_gateways) >= self.stop_on_found:
                self._response_received_event.set()