File: sdam_prose_test.go

package info (click to toggle)
golang-mongodb-mongo-driver 1.17.1%2Bds1-2
  • links: PTS, VCS
  • area: main
  • in suites: experimental, sid, trixie
  • size: 25,988 kB
  • sloc: perl: 533; ansic: 491; python: 432; sh: 327; makefile: 174
file content (278 lines) | stat: -rw-r--r-- 9,821 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
// Copyright (C) MongoDB, Inc. 2017-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package integration

import (
	"context"
	"net"
	"os"
	"runtime"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"go.mongodb.org/mongo-driver/bson/primitive"
	"go.mongodb.org/mongo-driver/event"
	"go.mongodb.org/mongo-driver/internal/assert"
	"go.mongodb.org/mongo-driver/internal/handshake"
	"go.mongodb.org/mongo-driver/internal/require"
	"go.mongodb.org/mongo-driver/mongo/address"
	"go.mongodb.org/mongo-driver/mongo/description"
	"go.mongodb.org/mongo-driver/mongo/integration/mtest"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
)

func TestSDAMProse(t *testing.T) {
	mt := mtest.New(t)

	// Server limits non-streaming heartbeats and explicit server transition checks to at most one
	// per 500ms. Set the test interval to 500ms to minimize the difference between the behavior of
	// streaming and non-streaming heartbeat intervals.
	heartbeatInterval := 500 * time.Millisecond
	heartbeatIntervalClientOpts := options.Client().
		SetHeartbeatInterval(heartbeatInterval)
	heartbeatIntervalMtOpts := mtest.NewOptions().
		ClientOptions(heartbeatIntervalClientOpts).
		CreateCollection(false).
		ClientType(mtest.Proxy).
		MinServerVersion("4.4") // RTT Monitor / Streaming protocol is not supported for versions < 4.4.
	mt.RunOpts("heartbeats processed more frequently", heartbeatIntervalMtOpts, func(mt *mtest.T) {
		// Test that setting heartbeat interval to 500ms causes the client to process heartbeats
		// approximately every 500ms instead of the default 10s. Note that a Client doesn't
		// guarantee that it will process heartbeats exactly every 500ms, just that it will wait at
		// least 500ms between heartbeats (and should process heartbeats more frequently for shorter
		// interval settings).
		//
		// For number of nodes N, interval I, and duration D, a Client should process at most X
		// operations:
		//
		//   X = (N * (2 handshakes + D/I heartbeats + D/I RTTs))
		//
		// Assert that a Client processes the expected number of operations for heartbeats sent at
		// an interval between I and 2*I to account for different actual heartbeat intervals under
		// different runtime conditions.

		// Measure the actual amount of time between the start of the test and when we inspect the
		// sent messages. The sleep duration will be at least the specified duration but
		// possibly longer, which could lead to extra heartbeat messages, so account for that in
		// the assertions.
		if len(os.Getenv("DOCKER_RUNNING")) > 0 {
			mt.Skip("skipping test in docker environment")
		}
		start := time.Now()
		time.Sleep(2 * time.Second)
		messages := mt.GetProxiedMessages()
		duration := time.Since(start)

		numNodes := len(options.Client().ApplyURI(mtest.ClusterURI()).Hosts)
		maxExpected := numNodes * (2 + 2*int(duration/heartbeatInterval))
		minExpected := numNodes * (2 + 2*int(duration/(heartbeatInterval*2)))

		assert.True(
			mt,
			len(messages) >= minExpected && len(messages) <= maxExpected,
			"expected number of messages to be in range [%d, %d], got %d"+
				" (num nodes = %d, duration = %v, interval = %v)",
			minExpected,
			maxExpected,
			len(messages),
			numNodes,
			duration,
			heartbeatInterval)
	})

	mt.RunOpts("rtt tests", noClientOpts, func(mt *mtest.T) {
		clientOpts := options.Client().
			SetHeartbeatInterval(500 * time.Millisecond).
			SetAppName("streamingRttTest")
		mtOpts := mtest.NewOptions().
			MinServerVersion("4.4").
			ClientOptions(clientOpts)
		mt.RunOpts("rtt is continuously updated", mtOpts, func(mt *mtest.T) {
			// Test that the RTT monitor updates the RTT for server descriptions.

			// The server has been discovered by the create command issued by mtest. Sleep for two seconds to allow
			// multiple heartbeats to finish.
			testTopology := getTopologyFromClient(mt.Client)
			time.Sleep(2 * time.Second)
			for _, serverDesc := range testTopology.Description().Servers {
				assert.NotEqual(mt, description.Unknown, serverDesc.Kind, "server %v is Unknown", serverDesc)
				assert.True(mt, serverDesc.AverageRTTSet, "AverageRTTSet for server description %v is false", serverDesc)

				if runtime.GOOS != "windows" {
					// Windows has a lower time resolution than other platforms, which causes the reported RTT to be
					// 0 if it's below some threshold. The assertion above already confirms that the RTT is set to
					// a value, so we can skip this assertion on Windows.
					assert.True(mt, serverDesc.AverageRTT > 0, "server description %v has 0 RTT", serverDesc)
				}
			}

			// Force hello requests to block for 500ms and wait until a server's average RTT goes over 250ms.
			mt.SetFailPoint(mtest.FailPoint{
				ConfigureFailPoint: "failCommand",
				Mode: mtest.FailPointMode{
					Times: 1000,
				},
				Data: mtest.FailPointData{
					FailCommands:    []string{handshake.LegacyHello, "hello"},
					BlockConnection: true,
					BlockTimeMS:     500,
					AppName:         "streamingRttTest",
				},
			})
			callback := func() bool {
				// We don't know which server received the failpoint command, so we wait until any of the server
				// RTTs cross the threshold.
				for _, serverDesc := range testTopology.Description().Servers {
					if serverDesc.AverageRTT > 250*time.Millisecond {
						return true
					}
				}

				// The next update will be in ~500ms.
				return false
			}
			assert.Eventually(t,
				callback,
				defaultCallbackTimeout,
				500*time.Millisecond,
				"expected average rtt heartbeats at least within every 500 ms period")
		})
	})

	mt.RunOpts("client waits between failed Hellos", mtest.NewOptions().MinServerVersion("4.9").Topologies(mtest.Single), func(mt *mtest.T) {
		// Force hello requests to fail 5 times.
		mt.SetFailPoint(mtest.FailPoint{
			ConfigureFailPoint: "failCommand",
			Mode: mtest.FailPointMode{
				Times: 5,
			},
			Data: mtest.FailPointData{
				FailCommands: []string{handshake.LegacyHello, "hello"},
				ErrorCode:    1234,
				AppName:      "SDAMMinHeartbeatFrequencyTest",
			},
		})

		// Reset client options to use direct connection, app name, and 5s SS timeout.
		clientOpts := options.Client().SetDirect(true).
			SetAppName("SDAMMinHeartbeatFrequencyTest").
			SetServerSelectionTimeout(5 * time.Second)
		mt.ResetClient(clientOpts)

		// Assert that Ping completes successfully within 2 to 3.5 seconds.
		start := time.Now()
		err := mt.Client.Ping(context.Background(), nil)
		assert.Nil(mt, err, "Ping error: %v", err)
		pingTime := time.Since(start)
		assert.True(mt, pingTime > 2000*time.Millisecond && pingTime < 3500*time.Millisecond,
			"expected Ping to take between 2 and 3.5 seconds, took %v seconds", pingTime.Seconds())

	})
}

