1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
|
"""Module takes care of sending and recieving DNS messages as a mock client"""
import errno
import socket
import struct
import time
from typing import Optional, Tuple, Union
import dns.message
import dns.inet
# Default time budget, in seconds, for socket operations in this module
SOCKET_OPERATION_TIMEOUT = 5
# Buffer size handed to recvfrom() on datagram sockets (65535 bytes)
RECEIVE_MESSAGE_SIZE = 2**16-1
# NOTE(review): presumably the pause, in seconds, between retries after
# ENOBUFS - confirm against the places that sleep before retrying
THROTTLE_BY = 0.1
def handle_socket_timeout(sock: socket.socket, deadline: float):
    """
    Set the socket timeout to the time left until `deadline`.

    `deadline` is an absolute time.monotonic() timestamp. RuntimeError is
    raised when no time is left; otherwise the socket timeout is set to the
    remaining number of seconds.
    """
    now = time.monotonic()
    time_left = deadline - now
    if time_left <= 0:
        # The time budget is already exhausted, do not touch the socket
        raise RuntimeError("Server took too long to respond")
    sock.settimeout(time_left)
def recv_n_bytes_from_tcp(stream: socket.socket, n: int, deadline: float) -> bytes:
    """
    Read exactly `n` bytes from a stream socket.

    `deadline` is an absolute time.monotonic() timestamp; the socket timeout
    is re-armed from it before every recv() call, so the whole read shares a
    single time budget.

    Returns the received bytes (empty bytes when `n` is 0).
    Raises OSError when the peer closes the connection before `n` bytes
    arrived, and whatever handle_socket_timeout() / socket.recv() raise.
    """
    # Collect the chunks and join once instead of re-copying the buffer on
    # every received chunk
    chunks = []
    remaining = n
    while remaining != 0:
        handle_socket_timeout(stream, deadline)
        chunk = stream.recv(remaining)
        # Empty bytes from socket.recv mean that socket is closed
        if not chunk:
            raise OSError("Connection closed before the whole message was received")
        remaining -= len(chunk)
        chunks.append(chunk)
    return b"".join(chunks)
def recvfrom_blob(sock: socket.socket,
                  timeout: int = SOCKET_OPERATION_TIMEOUT) -> Tuple[bytes, str]:
    """
    Receive DNS message from TCP/UDP socket.

    `timeout` is the total number of seconds allowed for the whole receive,
    including retries after ENOBUFS.

    Returns a (data, addr) pair. On a datagram socket `addr` is the sender
    address exactly as sock.recvfrom() reports it; on a stream socket it is
    the first element of sock.getpeername().
    NOTE(review): for UDP this is presumably an address tuple rather than
    the annotated str - confirm with callers.

    Raises RuntimeError when the time budget runs out, NotImplementedError
    for other socket types and OSError for any other socket failure.
    """
    # deadline is always time.monotonic
    deadline = time.monotonic() + timeout
    while True:
        try:
            if sock.type & socket.SOCK_DGRAM:
                handle_socket_timeout(sock, deadline)
                data, addr = sock.recvfrom(RECEIVE_MESSAGE_SIZE)
            elif sock.type & socket.SOCK_STREAM:
                # First 2 bytes of TCP packet are the size of the message
                # See https://tools.ietf.org/html/rfc1035#section-4.2.2
                data = recv_n_bytes_from_tcp(sock, 2, deadline)
                msg_len = struct.unpack_from("!H", data)[0]
                data = recv_n_bytes_from_tcp(sock, msg_len, deadline)
                addr = sock.getpeername()[0]
            else:
                raise NotImplementedError(f"[recvfrom_blob]: unknown socket type '{sock.type}'")
            return data, addr
        except socket.timeout as ex:
            # Must be handled before OSError, socket.timeout is its subclass
            raise RuntimeError("Server took too long to respond") from ex
        except OSError as ex:
            if ex.errno != errno.ENOBUFS:
                raise
            # Out of buffer space: back off for the module-wide throttle
            # interval and try again (the deadline still applies)
            time.sleep(THROTTLE_BY)
