# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import contextlib
import socket
import sys
import time
import unittest

try:
    import ssl

    have_ssl = True
except Exception:
    have_ssl = False

import dns.edns
import dns.exception
import dns.flags
import dns.inet
import dns.message
import dns.name
import dns.query
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.rrset
import dns.tsigkeyring
import dns.zone
import tests.util

# Some tests use a "nano nameserver" for testing.  It requires trio
# and threading, so try to import it; if that import fails, the tests
# that need it are skipped.
try:
    from .nanonameserver import Server

    _nanonameserver_available = True
except ImportError:
    _nanonameserver_available = False

    # Placeholder base class so that the Server subclasses defined later in
    # this module can still be created; the test classes that use them are
    # skipped when _nanonameserver_available is False.
    class Server(object):
        pass


# Google Public DNS addresses used by the Internet-dependent tests.  Only the
# address families this host can actually use are included.
query_addresses = [
    address
    for available, address in (
        (tests.util.have_ipv4(), "8.8.8.8"),
        (tests.util.have_ipv6(), "2001:4860:4860::8888"),
    )
    if available
]

# TSIG keyring (single key named "name") shared by the nano-nameserver tests.
keyring = dns.tsigkeyring.from_text({"name": "tDz6cfXXGtNivRpQ98hr6A=="})


