File: ssdp.py

package info (click to toggle)
async-upnp-client 0.44.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,072 kB
  • sloc: python: 11,921; xml: 2,826; sh: 32; makefile: 6
file content (497 lines) | stat: -rw-r--r-- 16,015 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
# -*- coding: utf-8 -*-
"""async_upnp_client.ssdp module."""

import logging
import socket
import sys
from asyncio import BaseTransport, DatagramProtocol, DatagramTransport
from asyncio.events import AbstractEventLoop
from datetime import datetime
from functools import lru_cache
from ipaddress import IPv4Address, IPv6Address, ip_address
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Coroutine,
    Dict,
    Optional,
    Tuple,
    Union,
    cast,
)
from urllib.parse import urlsplit, urlunsplit

from aiohttp.http_exceptions import InvalidHeader
from aiohttp.http_parser import HeadersParser
from multidict import CIMultiDictProxy

from async_upnp_client.const import (
    AddressTupleV4Type,
    AddressTupleV6Type,
    AddressTupleVXType,
    IPvXAddress,
    SsdpHeaders,
    UniqueDeviceName,
)
from async_upnp_client.exceptions import UpnpError
from async_upnp_client.utils import CaseInsensitiveDict, lowerstr

SSDP_PORT = 1900
SSDP_IP_V4 = "239.255.255.250"
SSDP_IP_V6_LINK_LOCAL = "FF02::C"
SSDP_IP_V6_SITE_LOCAL = "FF05::C"
SSDP_IP_V6_ORGANISATION_LOCAL = "FF08::C"
SSDP_IP_V6_GLOBAL = "FF0E::C"
SSDP_IP_V6 = SSDP_IP_V6_LINK_LOCAL
SSDP_TARGET_V4 = (SSDP_IP_V4, SSDP_PORT)
SSDP_TARGET_V6 = (
    SSDP_IP_V6,
    SSDP_PORT,
    0,
    0,
)  # Replace the last item with your scope_id!
SSDP_TARGET = SSDP_TARGET_V4
SSDP_ST_ALL = "ssdp:all"
SSDP_ST_ROOTDEVICE = "upnp:rootdevice"
SSDP_MX = 4
SSDP_DISCOVER = '"ssdp:discover"'

_LOGGER = logging.getLogger(__name__)
_LOGGER_TRAFFIC_SSDP = logging.getLogger("async_upnp_client.traffic.ssdp")


def get_host_string(addr: AddressTupleVXType) -> str:
    """Construct host string from address tuple."""
    if len(addr) == 4:
        if TYPE_CHECKING:
            addr = cast(AddressTupleV6Type, addr)
        if addr[3]:
            return f"{addr[0]}%{addr[3]}"

    return addr[0]


def get_host_port_string(addr: AddressTupleVXType) -> str:
    """Return a properly escaped host port pair."""
    host = get_host_string(addr)
    if ":" in host:
        return f"[{host}]:{addr[1]}"
    return f"{host}:{addr[1]}"


@lru_cache(maxsize=256)
def get_adjusted_url(url: str, addr: AddressTupleVXType) -> str:
    """Adjust a url with correction for link local scope."""
    if len(addr) < 4:
        return url

    if TYPE_CHECKING:
        addr = cast(AddressTupleV6Type, addr)

    if not addr[3]:
        return url

    data = urlsplit(url)
    assert data.hostname
    try:
        address = ip_address(data.hostname)
    except ValueError:
        return url

    if not address.is_link_local:
        return url

    netloc = f"[{data.hostname}%{addr[3]}]"
    if data.port:
        netloc += f":{data.port}"
    return urlunsplit(data._replace(netloc=netloc))


def is_ipv4_address(addr: AddressTupleVXType) -> bool:
    """Test if addr is a IPv4 tuple."""
    return len(addr) == 2


def is_ipv6_address(addr: AddressTupleVXType) -> bool:
    """Test if addr is a IPv6 tuple."""
    return len(addr) == 4


def build_ssdp_packet(status_line: str, headers: SsdpHeaders) -> bytes:
    """Construct a SSDP packet."""
    headers_str = "\r\n".join([f"{key}:{value}" for key, value in headers.items()])
    return f"{status_line}\r\n{headers_str}\r\n\r\n".encode()