def recvfrom_msg(sock: socket.socket,
                 timeout: int = SOCKET_OPERATION_TIMEOUT) -> Tuple[dns.message.Message, str]:
    """
    Receive one DNS message from the socket and parse it.

    The wire data comes from recvfrom_blob() and is parsed with
    one_rr_per_rrset=True. Returns the parsed message together with the
    address reported by recvfrom_blob().
    """
    wire, peer = recvfrom_blob(sock, timeout=timeout)
    return dns.message.from_wire(wire, one_rr_per_rrset=True), peer
def sendto_msg(sock: socket.socket, message: bytes, addr: Optional[str] = None) -> None:
    """
    Send DNS/UDP/TCP message.

    Datagram sockets get the message as-is, via sendto() when `addr` is given
    and via send() otherwise. Stream sockets get the message prefixed with
    its length as a 2-byte big-endian integer. Other socket types raise
    NotImplementedError.

    An OSError with errno ECONNREFUSED is swallowed; any other OSError
    propagates to the caller.
    """
    try:
        sock_type = sock.type
        if sock_type & socket.SOCK_DGRAM:
            if addr is not None:
                sock.sendto(message, addr)
            else:
                sock.send(message)
        elif sock_type & socket.SOCK_STREAM:
            length_prefix = struct.pack("!H", len(message))
            sock.sendall(length_prefix + message)
        else:
            raise NotImplementedError(f"[sendto_msg]: unknown socket type '{sock.type}'")
    except OSError as ex:
        # Reference: http://lkml.iu.edu/hypermail/linux/kernel/0002.3/0709.html
        if ex.errno == errno.ECONNREFUSED:
            return
        raise
def setup_socket(address: str,
                 port: int,
                 tcp: bool = False,
                 src_address: Optional[str] = None) -> socket.socket:
    """
    Create a socket connected to `address`:`port`.

    The address family is derived from `address`. A stream socket with
    TCP_NODELAY set is created when `tcp` is true, a datagram socket
    otherwise. When `src_address` is given the socket is bound to it with a
    random source port before connecting. The socket timeout is set to
    SOCKET_OPERATION_TIMEOUT.

    Returns the connected socket. If any setup step fails, the socket is
    closed and the original exception propagates.
    """
    family = dns.inet.af_for_address(address)
    sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
    try:
        if tcp:
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
        if src_address is not None:
            sock.bind((src_address, 0))  # random source port
        sock.settimeout(SOCKET_OPERATION_TIMEOUT)
        sock.connect((address, port))
    except BaseException:
        # Do not leak the file descriptor when setup fails; the error is
        # re-raised unchanged
        sock.close()
        raise
    return sock
def send_query(sock: socket.socket, query: Union[dns.message.Message, bytes]) -> None:
    """
    Send a DNS query, retrying for as long as the send fails with ENOBUFS.

    `query` may be wire-format bytes, which are sent as-is, or an object
    providing to_wire(), which is serialized first. Any OSError other than
    ENOBUFS raised by sendto_msg() propagates to the caller.
    """
    message = query if isinstance(query, bytes) else query.to_wire()
    while True:
        try:
            sendto_msg(sock, message)
        except OSError as ex:
            if ex.errno != errno.ENOBUFS:
                raise
            # ENOBUFS, throttle sending by the module-wide interval
            time.sleep(THROTTLE_BY)
        else:
            return
def get_answer(sock: socket.socket, timeout: int = SOCKET_OPERATION_TIMEOUT) -> bytes:
    """
    Compatibility function

    Returns only the data part of recvfrom_blob(); the address is dropped.
    """
    return recvfrom_blob(sock, timeout=timeout)[0]
def get_dns_message(sock: socket.socket,
                    timeout: int = SOCKET_OPERATION_TIMEOUT) -> dns.message.Message:
    """
    Receive one answer via get_answer() and parse it into a DNS message.
    """
    wire = get_answer(sock, timeout=timeout)
    return dns.message.from_wire(wire)
|