File: gateway_finder.go

package info (click to toggle)
gitlab-agent 16.11.5-1
  • links: PTS, VCS
  • area: contrib
  • in suites: experimental
  • size: 7,072 kB
  • sloc: makefile: 193; sh: 55; ruby: 3
file content (351 lines) | stat: -rw-r--r-- 11,019 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
package tunserver

import (
	"context"
	"errors"
	"io"
	"sync"
	"time"

	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/logz"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/retry"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tunnel/rpc"
	"go.opentelemetry.io/otel/trace"
	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"k8s.io/apimachinery/pkg/util/wait"
)

var (
	// proxyStreamDesc is the stream description passed to NewStream when opening a stream to a gateway
	// tunserver. Both directions are marked as streaming.
	proxyStreamDesc = grpc.StreamDesc{
		ServerStreams: true,
		ClientStreams: true,
	}

	// tunnelReadySentinelError is a sentinel error value to make stream visitor exit early.
	// It is returned by the tunnelReady callback in tryGatewayAsync and is compared by identity in the
	// switch that follows the visit; its message is empty.
	tunnelReadySentinelError = errors.New("")
)

// connAttempt is the per-URL record kept in gatewayFinder.connections for a gateway tunserver
// that a connection attempt has been started for.
type connAttempt struct {
	// cancel cancels the context the connection attempt for this URL runs under.
	cancel context.CancelFunc
}

// ReadyGateway is what Find returns on success: a gateway tunserver that has reported a ready tunnel and
// to which the StartStreaming message has been sent.
//
// URL is the gateway tunserver URL that was dialed, Conn is the pool connection obtained for it, Stream is
// the stream opened on that connection, and StreamCancel cancels the context the connection attempt ran under.
// Call Done to release the gateway.
type ReadyGateway struct {
	URL          string
	Stream       grpc.ClientStream
	Conn         grpctool.PoolConn
	StreamCancel context.CancelFunc
}

// Done releases the gateway: it invokes StreamCancel first and then calls Done on Conn.
// Both fields must be set; on a zero ReadyGateway the nil StreamCancel call panics.
func (g ReadyGateway) Done() {
	g.StreamCancel()
	g.Conn.Done()
}

// PollGatewayURLsCallback is called periodically with found kas URLs for a particular agent id.
// The callback that Find passes to PollGatewayURLs forwards each received slice to Find's select loop.
type PollGatewayURLsCallback func(kasURLs []string)

// PollingGatewayURLQuerier is the source of gateway tunserver (kas) URLs for an agent id.
type PollingGatewayURLQuerier interface {
	// PollGatewayURLs calls cb with the URLs found for agentID. Find runs it in a separate goroutine and
	// cancels ctx to stop it, so it is presumably expected to return once ctx is done — confirm against
	// the implementation.
	PollGatewayURLs(ctx context.Context, agentID int64, cb PollGatewayURLsCallback)
	// CachedGatewayURLs returns a list of URLs for agentID. Find uses the result as its initial set of
	// candidates, before any polling is started.
	CachedGatewayURLs(agentID int64) []string
}

// GatewayFinder finds a gateway tunserver that is ready to stream.
type GatewayFinder interface {
	// Find returns a gateway that is ready to stream. The implementation in this file returns ctx.Err()
	// when ctx is done before any gateway has become ready.
	Find(ctx context.Context) (ReadyGateway, error)
}