@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
class QueryTests(unittest.TestCase):
    """Live queries against Google Public DNS over UDP, TCP and TLS.

    Each test loops over ``query_addresses``, so every address family the
    host supports is exercised.
    """

    @staticmethod
    def _make_google_a_query(**kwargs):
        """Return ``(qname, query)`` for an A query of ``dns.google.``.

        Extra keyword arguments are passed through to make_query().
        """
        qname = dns.name.from_text("dns.google.")
        return qname, dns.message.make_query(qname, dns.rdatatype.A, **kwargs)

    def _check_google_a(self, response, qname):
        """Assert that *response* answers *qname* with 8.8.8.8 and 8.8.4.4."""
        rrs = response.get_rrset(
            response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
        )
        self.assertIsNotNone(rrs)
        seen = {rdata.address for rdata in rrs}
        self.assertIn("8.8.8.8", seen)
        self.assertIn("8.8.4.4", seen)

    def testQueryUDP(self):
        for address in query_addresses:
            qname, q = self._make_google_a_query()
            response = dns.query.udp(q, address, timeout=2)
            self._check_google_a(response, qname)

    def testQueryUDPWithSocket(self):
        for address in query_addresses:
            with socket.socket(
                dns.inet.af_for_address(address), socket.SOCK_DGRAM
            ) as s:
                s.setblocking(0)
                qname, q = self._make_google_a_query()
                response = dns.query.udp(q, address, sock=s, timeout=2)
                self._check_google_a(response, qname)

    def testQueryTCP(self):
        for address in query_addresses:
            qname, q = self._make_google_a_query()
            response = dns.query.tcp(q, address, timeout=2)
            self._check_google_a(response, qname)

    def testQueryTCPWithSocket(self):
        for address in query_addresses:
            with socket.socket(
                dns.inet.af_for_address(address), socket.SOCK_STREAM
            ) as s:
                ll = dns.inet.low_level_address_tuple((address, 53))
                s.settimeout(2)
                s.connect(ll)
                # Hand dns.query an already-connected, non-blocking socket.
                s.setblocking(0)
                qname, q = self._make_google_a_query()
                response = dns.query.tcp(q, None, sock=s, timeout=2)
                self._check_google_a(response, qname)

    @unittest.skipUnless(have_ssl, "No SSL support")
    def testQueryTLS(self):
        for address in query_addresses:
            qname, q = self._make_google_a_query()
            response = dns.query.tls(q, address, timeout=2)
            self._check_google_a(response, qname)

    @unittest.skipUnless(have_ssl, "No SSL support")
    def testQueryTLSWithContext(self):
        for address in query_addresses:
            qname, q = self._make_google_a_query()
            ssl_context = ssl.create_default_context()
            # We connect by IP address, so skip hostname verification.
            ssl_context.check_hostname = False
            response = dns.query.tls(q, address, timeout=2, ssl_context=ssl_context)
            self._check_google_a(response, qname)

    @unittest.skipUnless(have_ssl, "No SSL support")
    def testQueryTLSWithSocket(self):
        for address in query_addresses:
            with socket.socket(
                dns.inet.af_for_address(address), socket.SOCK_STREAM
            ) as base_s:
                ll = dns.inet.low_level_address_tuple((address, 853))
                base_s.settimeout(2)
                base_s.connect(ll)
                ctx = ssl.create_default_context()
                ctx.minimum_version = ssl.TLSVersion.TLSv1_2
                with ctx.wrap_socket(
                    base_s, server_hostname="dns.google"
                ) as s:  # lgtm[py/insecure-protocol]
                    s.setblocking(0)
                    qname, q = self._make_google_a_query()
                    response = dns.query.tls(q, None, sock=s, timeout=2)
                    self._check_google_a(response, qname)

    @unittest.skipUnless(have_ssl, "No SSL support")
    def testQueryTLSwithPadding(self):
        for address in query_addresses:
            qname, q = self._make_google_a_query(use_edns=0, pad=128)
            response = dns.query.tls(q, address, timeout=2)
            self._check_google_a(response, qname)
            # the response should have a padding option
            self.assertIsNotNone(response.opt)
            has_pad = any(
                option.otype == dns.edns.OptionType.PADDING
                for option in response.opt[0].options
            )
            self.assertTrue(has_pad)

    def testQueryUDPFallback(self):
        # The root DNSKEY answer is expected to need TCP (asserted below).
        for address in query_addresses:
            qname = dns.name.from_text(".")
            q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
            (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=4)
            self.assertTrue(tcp)

    def testQueryUDPFallbackWithSocket(self):
        for address in query_addresses:
            af = dns.inet.af_for_address(address)
            with socket.socket(af, socket.SOCK_DGRAM) as udp_s:
                udp_s.setblocking(0)
                with socket.socket(af, socket.SOCK_STREAM) as tcp_s:
                    ll = dns.inet.low_level_address_tuple((address, 53))
                    tcp_s.settimeout(2)
                    tcp_s.connect(ll)
                    tcp_s.setblocking(0)
                    qname = dns.name.from_text(".")
                    q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
                    (_, tcp) = dns.query.udp_with_fallback(
                        q, address, udp_sock=udp_s, tcp_sock=tcp_s, timeout=4
                    )
                    self.assertTrue(tcp)

    def testQueryUDPFallbackNoFallback(self):
        for address in query_addresses:
            _, q = self._make_google_a_query()
            (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2)
            self.assertFalse(tcp)

    def testUDPReceiveQuery(self):
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener:
            listener.bind(("127.0.0.1", 0))
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender:
                sender.bind(("127.0.0.1", 0))
                q = dns.message.make_query("dns.google", dns.rdatatype.A)
                dns.query.send_udp(sender, q, listener.getsockname())
                expiration = time.time() + 2
                (received, _, addr) = dns.query.receive_udp(
                    listener, expiration=expiration
                )
                self.assertEqual(addr, sender.getsockname())
                # The datagram must decode back to the query that was sent;
                # previously the received message overwrote ``q`` unchecked.
                self.assertEqual(received, q)


# for brevity
_d_and_s = dns.query._destination_and_source


