1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239
|
"""matchpart is used to compare two DNS messages using a single criterion"""
from typing import ( # noqa
Any, Hashable, Sequence, Tuple, Union)
import dns.edns
import dns.rcode
import dns.set
# Type of the values DataMismatch.format_value() accepts: either a plain
# string or a sequence of arbitrary items.
MismatchValue = Union[str, Sequence[Any]]
class DataMismatch(Exception):
    """Raised when an expected value differs from the value actually received.

    Attributes:
        exp_val: the expected value.
        got_val: the value that was received instead.

    Instances compare equal when both values are equal, and they are hashable
    so that identical mismatches can be de-duplicated in sets/dicts.
    """

    def __init__(self, exp_val, got_val):
        # Hand both values to Exception so they end up in ``args``.  Default
        # exception pickling re-creates the object as ``cls(*args)``; with an
        # empty ``args`` that call fails for this two-argument constructor.
        super().__init__(exp_val, got_val)
        self.exp_val = exp_val
        self.got_val = got_val

    @staticmethod
    def format_value(value: 'MismatchValue') -> str:
        """Return *value* as text; lists and tuples are space-joined."""
        # Tuples are included because compare_rrs_types() reports RR types
        # as tuples of strings.
        if isinstance(value, (list, tuple)):
            return ' '.join(str(val) for val in value)
        return str(value)

    def __str__(self) -> str:
        return (
            f'expected "{self.format_value(self.exp_val)}" '
            f'got "{self.format_value(self.got_val)}"'
        )

    def __eq__(self, other):
        return (isinstance(other, DataMismatch)
                and self.exp_val == other.exp_val
                and self.got_val == other.got_val)

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def key(self) -> Tuple[Hashable, Hashable]:
        """Hashable (expected, got) pair used by __hash__."""
        def make_hashable(value):
            # Lists and dns.set.Set instances are converted recursively to
            # tuples; every other value is returned unchanged.
            if isinstance(value, (list, dns.set.Set)):
                value = tuple(make_hashable(item) for item in value)
            return value
        return (make_hashable(self.exp_val), make_hashable(self.got_val))

    def __hash__(self) -> int:
        return hash(self.key)
def compare_val(exp, got):
    """Compare two arbitrary objects.

    Returns True when they are equal, raises DataMismatch(exp, got) otherwise.
    """
    differs = exp != got
    if differs:
        raise DataMismatch(exp, got)
    return True
def compare_rrs(expected, got):
    """Compare two lists of RR sets regardless of their order.

    Returns True when every item of one list is contained in the other and
    both lists have the same length; raises DataMismatch(expected, got)
    otherwise.
    """
    mismatch = (
        any(rr not in got for rr in expected)
        or any(rr not in expected for rr in got)
        or len(expected) != len(got)
    )
    if mismatch:
        raise DataMismatch(expected, got)
    return True
def compare_rrs_types(exp_val, got_val, skip_rrsigs):
    """Check that the sets of RR types in both sections match.

    Args:
        exp_val: iterable of expected RR sets (objects with ``rdtype`` and
            ``covers`` attributes).
        got_val: iterable of received RR sets.
        skip_rrsigs: when true, RR sets whose ``rdtype`` is RRSIG are ignored.

    Returns:
        True when both sections contain the same set of types (the same
        success value the other comparison helpers in this module return).

    Raises:
        DataMismatch: with the sorted textual type lists of both sections.
    """
    def rr_ordering_key(rrset):
        # The second tuple item flags a signature; when sorted, RRSIG(X)
        # therefore lands right after the plain type X.
        if rrset.covers:
            return rrset.covers, 1
        return rrset.rdtype, 0

    def key_to_text(rrtype, rrsig):
        text = dns.rdatatype.to_text(rrtype)
        return f'RRSIG({text})' if rrsig else text

    if skip_rrsigs:
        exp_val = (rrset for rrset in exp_val
                   if rrset.rdtype != dns.rdatatype.RRSIG)
        got_val = (rrset for rrset in got_val
                   if rrset.rdtype != dns.rdatatype.RRSIG)

    exp_types = frozenset(rr_ordering_key(rrset) for rrset in exp_val)
    got_types = frozenset(rr_ordering_key(rrset) for rrset in got_val)
    if exp_types != got_types:
        exp_types = tuple(key_to_text(*i) for i in sorted(exp_types))
        got_types = tuple(key_to_text(*i) for i in sorted(got_types))
        raise DataMismatch(exp_types, got_types)
    return True