func TestServerHeartbeatStartedEvent(t *testing.T) {
	t.Run("emits the first HeartbeatStartedEvent before the monitoring socket was created", func(t *testing.T) {
		t.Parallel()

		const address = address.Address("localhost:9999")
		expectedEvents := []string{
			"serverHeartbeatStartedEvent",
			"client connected",
			"client hello received",
			"serverHeartbeatFailedEvent",
		}

		events := make(chan string)

		listener, err := net.Listen("tcp", address.String())
		assert.NoError(t, err)
		defer listener.Close()
		go func() {
			conn, err := listener.Accept()
			assert.NoError(t, err)
			defer conn.Close()

			events <- "client connected"
			_, _ = conn.Read(nil)
			events <- "client hello received"
		}()

		server := topology.NewServer(
			address,
			primitive.NewObjectID(),
			topology.WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor {
				return &event.ServerMonitor{
					ServerHeartbeatStarted: func(*event.ServerHeartbeatStartedEvent) {
						events <- "serverHeartbeatStartedEvent"
					},
					ServerHeartbeatFailed: func(*event.ServerHeartbeatFailedEvent) {
						events <- "serverHeartbeatFailedEvent"
					},
				}
			}),
		)
		require.NoError(t, server.Connect(nil))

		ticker := time.NewTicker(5 * time.Second)
		defer ticker.Stop()

		actualEvents := make([]string, 0, len(expectedEvents))
		for len(actualEvents) < len(expectedEvents) {
			select {
			case event := <-events:
				actualEvents = append(actualEvents, event)
			case <-ticker.C:
				assert.FailNow(t, "timed out for incoming event")
			}
		}
		assert.Equal(t, expectedEvents, actualEvents)
	})

	mt := mtest.New(t)

	mt.Run("polling must await frequency", func(mt *mtest.T) {
		var heartbeatStartedCount atomic.Int64

		servers := map[string]bool{}
		serversMu := sync.RWMutex{} // Guard the servers set

		serverMonitor := &event.ServerMonitor{
			ServerHeartbeatStarted: func(*event.ServerHeartbeatStartedEvent) {
				heartbeatStartedCount.Add(1)
			},
			TopologyDescriptionChanged: func(evt *event.TopologyDescriptionChangedEvent) {
				serversMu.Lock()
				defer serversMu.Unlock()

				for _, srv := range evt.NewDescription.Servers {
					servers[srv.Addr.String()] = true
				}
			},
		}

		// Create a client with  heartbeatFrequency=100ms,
		// serverMonitoringMode=poll. Use SDAM to record the number of times the
		// a heartbeat is started and the number of servers discovered.
		mt.ResetClient(options.Client().
			SetServerMonitor(serverMonitor).
			SetServerMonitoringMode(options.ServerMonitoringModePoll))

		// Per specifications, minHeartbeatFrequencyMS=500ms. So, within the first
		// 500ms the heartbeatStartedCount should be LEQ to the number of discovered
		// servers.
		time.Sleep(500 * time.Millisecond)

		serversMu.Lock()
		serverCount := int64(len(servers))
		serversMu.Unlock()

		assert.LessOrEqual(mt, heartbeatStartedCount.Load(), serverCount)
	})
}