class DestinationAndSourceTests(unittest.TestCase):
    """Tests for dns.query._destination_and_source (aliased as ``_d_and_s``).

    Calls pass (where, port, source, source_port), plus a trailing ``False``
    when *where* is a URL; results are unpacked as (af, d, s).
    """

    def test_af_inferred_from_where(self):
        (af, _, _) = _d_and_s("1.2.3.4", 53, None, 0)
        self.assertEqual(af, socket.AF_INET)

    def test_af_inferred_from_where_v6(self):
        # Needs a distinct name: reusing the one above would shadow the IPv4
        # test so that it never runs.
        (af, _, _) = _d_and_s("1::2", 53, None, 0)
        self.assertEqual(af, socket.AF_INET6)

    def test_af_inferred_from_source(self):
        (af, _, _) = _d_and_s("https://example/dns-query", 443, "1.2.3.4", 0, False)
        self.assertEqual(af, socket.AF_INET)

    def test_af_mismatch(self):
        with self.assertRaises(ValueError):
            _d_and_s("1::2", 53, "1.2.3.4", 0)

    def test_source_port_but_no_af_inferred(self):
        with self.assertRaises(ValueError):
            _d_and_s("https://example/dns-query", 443, None, 12345, False)

    def test_where_must_be_an_address(self):
        with self.assertRaises(ValueError):
            _d_and_s("not a valid address", 53, "1.2.3.4", 0)

    def test_destination_is_none_of_where_url(self):
        (_, d, _) = _d_and_s("https://example/dns-query", 443, None, 0, False)
        self.assertIsNone(d)

    def test_v4_wildcard_source_set(self):
        (_, _, s) = _d_and_s("1.2.3.4", 53, None, 12345)
        self.assertEqual(s, ("0.0.0.0", 12345))

    def test_v6_wildcard_source_set(self):
        (_, _, s) = _d_and_s("1::2", 53, None, 12345)
        self.assertEqual(s, ("::", 12345, 0, 0))


class AddressesEqualTestCase(unittest.TestCase):
    """Tests for dns.query._addresses_equal(af, a1, a2)."""

    def test_v4(self):
        same = dns.query._addresses_equal(
            socket.AF_INET, ("10.0.0.1", 53), ("10.0.0.1", 53)
        )
        self.assertTrue(same)
        different = dns.query._addresses_equal(
            socket.AF_INET, ("10.0.0.1", 53), ("10.0.0.2", 53)
        )
        self.assertFalse(different)

    def test_v6(self):
        # Two spellings of the same IPv6 address compare equal.
        same = dns.query._addresses_equal(
            socket.AF_INET6, ("1::1", 53), ("0001:0000::1", 53)
        )
        self.assertTrue(same)
        different = dns.query._addresses_equal(
            socket.AF_INET6, ("::1", 53), ("::2", 53)
        )
        self.assertFalse(different)

    def test_mixed(self):
        # An IPv6 address given with af=AF_INET does not match an IPv4 one.
        mixed = dns.query._addresses_equal(
            socket.AF_INET, ("10.0.0.1", 53), ("::2", 53)
        )
        self.assertFalse(mixed)


# Zone served by AXFRNanoNameserver.  The XFR tests also build their expected
# zone from this same text, so a transfer must reproduce it exactly.
axfr_zone = """
$TTL 300
@ SOA ns1 root 1 7200 900 1209600 86400
@ NS ns1
@ NS ns2
ns1 A 10.0.0.1
ns2 A 10.0.0.1
"""


class AXFRNanoNameserver(Server):
    """Nano nameserver that answers every request with an AXFR of axfr_zone.

    The transfer is returned as three authoritative messages: the opening
    SOA (keeping the question section), all other rdatasets, and the closing
    SOA.
    """

    def handle(self, request):
        self.zone = dns.zone.from_text(axfr_zone, origin=self.origin)
        self.origin = self.zone.origin
        soa = self.zone.find_rrset(dns.name.empty, dns.rdatatype.SOA)

        def authoritative_reply(keep_question):
            # Fresh AA response to the request; only the first one keeps
            # the question section.
            reply = dns.message.make_response(request.message)
            if not keep_question:
                reply.question = []
            reply.flags |= dns.flags.AA
            return reply

        opening = authoritative_reply(True)
        opening.answer.append(soa)

        body = authoritative_reply(False)
        for name, rdataset in self.zone.iterate_rdatasets():
            is_apex_soa = (
                rdataset.rdtype == dns.rdatatype.SOA and name == dns.name.empty
            )
            if is_apex_soa:
                # The apex SOA is sent separately in the first/last message.
                continue
            rrset = dns.rrset.RRset(
                name, rdataset.rdclass, rdataset.rdtype, rdataset.covers
            )
            rrset.update(rdataset)
            body.answer.append(rrset)

        closing = authoritative_reply(False)
        closing.answer.append(soa)

        return [opening, body, closing]


