1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
|
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""curio async I/O library query support"""
import socket
import curio
import curio.socket # type: ignore
import dns._asyncbackend
import dns.exception
import dns.inet
def _maybe_timeout(timeout):
    """Return an async context manager that bounds the enclosed block by *timeout*.

    A falsy *timeout* (e.g. ``None`` or ``0``) means "no limit", and a no-op
    ``NullContext`` is returned. Otherwise ``curio.ignore_after(timeout)`` is
    returned. When that limit expires, the enclosed block is cancelled and
    execution silently continues after the ``async with``. Callers therefore
    raise ``dns.exception.Timeout`` themselves after the block.
    """
    if not timeout:
        return dns._asyncbackend.NullContext()
    return curio.ignore_after(timeout)
# Short alias for dns.inet.low_level_address_tuple. It is called below as
# _lltuple(source, af) to build the address handed to bind() / source_addr.
_lltuple = dns.inet.low_level_address_tuple
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket adapter wrapping a curio socket.

    The I/O methods accept a *timeout*. If it expires before the operation
    completes, ``dns.exception.Timeout`` is raised.
    """

    def __init__(self, socket):
        self.socket = socket
        self.family = socket.family

    async def sendto(self, what, destination, timeout):
        """Send *what* to *destination* and return the socket's result."""
        async with _maybe_timeout(timeout):
            sent = await self.socket.sendto(what, destination)
            return sent
        # Reached only when ignore_after() cancelled the send.
        raise dns.exception.Timeout(timeout=timeout)  # pragma: no cover

    async def recvfrom(self, size, timeout):
        """Receive up to *size* bytes and return the socket's result."""
        async with _maybe_timeout(timeout):
            received = await self.socket.recvfrom(size)
            return received
        # Reached only when ignore_after() cancelled the receive.
        raise dns.exception.Timeout(timeout=timeout)

    async def close(self):
        """Close the wrapped curio socket."""
        await self.socket.close()

    async def getpeername(self):
        """Return the wrapped socket's peer address."""
        return self.socket.getpeername()

    async def getsockname(self):
        """Return the wrapped socket's local address."""
        return self.socket.getsockname()
class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket adapter wrapping the object returned by curio.open_connection.

    The base class was previously ``dns._asyncbackend.DatagramSocket``. That
    gave this stream wrapper the datagram interface and made it fail
    ``isinstance`` checks against the stream base class.

    The I/O methods accept a *timeout*. If it expires before the operation
    completes, ``dns.exception.Timeout`` is raised.
    """

    def __init__(self, socket):
        self.socket = socket
        self.family = socket.family

    async def sendall(self, what, timeout):
        """Send all of *what* and return the socket's result."""
        async with _maybe_timeout(timeout):
            return await self.socket.sendall(what)
        # Reached only when ignore_after() cancelled the send.
        raise dns.exception.Timeout(timeout=timeout)

    async def recv(self, size, timeout):
        """Receive up to *size* bytes and return them."""
        async with _maybe_timeout(timeout):
            return await self.socket.recv(size)
        # Reached only when ignore_after() cancelled the receive.
        raise dns.exception.Timeout(timeout=timeout)

    async def close(self):
        """Close the wrapped curio socket."""
        await self.socket.close()

    async def getpeername(self):
        """Return the wrapped socket's peer address."""
        return self.socket.getpeername()

    async def getsockname(self):
        """Return the wrapped socket's local address."""
        return self.socket.getsockname()
class Backend(dns._asyncbackend.Backend):
    """dnspython async backend implemented on top of curio."""

    def name(self):
        """Return the backend's name, ``'curio'``."""
        return 'curio'

    async def make_socket(self, af, socktype, proto=0,
                          source=None, destination=None, timeout=None,
                          ssl_context=None, server_hostname=None):
        """Create and return a socket wrapper for *socktype*.

        For ``SOCK_DGRAM`` a curio socket is created and, if *source* is
        given, bound to it. The socket is wrapped in ``DatagramSocket``.

        For ``SOCK_STREAM`` a connection is opened to
        ``destination[0]:destination[1]`` with ``curio.open_connection``,
        optionally from *source* and using *ssl_context* /
        *server_hostname*. The result is wrapped in ``StreamSocket``.

        Raises:
            dns.exception.Timeout: the stream connection did not complete
                within *timeout*.
            NotImplementedError: *socktype* is neither SOCK_DGRAM nor
                SOCK_STREAM.
        """
        if socktype == socket.SOCK_DGRAM:
            s = curio.socket.socket(af, socktype, proto)
            try:
                if source:
                    s.bind(_lltuple(source, af))
            except Exception:  # pragma: no cover
                # Close the socket so a failed bind does not leak it.
                await s.close()
                raise
            return DatagramSocket(s)
        elif socktype == socket.SOCK_STREAM:
            if source:
                source_addr = _lltuple(source, af)
            else:
                source_addr = None
            async with _maybe_timeout(timeout):
                s = await curio.open_connection(destination[0], destination[1],
                                                ssl=ssl_context,
                                                source_addr=source_addr,
                                                server_hostname=server_hostname)
                return StreamSocket(s)
            # ignore_after() swallows the expiry and falls through to here
            # with `s` never assigned. Report a real timeout, as the socket
            # I/O methods do, instead of an UnboundLocalError.
            raise dns.exception.Timeout(timeout=timeout)
        raise NotImplementedError(
            f'unsupported socket type {socktype}')  # pragma: no cover

    async def sleep(self, interval):
        """Suspend the current task for *interval* seconds."""
        await curio.sleep(interval)
|