1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198
|
# -*- coding: utf-8 -*-
"""async_upnp_client.search module."""
import asyncio
import logging
import socket
import sys
from asyncio import DatagramTransport
from asyncio.events import AbstractEventLoop
from ipaddress import IPv4Address, IPv6Address
from typing import Any, Callable, Coroutine, Optional, cast
from async_upnp_client.const import SsdpSource
from async_upnp_client.ssdp import (
SSDP_DISCOVER,
SSDP_MX,
SSDP_ST_ALL,
AddressTupleVXType,
IPvXAddress,
SsdpProtocol,
build_ssdp_search_packet,
determine_source_target,
get_host_string,
get_ssdp_socket,
)
from async_upnp_client.utils import CaseInsensitiveDict
_LOGGER = logging.getLogger(__name__)
class SsdpSearchListener:
"""SSDP Search (response) listener."""
# pylint: disable=too-many-instance-attributes
def __init__(
self,
async_callback: Optional[
Callable[[CaseInsensitiveDict], Coroutine[Any, Any, None]]
] = None,
callback: Optional[Callable[[CaseInsensitiveDict], None]] = None,
loop: Optional[AbstractEventLoop] = None,
source: Optional[AddressTupleVXType] = None,
target: Optional[AddressTupleVXType] = None,
timeout: int = SSDP_MX,
search_target: str = SSDP_ST_ALL,
async_connect_callback: Optional[
Callable[[], Coroutine[Any, Any, None]]
] = None,
connect_callback: Optional[Callable[[], None]] = None,
) -> None:
"""Init the ssdp listener class."""
# pylint: disable=too-many-arguments,too-many-positional-arguments
assert (
callback is not None or async_callback is not None
), "Provide at least one callback"
self.async_callback = async_callback
self.callback = callback
self.async_connect_callback = async_connect_callback
self.connect_callback = connect_callback
self.search_target = search_target
self.source, self.target = determine_source_target(source, target)
self.timeout = timeout
self.loop = loop or asyncio.get_event_loop()
self._target_host: Optional[str] = None
self._transport: Optional[DatagramTransport] = None
def async_search(
self, override_target: Optional[AddressTupleVXType] = None
) -> None:
"""Start an SSDP search."""
assert self._transport is not None
sock: Optional[socket.socket] = self._transport.get_extra_info("socket")
_LOGGER.debug(
"Sending SEARCH packet, transport: %s, socket: %s, override_target: %s",
self._transport,
sock,
override_target,
)
assert self._target_host is not None, "Call async_start() first"
packet = build_ssdp_search_packet(self.target, self.timeout, self.search_target)
protocol = cast(SsdpProtocol, self._transport.get_protocol())
target = override_target or self.target
protocol.send_ssdp_packet(packet, target)
def _on_data(self, request_line: str, headers: CaseInsensitiveDict) -> None:
"""Handle data."""
if headers.get_lower("man") == SSDP_DISCOVER:
# Ignore discover packets.
return
if headers.get_lower("nts"):
_LOGGER.debug(
"Got non-search response packet: %s, %s", request_line, headers
)
return
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"Received search response, _remote_addr: %s, USN: %s, location: %s",
headers.get_lower("_remote_addr", ""),
headers.get_lower("usn", "<no USN>"),
headers.get_lower("location", ""),
)
headers["_source"] = SsdpSource.SEARCH
if self._target_host and self._target_host != headers["_host"]:
return
if self.async_callback:
coro = self.async_callback(headers)
self.loop.create_task(coro)
if self.callback:
self.callback(headers)
def _on_connect(self, transport: DatagramTransport) -> None:
sock: Optional[socket.socket] = transport.get_extra_info("socket")
_LOGGER.debug("On connect, transport: %s, socket: %s", transport, sock)
self._transport = transport
if self.async_connect_callback:
coro = self.async_connect_callback()
self.loop.create_task(coro)
if self.connect_callback:
self.connect_callback()
@property
def target_ip(self) -> IPvXAddress:
"""Get target IP."""
if len(self.target) == 4:
return IPv6Address(self.target[0])
return IPv4Address(self.target[0])
async def async_start(self) -> None:
"""Start the listener."""
_LOGGER.debug("Start listening for search responses")
sock, _source, _target = get_ssdp_socket(self.source, self.target)
if sys.platform.startswith("win32"):
address = self.source
_LOGGER.debug("Binding socket, socket: %s, address: %s", sock, address)
sock.bind(address)
if not self.target_ip.is_multicast:
self._target_host = get_host_string(self.target)
else:
self._target_host = ""
loop = self.loop
await loop.create_datagram_endpoint(
lambda: SsdpProtocol(
loop,
on_connect=self._on_connect,
on_data=self._on_data,
),
sock=sock,
)
def async_stop(self) -> None:
"""Stop the listener."""
if self._transport:
self._transport.close()
async def async_search(
async_callback: Callable[[CaseInsensitiveDict], Coroutine[Any, Any, None]],
timeout: int = SSDP_MX,
search_target: str = SSDP_ST_ALL,
source: Optional[AddressTupleVXType] = None,
target: Optional[AddressTupleVXType] = None,
loop: Optional[AbstractEventLoop] = None,
) -> None:
"""Discover devices via SSDP."""
# pylint: disable=too-many-arguments,too-many-positional-arguments
loop_: AbstractEventLoop = loop or asyncio.get_event_loop()
listener: Optional[SsdpSearchListener] = None
async def _async_connected() -> None:
nonlocal listener
assert listener is not None
listener.async_search()
listener = SsdpSearchListener(
async_callback=async_callback,
loop=loop_,
source=source,
target=target,
timeout=timeout,
search_target=search_target,
async_connect_callback=_async_connected,
)
await listener.async_start()
# Wait for devices to respond.
await asyncio.sleep(timeout)
listener.async_stop()
|