File: router_to_tunclient.go

package info (click to toggle)
gitlab-agent 16.11.5-1
  • links: PTS, VCS
  • area: contrib
  • in suites: experimental
  • size: 7,072 kB
  • sloc: makefile: 193; sh: 55; ruby: 3
file content (140 lines) | stat: -rw-r--r-- 3,741 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package tunserver

import (
	"io"
	"strings"

	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/module/modshared"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/prototool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tunnel/rpc"
	"go.uber.org/zap"
	statuspb "google.golang.org/genproto/googleapis/rpc/status"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

// routeToTunclient is a gRPC stream handler that relays stream to an agent
// through a tunnel located with r.plugin.FindTunnel.
//
// Sequence of messages sent/received on stream:
//  1. If no tunnel is available yet, a GatewayResponse NoTunnel message is sent
//     and the handler keeps waiting for one via findHandle.Get.
//  2. Once a tunnel is obtained, a GatewayResponse TunnelReady message is sent.
//  3. A StartStreaming message is awaited. io.EOF at this point means the
//     routing kas decided not to proceed and is reported as success (nil).
//  4. The stream is forwarded via tun.ForwardStream, with the tunnel's output
//     wrapped into GatewayResponse messages by a wrappingCallback.
//
// Incoming metadata keys carrying RoutingHopPrefix are stripped before forwarding.
// srv is unused; it is part of the gRPC stream-handler signature.
func (r *Router) routeToTunclient(srv interface{}, stream grpc.ServerStream) error {
	origCtx := stream.Context()
	rpcAPI := modshared.RPCAPIFromContext(origCtx)
	found, log, handle, err := r.plugin.FindTunnel(stream, rpcAPI)
	if err != nil {
		return err
	}
	// Done receives the original (unwrapped) stream context.
	defer handle.Done(origCtx)

	// Replace the incoming metadata with a copy that has routing hop keys removed.
	incomingMD, _ := metadata.FromIncomingContext(origCtx)
	sanitized := grpc_middleware.WrapServerStream(stream)
	sanitized.WrappedContext = metadata.NewIncomingContext(sanitized.WrappedContext, removeHopMeta(incomingMD))
	ctx := sanitized.WrappedContext
	stream = sanitized

	if !found {
		// Tell the caller there is no tunnel yet; we still wait for one below.
		noTunnel := &rpc.GatewayResponse{
			Msg: &rpc.GatewayResponse_NoTunnel_{NoTunnel: &rpc.GatewayResponse_NoTunnel{}},
		}
		if err = stream.SendMsg(noTunnel); err != nil {
			return rpcAPI.HandleIOError(log, "SendMsg(GatewayResponse_NoTunnel) failed", err)
		}
	}

	tun, err := handle.Get(ctx)
	if err != nil {
		return err
	}
	defer tun.Done(ctx)

	ready := &rpc.GatewayResponse{
		Msg: &rpc.GatewayResponse_TunnelReady_{TunnelReady: &rpc.GatewayResponse_TunnelReady{}},
	}
	if err = stream.SendMsg(ready); err != nil {
		return rpcAPI.HandleIOError(log, "SendMsg(GatewayResponse_TunnelReady) failed", err)
	}

	var start rpc.StartStreaming
	switch err = stream.RecvMsg(&start); {
	case err == io.EOF: //nolint:errorlint
		// Routing kas decided not to proceed.
		return nil
	case err != nil:
		return err
	}
	return tun.ForwardStream(log, rpcAPI, stream, newWrappingCallback(log, rpcAPI, stream))
}

// removeHopMeta returns a copy of md (made with md.Copy) from which every key
// starting with RoutingHopPrefix has been deleted. The argument is not modified.
func removeHopMeta(md metadata.MD) metadata.MD {
	sanitized := md.Copy()
	for key := range sanitized {
		if !strings.HasPrefix(key, RoutingHopPrefix) {
			continue
		}
		// Deleting the current key while ranging over a map is allowed in Go.
		delete(sanitized, key)
	}
	return sanitized
}

// wrappingCallback turns each piece of a forwarded call (header, message data,
// trailer, error status) into a GatewayResponse message sent on stream.
type wrappingCallback struct {
	// log is handed to rpcAPI.HandleIOError when sending a message fails.
	log *zap.Logger

	// rpcAPI turns SendMsg failures into the error returned to the caller.
	rpcAPI modshared.RPCAPI

	// stream receives the GatewayResponse messages.
	stream grpc.ServerStream
}

// newWrappingCallback builds a wrappingCallback that writes to stream and
// reports send failures through rpcAPI.HandleIOError using log.
func newWrappingCallback(log *zap.Logger, rpcAPI modshared.RPCAPI, stream grpc.ServerStream) *wrappingCallback {
	cb := new(wrappingCallback)
	cb.log, cb.rpcAPI, cb.stream = log, rpcAPI, stream
	return cb
}

// Header wraps the given header metadata into a GatewayResponse Header message
// and sends it on the callback's stream.
func (c *wrappingCallback) Header(md map[string]*prototool.Values) error {
	header := &rpc.GatewayResponse_Header{Meta: md}
	resp := &rpc.GatewayResponse{
		Msg: &rpc.GatewayResponse_Header_{Header: header},
	}
	return c.sendMsg("SendMsg(GatewayResponse_Header) failed", resp)
}

// Message wraps a chunk of payload bytes into a GatewayResponse Message
// message and sends it on the callback's stream.
func (c *wrappingCallback) Message(data []byte) error {
	payload := &rpc.GatewayResponse_Message{Data: data}
	resp := &rpc.GatewayResponse{
		Msg: &rpc.GatewayResponse_Message_{Message: payload},
	}
	return c.sendMsg("SendMsg(GatewayResponse_Message) failed", resp)
}

// Trailer wraps the given trailer metadata into a GatewayResponse Trailer
// message and sends it on the callback's stream.
func (c *wrappingCallback) Trailer(md map[string]*prototool.Values) error {
	trailer := &rpc.GatewayResponse_Trailer{Meta: md}
	resp := &rpc.GatewayResponse{
		Msg: &rpc.GatewayResponse_Trailer_{Trailer: trailer},
	}
	return c.sendMsg("SendMsg(GatewayResponse_Trailer) failed", resp)
}

// Error wraps the given status into a GatewayResponse Error message and sends
// it on the callback's stream. The returned error only reflects send failures.
func (c *wrappingCallback) Error(stat *statuspb.Status) error {
	errPayload := &rpc.GatewayResponse_Error{Status: stat}
	resp := &rpc.GatewayResponse{
		Msg: &rpc.GatewayResponse_Error_{Error: errPayload},
	}
	return c.sendMsg("SendMsg(GatewayResponse_Error) failed", resp)
}

// sendMsg writes msg to the callback's stream. A send failure is passed, along
// with errMsg and the callback's logger, to rpcAPI.HandleIOError, and whatever
// that returns is the result; a successful send yields nil.
func (c *wrappingCallback) sendMsg(errMsg string, msg *rpc.GatewayResponse) error {
	if sendErr := c.stream.SendMsg(msg); sendErr != nil {
		return c.rpcAPI.HandleIOError(c.log, errMsg, sendErr)
	}
	return nil
}