// gatewayFinder is the GatewayFinder implementation returned by NewGatewayFinder.
//
// Connection attempts run in goroutines tracked by wg and derive their contexts from outgoingCtx.
// foundGateway and noTunnel are unbuffered channels on which those goroutines report to Find:
// foundGateway carries the gateway that is ready to stream, noTunnel signals that a gateway replied that
// it has no tunnel. pollCancel is set by Find and stops polling the querier for URLs.
// tryNewGatewayInterval is how long Find waits for activity before trying one more gateway.
type gatewayFinder struct {
	log                   *zap.Logger
	gatewayPool           grpctool.PoolInterface
	gatewayQuerier        PollingGatewayURLQuerier
	rpcAPI                modshared.RPCAPI
	fullMethod            string // /service/method
	ownPrivateAPIURL      string
	agentID               int64
	outgoingCtx           context.Context
	pollConfig            retry.PollConfigFactory
	foundGateway          chan ReadyGateway
	noTunnel              chan struct{}
	wg                    wait.Group
	pollCancel            context.CancelFunc
	tryNewGatewayInterval time.Duration

	mu          sync.Mutex             // protects the fields below
	connections map[string]connAttempt // gateway tunserver URL -> conn info
	gatewayURLs []string               // currently known gateway tunserver URLs for the agent id
	done        bool                   // successfully done searching
}

// NewGatewayFinder assembles a GatewayFinder from its dependencies.
//
// fullMethod is the /service/method to open on the gateway tunserver, ownPrivateAPIURL is the URL that Find
// always tries first, agentID selects which agent's gateway URLs are queried, and outgoingCtx is the parent
// context of every connection attempt. tryNewGatewayInterval is the period Find waits without activity
// before trying another gateway.
func NewGatewayFinder(log *zap.Logger, gatewayPool grpctool.PoolInterface, gatewayQuerier PollingGatewayURLQuerier,
	rpcAPI modshared.RPCAPI, fullMethod string, ownPrivateAPIURL string, agentID int64, outgoingCtx context.Context,
	pollConfig retry.PollConfigFactory, tryNewGatewayInterval time.Duration) GatewayFinder {
	finder := &gatewayFinder{
		// Collaborators.
		log:            log,
		rpcAPI:         rpcAPI,
		gatewayPool:    gatewayPool,
		gatewayQuerier: gatewayQuerier,

		// What to connect to and on whose behalf.
		agentID:          agentID,
		fullMethod:       fullMethod,
		ownPrivateAPIURL: ownPrivateAPIURL,
		outgoingCtx:      outgoingCtx,

		// Retry and pacing configuration.
		pollConfig:            pollConfig,
		tryNewGatewayInterval: tryNewGatewayInterval,

		// Internal state: unbuffered signalling channels and the per-URL attempt registry.
		noTunnel:     make(chan struct{}),
		foundGateway: make(chan ReadyGateway),
		connections:  make(map[string]connAttempt),
	}
	return finder
}

