File: tunnel.go

package tunserver

import (
	"context"
	"io"

	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/prototool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tunnel/info"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tunnel/rpc"
	"go.uber.org/zap"
	statuspb "google.golang.org/genproto/googleapis/rpc/status"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/reflect/protoreflect"
)

type StateType int

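// Assumed lifecycle, per the comments below: StateReady -> StateFound -> StateForwarding -> StateDone,
// with StateContextDone reached when the reverse tunnel's context signals done in HandleTunnel().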
const (
	// zero value is invalid to catch initialization bugs.
	_ StateType = iota
	// StateReady - tunnel is owned by the registry and is ready to be found and used for forwarding.
	StateReady
	// StateFound - tunnel is not owned by the registry; it was found and is about to be used for forwarding.
	StateFound
	// StateForwarding - tunnel is not owned by the registry; it is being used for forwarding.
	StateForwarding
	// StateDone - tunnel is not owned by anyone; it has been used for forwarding and Done() has been called.
	StateDone
	// StateContextDone - tunnel is not owned by anyone; the reverse tunnel's context signaled done in HandleTunnel().
	StateContextDone
)

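// Field numbers of the rpc.ConnectRequest oneof members (assumed to match the .proto definition).
// They drive the stream visitor in forwardStream() below.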
const (
	descriptorNumber protoreflect.FieldNumber = 1
	headerNumber     protoreflect.FieldNumber = 2
	messageNumber    protoreflect.FieldNumber = 3
	trailerNumber    protoreflect.FieldNumber = 4
	errorNumber      protoreflect.FieldNumber = 5
)

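// DataCallback is invoked by ForwardStream with the pieces of the gRPC response read from the
// tunnel: header metadata, zero or more messages, trailer metadata and, on failure, an error status.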
type DataCallback interface {
	Header(map[string]*prototool.Values) error
	Message([]byte) error
	Trailer(map[string]*prototool.Values) error
	Error(*statuspb.Status) error
}

type Tunnel interface {
	// ForwardStream performs bi-directional message forwarding between incomingStream and the tunnel.
	// cb is called with the header, messages, and trailer coming from the tunnel. It's the caller's
	// responsibility to forward them into incomingStream.
	ForwardStream(log *zap.Logger, rpcAPI modshared.RPCAPI, incomingStream grpc.ServerStream, cb DataCallback) error
	// Done must be called when the caller is done with the Tunnel.
	// ctx is used for tracing only.
	Done(ctx context.Context)
}
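
// A hypothetical caller might use a Tunnel roughly like this (registry.FindTunnel is
// illustrative only and not defined in this file):
//
//	tun, err := registry.FindTunnel(ctx, agentID, service, method)
//	if err != nil {
//		return err
//	}
//	defer tun.Done(ctx)
//	return tun.ForwardStream(log, rpcAPI, incomingStream, cb)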

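// TunnelImpl implements Tunnel on top of a single rpc.ReverseTunnel Connect() stream.
// OnForward and OnDone are hooks, presumably used by the owning registry to track state transitions.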
type TunnelImpl struct {
	Tunnel       rpc.ReverseTunnel_ConnectServer
	TunnelRetErr chan<- error
	AgentID      int64
	Descriptor   *info.APIDescriptor
	State        StateType

	OnForward func(*TunnelImpl) error
	OnDone    func(context.Context, *TunnelImpl)
}

func (t *TunnelImpl) ForwardStream(log *zap.Logger, rpcAPI modshared.RPCAPI, incomingStream grpc.ServerStream, cb DataCallback) error {
	if err := t.OnForward(t); err != nil {
		return err
	}
	pair := t.forwardStream(log, rpcAPI, incomingStream, cb)
	t.TunnelRetErr <- pair.forTunnel
	return pair.forIncomingStream
}

func (t *TunnelImpl) forwardStream(log *zap.Logger, rpcAPI modshared.RPCAPI, incomingStream grpc.ServerStream, cb DataCallback) errPair {
	// Here we need to pipe one server stream into another server stream.
	// One stream is the incoming request stream and the other is the incoming tunnel stream.
	// We need at least one extra goroutine in addition to the current one (or two separate ones) to
	// implement full-duplex bidirectional stream piping: one goroutine reads and writes in one direction
	// and the other in the opposite direction.
	// What if one of them returns an error? We need to unblock the other one, ideally ASAP, to release
	// resources. If it's not unblocked, it sits there until it hits a timeout or is aborted by the peer.
	// Ok-ish, but far from ideal. To abort request processing on the server side, the gRPC stream handler
	// just has to return from the call.
	// See https://github.com/grpc/grpc-go/issues/465#issuecomment-179414474
	// To implement this, we read and write in both directions in separate goroutines and return from both
	// handlers whenever there is an error, aborting both connections:
	// - Returning from this function means returning from the incoming request handler.
	// - Sending to t.TunnelRetErr leads to returning that value from the tunnel handler.

	// Buffered channel of size 1 to ensure that if we return early, the remaining goroutine still
	// has space for its value and doesn't leak.
	// We don't care about the second value if the first one has at least one non-nil error.
	res := make(chan errPair, 1)
	incomingCtx := incomingStream.Context()
	// Pipe incoming stream (i.e. data a client is sending us) into the tunnel stream
	goErrPair(res, func() (error /* forTunnel */, error /* forIncomingStream */) {
		md, _ := metadata.FromIncomingContext(incomingCtx)
		err := t.Tunnel.Send(&rpc.ConnectResponse{
			Msg: &rpc.ConnectResponse_RequestInfo{
				RequestInfo: &rpc.RequestInfo{
					MethodName: grpc.ServerTransportStreamFromContext(incomingCtx).Method(),
					Meta:       grpctool.MetaToValuesMap(md),
				},
			},
		})
		if err != nil {
			err = rpcAPI.HandleIOError(log, "Send(ConnectResponse_RequestInfo)", err)
			return err, err
		}
		// Declared outside the loop to allocate once rather than on each message
		var frame grpctool.RawFrame
		for {
			err = incomingStream.RecvMsg(&frame)
			if err != nil {
				if err == io.EOF { //nolint:errorlint
					break
				}
				return status.Error(codes.Canceled, "read from incoming stream"), err
			}
			err = t.Tunnel.Send(&rpc.ConnectResponse{
				Msg: &rpc.ConnectResponse_Message{
					Message: &rpc.Message{
						Data: frame.Data,
					},
				},
			})
			if err != nil {
				err = rpcAPI.HandleIOError(log, "Send(ConnectResponse_Message)", err)
				return err, err
			}
		}
		err = t.Tunnel.Send(&rpc.ConnectResponse{
			Msg: &rpc.ConnectResponse_CloseSend{
				CloseSend: &rpc.CloseSend{},
			},
		})
		if err != nil {
			err = rpcAPI.HandleIOError(log, "Send(ConnectResponse_CloseSend)", err)
			return err, err
		}
		return nil, nil
	})
	// Pipe tunnel stream (i.e. data agentk is sending us) into the incoming stream
	goErrPair(res, func() (error /* forTunnel */, error /* forIncomingStream */) {
		var forTunnel, forIncomingStream error
		fromVisitor := rpc.ConnectRequestVisitor().Visit(t.Tunnel,
			grpctool.WithStartState(descriptorNumber),
			grpctool.WithCallback(headerNumber, func(header *rpc.Header) error {
				return cb.Header(header.Meta)
			}),
			grpctool.WithCallback(messageNumber, func(message *rpc.Message) error {
				return cb.Message(message.Data)
			}),
			grpctool.WithCallback(trailerNumber, func(trailer *rpc.Trailer) error {
				return cb.Trailer(trailer.Meta)
			}),
			grpctool.WithCallback(errorNumber, func(rpcError *rpc.Error) error {
				forIncomingStream = cb.Error(rpcError.Status)
				// Not returning an error since we must keep reading from the tunnel stream until io.EOF
				// to properly consume it. There is no need to abort it in this scenario.
				// The server is expected to close the stream (i.e. we'll get io.EOF) right after this message.
				return nil
			}),
		)
		if fromVisitor != nil {
			forIncomingStream = fromVisitor
			forTunnel = fromVisitor
		}
		return forTunnel, forIncomingStream
	})
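	// Wait for the first of the two goroutines to finish. A non-nil error in either
	// direction means both streams should be aborted right away.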
	pair := <-res
	if !pair.isNil() {
		return pair
	}
	select {
	case <-incomingCtx.Done():
		// The incoming stream has finished sending all data (i.e. io.EOF was read from it) but
		// now signals that it's closing. We need to abort the potentially stuck t.Tunnel.RecvMsg().
		err := grpctool.StatusErrorFromContext(incomingCtx, "Incoming stream closed")
		pair = errPair{
			forTunnel:         err,
			forIncomingStream: err,
		}
	case pair = <-res:
	}
	return pair
}

func (t *TunnelImpl) Done(ctx context.Context) {
	t.OnDone(ctx, t)
}

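// errPair holds the two errors produced by one forwarding direction: one to return from the
// tunnel handler and one to return from the incoming request handler.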
type errPair struct {
	forTunnel         error
	forIncomingStream error
}

func (p errPair) isNil() bool {
	return p.forTunnel == nil && p.forIncomingStream == nil
}

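// goErrPair runs f in its own goroutine and delivers the resulting errPair on c.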
func goErrPair(c chan<- errPair, f func() (error /* forTunnel */, error /* forIncomingStream */)) {
	go func() {
		var pair errPair
		pair.forTunnel, pair.forIncomingStream = f()
		c <- pair
	}()
}