def build_ssdp_search_packet(
    ssdp_target: AddressTupleVXType, ssdp_mx: int, ssdp_st: str
) -> bytes:
    """Construct a SSDP M-SEARCH packet."""
    request_line = "M-SEARCH * HTTP/1.1"
    headers = {
        "HOST": f"{get_host_port_string(ssdp_target)}",
        "MAN": SSDP_DISCOVER,
        "MX": f"{ssdp_mx}",
        "ST": f"{ssdp_st}",
    }
    return build_ssdp_packet(request_line, headers)


@lru_cache(maxsize=128)
def is_valid_ssdp_packet(data: bytes) -> bool:
    """Check if data is a valid and decodable packet."""
    return (
        bool(data)
        and b"\n" in data
        and (
            data.startswith(b"NOTIFY * HTTP/1.1")
            or data.startswith(b"M-SEARCH * HTTP/1.1")
            or data.startswith(b"HTTP/1.1 200 OK")
        )
    )


# No longer used internally, but left for backwards compatibility
def udn_from_headers(
    headers: Union[CIMultiDictProxy, CaseInsensitiveDict]
) -> Optional[UniqueDeviceName]:
    """Get UDN from USN in headers."""
    usn: str = headers.get("usn", "")
    return udn_from_usn(usn)


@lru_cache(maxsize=128)
def udn_from_usn(usn: str) -> Optional[UniqueDeviceName]:
    """Get UDN from USN in headers."""
    if usn.lower().startswith("uuid:"):
        return usn.partition("::")[0]
    return None


@lru_cache(maxsize=128)
def _cached_header_parse(
    data: bytes,
) -> Tuple[CIMultiDictProxy[str], str, Optional[UniqueDeviceName]]:
    """Cache parsing headers.

    SSDP discover packets frequently end up being sent multiple
    times on multiple interfaces.

    We can avoid parsing the sames ones over and over
    again with a simple lru_cache.
    """
    lines = data.replace(b"\r\n", b"\n").split(b"\n")

    # request_line
    request_line = lines[0].strip().decode()

    if lines and lines[-1] != b"":
        lines.append(b"")

    parsed_headers, _ = HeadersParser().parse_headers(lines)

    usn = parsed_headers.get("usn")
    udn = udn_from_usn(usn) if usn else None

    return parsed_headers, request_line, udn


LOWER__TIMESTAMP = lowerstr("_timestamp")
LOWER__HOST = lowerstr("_host")
LOWER__PORT = lowerstr("_port")
LOWER__LOCAL_ADDR = lowerstr("_local_addr")
LOWER__REMOTE_ADDR = lowerstr("_remote_addr")
LOWER__UDN = lowerstr("_udn")
LOWER__LOCATION_ORIGINAL = lowerstr("_location_original")
LOWER_LOCATION = lowerstr("location")


@lru_cache(maxsize=512)
def _cached_decode_ssdp_packet(
    data: bytes,
    remote_addr_without_port: AddressTupleVXType,
) -> Tuple[str, CaseInsensitiveDict]:
    """Cache decoding SSDP packets."""
    parsed_headers, request_line, udn = _cached_header_parse(data)
    # own data
    extra: Dict[str, Any] = {LOWER__HOST: get_host_string(remote_addr_without_port)}
    if udn:
        extra[LOWER__UDN] = udn

    # adjust some headers
    location = parsed_headers.get("location", "")
    if location.strip():
        extra[LOWER__LOCATION_ORIGINAL] = location
        extra[LOWER_LOCATION] = get_adjusted_url(location, remote_addr_without_port)

    headers = CaseInsensitiveDict(parsed_headers, **extra)
    return request_line, headers


def decode_ssdp_packet(
    data: bytes,
    local_addr: Optional[AddressTupleVXType],
    remote_addr: AddressTupleVXType,
) -> Tuple[str, CaseInsensitiveDict]:
    """Decode a message."""
    # We want to use remote_addr_without_port as the cache
    # key since nothing in _cached_decode_ssdp_packet cares
    # about the port
    if len(remote_addr) == 4:
        if TYPE_CHECKING:
            remote_addr = cast(AddressTupleV6Type, remote_addr)
        addr, port, flow, scope = remote_addr
        remote_addr_without_port: AddressTupleVXType = addr, 0, flow, scope
    else:
        if TYPE_CHECKING:
            remote_addr = cast(AddressTupleV4Type, remote_addr)
        addr, port = remote_addr
        remote_addr_without_port = remote_addr[0], 0
    request_line, headers = _cached_decode_ssdp_packet(data, remote_addr_without_port)
    return request_line, headers.combine_lower_dict(
        {
            LOWER__TIMESTAMP: datetime.now(),
            LOWER__REMOTE_ADDR: remote_addr,
            LOWER__PORT: port,
            LOWER__LOCAL_ADDR: local_addr,
        }
    )


