1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
|
// Package testutil implements common testing utilities.
package testutil
import (
"fmt"
"net"
"net/http"
"testing"
"time"
"blitiri.com.ar/go/dnss/internal/trace"
"github.com/miekg/dns"
)
// WaitForDNSServer waits about 5 seconds for a DNS server to start answering
// on addr (over UDP), and returns an error if it fails to do so.
//
// It does this by repeatedly querying the server, paced by a 100ms ticker,
// until it either gets a reply or the deadline passes. Each attempt is given
// a 1 second I/O deadline. Note we do not do any validation of the reply: any
// message that ReadMsg returns without error counts as success.
func WaitForDNSServer(addr string) error {
	conn, err := dns.DialTimeout("udp", addr, 1*time.Second)
	if err != nil {
		return fmt.Errorf("dns.Dial error: %v", err)
	}
	defer conn.Close()

	m := &dns.Msg{}
	m.SetQuestion("unused.", dns.TypeA)

	deadline := time.Now().Add(5 * time.Second)

	// Use an explicit ticker rather than time.Tick, so it can be stopped
	// once we return instead of being left running.
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	for (<-ticker.C).Before(deadline) {
		// Without a deadline the read below could block indefinitely, so
		// skip this attempt if we fail to set one.
		if err := conn.SetDeadline(time.Now().Add(1 * time.Second)); err != nil {
			continue
		}

		// If the query could not be sent there is nothing to wait for;
		// retry on the next tick instead of sitting out the read timeout.
		if err := conn.WriteMsg(m); err != nil {
			continue
		}

		if _, err := conn.ReadMsg(); err == nil {
			return nil
		}
	}

	return fmt.Errorf("timed out")
}
// WaitForHTTPServer waits about 5 seconds for an HTTP server to start
// answering on addr, and returns an error if it fails to do so.
//
// It does this by repeatedly issuing GET requests for
// "http://<addr>/testpoke", paced by a 100ms ticker and each with a 100ms
// client timeout, until one of them succeeds or the deadline passes. Any
// response counts as success; the status code is not inspected.
func WaitForHTTPServer(addr string) error {
	c := http.Client{
		Timeout: 100 * time.Millisecond,
	}

	deadline := time.Now().Add(5 * time.Second)

	// Use an explicit ticker rather than time.Tick, so it can be stopped
	// once we return instead of being left running.
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	for (<-ticker.C).Before(deadline) {
		resp, err := c.Get("http://" + addr + "/testpoke")
		if err == nil {
			// We only care that the server answered; close the body so
			// the underlying connection is not leaked.
			resp.Body.Close()
			return nil
		}
	}

	return fmt.Errorf("timed out")
}
// GetFreePort returns a free TCP address on localhost, in "host:port" form.
//
// It works by listening on port 0 (so the system picks the port), recording
// the resulting address, and closing the listener again. This is hacky and
// not race-free, as something else may grab the port after we release it,
// but it works well enough for testing purposes.
//
// It panics if it is unable to listen at all.
func GetFreePort() string {
	l, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		// There is no error return to report this through; fail loudly
		// with the cause instead of a nil pointer dereference below.
		panic(fmt.Sprintf("testutil: failed to get a free port: %v", err))
	}
	defer l.Close()

	return l.Addr().String()
}
// DNSQuery is a convenient wrapper to issue simple DNS queries.
//
// It asks the server srv for records of type qtype for the name addr, and
// returns the full reply along with its first answer record. If the reply
// has no answers, the returned record is nil. If the exchange fails, only
// the error is returned.
func DNSQuery(srv, addr string, qtype uint16) (*dns.Msg, dns.RR, error) {
	query := &dns.Msg{}
	query.SetQuestion(addr, qtype)

	reply, err := dns.Exchange(query, srv)
	if err != nil {
		return nil, nil, err
	}

	if len(reply.Answer) == 0 {
		return reply, nil, nil
	}

	return reply, reply.Answer[0], nil
}
// TestResolver is a dnsserver.Resolver implementation for testing, so we can
// control its responses during tests.
//
// Tests set Response and RespError to decide what Query returns, and can
// inspect Initialized, MaintainC and LastQuery to see which methods were
// called and with what.
type TestResolver struct {
// Has this resolver been initialized? Init() sets it to true.
Initialized bool
// Maintain() sends a value (true) over this channel on every call.
// NewTestResolver creates it with room for one buffered value.
MaintainC chan bool
// The last query we've seen; Query() stores its request here.
LastQuery *dns.Msg
// What we will respond to queries. Query() returns both as they are,
// except that a non-nil Response gets its Question overwritten with the
// request's, and is marked as authoritative.
Response *dns.Msg
RespError error
}
// NewTestResolver creates a new TestResolver with minimal initialization:
// only the MaintainC channel is set up (with room for one buffered value),
// all the other fields are left at their zero values.
func NewTestResolver() *TestResolver {
	r := new(TestResolver)
	r.MaintainC = make(chan bool, 1)
	return r
}
// Init the resolver. It only records that it was called, by setting
// Initialized, and never fails.
func (r *TestResolver) Init() error {
	r.Initialized = true

	return nil
}
// Maintain the resolver. It only signals that it was called, by sending a
// value over MaintainC. Note this blocks if the channel cannot take another
// value at the time.
func (r *TestResolver) Maintain() {
	r.MaintainC <- true
}
// Query handles the given query, returning the pre-recorded response and
// error (Response and RespError).
//
// The request is saved in LastQuery. If there is a pre-recorded response, it
// is modified in place before being returned: its question section is
// replaced with the request's, and it is marked as authoritative.
// The trace is not used.
func (r *TestResolver) Query(req *dns.Msg, tr *trace.Trace) (*dns.Msg, error) {
	r.LastQuery = req

	resp := r.Response
	if resp == nil {
		return nil, r.RespError
	}

	resp.Question = req.Question
	resp.Authoritative = true
	return resp, r.RespError
}
// ServeTestDNSServer starts the fake DNS server: it listens on addr over
// UDP, and passes every request to the given handler.
//
// It does not return: it blocks while serving, and panics with the
// resulting error once the server stops. Callers will usually want to run it
// in its own goroutine.
func ServeTestDNSServer(addr string, handler func(dns.ResponseWriter, *dns.Msg)) {
	srv := &dns.Server{
		Net:     "udp",
		Addr:    addr,
		Handler: dns.HandlerFunc(handler),
	}

	panic(srv.ListenAndServe())
}
// MakeStaticHandler returns a handler for the DNS server which replies to
// every request with the same, single answer.
//
// The given answer must be a valid resource record in text form; it is
// parsed once, up front, using NewRR (which fails tb if it is not valid).
func MakeStaticHandler(tb testing.TB, answer string) func(dns.ResponseWriter, *dns.Msg) {
	record := NewRR(tb, answer)

	return func(w dns.ResponseWriter, req *dns.Msg) {
		reply := new(dns.Msg)
		reply.SetReply(req)
		reply.Answer = append(reply.Answer, record)
		w.WriteMsg(reply)
	}
}
// NewRR parses the string s into a DNS resource record, using dns.NewRR.
// If s fails to parse, the test (or benchmark) tb is failed immediately via
// tb.Fatalf.
func NewRR(tb testing.TB, s string) dns.RR {
	// Mark this as a helper, so a failure is reported at the caller's line
	// instead of in here.
	tb.Helper()

	rr, err := dns.NewRR(s)
	if err != nil {
		tb.Fatalf("Error parsing RR for testing: %v", err)
	}
	return rr
}
|