// Find connects to gateway tunservers until one of them reports a ready tunnel and returns that gateway.
//
// The own private API URL is tried first. Further candidates come from the querier's cached URLs and, once
// none of those is left to try, from polling the querier. One more candidate is tried each time a connected
// gateway reports that it has no tunnel, or when tryNewGatewayInterval elapses without such a report.
//
// If ctx is done before a gateway is found, Find returns ctx.Err(). On both exits the connection attempts
// that did not win are cancelled, and Find waits for every goroutine it started before returning.
func (f *gatewayFinder) Find(ctx context.Context) (ReadyGateway, error) {
	// Deferred first so that it runs last: the timer is stopped and polling is cancelled before waiting.
	defer f.wg.Wait()
	var pollCtx context.Context
	pollCtx, f.pollCancel = context.WithCancel(ctx)
	defer f.pollCancel()

	// Unconditionally connect to self ASAP.
	// f.mu is not held here: no goroutine has been started yet, so nothing else can touch f.connections.
	f.tryGatewayLocked(f.ownPrivateAPIURL) //nolint: contextcheck
	startedPolling := false
	// This flag is set when we've run out of gateway tunserver URLs to try. When a new set of URLs is received, if this is set,
	// we try to connect to one of those URLs.
	needToTryNewGateway := false

	// Timer is used to wake up the loop below after a certain amount of time has passed but there has been no activity,
	// in particular, a recently connected to gateway tunserver didn't reply with noTunnel. If it's not replying, we
	// need to try another instance if it has been discovered.
	// If, for some reason, our own private API server doesn't respond with noTunnel/startStreaming in time, we
	// want to proceed with normal flow too.
	t := time.NewTimer(f.tryNewGatewayInterval)
	defer t.Stop()
	gatewayURLsC := make(chan []string)
	// NOTE(review): written without f.mu. The only goroutine running at this point is the attempt for the
	// own URL, and maybeStopTrying returns for that URL before it reads gatewayURLs.
	f.gatewayURLs = f.gatewayQuerier.CachedGatewayURLs(f.agentID)
	done := ctx.Done()

	// Timer must have been stopped or has fired when this function is called
	tryNewGatewayWhenTimerNotRunning := func() {
		if f.tryNewGateway() { //nolint: contextcheck
			// Connected to an instance.
			needToTryNewGateway = false
			t.Reset(f.tryNewGatewayInterval)
		} else {
			// Couldn't find a gateway tunserver instance we haven't connected to already.
			// The timer is left stopped; it is re-armed only after a later attempt is started.
			needToTryNewGateway = true
			if !startedPolling {
				startedPolling = true
				// No more cached instances, start polling for gateway tunserver instances.
				f.wg.Start(func() {
					pollDone := pollCtx.Done()
					f.gatewayQuerier.PollGatewayURLs(pollCtx, f.agentID, func(gatewayURLs []string) {
						// Hand the URLs to the loop below, or give up once polling has been cancelled
						// so that the callback never blocks after Find has stopped receiving.
						select {
						case <-pollDone:
						case gatewayURLsC <- gatewayURLs:
						}
					})
				})
			}
		}
	}

	for {
		select {
		case <-done:
			f.stopAllConnectionAttempts()
			return ReadyGateway{}, ctx.Err()
		case <-f.noTunnel:
			// A gateway has no tunnel right now: try one more gateway without waiting for the timer.
			stopAndDrain(t)
			tryNewGatewayWhenTimerNotRunning()
		case gatewayURLs := <-gatewayURLsC:
			f.mu.Lock()
			f.gatewayURLs = gatewayURLs
			f.mu.Unlock()
			if !needToTryNewGateway {
				// An attempt is still within its interval; the new URLs are picked up on the next trigger.
				continue
			}
			if f.tryNewGateway() { //nolint: contextcheck
				// Connected to a new gateway instance.
				needToTryNewGateway = false
				stopAndDrain(t)
				t.Reset(f.tryNewGatewayInterval)
			}
		case <-t.C:
			tryNewGatewayWhenTimerNotRunning()
		case rt := <-f.foundGateway:
			// Keep only the winning attempt alive; its context is cancelled later via ReadyGateway.StreamCancel.
			f.stopAllConnectionAttemptsExcept(rt.URL)
			return rt, nil
		}
	}
}

// tryNewGateway starts a connection attempt to the first known gateway tunserver URL that has no entry in
// f.connections yet. It reports whether such a URL was found; false means every known URL has already
// been attempted.
func (f *gatewayFinder) tryNewGateway() bool {
	f.mu.Lock()
	defer f.mu.Unlock()
	for _, candidate := range f.gatewayURLs {
		if _, attempted := f.connections[candidate]; !attempted {
			f.tryGatewayLocked(candidate)
			return true
		}
	}
	return false
}

// tryGatewayLocked registers a connection attempt for gatewayURL in f.connections and starts it in a
// goroutine tracked by f.wg. The attempt runs under a cancellable child of f.outgoingCtx.
// It writes to f.connections, so f.mu must be held, unless no other goroutine exists yet, as in
// Find's initial call for the own URL.
func (f *gatewayFinder) tryGatewayLocked(gatewayURL string) {
	attemptCtx, stopAttempt := context.WithCancel(f.outgoingCtx)
	f.connections[gatewayURL] = connAttempt{cancel: stopAttempt}
	run := func() {
		f.tryGatewayAsync(attemptCtx, stopAttempt, gatewayURL)
	}
	f.wg.Start(run)
}

