File: connection.go

package info (click to toggle)
prometheus-mongodb-exporter 1.0.0+git20180522.e755a44-1
  • links: PTS, VCS
  • area: main
  • in suites: buster
  • size: 668 kB
  • sloc: sh: 65; makefile: 27
file content (129 lines) | stat: -rw-r--r-- 3,154 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package shared

import (
	"time"

	"crypto/tls"
	"crypto/x509"
	"fmt"
	"net"

	"github.com/golang/glog"
	"gopkg.in/mgo.v2"
)

const (
	// dialMongodbTimeout is the value MongoSession assigns to
	// mgo.DialInfo.Timeout before dialing.
	dialMongodbTimeout = 10 * time.Second
	// syncMongodbTimeout is the value MongoSession passes to
	// Session.SetSyncTimeout on a freshly dialed session.
	syncMongodbTimeout = 1 * time.Minute
)

// MongoSessionOpts represents options for a Mongo session.
//
// URI is the connection string handed to mgo.ParseURL.
//
// TLSCertificateFile, when non-empty, switches the session to a TLS dialer;
// it is loaded together with TLSPrivateKeyFile as the client key pair.
// TLSCaFile, when non-empty, is loaded as the root CA pool for that TLS
// dialer; it is only consulted when TLSCertificateFile is set.
// TLSHostnameValidation, when false, sets InsecureSkipVerify on the TLS
// configuration and the dialer verifies the peer certificate chain itself
// without a host name.
//
// UserName, when non-empty, replaces the user name parsed from URI.
// AuthMechanism, when non-empty, is copied to mgo.DialInfo.Mechanism.
type MongoSessionOpts struct {
	URI                   string
	TLSCertificateFile    string
	TLSPrivateKeyFile     string
	TLSCaFile             string
	TLSHostnameValidation bool
	UserName              string
	AuthMechanism         string
}

// MongoSession dials the MongoDB server described by opts and returns the
// resulting session.
//
// It returns nil, after logging the cause through glog, when opts.URI cannot
// be parsed, when the dial options cannot be prepared, or when the dial
// itself fails.
func MongoSession(opts MongoSessionOpts) *mgo.Session {
	info, parseErr := mgo.ParseURL(opts.URI)
	if parseErr != nil {
		glog.Errorf("Cannot connect to server using url %s: %s", opts.URI, parseErr)
		return nil
	}

	// Force a direct connection and apply the package-level dial timeout.
	info.Direct = true
	info.Timeout = dialMongodbTimeout
	// An explicitly configured user name wins over the one parsed from the URI.
	if len(opts.UserName) > 0 {
		info.Username = opts.UserName
	}

	if cfgErr := opts.configureDialInfoIfRequired(info); cfgErr != nil {
		glog.Errorf("%s", cfgErr)
		return nil
	}

	sess, dialErr := mgo.DialWithInfo(info)
	if dialErr != nil {
		glog.Errorf("Cannot connect to server using url %s: %s", opts.URI, dialErr)
		return nil
	}

	sess.SetMode(mgo.Eventual, true)
	sess.SetSyncTimeout(syncMongodbTimeout)
	sess.SetSocketTimeout(0)
	return sess
}

// configureDialInfoIfRequired applies the optional settings from opts to
// dialInfo.
//
// When opts.AuthMechanism is set it is copied to dialInfo.Mechanism. When
// opts.TLSCertificateFile is set, the client key pair (and, if configured,
// the CA pool from opts.TLSCaFile) is loaded and dialInfo.DialServer is
// replaced by a dialer that connects over TLS. If host name validation is
// disabled, that dialer still verifies the peer certificate chain through
// enrichWithOwnChecks.
//
// It returns an error when the key pair or the CA file cannot be loaded;
// dialInfo.DialServer is left untouched in that case.
func (opts MongoSessionOpts) configureDialInfoIfRequired(dialInfo *mgo.DialInfo) error {
	if opts.AuthMechanism != "" {
		dialInfo.Mechanism = opts.AuthMechanism
	}
	if len(opts.TLSCertificateFile) == 0 {
		// No client certificate configured: keep mgo's default dialer.
		return nil
	}

	certificates, err := LoadKeyPairFrom(opts.TLSCertificateFile, opts.TLSPrivateKeyFile)
	if err != nil {
		return fmt.Errorf("Cannot load key pair from '%s' and '%s' to connect to server '%s'. Got: %v", opts.TLSCertificateFile, opts.TLSPrivateKeyFile, opts.URI, err)
	}
	config := &tls.Config{
		Certificates:       []tls.Certificate{certificates},
		InsecureSkipVerify: !opts.TLSHostnameValidation,
	}
	if len(opts.TLSCaFile) > 0 {
		ca, err := LoadCertificatesFrom(opts.TLSCaFile)
		if err != nil {
			return fmt.Errorf("Couldn't load client CAs from %s. Got: %s", opts.TLSCaFile, err)
		}
		config.RootCAs = ca
	}

	dialInfo.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
		// A custom DialServer bypasses mgo's own dialing, so carry the
		// configured timeout over to the TCP connect and TLS handshake.
		// A zero timeout means no limit, as with plain tls.Dial.
		dialer := &net.Dialer{Timeout: dialInfo.Timeout}
		conn, err := tls.DialWithDialer(dialer, "tcp", addr.String(), config)
		if err != nil {
			glog.Infof("Could not connect to %v. Got: %v", addr, err)
			return nil, err
		}
		if config.InsecureSkipVerify {
			if err := enrichWithOwnChecks(conn, config); err != nil {
				glog.Infof("Could not disable hostname validation. Got: %v", err)
				// enrichWithOwnChecks has already closed conn; never hand a
				// closed connection back together with an error.
				return nil, err
			}
		}
		return conn, nil
	}
	return nil
}

// enrichWithOwnChecks completes the TLS handshake on conn and then verifies
// the peer's certificate chain against tlsConfig.RootCAs without checking
// the host name (the DNSName verify option is left empty).
//
// The first peer certificate is treated as the leaf and every following one
// as an intermediate. On a handshake or verification failure conn is closed
// and the error is returned; on success conn is left open and nil is
// returned.
func enrichWithOwnChecks(conn *tls.Conn, tlsConfig *tls.Config) error {
	if hsErr := conn.Handshake(); hsErr != nil {
		conn.Close()
		return hsErr
	}

	peerCerts := conn.ConnectionState().PeerCertificates

	// Everything after the leaf is offered as a possible intermediate.
	intermediates := x509.NewCertPool()
	for i := 1; i < len(peerCerts); i++ {
		intermediates.AddCert(peerCerts[i])
	}

	verifyOpts := x509.VerifyOptions{
		Roots:         tlsConfig.RootCAs,
		CurrentTime:   time.Now(),
		DNSName:       "",
		Intermediates: intermediates,
	}
	if _, verifyErr := peerCerts[0].Verify(verifyOpts); verifyErr != nil {
		conn.Close()
		return verifyErr
	}
	return nil
}