1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
|
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
package mongo
import (
"context"
"net"
"os"
"sync/atomic"
"testing"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/internal/testutil"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/x/bsonx"
)
// TestClientOptions_CustomDialer verifies that a Dialer supplied through
// options.Client().SetDialer is actually used to open connections. It needs a
// live server and is skipped unless TEST_MONGODB_SERVER is set.
func TestClientOptions_CustomDialer(t *testing.T) {
	if os.Getenv("TEST_MONGODB_SERVER") == "" {
		t.Skip("skipping: TEST_MONGODB_SERVER is not set")
	}

	ctx := context.Background()
	td := &testDialer{d: &net.Dialer{}}
	cs := testutil.ConnString(t)
	opts := options.Client().ApplyURI(cs.String()).SetDialer(td)
	testutil.AddTestServerAPIVersion(opts)

	client, err := NewClient(opts)
	require.NoError(t, err)
	err = client.Connect(ctx)
	require.NoError(t, err)
	// Release the client's connections once the test finishes so they do not
	// leak into other tests in this package.
	defer func() {
		if derr := client.Disconnect(ctx); derr != nil {
			t.Errorf("error disconnecting client: %v", derr)
		}
	}()

	// Running any operation forces the driver to dial at least one connection.
	_, err = client.ListDatabases(ctx, bsonx.Doc{})
	require.NoError(t, err)

	got := atomic.LoadInt32(&td.called)
	if got < 1 {
		t.Errorf("Custom dialer was not used when dialing new connections")
	}
}
// testDialer is a Dialer wrapper that counts how many times DialContext is
// invoked. Tests use it to confirm that a custom dialer was actually used.
type testDialer struct {
	// d is the underlying Dialer that every dial is delegated to.
	d Dialer
	// called counts DialContext invocations. It is accessed through sync/atomic,
	// presumably because connections may be dialed concurrently.
	called int32
}
// DialContext atomically records that a dial was requested, then forwards the
// request unchanged to the wrapped Dialer. It returns whatever connection and
// error the wrapped Dialer produces.
func (td *testDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	inner := td.d
	atomic.AddInt32(&td.called, 1)
	return inner.DialContext(ctx, network, address)
}
|