// tryGatewayAsync repeatedly tries to obtain a ready tunnel from the gateway tunserver at gatewayURL,
// retrying via retry.PollWithBackoff. It is run in its own goroutine by tryGatewayLocked.
//
// ctx bounds the whole series of attempts. cancel is the function that cancels ctx; on success it is handed
// to the caller of Find as ReadyGateway.StreamCancel. Results are reported to Find through f.noTunnel (at
// most once per call) and f.foundGateway.
func (f *gatewayFinder) tryGatewayAsync(ctx context.Context, cancel context.CancelFunc, gatewayURL string) {
	log := f.log.With(logz.KASURL(gatewayURL))
	noTunnelSent := false // shared by all retries below so that noTunnel is signalled only once
	_ = retry.PollWithBackoff(ctx, f.pollConfig(), func(ctx context.Context) (error, retry.AttemptResult) {
		success := false

		// 1. Dial another gateway tunserver
		log.Debug("Trying tunnel")
		attemptCtx, attemptCancel := context.WithCancel(ctx)
		// Unless the gateway was handed over to Find, cancel this attempt and check whether the URL should
		// still be retried. Deferred functions run in reverse order, so this runs after the connection
		// has been released by the defer below.
		defer func() {
			if !success {
				attemptCancel()
				f.maybeStopTrying(gatewayURL)
			}
		}()
		gatewayConn, err := f.gatewayPool.Dial(attemptCtx, gatewayURL)
		if err != nil {
			f.rpcAPI.HandleProcessingError(log, f.agentID, "Failed to dial gateway tunserver", err)
			return nil, retry.Backoff
		}
		// On success the connection is owned by the ReadyGateway and is released by ReadyGateway.Done instead.
		defer func() {
			if !success {
				gatewayConn.Done()
			}
		}()

		// 2. Open a stream to the desired service/method
		gatewayStream, err := gatewayConn.NewStream(
			attemptCtx,
			&proxyStreamDesc,
			f.fullMethod,
			grpc.ForceCodec(grpctool.RawCodecWithProtoFallback{}),
			grpc.WaitForReady(true),
		)
		if err != nil {
			f.rpcAPI.HandleProcessingError(log, f.agentID, "Failed to open a new stream to gateway tunserver", err)
			return nil, retry.Backoff
		}

		// 3. Wait for the gateway tunserver to say it's ready to start streaming i.e. has a suitable tunnel to an agent
		err = rpc.GatewayResponseVisitor().Visit(gatewayStream,
			grpctool.WithCallback(noTunnelFieldNumber, func(noTunnel *rpc.GatewayResponse_NoTunnel) error {
				trace.SpanFromContext(gatewayStream.Context()).AddEvent("No tunnel") //nolint: contextcheck
				if !noTunnelSent {                                                   // send only once
					noTunnelSent = true
					// Let Find() know there is no tunnel available from that gateway tunserver instantaneously.
					// A tunnel may still be found when a suitable agent connects later, but none available immediately.
					// The send is abandoned if the attempt is cancelled, so this goroutine cannot block forever.
					select {
					case <-attemptCtx.Done():
					case f.noTunnel <- struct{}{}:
					}
				}
				return nil
			}),
			grpctool.WithCallback(tunnelReadyFieldNumber, func(tunnelReady *rpc.GatewayResponse_TunnelReady) error { //nolint:contextcheck
				trace.SpanFromContext(gatewayStream.Context()).AddEvent("Ready")
				return tunnelReadySentinelError
			}),
			grpctool.WithNotExpectingToGet(codes.Internal, headerFieldNumber, messageFieldNumber, trailerFieldNumber, errorFieldNumber),
		)
		switch err { //nolint:errorlint
		case nil:
			// Gateway tunserver closed the connection cleanly, perhaps it's been open for too long
			return nil, retry.ContinueImmediately
		case tunnelReadySentinelError:
			// Tunnel is ready. Go cases do not fall through: this case simply leaves the switch and
			// execution continues with step 4 below.
		default:
			f.rpcAPI.HandleProcessingError(log, f.agentID, "RecvMsg(GatewayResponse)", err)
			return nil, retry.Backoff
		}

		// 4. Check if another goroutine has found a suitable tunnel already
		f.mu.Lock() // Ensure only one gateway tunserver gets StartStreaming message
		if f.done {
			f.mu.Unlock()
			return nil, retry.Done
		}
		// 5. Tell the gateway tunserver we are starting streaming
		// The mutex stays locked across SendMsg so that f.done is set only after the message has been sent.
		err = gatewayStream.SendMsg(&rpc.StartStreaming{})
		if err != nil {
			f.mu.Unlock()
			if err == io.EOF { //nolint:errorlint
				var frame grpctool.RawFrame
				err = gatewayStream.RecvMsg(&frame) // get the real error
			}
			_ = f.rpcAPI.HandleIOError(log, "SendMsg(StartStreaming)", err)
			return nil, retry.Backoff
		}
		f.done = true
		f.mu.Unlock()
		// A gateway has been chosen, so polling for more gateway URLs is no longer needed.
		f.pollCancel()
		rt := ReadyGateway{
			URL:          gatewayURL,
			Stream:       gatewayStream,
			Conn:         gatewayConn,
			StreamCancel: cancel,
		}
		// Hand the gateway over to Find. If the attempt is cancelled first, success stays false and the
		// deferred functions above release the connection.
		select {
		case <-attemptCtx.Done():
		case f.foundGateway <- rt:
			success = true
		}
		return nil, retry.Done
	})
}

