1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
|
package tunserver
import (
"context"
"io"
"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/module/modshared"
"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/grpctool"
"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/prototool"
"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tunnel/info"
"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tunnel/rpc"
"go.uber.org/zap"
statuspb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/reflect/protoreflect"
)
// StateType is the lifecycle state of a tunnel and records who currently owns it.
// The zero value is intentionally not a valid state so that a StateType that was
// never initialized is easy to spot (it catches initialization bugs).
type StateType int

const (
	// StateReady means the registry owns the tunnel and it is ready to be found and used for forwarding.
	StateReady StateType = iota + 1
	// StateFound means the registry no longer owns the tunnel; it was found and is about to be used for forwarding.
	StateFound
	// StateForwarding means the registry no longer owns the tunnel; it is being used for forwarding.
	StateForwarding
	// StateDone means nobody owns the tunnel; it has been used for forwarding and Done() has been called.
	StateDone
	// StateContextDone means nobody owns the tunnel; the reverse tunnel's context signaled done in HandleTunnel().
	StateContextDone
)
// Field numbers that forwardStream passes to the ConnectRequest visitor: the
// first one is the visitor's start state, the rest select which callback handles
// a message. NOTE(review): these presumably must match the field numbers of the
// ConnectRequest message in the tunnel proto definition — verify there before
// changing any of them.
const (
	descriptorNumber = protoreflect.FieldNumber(1)
	headerNumber     = protoreflect.FieldNumber(2)
	messageNumber    = protoreflect.FieldNumber(3)
	trailerNumber    = protoreflect.FieldNumber(4)
	errorNumber      = protoreflect.FieldNumber(5)
)
// DataCallback receives what arrives from the tunnel so the caller can relay it
// into the incoming stream.
//
// In forwardStream, an error returned by Header, Message or Trailer is handed back
// to the tunnel visitor. An error returned by Error is recorded as the result for
// the incoming stream, and reading from the tunnel continues.
type DataCallback interface {
	// Header receives header metadata sent through the tunnel.
	Header(meta map[string]*prototool.Values) error
	// Message receives the raw payload of a single message sent through the tunnel.
	Message(data []byte) error
	// Trailer receives trailer metadata sent through the tunnel.
	Trailer(meta map[string]*prototool.Values) error
	// Error receives an error status sent through the tunnel.
	Error(errStatus *statuspb.Status) error
}
// Tunnel is a reverse tunnel that can be used to forward an incoming gRPC stream.
type Tunnel interface {
	// ForwardStream forwards messages in both directions between incomingStream and
	// the tunnel. The header, messages and trailer coming back from the tunnel are
	// handed to cb. Writing them into incomingStream is cb's job, not ForwardStream's.
	ForwardStream(log *zap.Logger, rpcAPI modshared.RPCAPI, incomingStream grpc.ServerStream, cb DataCallback) error
	// Done must be called once the caller has finished with the Tunnel.
	// ctx is used for tracing only.
	Done(ctx context.Context)
}
// TunnelImpl implements Tunnel on top of a reverse tunnel stream.
type TunnelImpl struct {
	// Tunnel is the reverse tunnel stream. forwardStream sends rpc.ConnectResponse
	// messages on it and reads rpc.ConnectRequest messages from it.
	Tunnel rpc.ReverseTunnel_ConnectServer
	// TunnelRetErr receives the error the tunnel's handler should return.
	// ForwardStream sends one value to it after forwarding finishes. Nothing is
	// sent if OnForward fails.
	TunnelRetErr chan<- error
	// AgentID is presumably the ID of the agent that opened this tunnel
	// (not read by the code here) — TODO confirm.
	AgentID int64
	// Descriptor presumably describes the APIs reachable through this tunnel
	// (not read by the code here) — TODO confirm.
	Descriptor *info.APIDescriptor
	// State is the tunnel's current lifecycle state; see the StateType constants.
	State StateType
	// OnForward is called first thing in ForwardStream. A non-nil error aborts
	// forwarding and is returned to ForwardStream's caller.
	OnForward func(*TunnelImpl) error
	// OnDone is called by Done with Done's ctx and this tunnel.
	OnDone func(context.Context, *TunnelImpl)
}
// ForwardStream implements Tunnel.
//
// It runs the OnForward hook first and stops if the hook fails. Otherwise it pipes
// data in both directions between incomingStream and the tunnel. When piping is
// over, it delivers the tunnel-side error to TunnelRetErr and returns the
// incoming-side error.
func (t *TunnelImpl) ForwardStream(log *zap.Logger, rpcAPI modshared.RPCAPI, incomingStream grpc.ServerStream, cb DataCallback) error {
	hookErr := t.OnForward(t)
	if hookErr != nil {
		// Forwarding never started, so nothing is delivered to TunnelRetErr.
		return hookErr
	}
	outcome := t.forwardStream(log, rpcAPI, incomingStream, cb)
	// Hand the tunnel's result to its handler before returning ours.
	t.TunnelRetErr <- outcome.forTunnel
	return outcome.forIncomingStream
}
// forwardStream pipes data in both directions between incomingStream and t.Tunnel.
// It returns two errors: forTunnel is what the tunnel's handler should return, and
// forIncomingStream is what the incoming request's handler should return.
//
// It returns in three cases:
//   - as soon as either direction reports an error;
//   - when both directions have finished cleanly;
//   - when one direction has finished cleanly and incomingStream's context is then done.
func (t *TunnelImpl) forwardStream(log *zap.Logger, rpcAPI modshared.RPCAPI, incomingStream grpc.ServerStream, cb DataCallback) errPair {
	// Here we have a situation where we need to pipe one server stream into another server stream.
	// One stream is the incoming request stream and the other one is the incoming tunnel stream.
	// We need to use at least one extra goroutine in addition to the current one (or two separate ones) to
	// implement full-duplex bidirectional stream piping. One goroutine reads and writes in one direction and the other
	// one in the opposite direction.
	// What if one of them returns an error? We need to unblock the other one, ideally ASAP, to release resources. If
	// it's not unblocked, it'll sit there until it hits a timeout or is aborted by the peer. OK-ish, but far from ideal.
	// To abort request processing on the server side, a gRPC stream handler should just return from the call.
	// See https://github.com/grpc/grpc-go/issues/465#issuecomment-179414474
	// To implement this, we read and write in both directions in separate goroutines and return from both
	// handlers whenever there is an error, aborting both connections:
	// - Returning from this function means returning from the incoming request handler.
	// - ForwardStream sends the returned forTunnel value to t.TunnelRetErr, which leads to returning that
	//   value from the tunnel handler.
	// The channel has size 1 so that, if we return early, the second goroutine still has space for its value
	// and does not block forever.
	// We don't care about the second value if the first one has at least one non-nil error.
	res := make(chan errPair, 1)
	incomingCtx := incomingStream.Context()
	// Pipe the incoming stream (i.e. data a client is sending us) into the tunnel stream.
	goErrPair(res, func() (error /* forTunnel */, error /* forIncomingStream */) {
		// ok is deliberately ignored: md is nil when the context carries no incoming metadata.
		md, _ := metadata.FromIncomingContext(incomingCtx)
		// First send the method name and request metadata through the tunnel.
		// NOTE(review): grpc.ServerTransportStreamFromContext returns nil when ctx has no server
		// transport stream, which would panic here; fine for contexts of real gRPC handlers.
		err := t.Tunnel.Send(&rpc.ConnectResponse{
			Msg: &rpc.ConnectResponse_RequestInfo{
				RequestInfo: &rpc.RequestInfo{
					MethodName: grpc.ServerTransportStreamFromContext(incomingCtx).Method(),
					Meta:       grpctool.MetaToValuesMap(md),
				},
			},
		})
		if err != nil {
			err = rpcAPI.HandleIOError(log, "Send(ConnectResponse_RequestInfo)", err)
			return err, err
		}
		// Outside the loop to allocate once vs on each message
		var frame grpctool.RawFrame
		for {
			err = incomingStream.RecvMsg(&frame)
			if err != nil {
				// io.EOF from a server stream's RecvMsg means the client has finished sending.
				if err == io.EOF { //nolint:errorlint
					break
				}
				// The tunnel is aborted with Canceled; the incoming handler gets the original read error.
				return status.Error(codes.Canceled, "read from incoming stream"), err
			}
			err = t.Tunnel.Send(&rpc.ConnectResponse{
				Msg: &rpc.ConnectResponse_Message{
					Message: &rpc.Message{
						Data: frame.Data,
					},
				},
			})
			if err != nil {
				err = rpcAPI.HandleIOError(log, "Send(ConnectResponse_Message)", err)
				return err, err
			}
		}
		// Tell the tunnel side that the client will not send any more messages.
		err = t.Tunnel.Send(&rpc.ConnectResponse{
			Msg: &rpc.ConnectResponse_CloseSend{
				CloseSend: &rpc.CloseSend{},
			},
		})
		if err != nil {
			err = rpcAPI.HandleIOError(log, "Send(ConnectResponse_CloseSend)", err)
			return err, err
		}
		return nil, nil
	})
	// Pipe the tunnel stream (i.e. data agentk is sending us) into the incoming stream via cb.
	goErrPair(res, func() (error /* forTunnel */, error /* forIncomingStream */) {
		var forTunnel, forIncomingStream error
		// Read ConnectRequest messages from t.Tunnel and dispatch them to the callbacks below by
		// field number, starting in the descriptorNumber state.
		fromVisitor := rpc.ConnectRequestVisitor().Visit(t.Tunnel,
			grpctool.WithStartState(descriptorNumber),
			grpctool.WithCallback(headerNumber, func(header *rpc.Header) error {
				return cb.Header(header.Meta)
			}),
			grpctool.WithCallback(messageNumber, func(message *rpc.Message) error {
				return cb.Message(message.Data)
			}),
			grpctool.WithCallback(trailerNumber, func(trailer *rpc.Trailer) error {
				return cb.Trailer(trailer.Meta)
			}),
			grpctool.WithCallback(errorNumber, func(rpcError *rpc.Error) error {
				forIncomingStream = cb.Error(rpcError.Status)
				// Not returning an error since we must be reading from the tunnel stream till io.EOF
				// to properly consume it. There is no need to abort it in this scenario.
				// The server is expected to close the stream (i.e. we'll get io.EOF) right after we got this message.
				return nil
			}),
		)
		if fromVisitor != nil {
			// A visitor error aborts both sides and overrides any result recorded from cb.Error.
			// It presumably includes errors returned by the callbacks above.
			forIncomingStream = fromVisitor
			forTunnel = fromVisitor
		}
		return forTunnel, forIncomingStream
	})
	pair := <-res
	if !pair.isNil() {
		// One direction failed: return right away so both handlers abort. The other goroutine
		// can still deposit its result into the buffered res channel without blocking.
		return pair
	}
	// One direction finished cleanly. Wait for the other one, unless the incoming stream's context ends first.
	select {
	case <-incomingCtx.Done():
		// Typically the incoming stream has finished sending all data (i.e. io.EOF was read from it) and
		// now it signals that it's closing. The read from t.Tunnel in the other goroutine may still be
		// blocked, so return an error for the tunnel too in order to abort it.
		err := grpctool.StatusErrorFromContext(incomingCtx, "Incoming stream closed")
		pair = errPair{
			forTunnel:         err,
			forIncomingStream: err,
		}
	case pair = <-res:
	}
	return pair
}
// Done implements Tunnel by invoking the OnDone hook with ctx and the tunnel itself.
func (t *TunnelImpl) Done(ctx context.Context) {
	onDone := t.OnDone
	onDone(ctx, t)
}
// errPair holds the outcome of forwarding as two errors: the one the tunnel's
// handler should return and the one the incoming stream's handler should return.
type errPair struct {
	forTunnel         error
	forIncomingStream error
}

// isNil reports whether neither of the two errors is set.
func (p errPair) isNil() bool {
	switch {
	case p.forTunnel != nil, p.forIncomingStream != nil:
		return false
	default:
		return true
	}
}

// goErrPair runs f in a new goroutine and sends its two results, packed into an
// errPair, on c. The first result of f becomes forTunnel and the second becomes
// forIncomingStream.
func goErrPair(c chan<- errPair, f func() (error /* forTunnel */, error /* forIncomingStream */)) {
	go func() {
		tunnelErr, incomingErr := f()
		c <- errPair{
			forTunnel:         tunnelErr,
			forIncomingStream: incomingErr,
		}
	}()
}
|