1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
|
package shared
import (
"time"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"github.com/golang/glog"
"gopkg.in/mgo.v2"
)
// Timeouts applied by MongoSession to every session it creates.
const (
	// dialMongodbTimeout is assigned to mgo.DialInfo.Timeout before dialing.
	dialMongodbTimeout = 10 * time.Second
	// syncMongodbTimeout is passed to Session.SetSyncTimeout after dialing.
	syncMongodbTimeout = 60 * time.Second
)
// MongoSessionOpts groups the settings MongoSession needs to parse, configure
// and dial a Mongo connection.
type MongoSessionOpts struct {
	// URI is the connection string handed to mgo.ParseURL.
	URI string

	// TLS is enabled only when TLSCertificateFile is non-empty; it is loaded
	// together with TLSPrivateKeyFile as the client key pair. TLSCaFile, when
	// non-empty, supplies the trusted root certificates.
	TLSCertificateFile, TLSPrivateKeyFile, TLSCaFile string
	// TLSHostnameValidation, when false, turns off the host name check while
	// the certificate chain is still verified by enrichWithOwnChecks.
	TLSHostnameValidation bool

	// UserName, when non-empty, replaces the user name parsed from URI.
	// AuthMechanism, when non-empty, replaces the dial's auth mechanism.
	UserName, AuthMechanism string
}
// MongoSession dials the Mongo server described by opts and returns the
// session, or nil if the URI cannot be parsed, the auth/TLS options cannot be
// applied, or the dial fails. Failures are logged through glog rather than
// returned.
//
// The dial is forced to a direct connection with a dialMongodbTimeout timeout.
// The resulting session is put in mgo.Eventual mode with refresh, uses
// syncMongodbTimeout as its sync timeout and has its socket timeout set to 0
// (NOTE(review): mgo appears to treat 0 as "no timeout" — confirm).
func MongoSession(opts MongoSessionOpts) *mgo.Session {
	// The URI may embed "user:password@" credentials; only ever log a
	// redacted copy of it.
	safeURI := redactMongoURIPassword(opts.URI)

	dialInfo, err := mgo.ParseURL(opts.URI)
	if err != nil {
		glog.Errorf("Cannot parse url %s: %s", safeURI, err)
		return nil
	}
	dialInfo.Direct = true // Force direct connection
	dialInfo.Timeout = dialMongodbTimeout
	if opts.UserName != "" {
		dialInfo.Username = opts.UserName
	}
	if err := opts.configureDialInfoIfRequired(dialInfo); err != nil {
		glog.Errorf("%s", err)
		return nil
	}
	session, err := mgo.DialWithInfo(dialInfo)
	if err != nil {
		glog.Errorf("Cannot connect to server using url %s: %s", safeURI, err)
		return nil
	}
	session.SetMode(mgo.Eventual, true)
	session.SetSyncTimeout(syncMongodbTimeout)
	session.SetSocketTimeout(0)
	return session
}