# Single-message IXFR response taking "example." from serial 2 to serial 4
# via two deltas (2 -> 3, then 3 -> 4), framed by the current SOA (serial 4).
# NOTE(review): the "SOA ... 3" line opening the 3 -> 4 delta omits the IN
# class, presumably on purpose to exercise the text parser's default class --
# confirm before "fixing" it.
ixfr_message = """id 12345
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN IXFR
;ANSWER
example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
deleted.example. 300 IN A 10.0.0.1
changed.example. 300 IN A 10.0.0.2
example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
changed.example. 300 IN A 10.0.0.4
added.example. 300 IN A 10.0.0.3
example. 300 SOA ns1.example. root.example. 3 7200 900 1209600 86400
example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
added2.example. 300 IN A 10.0.0.5
example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
"""

# Same IXFR with an extra record after the final SOA; the transfer is
# expected to fail with FormError (test_ixfr_trailing_junk).
ixfr_trailing_junk = ixfr_message + "junk.example. 300 IN A 10.0.0.6"

# Reply consisting of a single SOA at serial 2, used when the client asks
# for changes since serial 2 (test_ixfr_up_to_date).
ixfr_up_to_date_message = """id 12345
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN IXFR
;ANSWER
example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
"""

# AXFR response whose closing SOA is followed by an extra record; the
# transfer is expected to fail with FormError (test_axfr_trailing_junk).
axfr_trailing_junk = """id 12345
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN AXFR
;ANSWER
example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
added.example. 300 IN A 10.0.0.3
added2.example. 300 IN A 10.0.0.5
changed.example. 300 IN A 10.0.0.4
example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
junk.example. 300 IN A 10.0.0.6
"""


class IXFRNanoNameserver(Server):
    """Nano nameserver that replies to every request with a fixed message
    given in dnspython text format; only the id is taken from the request."""

    def __init__(self, response_text):
        super().__init__()
        # Parsed afresh for each request in handle().
        self.response_text = response_text

    def handle(self, request):
        try:
            reply = dns.message.from_text(self.response_text, one_rr_per_rrset=True)
            reply.id = request.message.id
        except Exception:
            # Deliberately best effort: on any failure return None instead
            # of raising inside the server.
            return None
        return reply


