1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
|
package topology
import (
"context"
"crypto/tls"
"net"
"time"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
)
// Dialer opens network connections. *net.Dialer satisfies it, and so does any
// function converted to DialerFunc.
type Dialer interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// DialerFunc lets an ordinary function with the DialContext signature be used
// wherever a Dialer is expected.
type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error)

// Compile-time assertion that DialerFunc satisfies Dialer.
var _ Dialer = DialerFunc(nil)

// DialContext satisfies Dialer by calling f with its arguments unchanged and
// handing back whatever connection and error f produces.
func (f DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	conn, err := f(ctx, network, address)
	return conn, err
}
// DefaultDialer is the Dialer implementation that is used by this package. Changing this
// will also change the Dialer used for this package. This should only be changed when all
// of the connections being made need to use a different Dialer. Most of the time, using a
// WithDialer option is more appropriate than changing this variable.
//
// NOTE(review): newConnectionConfig in this file does not read DefaultDialer. When no
// dialer option is supplied, it falls back to a fresh &net.Dialer{}, so reassigning this
// variable does not affect that path. Confirm which other code in the package reads it
// before relying on the statement above.
var DefaultDialer Dialer = &net.Dialer{}
type (
	// Handshaker is implemented by types that can perform a MongoDB handshake
	// over a provided driver.Connection; it is used while a connection is being
	// initialized. Implementations must be safe for use by multiple goroutines.
	Handshaker = driver.Handshaker

	// generationNumberFn is a callback through which a connection looks up its
	// generation number for the given service ID.
	generationNumberFn func(serviceID *primitive.ObjectID) uint64
)
// connectionConfig holds the settings used when creating a connection. It is
// built by newConnectionConfig, which applies a list of ConnectionOption values
// on top of a small set of defaults.
type connectionConfig struct {
	// connectTimeout is set by WithConnectTimeout; newConnectionConfig defaults it to 30s.
	connectTimeout time.Duration
	// dialer is set by WithDialer; newConnectionConfig substitutes &net.Dialer{} if it is still nil.
	dialer Dialer
	// handshaker is set by WithHandshaker.
	handshaker Handshaker
	// idleTimeout is set by WithIdleTimeout.
	idleTimeout time.Duration
	// cmdMonitor is set by WithMonitor.
	cmdMonitor *event.CommandMonitor
	// readTimeout is set by WithReadTimeout.
	readTimeout time.Duration
	// writeTimeout is set by WithWriteTimeout.
	writeTimeout time.Duration
	// tlsConfig is set by WithTLSConfig.
	tlsConfig *tls.Config
	// compressors is set by WithCompressors.
	compressors []string
	// zlibLevel is set by WithZlibLevel; nil unless an option sets it.
	zlibLevel *int
	// zstdLevel is set by WithZstdLevel; nil unless an option sets it.
	zstdLevel *int
	// ocspCache is set by WithOCSPCache.
	ocspCache ocsp.Cache
	// disableOCSPEndpointCheck is set by WithDisableOCSPEndpointCheck.
	disableOCSPEndpointCheck bool
	// errorHandlingCallback is set by withErrorHandlingCallback; it receives the
	// error, a starting generation number, and a service ID.
	errorHandlingCallback func(err error, startGenNum uint64, svcID *primitive.ObjectID)
	// tlsConnectionSource is set by withTLSConnectionSource; newConnectionConfig
	// defaults it to defaultTLSConnectionSource.
	tlsConnectionSource tlsConnectionSource
	// loadBalanced is set by WithConnectionLoadBalanced.
	loadBalanced bool
	// getGenerationFn is set by withGenerationNumberFn.
	getGenerationFn generationNumberFn
}
// newConnectionConfig builds a connectionConfig by applying opts, in order, on
// top of the package defaults:
//   - a 30 second connect timeout;
//   - defaultTLSConnectionSource as the TLS connection source.
//
// Nil entries in opts are skipped. If an option returns an error, construction
// stops and that error is returned together with a nil config.
//
// If no option leaves a non-nil dialer configured, a zero-value net.Dialer is
// used. DefaultDialer is not consulted here.
func newConnectionConfig(opts ...ConnectionOption) (*connectionConfig, error) {
	cfg := &connectionConfig{
		connectTimeout:      30 * time.Second,
		dialer:              nil,
		tlsConnectionSource: defaultTLSConnectionSource,
	}

	for _, opt := range opts {
		// Calling a nil option would panic. Treat it as a no-op so callers can
		// assemble option slices conditionally.
		if opt == nil {
			continue
		}
		if err := opt(cfg); err != nil {
			return nil, err
		}
	}

	if cfg.dialer == nil {
		// A zero-value net.Dialer keeps the standard library's default dialing
		// behavior (keep-alive, DNS resolution, no Dialer-level timeout).
		// NOTE(review): presumably connectTimeout is enforced by the caller; confirm.
		cfg.dialer = &net.Dialer{}
	}

	return cfg, nil
}
// ConnectionOption configures a connection by updating the connectionConfig it
// is applied to. Returning a non-nil error aborts newConnectionConfig.
type ConnectionOption func(cfg *connectionConfig) error
// withTLSConnectionSource returns an option that replaces the configured TLS
// connection source with fn's result; fn is given the current source.
func withTLSConnectionSource(fn func(tlsConnectionSource) tlsConnectionSource) ConnectionOption {
	apply := func(cfg *connectionConfig) error {
		cfg.tlsConnectionSource = fn(cfg.tlsConnectionSource)
		return nil
	}
	return apply
}
// withErrorHandlingCallback returns an option that installs fn as the
// connection's error-handling callback, discarding any previously configured
// one. The callback receives the error, a starting generation number, and a
// service ID.
func withErrorHandlingCallback(fn func(err error, startGenNum uint64, svcID *primitive.ObjectID)) ConnectionOption {
	return ConnectionOption(func(cfg *connectionConfig) error {
		cfg.errorHandlingCallback = fn
		return nil
	})
}
// WithCompressors sets the compressors that can be used for communication.
// fn receives the currently configured list and returns the list to use.
func WithCompressors(fn func([]string) []string) ConnectionOption {
	return func(cfg *connectionConfig) error {
		current := cfg.compressors
		cfg.compressors = fn(current)
		return nil
	}
}
// WithConnectTimeout configures the maximum amount of time a dial will wait for
// a Connect to complete. fn maps the current timeout to the new one. The
// default is 30 seconds.
func WithConnectTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
	apply := func(cfg *connectionConfig) error {
		cfg.connectTimeout = fn(cfg.connectTimeout)
		return nil
	}
	return apply
}
// WithDialer configures the Dialer to use when making a new connection to
// MongoDB. fn receives the currently configured Dialer, which may be nil, and
// returns the one to use.
func WithDialer(fn func(Dialer) Dialer) ConnectionOption {
	return ConnectionOption(func(cfg *connectionConfig) error {
		next := fn(cfg.dialer)
		cfg.dialer = next
		return nil
	})
}
// WithHandshaker configures the Handshaker that will be used to initialize newly
// dialed connections. fn receives the currently configured Handshaker (nil if
// none has been set) and returns the replacement.
func WithHandshaker(fn func(Handshaker) Handshaker) ConnectionOption {
	return func(c *connectionConfig) error {
		c.handshaker = fn(c.handshaker)
		return nil
	}
}
// WithIdleTimeout configures the maximum idle time to allow for a connection.
// fn is handed the currently configured duration and returns its replacement.
func WithIdleTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
	return func(cfg *connectionConfig) error {
		updated := fn(cfg.idleTimeout)
		cfg.idleTimeout = updated
		return nil
	}
}
// WithReadTimeout configures the maximum read time for a connection. The
// callback maps the current read timeout to the new one.
func WithReadTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
	setter := func(cfg *connectionConfig) error {
		cfg.readTimeout = fn(cfg.readTimeout)
		return nil
	}
	return setter
}
// WithWriteTimeout configures the maximum write time for a connection. The
// callback maps the current write timeout to the new one.
func WithWriteTimeout(fn func(time.Duration) time.Duration) ConnectionOption {
	return ConnectionOption(func(cfg *connectionConfig) error {
		cfg.writeTimeout = fn(cfg.writeTimeout)
		return nil
	})
}
// WithTLSConfig configures the TLS options for a connection. fn receives the
// currently configured *tls.Config (possibly nil) and returns the one to use.
func WithTLSConfig(fn func(*tls.Config) *tls.Config) ConnectionOption {
	return func(cfg *connectionConfig) error {
		prev := cfg.tlsConfig
		cfg.tlsConfig = fn(prev)
		return nil
	}
}
// WithMonitor configures an event.CommandMonitor for command monitoring. fn
// receives the currently configured monitor (nil if none has been set) and
// returns the replacement.
func WithMonitor(fn func(*event.CommandMonitor) *event.CommandMonitor) ConnectionOption {
	return func(c *connectionConfig) error {
		c.cmdMonitor = fn(c.cmdMonitor)
		return nil
	}
}
// WithZlibLevel sets the zlib compression level. fn receives the current level
// pointer (nil when no level has been set) and returns the one to use.
func WithZlibLevel(fn func(*int) *int) ConnectionOption {
	apply := func(cfg *connectionConfig) error {
		cfg.zlibLevel = fn(cfg.zlibLevel)
		return nil
	}
	return apply
}
// WithZstdLevel sets the zstd compression level. fn receives the current level
// pointer (nil when no level has been set) and returns the one to use.
func WithZstdLevel(fn func(*int) *int) ConnectionOption {
	return ConnectionOption(func(cfg *connectionConfig) error {
		cfg.zstdLevel = fn(cfg.zstdLevel)
		return nil
	})
}
// WithOCSPCache specifies a cache to use for OCSP verification. fn receives the
// currently configured cache and returns the cache to use.
func WithOCSPCache(fn func(ocsp.Cache) ocsp.Cache) ConnectionOption {
	return func(cfg *connectionConfig) error {
		existing := cfg.ocspCache
		cfg.ocspCache = fn(existing)
		return nil
	}
}
// WithDisableOCSPEndpointCheck specifies whether the driver should skip non-stapled OCSP verification. If set
// to true, the driver will only check stapled responses and will continue the connection without reaching out to
// OCSP responders. fn receives the current setting (false unless an earlier option changed it) and returns the
// setting to use.
func WithDisableOCSPEndpointCheck(fn func(bool) bool) ConnectionOption {
	return func(c *connectionConfig) error {
		c.disableOCSPEndpointCheck = fn(c.disableOCSPEndpointCheck)
		return nil
	}
}
// WithConnectionLoadBalanced specifies whether or not the connection is to a
// server behind a load balancer. fn maps the current flag to the new one.
func WithConnectionLoadBalanced(fn func(bool) bool) ConnectionOption {
	apply := func(cfg *connectionConfig) error {
		cfg.loadBalanced = fn(cfg.loadBalanced)
		return nil
	}
	return apply
}
// withGenerationNumberFn returns an option that replaces the configured
// generation-number callback with fn's result; fn is given the current callback.
func withGenerationNumberFn(fn func(generationNumberFn) generationNumberFn) ConnectionOption {
	return ConnectionOption(func(cfg *connectionConfig) error {
		replacement := fn(cfg.getGenerationFn)
		cfg.getGenerationFn = replacement
		return nil
	})
}
|