1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
|
import asyncio
import contextlib
import unittest
from typing import AsyncGenerator
from aioice import mdns
from .utils import asynctest
@contextlib.asynccontextmanager
async def querier_and_responder() -> AsyncGenerator[
    tuple[mdns.MDnsProtocol, mdns.MDnsProtocol], None
]:
    """Yield a ``(querier, responder)`` pair of mDNS protocol instances.

    Both instances are created with ``mdns.create_mdns_protocol()`` and are
    closed when the context exits, whether or not the body raised.

    Yields:
        A 2-tuple ``(querier, responder)``.
    """
    querier = await mdns.create_mdns_protocol()
    try:
        responder = await mdns.create_mdns_protocol()
    except BaseException:
        # The second creation failed: do not leak the already-open querier.
        await querier.close()
        raise

    try:
        yield querier, responder
    finally:
        # Close in the same order as before (querier, then responder), but
        # make sure a failure closing the querier cannot skip the responder.
        try:
            await querier.close()
        finally:
            await responder.close()
class MdnsTest(unittest.TestCase):
    """Tests for mDNS resolution using a querier/responder protocol pair."""

    @asynctest
    async def test_receive_junk(self) -> None:
        """Feeding a non-DNS payload to the querier must not raise."""
        async with querier_and_responder() as (querier, _):
            querier.datagram_received(b"junk", None)

    @asynctest
    async def test_resolve_bad(self) -> None:
        """Resolving a hostname that was never published yields ``None``."""
        hostname = mdns.create_mdns_hostname()
        async with querier_and_responder() as (querier, _):
            result = await querier.resolve(hostname)
            # Identity check: the "not found" result must be None itself,
            # not merely something that compares equal to it.
            self.assertIsNone(result)

    @asynctest
    async def test_resolve_close(self) -> None:
        """Closing the querier during a query makes ``resolve`` return ``None``."""
        hostname = mdns.create_mdns_hostname()

        # close the querier while the query is ongoing
        async with querier_and_responder() as (querier, _):
            # timeout=None is passed so that, presumably, only close() can
            # end the pending query -- NOTE(review): confirm against resolve().
            result = await asyncio.gather(
                querier.resolve(hostname, timeout=None), querier.close()
            )
            self.assertEqual(result, [None, None])

    @asynctest
    async def test_resolve_good_ipv4(self) -> None:
        """A published IPv4 address is returned by ``resolve``."""
        hostaddr = "1.2.3.4"
        hostname = mdns.create_mdns_hostname()
        async with querier_and_responder() as (querier, responder):
            await responder.publish(hostname, hostaddr)
            result = await querier.resolve(hostname)
            self.assertEqual(result, hostaddr)

    @asynctest
    async def test_resolve_good_ipv6(self) -> None:
        """A published IPv6 address is returned by ``resolve``."""
        # IPv4-mapped IPv6 address.
        hostaddr = "::ffff:1.2.3.4"
        hostname = mdns.create_mdns_hostname()
        async with querier_and_responder() as (querier, responder):
            await responder.publish(hostname, hostaddr)
            result = await querier.resolve(hostname)
            self.assertEqual(result, hostaddr)

    @asynctest
    async def test_resolve_simultaneous_bad(self) -> None:
        """Two concurrent lookups of an unpublished hostname both yield ``None``."""
        hostname = mdns.create_mdns_hostname()
        async with querier_and_responder() as (querier, _):
            results = await asyncio.gather(
                querier.resolve(hostname), querier.resolve(hostname)
            )
            self.assertEqual(results, [None, None])

    @asynctest
    async def test_resolve_simultaneous_good(self) -> None:
        """Two concurrent lookups of a published hostname both get its address."""
        hostaddr = "1.2.3.4"
        hostname = mdns.create_mdns_hostname()
        async with querier_and_responder() as (querier, responder):
            await responder.publish(hostname, hostaddr)
            results = await asyncio.gather(
                querier.resolve(hostname), querier.resolve(hostname)
            )
            self.assertEqual(results, [hostaddr, hostaddr])
|