@unittest.skipIf(not _nanonameserver_available, "nanonameserver required")
class XfrTests(unittest.TestCase):
    """AXFR and IXFR transfers against the local nano nameservers."""

    @staticmethod
    def _ixfr_messages(address, serial, **extra):
        """Run a non-relativized IXFR of ``example`` from *serial* against
        *address* (host at index 0, port at index 1) and return every
        message received.  *extra* is passed through to dns.query.xfr()."""
        transfer = dns.query.xfr(
            address[0],
            "example",
            dns.rdatatype.IXFR,
            port=address[1],
            serial=serial,
            relativize=False,
            **extra,
        )
        return list(transfer)

    def _assert_single_message(self, messages, text):
        """Assert *messages* holds exactly one message equal to *text*,
        ignoring the id (the server copies it from the request)."""
        self.assertEqual(len(messages), 1)
        expected = dns.message.from_text(text, one_rr_per_rrset=True)
        expected.id = messages[0].id
        self.assertEqual(messages[0], expected)

    def test_axfr(self):
        expected = dns.zone.from_text(axfr_zone, origin="example")
        with AXFRNanoNameserver(origin="example") as ns:
            transfer = dns.query.xfr(
                ns.tcp_address[0], "example", port=ns.tcp_address[1]
            )
            self.assertEqual(dns.zone.from_xfr(transfer), expected)

    def test_axfr_tsig(self):
        expected = dns.zone.from_text(axfr_zone, origin="example")
        with AXFRNanoNameserver(origin="example", keyring=keyring) as ns:
            transfer = dns.query.xfr(
                ns.tcp_address[0],
                "example",
                port=ns.tcp_address[1],
                keyring=keyring,
                keyname="name",
            )
            self.assertEqual(dns.zone.from_xfr(transfer), expected)

    def test_axfr_root_tsig(self):
        expected = dns.zone.from_text(axfr_zone, origin=".")
        with AXFRNanoNameserver(origin=".", keyring=keyring) as ns:
            transfer = dns.query.xfr(
                ns.tcp_address[0],
                ".",
                port=ns.tcp_address[1],
                keyring=keyring,
                keyname="name",
            )
            self.assertEqual(dns.zone.from_xfr(transfer), expected)

    def test_axfr_udp(self):
        # xfr() should reject use_udp=True for an AXFR.
        with self.assertRaises(ValueError):
            with AXFRNanoNameserver(origin="example") as ns:
                transfer = dns.query.xfr(
                    ns.udp_address[0], "example", port=ns.udp_address[1], use_udp=True
                )
                list(transfer)

    def test_axfr_bad_rcode(self):
        # We just use Server here as by default it will refuse.
        with self.assertRaises(dns.query.TransferError):
            with Server() as ns:
                transfer = dns.query.xfr(
                    ns.tcp_address[0], "example", port=ns.tcp_address[1]
                )
                list(transfer)

    def test_axfr_trailing_junk(self):
        # The IXFR server is used here because it replays a fixed message.
        with self.assertRaises(dns.exception.FormError):
            with IXFRNanoNameserver(axfr_trailing_junk) as ns:
                transfer = dns.query.xfr(
                    ns.tcp_address[0],
                    "example",
                    dns.rdatatype.AXFR,
                    port=ns.tcp_address[1],
                )
                list(transfer)

    def test_ixfr_tcp(self):
        with IXFRNanoNameserver(ixfr_message) as ns:
            messages = self._ixfr_messages(ns.tcp_address, 2)
            self._assert_single_message(messages, ixfr_message)

    def test_ixfr_udp(self):
        with IXFRNanoNameserver(ixfr_message) as ns:
            messages = self._ixfr_messages(ns.udp_address, 2, use_udp=True)
            self._assert_single_message(messages, ixfr_message)

    def test_ixfr_up_to_date(self):
        with IXFRNanoNameserver(ixfr_up_to_date_message) as ns:
            messages = self._ixfr_messages(ns.tcp_address, 2)
            self._assert_single_message(messages, ixfr_up_to_date_message)

    def test_ixfr_trailing_junk(self):
        with self.assertRaises(dns.exception.FormError):
            with IXFRNanoNameserver(ixfr_trailing_junk) as ns:
                self._ixfr_messages(ns.tcp_address, 2)

    def test_ixfr_base_serial_mismatch(self):
        # The fixture's delta starts at serial 2, but we ask from serial 1.
        with self.assertRaises(dns.exception.FormError):
            with IXFRNanoNameserver(ixfr_message) as ns:
                self._ixfr_messages(ns.tcp_address, 1)


class TSIGNanoNameserver(Server):
    """Nano nameserver that answers IN/A questions with 1.2.3.4 and refuses
    everything else."""

    def handle(self, request):
        reply = dns.message.make_response(request.message)
        # Start out as REFUSED; only a successfully built A answer changes it.
        reply.set_rcode(dns.rcode.REFUSED)
        reply.flags |= dns.flags.RA
        try:
            wants_in_a = (
                request.qtype == dns.rdatatype.A
                and request.qclass == dns.rdataclass.IN
            )
            if wants_in_a:
                answer = dns.rrset.from_text(request.qname, 300, "IN", "A", "1.2.3.4")
                reply.answer.append(answer)
                reply.set_rcode(dns.rcode.NOERROR)
                reply.flags |= dns.flags.AA
        except Exception:
            # Deliberately best effort: keep whatever reply was built so far
            # (normally still REFUSED).
            pass
        return reply


@unittest.skipIf(not _nanonameserver_available, "nanonameserver required")
class TsigTests(unittest.TestCase):
    """TSIG-signed UDP query against TSIGNanoNameserver."""

    def test_tsig(self):
        with TSIGNanoNameserver(keyring=keyring) as ns:
            qname = dns.name.from_text("example.com")
            query = dns.message.make_query(qname, "A")
            query.use_tsig(keyring=keyring, keyname="name")
            host, port = ns.udp_address[0], ns.udp_address[1]
            response = dns.query.udp(query, host, port=port)
            self.assertTrue(response.had_tsig)
            answer = response.get_rrset(
                response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
            )
            self.assertIsNotNone(answer)
            self.assertIn("1.2.3.4", {rdata.address for rdata in answer})