class SsdpProtocol(DatagramProtocol):
    """SSDP Protocol."""

    def __init__(
        self,
        loop: AbstractEventLoop,
        async_on_connect: Optional[
            Callable[[DatagramTransport], Coroutine[Any, Any, None]]
        ] = None,
        on_connect: Optional[Callable[[DatagramTransport], None]] = None,
        async_on_data: Optional[
            Callable[[str, CaseInsensitiveDict], Coroutine[Any, Any, None]]
        ] = None,
        on_data: Optional[Callable[[str, CaseInsensitiveDict], None]] = None,
    ) -> None:
        """Initialize."""
        # pylint: disable=too-many-arguments,too-many-positional-arguments
        self.loop = loop
        self.async_on_connect = async_on_connect
        self.on_connect = on_connect
        self.async_on_data = async_on_data
        self.on_data = on_data

        self.transport: Optional[DatagramTransport] = None
        self.local_addr: Optional[AddressTupleVXType] = None

    def connection_made(self, transport: BaseTransport) -> None:
        """Handle connection made."""
        self.transport = cast(DatagramTransport, transport)
        sock: Optional[socket.socket] = transport.get_extra_info("socket")
        self.local_addr = sock.getsockname() if sock is not None else None
        _LOGGER.debug(
            "Connection made, transport: %s, socket: %s",
            transport,
            sock,
        )

        if self.async_on_connect:
            coro = self.async_on_connect(self.transport)
            self.loop.create_task(coro)
        if self.on_connect:
            self.on_connect(self.transport)

    def datagram_received(self, data: bytes, addr: AddressTupleVXType) -> None:
        """Handle a discovery-response."""
        _LOGGER_TRAFFIC_SSDP.debug("Received packet from %s: %s", addr, data)
        assert self.transport

        if is_valid_ssdp_packet(data):
            try:
                request_line, headers = decode_ssdp_packet(data, self.local_addr, addr)
            except InvalidHeader as exc:
                _LOGGER.debug("Ignoring received packet with invalid headers: %s", exc)
                return

            if self.async_on_data:
                coro = self.async_on_data(request_line, headers)
                self.loop.create_task(coro)
            if self.on_data:
                self.on_data(request_line, headers)

    def error_received(self, exc: Exception) -> None:
        """Handle an error."""
        sock: Optional[socket.socket] = (
            self.transport.get_extra_info("socket") if self.transport else None
        )
        _LOGGER.error(
            "Received error: %s, transport: %s, socket: %s", exc, self.transport, sock
        )

    def connection_lost(self, exc: Optional[Exception]) -> None:
        """Handle connection lost."""
        if not _LOGGER.isEnabledFor(logging.DEBUG):
            return
        assert self.transport
        sock: Optional[socket.socket] = self.transport.get_extra_info("socket")
        _LOGGER.debug(
            "Lost connection, error: %s, transport: %s, socket: %s",
            exc,
            self.transport,
            sock,
        )

    def send_ssdp_packet(self, packet: bytes, target: AddressTupleVXType) -> None:
        """Send a SSDP packet."""
        assert self.transport
        if _LOGGER.isEnabledFor(logging.DEBUG):
            sock: Optional[socket.socket] = self.transport.get_extra_info("socket")
            _LOGGER.debug(
                "Sending SSDP packet, transport: %s, socket: %s, target: %s",
                self.transport,
                sock,
                target,
            )
        if _LOGGER_TRAFFIC_SSDP.isEnabledFor(logging.DEBUG):
            _LOGGER_TRAFFIC_SSDP.debug(
                "Sending SSDP packet, target: %s, data: %s", target, packet
            )
        self.transport.sendto(packet, target)