// redactMongoURIPassword returns uri with the password of any embedded
// "user:password@" credentials replaced by "xxxxx" so it is safe to log.
// Credentials are looked for after an optional "mongodb://" prefix, before
// the first '@' that precedes any '?' option list; the password is everything
// between the first ':' and that '@'. URIs without a password are returned
// unchanged.
func redactMongoURIPassword(uri string) string {
	const scheme = "mongodb://"
	start := 0
	if len(uri) >= len(scheme) && uri[:len(scheme)] == scheme {
		start = len(scheme)
	}
	at, colon := -1, -1
	for i := start; i < len(uri) && uri[i] != '?'; i++ {
		if uri[i] == '@' {
			at = i
			break
		}
		if uri[i] == ':' && colon < 0 {
			colon = i
		}
	}
	if at < 0 || colon < 0 {
		return uri
	}
	return uri[:colon+1] + "xxxxx" + uri[at:]
}
// configureDialInfoIfRequired copies the optional auth mechanism into
// dialInfo and, when a client certificate is configured, installs a TLS
// DialServer on it.
//
// TLS is enabled only when opts.TLSCertificateFile is non-empty. The client
// key pair is loaded from TLSCertificateFile/TLSPrivateKeyFile; if TLSCaFile is
// non-empty its certificates become the trusted roots. When
// TLSHostnameValidation is false, crypto/tls verification is switched off
// (InsecureSkipVerify) and enrichWithOwnChecks verifies the peer's chain
// without a host name check.
//
// The DialServer callback receives no timeout, so the TCP connect plus TLS
// handshake is bounded by dialInfo.Timeout as it is when this method runs
// (zero means no limit).
//
// It returns an error if the key pair or the CA file cannot be loaded.
func (opts MongoSessionOpts) configureDialInfoIfRequired(dialInfo *mgo.DialInfo) error {
	if opts.AuthMechanism != "" {
		dialInfo.Mechanism = opts.AuthMechanism
	}
	if len(opts.TLSCertificateFile) == 0 {
		return nil
	}

	certificates, err := LoadKeyPairFrom(opts.TLSCertificateFile, opts.TLSPrivateKeyFile)
	if err != nil {
		// Report the parsed hosts, not opts.URI: the URI may embed a password.
		return fmt.Errorf("Cannot load key pair from '%s' and '%s' to connect to server '%v'. Got: %v", opts.TLSCertificateFile, opts.TLSPrivateKeyFile, dialInfo.Addrs, err)
	}
	config := &tls.Config{
		Certificates:       []tls.Certificate{certificates},
		InsecureSkipVerify: !opts.TLSHostnameValidation,
	}
	if len(opts.TLSCaFile) > 0 {
		ca, err := LoadCertificatesFrom(opts.TLSCaFile)
		if err != nil {
			return fmt.Errorf("Couldn't load client CAs from %s. Got: %s", opts.TLSCaFile, err)
		}
		config.RootCAs = ca
	}

	// tls.DialWithDialer applies the dialer's timeout to the connect and the
	// handshake as a whole; plain tls.Dial would wait indefinitely.
	dialer := &net.Dialer{Timeout: dialInfo.Timeout}
	dialInfo.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) {
		conn, err := tls.DialWithDialer(dialer, "tcp", addr.String(), config)
		if err != nil {
			glog.Infof("Could not connect to %v. Got: %v", addr, err)
			return nil, err
		}
		if config.InsecureSkipVerify {
			if err := enrichWithOwnChecks(conn, config); err != nil {
				// enrichWithOwnChecks has already closed conn; never hand it out.
				glog.Infof("Could not verify certificate chain of %v. Got: %v", addr, err)
				return nil, err
			}
		}
		return conn, nil
	}
	return nil
}
// enrichWithOwnChecks completes the TLS handshake on conn and verifies the
// peer's certificate chain against tlsConfig.RootCAs (x509 falls back to the
// system roots when that is nil) without any host name check. It is used for
// connections dialed with InsecureSkipVerify, where crypto/tls performs no
// verification of its own.
//
// The first peer certificate is treated as the leaf and the remaining ones as
// intermediates. On any failure — handshake error, no peer certificates, or an
// unverifiable chain — conn is closed and the error is returned.
func enrichWithOwnChecks(conn *tls.Conn, tlsConfig *tls.Config) error {
	if err := conn.Handshake(); err != nil {
		conn.Close()
		return err
	}

	certs := conn.ConnectionState().PeerCertificates
	if len(certs) == 0 {
		// Without this guard certs[0] below would panic (e.g. on a server-side
		// conn whose client sent no certificate).
		conn.Close()
		return fmt.Errorf("no peer certificates to verify")
	}

	intermediates := x509.NewCertPool()
	for _, cert := range certs[1:] {
		intermediates.AddCert(cert)
	}
	opts := x509.VerifyOptions{
		Roots:         tlsConfig.RootCAs,
		CurrentTime:   time.Now(),
		DNSName:       "", // deliberately skip host name validation
		Intermediates: intermediates,
	}
	if _, err := certs[0].Verify(opts); err != nil {
		conn.Close()
		return err
	}
	return nil
}
|