@unittest.skipIf(sys.platform == "win32", "low level tests do not work on win32")
class LowLevelWaitTests(unittest.TestCase):
    """Direct tests of dns.query._wait_for() on a connected socket pair."""

    def test_wait_for(self):
        # Create the pair outside the cleanup scope so that a socketpair()
        # failure is reported as itself.  Previously the call sat inside
        # ``try`` and a failure made the ``finally`` raise UnboundLocalError.
        left, right = socket.socketpair()
        with left, right:
            # NOTE(review): the three booleans appear to select
            # readable/writable/error waiting -- confirm against dns.query.
            # already expired
            with self.assertRaises(dns.exception.Timeout):
                dns.query._wait_for(left, True, True, True, 0)
            # simple timeout
            with self.assertRaises(dns.exception.Timeout):
                dns.query._wait_for(left, False, False, False, time.time() + 0.05)
            # writable no timeout (not hanging is passing)
            dns.query._wait_for(left, False, True, False, None)


class MiscTests(unittest.TestCase):
    def test_matches_destination(self):
        """Check dns.query._matches_destination(af, from, dest, ignore)."""
        matching = (
            (socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1234)),
            (socket.AF_INET6, ("1::2", 1234), ("0001::2", 1234)),
            # No destination given: anything matches.
            (socket.AF_INET, ("10.0.0.1", 1234), None),
        )
        for af, source, destination in matching:
            self.assertTrue(
                dns.query._matches_destination(af, source, destination, True)
            )

        mismatching = (
            (socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.2", 1234)),
            (socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235)),
        )
        for af, source, destination in mismatching:
            self.assertFalse(
                dns.query._matches_destination(af, source, destination, True)
            )

        # With the last argument False a mismatch raises instead.
        with self.assertRaises(dns.query.UnexpectedSource):
            dns.query._matches_destination(
                socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), False
            )


@contextlib.contextmanager
def mock_udp_recv(wire1, from1, wire2, from2):
    """Temporarily replace dns.query._udp_recv with a canned stub.

    The stub's first call returns ``(wire1, from1)``; every later call returns
    ``(wire2, from2)``.  The original function is restored on exit, even if
    the body raises.
    """
    original = dns.query._udp_recv
    first_reply = [(wire1, from1)]

    def fake_udp_recv(sock, max_size, expiration):
        if first_reply:
            return first_reply.pop()
        return wire2, from2

    try:
        dns.query._udp_recv = fake_udp_recv
        yield None
    finally:
        dns.query._udp_recv = original


class MockSock:
    """Minimal stand-in for an IPv4 UDP socket.

    sendto() sends nothing and reports the whole payload as sent.
    """

    def __init__(self):
        self.family = socket.AF_INET

    def sendto(self, data, where):
        # *where* is accepted for signature compatibility and ignored.
        sent = len(data)
        return sent