def determine_source_target(
    source: Optional[AddressTupleVXType] = None,
    target: Optional[AddressTupleVXType] = None,
) -> Tuple[AddressTupleVXType, AddressTupleVXType]:
    """Determine source and target."""
    if source is None and target is None:
        return ("0.0.0.0", 0), (SSDP_IP_V4, SSDP_PORT)

    if source is not None and target is None:
        if len(source) == 2:
            return source, (SSDP_IP_V4, SSDP_PORT)

        source = cast(AddressTupleV6Type, source)
        return source, (SSDP_IP_V6, SSDP_PORT, 0, source[3])

    if source is None and target is not None:
        if len(target) == 2:
            return (
                "0.0.0.0",
                0,
            ), target

        target = cast(AddressTupleV6Type, target)
        return ("::", 0, 0, target[3]), target

    if source is not None and target is not None and len(source) != len(target):
        raise UpnpError("Source and target do not match protocol")

    return cast(AddressTupleVXType, source), cast(AddressTupleVXType, target)


def fix_ipv6_address_scope_id(
    address: Optional[AddressTupleVXType],
) -> Optional[AddressTupleVXType]:
    """Fix scope_id for an IPv6 address, if needed."""
    if address is None or is_ipv4_address(address):
        return address

    ip_str = address[0]
    if "%" not in ip_str:
        # Nothing to fix.
        return address

    address = cast(AddressTupleV6Type, address)
    idx = ip_str.index("%")
    try:
        ip_scope_id = int(ip_str[idx + 1 :])
    except ValueError:
        pass
    scope_id = address[3]
    new_scope_id = ip_scope_id if not scope_id and ip_scope_id else address[3]
    new_ip = ip_str[:idx]
    return (
        new_ip,
        address[1],
        address[2],
        new_scope_id,
    )


def ip_port_from_address_tuple(
    address_tuple: AddressTupleVXType,
) -> Tuple[IPvXAddress, int]:
    """Get IPvXAddress from AddressTupleVXType."""
    if len(address_tuple) == 4:
        address_tuple = cast(AddressTupleV6Type, address_tuple)
        if "%" in address_tuple[0]:
            return IPv6Address(address_tuple[0]), address_tuple[1]

        return IPv6Address(f"{address_tuple[0]}%{address_tuple[3]}"), address_tuple[1]

    return IPv4Address(address_tuple[0]), address_tuple[1]


def get_ssdp_socket(
    source: AddressTupleVXType,
    target: AddressTupleVXType,
) -> Tuple[socket.socket, AddressTupleVXType, AddressTupleVXType]:
    """Create a socket to listen on."""
    # Ensure a proper IPv6 source/target.
    if is_ipv6_address(source):
        source = cast(AddressTupleV6Type, source)
        if not source[3]:
            raise UpnpError(f"Source missing scope_id, source: {source}")

    if is_ipv6_address(target):
        target = cast(AddressTupleV6Type, target)
        if not target[3]:
            raise UpnpError(f"Target missing scope_id, target: {target}")

    target_ip, target_port = ip_port_from_address_tuple(target)
    target_info = socket.getaddrinfo(
        str(target_ip),
        target_port,
        type=socket.SOCK_DGRAM,
        proto=socket.IPPROTO_UDP,
    )[0]
    source_ip, source_port = ip_port_from_address_tuple(source)
    source_info = socket.getaddrinfo(
        str(source_ip), source_port, type=socket.SOCK_DGRAM, proto=socket.IPPROTO_UDP
    )[0]
    _LOGGER.debug("Creating socket, source: %s, target: %s", source_info, target_info)

    # create socket
    sock = socket.socket(source_info[0], source_info[1])

    # set options
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    except AttributeError:
        pass
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)

    # multicast
    if target_ip.is_multicast:
        if source_info[0] == socket.AF_INET6:
            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, 2)
            addr = cast(AddressTupleV6Type, source_info[4])
            if addr[3]:
                mreq = target_ip.packed + addr[3].to_bytes(4, sys.byteorder)
                sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_JOIN_GROUP, mreq)
                sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, addr[3])
            else:
                _LOGGER.debug("Skipping setting multicast interface")
        else:
            sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, source_ip.packed)
            sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 2)
            sock.setsockopt(
                socket.IPPROTO_IP,
                socket.IP_ADD_MEMBERSHIP,
                target_ip.packed + source_ip.packed,
            )

    return sock, source_info[4], target_info[4]