1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""asyncio library query support"""
import socket
import asyncio
import dns._asyncbackend
import dns.exception
def _get_running_loop():
try:
return asyncio.get_running_loop()
except AttributeError: # pragma: no cover
return asyncio.get_event_loop()
class _DatagramProtocol:
def __init__(self):
self.transport = None
self.recvfrom = None
def connection_made(self, transport):
self.transport = transport
def datagram_received(self, data, addr):
if self.recvfrom:
self.recvfrom.set_result((data, addr))
self.recvfrom = None
def error_received(self, exc): # pragma: no cover
if self.recvfrom:
self.recvfrom.set_exception(exc)
def connection_lost(self, exc):
if self.recvfrom:
self.recvfrom.set_exception(exc)
def close(self):
self.transport.close()
async def _maybe_wait_for(awaitable, timeout):
if timeout:
try:
return await asyncio.wait_for(awaitable, timeout)
except asyncio.TimeoutError:
raise dns.exception.Timeout(timeout=timeout)
else:
return await awaitable
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket backed by an asyncio transport/protocol pair.

    ``protocol`` is expected to expose a ``recvfrom`` attribute holding the
    future of the single outstanding receive (or ``None``) and a ``close()``
    method, as the datagram protocol in this module does.
    """

    def __init__(self, family, transport, protocol):
        self.family = family
        self.transport = transport
        self.protocol = protocol

    async def sendto(self, what, destination, timeout):  # pragma: no cover
        """Queue *what* for sending to *destination*; *timeout* is unused."""
        # no timeout for asyncio sendto
        self.transport.sendto(what, destination)

    async def recvfrom(self, size, timeout):
        """Wait for the next datagram and return what the protocol delivers.

        *size* is ignored.  Raises ``dns.exception.Timeout`` if *timeout*
        expires first.  Only one receive may be outstanding at a time.
        """
        # ignore size as there's no way I know to tell protocol about it
        done = _get_running_loop().create_future()
        assert self.protocol.recvfrom is None
        self.protocol.recvfrom = done
        try:
            await _maybe_wait_for(done, timeout)
            return done.result()
        finally:
            # Release the slot on every exit path.  Without this, a timeout
            # or cancellation leaves the dead future installed and the next
            # recvfrom() call fails the assertion above.
            if self.protocol.recvfrom is done:
                self.protocol.recvfrom = None

    async def close(self):
        """Close the socket by closing the protocol."""
        self.protocol.close()

    async def getpeername(self):
        """Return the transport's 'peername' extra info."""
        return self.transport.get_extra_info('peername')

    async def getsockname(self):
        """Return the transport's 'sockname' extra info."""
        return self.transport.get_extra_info('sockname')
# NOTE(review): this stream class derives from the *datagram* base class,
# which looks like a copy/paste slip.  Presumably it should derive from a
# stream base in dns._asyncbackend -- confirm that class exists before
# changing the base.
class StreamSocket(dns._asyncbackend.DatagramSocket):
    """Stream socket wrapping a reader/writer pair.

    ``reader`` and ``writer`` are the objects produced by
    ``asyncio.open_connection()`` in this module's backend.
    """

    def __init__(self, af, reader, writer):
        self.family = af
        self.reader = reader
        self.writer = writer

    async def sendall(self, what, timeout):
        """Write *what* and wait for the write buffer to drain.

        Raises ``dns.exception.Timeout`` if draining exceeds *timeout*.
        """
        self.writer.write(what)
        return await _maybe_wait_for(self.writer.drain(), timeout)

    async def recv(self, count, timeout):
        """Read up to *count* bytes from the reader.

        Raises ``dns.exception.Timeout`` if nothing arrives within *timeout*.
        """
        return await _maybe_wait_for(self.reader.read(count), timeout)

    async def close(self):
        """Close the writer and, where supported, wait for it to finish."""
        self.writer.close()
        try:
            await self.writer.wait_closed()
        except AttributeError:  # pragma: no cover
            # The writer has no wait_closed() (StreamWriter gained it in
            # Python 3.7); closing without waiting is the best we can do.
            pass

    async def getpeername(self):
        """Return the writer's 'peername' extra info."""
        return self.writer.get_extra_info('peername')

    async def getsockname(self):
        """Return the writer's 'sockname' extra info."""
        return self.writer.get_extra_info('sockname')
class Backend(dns._asyncbackend.Backend):
    """Async backend implemented on top of asyncio."""

    def name(self):
        """Return the backend's name, ``'asyncio'``."""
        return 'asyncio'

    async def make_socket(self, af, socktype, proto=0,
                          source=None, destination=None, timeout=None,
                          ssl_context=None, server_hostname=None):
        """Create a socket wrapper of the requested type.

        For ``socket.SOCK_DGRAM`` a datagram endpoint is created, bound to
        *source* when given, and returned as a ``DatagramSocket``.

        For ``socket.SOCK_STREAM`` a connection to *destination* (a sequence
        whose first two items are host and port) is opened, optionally with
        *ssl_context* / *server_hostname*, and returned as a
        ``StreamSocket``.  *timeout* bounds only this connection attempt.

        Raises ``ValueError`` if a stream socket is requested without a
        destination, ``dns.exception.Timeout`` if connecting times out, and
        ``NotImplementedError`` for any other socket type.
        """
        loop = _get_running_loop()
        if socktype == socket.SOCK_DGRAM:
            transport, protocol = await loop.create_datagram_endpoint(
                _DatagramProtocol, source, family=af,
                proto=proto)
            return DatagramSocket(af, transport, protocol)
        elif socktype == socket.SOCK_STREAM:
            if destination is None:
                # Fail with a clear message instead of a TypeError from
                # subscripting None below.
                raise ValueError('destination required for stream sockets')
            (r, w) = await _maybe_wait_for(
                asyncio.open_connection(destination[0],
                                        destination[1],
                                        ssl=ssl_context,
                                        family=af,
                                        proto=proto,
                                        local_addr=source,
                                        server_hostname=server_hostname),
                timeout)
            return StreamSocket(af, r, w)
        raise NotImplementedError(
            f'unsupported socket type {socktype}')  # pragma: no cover

    async def sleep(self, interval):
        """Suspend the calling task for *interval* seconds."""
        await asyncio.sleep(interval)
|