def check_question(question):
    """Reject QUESTION sections this module cannot compare.

    Only the first question entry is ever inspected by the match functions,
    so any section holding more than one record is refused.

    Raises:
        NotImplementedError: when *question* contains more than one record.
    """
    if len(question) > 1:
        raise NotImplementedError("More than one record in QUESTION SECTION.")
def match_opcode(exp, got):
    """Compare the values returned by ``opcode()`` on both messages."""
    expected_opcode = exp.opcode()
    received_opcode = got.opcode()
    return compare_val(expected_opcode, received_opcode)
def match_qtype(exp, got):
    """Compare the ``rdtype`` of the first question of both messages.

    Two empty question sections match; a section that is empty on one side
    only raises DataMismatch with "<empty question>" standing in for the
    missing value.
    """
    exp_question = exp.question
    check_question(exp_question)
    got_question = got.question
    check_question(got_question)

    if exp_question and got_question:
        return compare_val(exp_question[0].rdtype, got_question[0].rdtype)
    if exp_question:
        # Only the received message lacks a question.
        raise DataMismatch(exp_question[0].rdtype, "<empty question>")
    if got_question:
        # Only the expected message lacks a question.
        raise DataMismatch("<empty question>", got_question[0].rdtype)
    return True
def match_qname(exp, got):
    """Compare the ``name`` of the first question of both messages.

    Two empty question sections match; a section that is empty on one side
    only raises DataMismatch with "<empty question>" standing in for the
    missing value.
    """
    exp_question = exp.question
    check_question(exp_question)
    got_question = got.question
    check_question(got_question)

    if exp_question and got_question:
        return compare_val(exp_question[0].name, got_question[0].name)
    if exp_question:
        # Only the received message lacks a question.
        raise DataMismatch(exp_question[0].name, "<empty question>")
    if got_question:
        # Only the expected message lacks a question.
        raise DataMismatch("<empty question>", got_question[0].name)
    return True
def match_qcase(exp, got):
    """Compare ``name.labels`` of the first question of both messages.

    The label sequences are compared directly rather than the name objects
    (NOTE(review): presumably to make the check case-sensitive - confirm).
    Two empty question sections match; a section that is empty on one side
    only raises DataMismatch with "<empty question>" standing in for the
    missing value.
    """
    exp_question = exp.question
    check_question(exp_question)
    got_question = got.question
    check_question(got_question)

    if exp_question and got_question:
        return compare_val(exp_question[0].name.labels,
                           got_question[0].name.labels)
    if exp_question:
        # Only the received message lacks a question.
        raise DataMismatch(exp_question[0].name.labels, "<empty question>")
    if got_question:
        # Only the expected message lacks a question.
        raise DataMismatch("<empty question>", got_question[0].name.labels)
    return True
def match_subdomain(exp, got):
    """Check that the expected qname is a superdomain of the received qname.

    An expected message without a question matches anything.  A received
    message without a question is treated as asking for ``dns.name.root``.
    Raises DataMismatch(exp, got) when the superdomain test fails.
    """
    if not exp.question:
        return True

    qname = got.question[0].name if got.question else dns.name.root
    if not exp.question[0].name.is_superdomain(qname):
        raise DataMismatch(exp, got)
    return True
def match_flags(exp, got):
    """Compare ``flags`` of both messages as rendered by dns.flags.to_text."""
    expected_flags = dns.flags.to_text(exp.flags)
    received_flags = dns.flags.to_text(got.flags)
    return compare_val(expected_flags, received_flags)
