// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package dns
import (
"errors"
"fmt"
"net"
"runtime"
"strings"
)
// Resolver resolves the DNS SRV and TXT records used by SRV connection URIs.
type Resolver struct {
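	// LookupSRV and LookupTXT perform the DNS lookups. They have the same
	// signatures as net.LookupSRV and net.LookupTXT and are fields so that
	// callers can substitute their own implementations.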
LookupSRV func(string, string, string) (string, []*net.SRV, error)
LookupTXT func(string) ([]string, error)
}
// DefaultResolver is a Resolver that uses the default Resolver from the net package.
var DefaultResolver = &Resolver{LookupSRV: net.LookupSRV, LookupTXT: net.LookupTXT}
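// ParseHosts looks up the SRV records for host, which must be a single hostname
// without a port, and returns the resulting seed list as "host:port" strings.
// If srvName is empty, the service name defaults to "mongodb". If stopOnErr is
// true, the first SRV target that fails validation aborts the lookup with an
// error; otherwise invalid targets are skipped.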
func (r *Resolver) ParseHosts(host string, srvName string, stopOnErr bool) ([]string, error) {
parsedHosts := strings.Split(host, ",")
if len(parsedHosts) != 1 {
		return nil, errors.New("URI with SRV must include one and only one hostname")
}
return r.fetchSeedlistFromSRV(parsedHosts[0], srvName, stopOnErr)
}
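// GetConnectionArgsFromTXT looks up the TXT record for host and returns the
// connection arguments it contains, split on ';' and '&'. A failed or empty
// lookup yields no arguments and no error. More than one record, a parameter
// that is not of the form key=value, or an option other than authSource,
// replicaSet, or loadBalanced is an error.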
func (r *Resolver) GetConnectionArgsFromTXT(host string) ([]string, error) {
var connectionArgsFromTXT []string
	// The lookup error is deliberately ignored: not finding a TXT record
	// should not be considered an error.
recordsFromTXT, _ := r.LookupTXT(host)
	// Temporary workaround for https://github.com/golang/go/issues/21472.
	// On Windows all returned strings are joined into one record. As a side
	// effect, multiple distinct TXT records are incorrectly concatenated into
	// one instead of being rejected.
if runtime.GOOS == "windows" {
recordsFromTXT = []string{strings.Join(recordsFromTXT, "")}
}
if len(recordsFromTXT) > 1 {
return nil, errors.New("multiple records from TXT not supported")
}
if len(recordsFromTXT) > 0 {
		// The rune parameter is named c so that it does not shadow the receiver r.
		connectionArgsFromTXT = strings.FieldsFunc(recordsFromTXT[0], func(c rune) bool { return c == ';' || c == '&' })
err := validateTXTResult(connectionArgsFromTXT)
if err != nil {
return nil, err
}
}
return connectionArgsFromTXT, nil
}
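// fetchSeedlistFromSRV looks up the SRV records for host under the given
// service name and returns each valid target as a "host:port" string.
func (r *Resolver) fetchSeedlistFromSRV(host string, srvName string, stopOnErr bool) ([]string, error) {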
	if _, _, err := net.SplitHostPort(host); err == nil {
		// A port was successfully extracted from the host, which must not
		// be possible when using SRV.
		return nil, errors.New("URI with srv must not include a port number")
	}
// default to "mongodb" as service name if not supplied
if srvName == "" {
srvName = "mongodb"
}
_, addresses, err := r.LookupSRV(srvName, "tcp", host)
if err != nil {
return nil, err
}
trimmedHost := strings.TrimSuffix(host, ".")
var parsedHosts []string
for _, address := range addresses {
trimmedAddressTarget := strings.TrimSuffix(address.Target, ".")
err := validateSRVResult(trimmedAddressTarget, trimmedHost)
if err != nil {
if stopOnErr {
return nil, err
}
continue
}
parsedHosts = append(parsedHosts, fmt.Sprintf("%s:%d", trimmedAddressTarget, address.Port))
}
return parsedHosts, nil
}
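// validateSRVResult checks that the SRV target recordFromSRV has at least two
// labels and that it ends with the domain of inputHostName, that is, the input
// host name without its first label. For example, with the input
// "cluster0.example.com", the target "shard0.example.com" is accepted and
// "shard0.example.org" is rejected.
func validateSRVResult(recordFromSRV, inputHostName string) error {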
separatedInputDomain := strings.Split(inputHostName, ".")
separatedRecord := strings.Split(recordFromSRV, ".")
if len(separatedRecord) < 2 {
return errors.New("DNS name must contain at least 2 labels")
}
if len(separatedRecord) < len(separatedInputDomain) {
return errors.New("Domain suffix from SRV record not matched input domain")
}
inputDomainSuffix := separatedInputDomain[1:]
domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1)
recordDomainSuffix := separatedRecord[domainSuffixOffset:]
for ix, label := range inputDomainSuffix {
if label != recordDomainSuffix[ix] {
return errors.New("Domain suffix from SRV record not matched input domain")
}
}
return nil
}
// allowedTXTOptions is the set of lower-cased option names permitted in a TXT record.
var allowedTXTOptions = map[string]struct{}{
"authsource": {},
"replicaset": {},
"loadbalanced": {},
}
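// validateTXTResult checks that every TXT parameter has the form key=value and
// that each key, compared case-insensitively, is an allowed option.
func validateTXTResult(paramsFromTXT []string) error {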
for _, param := range paramsFromTXT {
kv := strings.SplitN(param, "=", 2)
if len(kv) != 2 {
return errors.New("Invalid TXT record")
}
key := strings.ToLower(kv[0])
if _, ok := allowedTXTOptions[key]; !ok {
return fmt.Errorf("Cannot specify option '%s' in TXT record", kv[0])
}
}
return nil
}
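
// Usage sketch (illustrative only): the import name dns, the host names, and
// the stubbed lookup functions below are assumptions, not part of this package.
// Because LookupSRV and LookupTXT are plain function fields, a caller should be
// able to exercise the resolver without network access along these lines:
//
//	r := &dns.Resolver{
//		LookupSRV: func(service, proto, name string) (string, []*net.SRV, error) {
//			return "", []*net.SRV{{Target: "shard0.example.com.", Port: 27017}}, nil
//		},
//		LookupTXT: func(name string) ([]string, error) {
//			return []string{"authSource=admin&replicaSet=rs0"}, nil
//		},
//	}
//	hosts, err := r.ParseHosts("cluster0.example.com", "", true)
//	// hosts is ["shard0.example.com:27017"] and err is nil.
//	args, err := r.GetConnectionArgsFromTXT("cluster0.example.com")
//	// args is ["authSource=admin", "replicaSet=rs0"] and err is nil.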