1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
|
# SPDX-License-Identifier: GPL-3.0-or-later
from contextlib import contextmanager
import random
import ssl
import struct
import time
import dns
import dns.message
import pytest
# Timeout budget in seconds, spelled out as its components: the default
# net.tcp_in_idle (10s) plus TCP_DEFER_ACCEPT (3s), plus 3s of slack for
# Python-side handling / edge cases.
MAX_TIMEOUT = 10 + 3 + 3
def _recv_exactly(sock, length):
    """Receive exactly ``length`` bytes from ``sock``.

    ``sock.recv()`` may return fewer bytes than requested, so keep reading
    until the full amount has arrived.

    Returns:
        The received bytes, or None if ``recv()`` returned an empty chunk
        (peer closed the connection) before ``length`` bytes were read.
    """
    chunks = []
    remaining = length
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            return None
        chunks.append(chunk)
        # Count only the bytes just received, not the accumulated total.
        remaining -= len(chunk)
    return b''.join(chunks)


def receive_answer(sock):
    """Read one message framed with a 2-byte big-endian length prefix.

    Both the prefix and the body are read with repeated ``recv()`` calls, so
    messages split across several TCP segments / TLS records are reassembled.

    Args:
        sock: socket-like object with a ``recv(bufsize)`` method.

    Returns:
        The message bytes (without the prefix). Returns None if the
        connection was closed before the prefix or the full body arrived.
    """
    prefix = _recv_exactly(sock, 2)
    if prefix is None:
        return None
    (answer_total_len,) = struct.unpack("!H", prefix)
    return _recv_exactly(sock, answer_total_len)
def receive_parse_answer(sock):
    """Receive one length-prefixed answer from ``sock`` and parse it.

    Returns:
        A ``dns.message.Message`` parsed with ``one_rr_per_rrset=True``.

    Raises:
        BrokenPipeError: if ``receive_answer`` returned None, i.e. the
            connection was closed before a complete answer arrived.
    """
    wire = receive_answer(sock)
    if wire is not None:
        return dns.message.from_wire(wire, one_rr_per_rrset=True)
    raise BrokenPipeError("kresd closed connection")
def prepare_wire(
        qname='localhost.',
        qtype=dns.rdatatype.A,
        qclass=dns.rdataclass.IN,
        msgid=None):
    """Build an EDNS-enabled DNS query and serialize it to wire format.

    Args:
        qname: query name.
        qtype: query RR type.
        qclass: query class.
        msgid: message id to force; if None, the id chosen by
            ``dns.message.make_query`` is kept.

    Returns:
        Tuple ``(wire_bytes, message_id)``.
    """
    query = dns.message.make_query(qname, qtype, qclass, use_edns=True)
    if msgid is not None:
        query.id = msgid
    wire = query.to_wire()
    return wire, query.id
def prepare_buffer(wire, datalen=None):
    """Prefix a wire-format DNS message with a 2-byte big-endian length.

    Args:
        wire: message bytes (must be ``bytes``; checked with assert).
        datalen: value written into the length field instead of
            ``len(wire)``. It is not checked against the payload.

    Returns:
        ``length_prefix + wire`` as bytes.
    """
    assert isinstance(wire, bytes)
    length = len(wire) if datalen is None else datalen
    header = struct.pack("!H", length)
    return header + wire
def get_msgbuff(qname='localhost.', qtype=dns.rdatatype.A, msgid=None):
    """Build a length-prefixed DNS query ready to send over TCP/TLS.

    Returns:
        Tuple ``(buffer, message_id)``, where ``buffer`` is the wire-format
        query produced by ``prepare_wire`` framed by ``prepare_buffer``.
    """
    wire, query_id = prepare_wire(qname, qtype, msgid=msgid)
    return prepare_buffer(wire), query_id
