1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
|
package net
import (
"errors"
"log"
"net"
"sort"
"strconv"
"time"
"github.com/miekg/dns"
"golang.org/x/net/proxy"
)
const (
	// defaultDNSServer is the resolver (host:port) LookupSRV queries; it is
	// an OpenDNS address.
	defaultDNSServer = "208.67.222.222:53"
	// defaultLookupTimeout bounds how long LookupSRVWith waits for a lookup.
	defaultLookupTimeout = 10 * time.Second
)
// LookupSRV mirrors net.LookupSRV, but sends the query through the given
// proxy dialer instead of using the system resolver. The query goes to the
// OpenDNS resolver at defaultDNSServer; use LookupSRVWith to pick another.
func LookupSRV(dialer proxy.Dialer, service, proto, name string) (cname string, addrs []*net.SRV, err error) {
	cname, addrs, err = LookupSRVWith(dialer, defaultDNSServer, service, proto, name)
	return cname, addrs, err
}
// timingOutLookup runs f in its own goroutine and waits at most t for it.
// If f finishes in time, its results are returned unchanged. Otherwise a
// message is logged and ("", nil, ErrTimeout) is returned.
//
// f cannot be cancelled from here and keeps running after a timeout; its
// late results are discarded. Callers should make f bound its own blocking
// operations so the goroutine eventually exits.
func timingOutLookup(f func() (cname string, addrs []*net.SRV, err error), t time.Duration) (cname string, addrs []*net.SRV, err error) {
	type lookupResult struct {
		cname string
		addrs []*net.SRV
		err   error
	}

	// The goroutine must never touch this function's named results. On
	// timeout they are being assigned by the return statement below, and a
	// concurrent write from f would be a data race. Results travel over the
	// channel instead. The buffer of 1 lets the goroutine deliver and exit
	// even when nobody is receiving any more.
	done := make(chan lookupResult, 1)
	go func() {
		var r lookupResult
		r.cname, r.addrs, r.err = f()
		done <- r
	}()

	// Stop the timer on return so it is released promptly when f wins.
	timer := time.NewTimer(t)
	defer timer.Stop()

	select {
	case r := <-done:
		return r.cname, r.addrs, r.err
	case <-timer.C:
		log.Println("dns: lookup timed out")
		return "", nil, ErrTimeout
	}
}
// LookupSRVWith looks up the SRV records for service and proto on name. The
// query is sent over TCP, through dialer, to dnsServer (host:port).
//
// The lookup is bounded by defaultLookupTimeout; if it expires, ErrTimeout is
// returned. The returned cname is the query name built by createCName, not a
// CNAME read from the reply. On success, addrs holds the SRV answers as
// converted and ordered by convertAnswersToSRV.
func LookupSRVWith(dialer proxy.Dialer, dnsServer, service, proto, name string) (cname string, addrs []*net.SRV, err error) {
	return timingOutLookup(func() (cname string, addrs []*net.SRV, err error) {
		// Fix the deadline before dialing so the whole dial+exchange shares
		// one budget with the outer timeout.
		deadline := time.Now().Add(defaultLookupTimeout)
		cname = createCName(service, proto, name)

		conn, err := dialer.Dial("tcp", dnsServer)
		if err != nil {
			return cname, nil, err
		}
		dnsConn := &dns.Conn{Conn: conn}
		defer dnsConn.Close()

		// timingOutLookup stops waiting at the timeout but cannot cancel this
		// goroutine. Without a deadline, a server that accepts the
		// connection and never answers would keep the goroutine and its
		// socket alive indefinitely. The deadline is best effort: if the
		// proxied connection rejects it, the outer timeout still applies.
		_ = conn.SetDeadline(deadline)

		r, err := exchange(dnsConn, msgSRV(cname))
		if err != nil {
			return cname, nil, err
		}
		return cname, convertAnswersToSRV(r.Answer), nil
	}, defaultLookupTimeout)
}
// createCName builds the fully qualified SRV query name "_service._proto.name."
// for the given service, protocol and domain.
//
// name may be given with or without its trailing dot. Previously an
// already-qualified name produced an invalid "..". As in net.LookupSRV, if
// service and proto are both empty, name itself is queried (made fully
// qualified).
func createCName(service, proto, name string) string {
	if n := len(name); n == 0 || name[n-1] != '.' {
		name += "."
	}
	if service == "" && proto == "" {
		return name
	}
	return "_" + service + "._" + proto + "." + name
}
// msgSRV returns a new DNS query asking for the SRV records of cname, with
// the recursion-desired flag set.
func msgSRV(cname string) *dns.Msg {
	query := new(dns.Msg)
	query.SetQuestion(cname, dns.TypeSRV)
	query.RecursionDesired = true
	return query
}
// exchange writes m to conn and reads back a single reply message.
//
// If the reply's Rcode is not dns.RcodeSuccess, the reply is still returned,
// together with an error containing the numeric Rcode.
func exchange(conn *dns.Conn, m *dns.Msg) (*dns.Msg, error) {
	if err := conn.WriteMsg(m); err != nil {
		return nil, err
	}
	reply, err := conn.ReadMsg()
	if err != nil {
		return reply, err
	}
	if reply.Rcode != dns.RcodeSuccess {
		return reply, errors.New("got return: " + strconv.Itoa(reply.Rcode))
	}
	return reply, nil
}
// byPriorityWeight implements sort.Interface for SRV records. It orders them
// by ascending Priority and, for equal priority, by ascending Weight. nil
// entries sort before every non-nil entry.
type byPriorityWeight []*net.SRV

// Len returns the number of records.
func (s byPriorityWeight) Len() int { return len(s) }

// Less reports whether s[i] sorts before s[j].
func (s byPriorityWeight) Less(i, j int) bool {
	a, b := s[i], s[j]
	switch {
	case a == nil || b == nil:
		// A nil record precedes any non-nil record, and two nils are equal.
		// Reporting true for two nils would make Less(i,j) and Less(j,i)
		// both true, breaking the strict weak ordering sort requires.
		return a == nil && b != nil
	case a.Priority != b.Priority:
		return a.Priority < b.Priority
	default:
		return a.Weight < b.Weight
	}
}

// Swap exchanges the records at i and j.
func (s byPriorityWeight) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// convertAnswersToSRV collects the SRV records from a DNS answer section and
// skips records of any other type. The result is sorted with byPriorityWeight.
func convertAnswersToSRV(in []dns.RR) []*net.SRV {
	srvs := make([]*net.SRV, 0, len(in))
	for i := range in {
		if rec := convertAnswerToSRV(in[i]); rec != nil {
			srvs = append(srvs, rec)
		}
	}
	sort.Sort(byPriorityWeight(srvs))
	return srvs
}
// convertAnswerToSRV copies the target, port, priority and weight of a
// *dns.SRV record into a new *net.SRV. It returns nil when in is any other
// record type.
func convertAnswerToSRV(in dns.RR) *net.SRV {
	rec, isSRV := in.(*dns.SRV)
	if !isSRV {
		return nil
	}
	return &net.SRV{
		Target:   rec.Target,
		Port:     rec.Port,
		Priority: rec.Priority,
		Weight:   rec.Weight,
	}
}
|