class IgnoreErrors(unittest.TestCase):
    """Tests for how receive_udp()/udp() treat unexpected or bad replies.

    Each case feeds two canned datagrams through mock_udp_recv().  The first
    is the one under test; the second is the good reply that should be
    returned when the first is ignored.
    """

    def setUp(self):
        self.q = dns.message.make_query("example.", "A")
        self.good_r = dns.message.make_response(self.q)
        self.good_r.set_rcode(dns.rcode.NXDOMAIN)
        self.good_r_wire = self.good_r.to_wire()

    def _wrong_id_wire(self, flags=0):
        """Return the wire form of a response to self.q with a different id.

        *flags* is OR-ed into the response flags.  The id wraps within 16 bits.
        A plain ``+= 1`` overflows to 65536 when the query id is 65535, making
        to_wire() fail.
        """
        bad_r = dns.message.make_response(self.q)
        bad_r.id = (bad_r.id + 1) % 65536
        bad_r.flags |= flags
        return bad_r.to_wire()

    def mock_receive(
        self,
        wire1,
        from1,
        wire2,
        from2,
        ignore_unexpected=True,
        ignore_errors=True,
        raise_on_truncation=False,
        good_r=None,
    ):
        """Run receive_udp() over the canned datagrams and assert that it
        returns *good_r* (``self.good_r`` by default).

        Exceptions from receive_udp(), and the AssertionError from a
        mismatching result, propagate to the caller.  Returns None.
        """
        if good_r is None:
            good_r = self.good_r
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            with mock_udp_recv(wire1, from1, wire2, from2):
                (r, _) = dns.query.receive_udp(
                    s,
                    ("127.0.0.1", 53),
                    time.time() + 2,
                    ignore_unexpected=ignore_unexpected,
                    ignore_errors=ignore_errors,
                    raise_on_truncation=raise_on_truncation,
                    query=self.q,
                )
                self.assertEqual(r, good_r)

    def test_good_mock(self):
        self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None)

    def test_bad_address(self):
        self.mock_receive(
            self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53)
        )

    def test_bad_address_not_ignored(self):
        with self.assertRaises(dns.query.UnexpectedSource):
            self.mock_receive(
                self.good_r_wire,
                ("127.0.0.2", 53),
                self.good_r_wire,
                ("127.0.0.1", 53),
                ignore_unexpected=False,
            )

    def test_bad_id(self):
        self.mock_receive(
            self._wrong_id_wire(),
            ("127.0.0.1", 53),
            self.good_r_wire,
            ("127.0.0.1", 53),
        )

    def test_bad_id_not_ignored(self):
        bad_r_wire = self._wrong_id_wire()
        # The wrong-id reply is not skipped, so the equality check inside
        # mock_receive() is what fails.
        with self.assertRaises(AssertionError):
            self.mock_receive(
                bad_r_wire,
                ("127.0.0.1", 53),
                self.good_r_wire,
                ("127.0.0.1", 53),
                ignore_errors=False,
            )

    def test_not_response_not_ignored_udp_level(self):
        bad_r_wire = self._wrong_id_wire()
        with self.assertRaises(dns.query.BadResponse):
            with mock_udp_recv(
                bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
            ):
                dns.query.udp(self.good_r, "127.0.0.1", sock=MockSock())

    def test_bad_wire(self):
        self.mock_receive(
            self._wrong_id_wire()[:10],
            ("127.0.0.1", 53),
            self.good_r_wire,
            ("127.0.0.1", 53),
        )

    def test_good_wire_with_truncation_flag_and_no_truncation_raise(self):
        tc_r = dns.message.make_response(self.q)
        tc_r.flags |= dns.flags.TC
        tc_r_wire = tc_r.to_wire()
        self.mock_receive(tc_r_wire, ("127.0.0.1", 53), None, None, good_r=tc_r)

    def test_good_wire_with_truncation_flag_and_truncation_raise(self):
        tc_r = dns.message.make_response(self.q)
        tc_r.flags |= dns.flags.TC
        tc_r_wire = tc_r.to_wire()
        with self.assertRaises(dns.message.Truncated):
            self.mock_receive(
                tc_r_wire, ("127.0.0.1", 53), None, None, raise_on_truncation=True
            )

    def test_wrong_id_wire_with_truncation_flag_and_no_truncation_raise(self):
        self.mock_receive(
            self._wrong_id_wire(dns.flags.TC),
            ("127.0.0.1", 53),
            self.good_r_wire,
            ("127.0.0.1", 53),
        )

    def test_wrong_id_wire_with_truncation_flag_and_truncation_raise(self):
        self.mock_receive(
            self._wrong_id_wire(dns.flags.TC),
            ("127.0.0.1", 53),
            self.good_r_wire,
            ("127.0.0.1", 53),
            raise_on_truncation=True,
        )

    def test_bad_wire_not_ignored(self):
        bad_r_wire = self._wrong_id_wire()
        with self.assertRaises(dns.message.ShortHeader):
            self.mock_receive(
                bad_r_wire[:10],
                ("127.0.0.1", 53),
                self.good_r_wire,
                ("127.0.0.1", 53),
                ignore_errors=False,
            )

    def test_trailing_wire(self):
        wire = self.good_r_wire + b"abcd"
        self.mock_receive(wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53))

    def test_trailing_wire_not_ignored(self):
        wire = self.good_r_wire + b"abcd"
        with self.assertRaises(dns.message.TrailingJunk):
            self.mock_receive(
                wire,
                ("127.0.0.1", 53),
                self.good_r_wire,
                ("127.0.0.1", 53),
                ignore_errors=False,
            )