def get_garbage(length):
    """Return ``length`` pseudo-random bytes from the ``random`` module.

    Each byte comes from ``random.getrandbits(8)``, so the output is
    reproducible after ``random.seed()``. Not suitable for secrets.
    """
    buf = bytearray()
    for _ in range(length):
        buf.append(random.getrandbits(8))
    return bytes(buf)
def get_prefixed_garbage(length):
    """Return ``length`` random bytes preceded by a matching 2-byte length prefix."""
    return prepare_buffer(get_garbage(length))
def try_ping_alive(sock, msgid=None, close=False):
    """Run ``ping_alive`` and report whether it passed.

    Only AssertionError (answer id mismatch) is turned into False; any other
    exception raised by ``ping_alive`` propagates to the caller.

    Args:
        sock: connected socket to ping.
        msgid: optional message id to use for the query.
        close: if True, ``sock`` is closed afterwards, whatever the outcome.

    Returns:
        True if the ping succeeded, False on AssertionError.
    """
    alive = True
    try:
        ping_alive(sock, msgid)
    except AssertionError:
        alive = False
    finally:
        if close:
            sock.close()
    return alive
def ping_alive(sock, msgid=None):
    """Send ``get_msgbuff()``'s default query on ``sock`` and check the reply.

    Raises:
        AssertionError: if the answer's id differs from the query's id.
        BrokenPipeError: propagated from ``receive_parse_answer`` when the
            connection is closed before an answer arrives.
    """
    query_buf, expected_id = get_msgbuff(msgid=msgid)
    sock.sendall(query_buf)
    reply = receive_parse_answer(sock)
    assert reply.id == expected_id
@contextmanager
def expect_kresd_close(rst_ok=False):
    """Context manager asserting that the connection is closed during its body.

    Before the body runs, it sleeps 0.2 s. The body is then expected to
    raise BrokenPipeError or ssl.SSLEOFError. The enclosing pytest.raises
    catches these, so they do not escape the ``with`` statement.

    Args:
        rst_ok: what to do when the body raises ConnectionResetError (TCP
            RST). If True, it is converted to BrokenPipeError and accepted
            as a close. If False, the test is skipped via pytest.skip().

    If the body finishes without raising, pytest.fail() is called.
    """
    with pytest.raises((BrokenPipeError, ssl.SSLEOFError)):
        try:
            time.sleep(0.2)  # give kresd time to close connection with TCP FIN
            yield
        except ConnectionResetError as ex:
            if rst_ok:
                # Re-raise as an exception type the enclosing pytest.raises accepts.
                raise BrokenPipeError from ex
            # The skip exception is not one of the expected types, so it
            # escapes pytest.raises and marks the test as skipped.
            pytest.skip("kresd closed connection with TCP RST")
        # Reached only when the body raised nothing. The failure exception
        # is not one of the expected types and escapes pytest.raises.
        pytest.fail("kresd didn't close the connection")
def make_ssl_context(insecure=False, verify_location=None,
                     minimum_tls=ssl.TLSVersion.TLSv1_2,
                     maximum_tls=ssl.TLSVersion.MAXIMUM_SUPPORTED):
    """Create a client-side TLS context.

    Args:
        insecure: if True, disable hostname checking and certificate
            verification.
        verify_location: CA file passed to ``load_verify_locations``. It is
            only used when ``insecure`` is False.
        minimum_tls: lowest TLS version allowed.
        maximum_tls: highest TLS version allowed.

    Returns:
        The configured ``ssl.SSLContext``.
    """
    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ctx.minimum_version = minimum_tls
    ctx.maximum_version = maximum_tls

    if not insecure:
        ctx.verify_mode = ssl.CERT_REQUIRED
        ctx.check_hostname = True
        if verify_location is not None:
            ctx.load_verify_locations(verify_location)
        return ctx

    # check_hostname must be cleared first: verify_mode cannot be lowered
    # to CERT_NONE while hostname checking is still enabled.
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE
    return ctx
|