package kasapp

import (
	"strconv"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/module/modserver"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/module/reverse_tunnel"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/module/reverse_tunnel/tracker"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/metric"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/retry"
	"go.opentelemetry.io/otel/trace"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

const (
	kasRoutingDurationMetricName      = "k8s_api_proxy_routing_duration_seconds"
	kasRoutingTimeoutMetricName       = "k8s_api_proxy_routing_timeout_total"
	kasRoutingStatusLabelName         = "status"
	kasRoutingStatusSuccessLabelValue = "success"
	kasRoutingStatusAbortedLabelValue = "aborted"
	routerTracerName                  = "tunnel-router"
)

// kasRouter registers agent API services so that calls to them can be routed.
type kasRouter interface {
	RegisterAgentApi(desc *grpc.ServiceDesc)
}

// router routes traffic from one kas instance, via another kas instance, to agentk.
// Routing path: routing kas -> gateway kas -> agentk.
type router struct {
	kasPool          grpctool.PoolInterface
	tunnelQuerier    tracker.PollingQuerier
	tunnelFinder     reverse_tunnel.TunnelFinder
	ownPrivateApiUrl string
	pollConfig       retry.PollConfigFactory
	// internalServer is the internal gRPC server for use inside of kas.
	// Request handlers can obtain the per-request logger using grpctool.LoggerFromContext(requestContext).
	internalServer grpc.ServiceRegistrar
	// privateApiServer is the gRPC server that other kas instances can talk to.
	// Request handlers can obtain the per-request logger using grpctool.LoggerFromContext(requestContext).
	privateApiServer          grpc.ServiceRegistrar
	gatewayKasVisitor         *grpctool.StreamVisitor
	tracer                    trace.Tracer
	kasRoutingDurationSuccess prometheus.Observer
	kasRoutingDurationAborted prometheus.Observer
	kasRoutingDurationTimeout prometheus.Counter
	tunnelFindTimeout         time.Duration
}

// newRouter constructs a router and registers its routing metrics with registerer.
// It returns an error if the stream visitor cannot be built or the metrics cannot be registered.
func newRouter(kasPool grpctool.PoolInterface, tunnelQuerier tracker.PollingQuerier,
	tunnelFinder reverse_tunnel.TunnelFinder, ownPrivateApiUrl string,
	internalServer, privateApiServer grpc.ServiceRegistrar,
	pollConfig retry.PollConfigFactory, tp trace.TracerProvider, registerer prometheus.Registerer) (*router, error) {
	gatewayKasVisitor, err := grpctool.NewStreamVisitor(&GatewayKasResponse{})
	if err != nil {
		return nil, err
	}
	routingDuration, timeoutCounter := constructKasRoutingMetrics()
	err = metric.Register(registerer, routingDuration, timeoutCounter)
	if err != nil {
		return nil, err
	}
	return &router{
		kasPool:                   kasPool,
		tunnelQuerier:             tunnelQuerier,
		tunnelFinder:              tunnelFinder,
		ownPrivateApiUrl:          ownPrivateApiUrl,
		pollConfig:                pollConfig,
		internalServer:            internalServer,
		privateApiServer:          privateApiServer,
		gatewayKasVisitor:         gatewayKasVisitor,
		tracer:                    tp.Tracer(routerTracerName),
		kasRoutingDurationSuccess: routingDuration.WithLabelValues(kasRoutingStatusSuccessLabelValue),
		kasRoutingDurationAborted: routingDuration.WithLabelValues(kasRoutingStatusAbortedLabelValue),
		kasRoutingDurationTimeout: timeoutCounter,
		tunnelFindTimeout:         routingTunnelFindTimeout,
	}, nil
}

// constructKasRoutingMetrics creates, but does not register, the routing duration histogram
// (labeled by routing status) and the routing timeout counter.
func constructKasRoutingMetrics() (*prometheus.HistogramVec, prometheus.Counter) {
	hist := prometheus.NewHistogramVec(prometheus.HistogramOpts{
		Name: kasRoutingDurationMetricName,
		Help: "The time it takes the routing kas to find a suitable tunnel in seconds",
		// 8 buckets: 0.001s, 0.004s, 0.016s, 0.064s, 0.256s, 1.024s, 4.096s, 16.384s; implicit: +Inf
		Buckets: prometheus.ExponentialBuckets(time.Millisecond.Seconds(), 4, 8),
	}, []string{kasRoutingStatusLabelName})
	timeoutCounter := prometheus.NewCounter(prometheus.CounterOpts{
		Name: kasRoutingTimeoutMetricName,
		Help: "The total number of times routing timed out i.e. didn't find a suitable agent connection within allocated time",
	})
	return hist, timeoutCounter
}

// RegisterAgentApi registers the agent API described by desc on both the internal
// and the private API gRPC servers, with every call handled by a routing handler.
func (r *router) RegisterAgentApi(desc *grpc.ServiceDesc) {
	// 1. Munge the descriptor into the right shape:
	//    - turn all unary calls into streaming calls
	//    - all streaming calls, including the ones from above, are handled by routing handlers
	internalServerDesc := mungeDescriptor(desc, r.RouteToKasStreamHandler)
	privateApiServerDesc := mungeDescriptor(desc, r.RouteToAgentStreamHandler)

	// 2. Register on InternalServer gRPC server so that ReverseTunnelClient can be used in kas to send data to
	//    this API within this kas instance. This kas instance then routes the stream to the gateway kas instance.
	r.internalServer.RegisterService(internalServerDesc, nil)

	// 3. Register on PrivateApiServer gRPC server so that this kas instance can act as the gateway kas instance
	//    from above and then route to one of the matching connected agentk instances.
	r.privateApiServer.RegisterService(privateApiServerDesc, nil)
}

// mungeDescriptor returns a new service descriptor in which every stream and every
// unary method of in is exposed as a bidirectional stream served by handler.
func mungeDescriptor(in *grpc.ServiceDesc, handler grpc.StreamHandler) *grpc.ServiceDesc {
	streams := make([]grpc.StreamDesc, 0, len(in.Streams)+len(in.Methods))
	for _, stream := range in.Streams {
		streams = append(streams, grpc.StreamDesc{
			StreamName:    stream.StreamName,
			Handler:       handler,
			ServerStreams: true,
			ClientStreams: true,
		})
	}
	// Turn all methods into streams
	for _, method := range in.Methods {
		streams = append(streams, grpc.StreamDesc{
			StreamName:    method.MethodName,
			Handler:       handler,
			ServerStreams: true,
			ClientStreams: true,
		})
	}
	return &grpc.ServiceDesc{
		ServiceName: in.ServiceName,
		Streams:     streams,
		Metadata:    in.Metadata,
	}
}

// agentIdFromMeta extracts the routing agent id from gRPC metadata.
// It returns an InvalidArgument error unless the key is present exactly once
// and its value parses as a base-10 int64.
func agentIdFromMeta(md metadata.MD) (int64, error) {
	val := md.Get(modserver.RoutingAgentIdMetadataKey)
	if len(val) != 1 {
		return 0, status.Errorf(codes.InvalidArgument, "Expecting a single %s, got %d", modserver.RoutingAgentIdMetadataKey, len(val))
	}
	agentId, err := strconv.ParseInt(val[0], 10, 64)
	if err != nil {
		return 0, status.Errorf(codes.InvalidArgument, "Invalid %s", modserver.RoutingAgentIdMetadataKey)
	}
	return agentId, nil
}