package zk

import (
	"context"
	"fmt"
	"net"
	"sync"
	"time"
)

const _defaultLookupTimeout = 3 * time.Second

type lookupHostFn func(context.Context, string) ([]string, error)

// DNSHostProviderOption is an option for the DNSHostProvider.
type DNSHostProviderOption interface {
	apply(*DNSHostProvider)
}

type lookupTimeoutOption struct {
	timeout time.Duration
}

// WithLookupTimeout returns a DNSHostProviderOption that sets the lookup timeout.
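//
// The timeout bounds all of the DNS lookups made by a single call to
// Init, because they share one context. A zero value selects the
// default of 3 seconds. A usage sketch (the 10-second value is only
// illustrative):
//
//	hp := NewDNSHostProvider(WithLookupTimeout(10 * time.Second))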
func WithLookupTimeout(timeout time.Duration) DNSHostProviderOption {
	return lookupTimeoutOption{
		timeout: timeout,
	}
}

func (o lookupTimeoutOption) apply(provider *DNSHostProvider) {
	provider.lookupTimeout = o.timeout
}

// DNSHostProvider is the default HostProvider. It currently matches
// the Java StaticHostProvider, resolving hosts from DNS once during
// the call to Init. It could be easily extended to re-query DNS
// periodically or if there is trouble connecting.
type DNSHostProvider struct {
	mu            sync.Mutex // Protects everything, so we can add asynchronous updates later.
	servers       []string
	curr          int
	last          int
	lookupTimeout time.Duration
	lookupHost    lookupHostFn // Override of (*net.Resolver).LookupHost, for testing.
}

// NewDNSHostProvider creates a new DNSHostProvider with the given options.
func NewDNSHostProvider(options ...DNSHostProviderOption) *DNSHostProvider {
	var provider DNSHostProvider
	for _, option := range options {
		option.apply(&provider)
	}
	return &provider
}

// Init is called first, with the servers specified in the connection
// string. It uses DNS to look up addresses for each server, then
// shuffles them all together.
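//
// Each server must be in "host:port" form, as accepted by
// net.SplitHostPort, and every address the host resolves to is paired
// with that port. Init returns the first parse or lookup error it
// encounters. A usage sketch; the host names are hypothetical:
//
//	hp := NewDNSHostProvider()
//	servers := []string{"zk1.example.com:2181", "zk2.example.com:2181"}
//	if err := hp.Init(servers); err != nil {
//		// handle the error
//	}
//	fmt.Println(hp.Len()) // number of resolved host:port addresses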
func (hp *DNSHostProvider) Init(servers []string) error {
	hp.mu.Lock()
	defer hp.mu.Unlock()

	lookupHost := hp.lookupHost
	if lookupHost == nil {
		var resolver net.Resolver
		lookupHost = resolver.LookupHost
	}

	timeout := hp.lookupTimeout
	if timeout == 0 {
		timeout = _defaultLookupTimeout
	}

	// TODO: consider using a context from the caller.
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	found := []string{}
	for _, server := range servers {
		host, port, err := net.SplitHostPort(server)
		if err != nil {
			return err
		}
		addrs, err := lookupHost(ctx, host)
		if err != nil {
			return err
		}
		for _, addr := range addrs {
			found = append(found, net.JoinHostPort(addr, port))
		}
	}

	if len(found) == 0 {
		return fmt.Errorf("no hosts found for addresses %q", servers)
	}

	// Randomize the order of the servers to avoid creating hotspots
	stringShuffle(found)

	hp.servers = found
	hp.curr = -1
	hp.last = -1

	return nil
}

// Len returns the number of servers available.
func (hp *DNSHostProvider) Len() int {
	hp.mu.Lock()
	defer hp.mu.Unlock()
	return len(hp.servers)
}
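
// Next returns the next server to connect to. retryStart will be true
// if we've looped through all known servers without Connected() being
// called. Next must only be called after Init has succeeded; with an
// empty server list it panics (integer divide by zero).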
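//
// For example, with placeholder addresses a, b and c in the shuffled
// order Init produced, and no call to Connected, successive calls
// return:
//
//	a, false
//	b, false
//	c, false
//	a, true // back at the starting server
//	b, false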
func (hp *DNSHostProvider) Next() (server string, retryStart bool) {
	hp.mu.Lock()
	defer hp.mu.Unlock()
	hp.curr = (hp.curr + 1) % len(hp.servers)
	retryStart = hp.curr == hp.last
	if hp.last == -1 {
		hp.last = 0
	}
	return hp.servers[hp.curr], retryStart
}

// Connected notifies the HostProvider of a successful connection.
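//
// The server most recently returned by Next becomes the new starting
// point for retryStart. For example, with placeholder addresses
// [a b c]: if Next returns a, then b, and Connected is then called,
// the following calls to Next return (c, false), (a, false), and then
// (b, true).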
func (hp *DNSHostProvider) Connected() {
	hp.mu.Lock()
	defer hp.mu.Unlock()
	hp.last = hp.curr
}