def match_rcode(exp, got):
    """Compare ``rcode()`` of both messages as rendered by dns.rcode.to_text."""
    expected_rcode = dns.rcode.to_text(exp.rcode())
    received_rcode = dns.rcode.to_text(got.rcode())
    return compare_val(expected_rcode, received_rcode)
def match_answer(exp, got):
    """Compare the ``answer`` sections of both messages with compare_rrs."""
    expected_section = exp.answer
    received_section = got.answer
    return compare_rrs(expected_section, received_section)
def match_answertypes(exp, got):
    """Compare RR types present in both ``answer`` sections, ignoring RRSIGs.

    Returns whatever compare_rrs_types returns.
    """
    expected_section = exp.answer
    received_section = got.answer
    return compare_rrs_types(expected_section, received_section,
                             skip_rrsigs=True)
def match_answerrrsigs(exp, got):
    """Compare RR types present in both ``answer`` sections, RRSIGs included.

    Returns whatever compare_rrs_types returns.
    """
    expected_section = exp.answer
    received_section = got.answer
    return compare_rrs_types(expected_section, received_section,
                             skip_rrsigs=False)
def match_authority(exp, got):
    """Compare the ``authority`` sections of both messages with compare_rrs."""
    expected_section = exp.authority
    received_section = got.authority
    return compare_rrs(expected_section, received_section)
def match_additional(exp, got):
    """Compare the ``additional`` sections of both messages with compare_rrs."""
    expected_section = exp.additional
    received_section = got.additional
    return compare_rrs(expected_section, received_section)
def match_edns(exp, got):
    """Compare the ``edns`` and ``payload`` attributes of both messages.

    Returns:
        True when both attributes are equal (the same success value the
        other match functions in this module return).

    Raises:
        DataMismatch: for the first attribute that differs; ``edns`` is
            checked before ``payload``.
    """
    if got.edns != exp.edns:
        raise DataMismatch(exp.edns, got.edns)
    if got.payload != exp.payload:
        raise DataMismatch(exp.payload, got.payload)
    return True
def match_nsid(exp, got):
    """Compare the NSID option carried in ``options`` of both messages.

    Only the first option whose ``otype`` is dns.edns.NSID is considered on
    each side.  Returns True when both sides carry equal NSID options or
    neither carries one; raises DataMismatch with the options' ``data``
    (None for the side without an NSID option) otherwise.
    """
    # First NSID option of the expected message, or None when there is none.
    expected_nsid = next(
        (option for option in exp.options if option.otype == dns.edns.NSID),
        None,
    )

    # The first NSID option of the received message decides the outcome.
    for option in got.options:
        if option.otype != dns.edns.NSID:
            continue
        if not expected_nsid:
            raise DataMismatch(None, option.data)
        if option == expected_nsid:
            return True
        raise DataMismatch(expected_nsid.data, option.data)

    # No NSID option was received.
    if expected_nsid:
        raise DataMismatch(expected_nsid.data, None)
    return True
# Dispatch table used by match_part(): criterion name -> function(exp, got).
MATCH = {
    "opcode": match_opcode,
    "qtype": match_qtype,
    "qname": match_qname,
    "qcase": match_qcase,
    "subdomain": match_subdomain,
    "flags": match_flags,
    "rcode": match_rcode,
    "answer": match_answer,
    "answertypes": match_answertypes,
    "answerrrsigs": match_answerrrsigs,
    "authority": match_authority,
    "additional": match_additional,
    "edns": match_edns,
    "nsid": match_nsid,
}
def match_part(exp, got, code):
    """Compare two messages using the single criterion named *code*.

    Args:
        exp: the expected message.
        got: the received message.
        code: key into the MATCH dispatch table.

    Returns:
        Whatever the selected match function returns.

    Raises:
        NotImplementedError: when *code* is not a key of MATCH.
        DataMismatch: raised by the selected match function on a mismatch.
    """
    # Only the table lookup is guarded: a KeyError raised inside a match
    # function must propagate as-is instead of being reported as an unknown
    # match request.
    try:
        matcher = MATCH[code]
    except KeyError as ex:
        raise NotImplementedError(f'unknown match request "{code}"') from ex
    return matcher(exp, got)
|