// maybeStopTrying decides, after an unsuccessful attempt, whether tryingGatewayURL should still be retried.
// The own private API URL and every URL still present in f.gatewayURLs are kept. Any other URL is removed
// from f.connections and its attempt is cancelled.
func (f *gatewayFinder) maybeStopTrying(tryingGatewayURL string) {
	if f.ownPrivateAPIURL == tryingGatewayURL {
		// The own URL is retried unconditionally.
		return
	}
	f.mu.Lock()
	defer f.mu.Unlock()
	for i := range f.gatewayURLs {
		if f.gatewayURLs[i] == tryingGatewayURL {
			// Still among the known URLs for this agent, so keep retrying it.
			return
		}
	}
	// No longer known: forget the attempt, then cancel it.
	stopAttempt := f.connections[tryingGatewayURL].cancel
	delete(f.connections, tryingGatewayURL)
	stopAttempt()
}

// stopAllConnectionAttemptsExcept cancels every registered connection attempt other than the one for
// gatewayURL. Entries are left in f.connections.
func (f *gatewayFinder) stopAllConnectionAttemptsExcept(gatewayURL string) {
	f.mu.Lock()
	defer f.mu.Unlock()
	for attemptURL, attempt := range f.connections {
		if attemptURL == gatewayURL {
			continue // spare the attempt for the given URL
		}
		attempt.cancel()
	}
}

// stopAllConnectionAttempts cancels every registered connection attempt. Entries are left in f.connections.
func (f *gatewayFinder) stopAllConnectionAttempts() {
	f.mu.Lock()
	defer f.mu.Unlock()
	for gatewayURL := range f.connections {
		f.connections[gatewayURL].cancel()
	}
}

// stopAndDrain stops t and, if Stop reports that the timer was not running, removes a pending tick from
// t.C when one is there. The receive is non-blocking, so it is safe to call on a timer whose tick has
// already been consumed.
func stopAndDrain(t *time.Timer) {
	if t.Stop() {
		// Stopped before firing: nothing to drain.
		return
	}
	select {
	case <-t.C:
	default:
	}
}