File: server_interceptors.go

package info (click to toggle)
golang-gitlab-gitlab-org-labkit 1.17.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,092 kB
  • sloc: sh: 210; javascript: 49; makefile: 4
file content (76 lines) | stat: -rw-r--r-- 2,704 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package grpccorrelation

import (
	"context"

	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
	"gitlab.com/gitlab-org/labkit/correlation"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

// extractFromContext builds a context enriched with correlation data taken
// from the incoming gRPC metadata attached to ctx.
//
// When incoming metadata is present, the first value stored under
// metadataClientNameKey (if any) is recorded as the client name, and, only if
// propagateIncomingCorrelationID is true, the correlation ID is read from the
// metadata. If no correlation ID was obtained (or it is empty) a new one is
// generated with correlation.SafeRandomID.
//
// It returns the derived context together with the correlation ID stored in it.
func extractFromContext(ctx context.Context, propagateIncomingCorrelationID bool) (context.Context, string) {
	id := ""
	if incoming, found := metadata.FromIncomingContext(ctx); found {
		// The caller-supplied ID is only trusted when propagation is enabled.
		if propagateIncomingCorrelationID {
			id = CorrelationIDFromMetadata(incoming)
		}

		// The client name is picked up regardless of the propagation setting.
		if names := incoming.Get(metadataClientNameKey); len(names) != 0 {
			ctx = correlation.ContextWithClientName(ctx, names[0])
		}
	}

	// Fall back to a freshly generated ID so the returned context always has one.
	if id == "" {
		id = correlation.SafeRandomID()
	}

	return correlation.ContextWithCorrelation(ctx, id), id
}

// CorrelationIDFromMetadata looks up the correlation ID in request/response
// metadata. It returns the first value stored under the correlator key, or an
// empty string when the metadata holds no value for that key.
func CorrelationIDFromMetadata(md metadata.MD) string {
	ids := md.Get(metadataCorrelatorKey)
	if len(ids) == 0 {
		return ""
	}
	return ids[0]
}

// UnaryServerCorrelationInterceptor propagates Correlation-IDs from incoming upstream services.
//
// The returned interceptor derives the correlation ID for each unary call via
// extractFromContext (honouring the propagateIncomingCorrelationID option) and
// invokes the handler with the enriched context. When the
// reversePropagateCorrelationID option is enabled, the correlation ID is also
// sent back to the client as a response header; if that header cannot be set,
// the call fails with the resulting error and the handler is not invoked.
func UnaryServerCorrelationInterceptor(opts ...ServerCorrelationInterceptorOption) grpc.UnaryServerInterceptor {
	config := applyServerCorrelationInterceptorOptions(opts)
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		ctx, correlationID := extractFromContext(ctx, config.propagateIncomingCorrelationID)
		if config.reversePropagateCorrelationID {
			// Use grpc.SetHeader rather than calling SetHeader on the result of
			// grpc.ServerTransportStreamFromContext directly: the latter returns
			// nil when ctx carries no server transport stream, which would make
			// the method call panic. grpc.SetHeader reports that situation as an
			// error instead.
			if err := grpc.SetHeader(ctx, metadata.Pairs(metadataCorrelatorKey, correlationID)); err != nil {
				return nil, err
			}
		}
		return handler(ctx, req)
	}
}

// StreamServerCorrelationInterceptor propagates Correlation-IDs from incoming upstream services.
//
// The returned interceptor wraps the server stream so that its context carries
// the correlation data produced by extractFromContext, then hands the wrapped
// stream to the handler. When the reversePropagateCorrelationID option is
// enabled, the correlation ID is first set as a header on the wrapped stream;
// a failure to do so is returned without invoking the handler.
func StreamServerCorrelationInterceptor(opts ...ServerCorrelationInterceptorOption) grpc.StreamServerInterceptor {
	cfg := applyServerCorrelationInterceptorOptions(opts)
	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		stream := grpc_middleware.WrapServerStream(ss)

		// Replace the wrapped stream's context with the correlation-enriched one.
		enriched, id := extractFromContext(ss.Context(), cfg.propagateIncomingCorrelationID)
		stream.WrappedContext = enriched

		if !cfg.reversePropagateCorrelationID {
			return handler(srv, stream)
		}

		// Echo the correlation ID back to the client before running the handler.
		if err := stream.SetHeader(metadata.Pairs(metadataCorrelatorKey, id)); err != nil {
			return err
		}
		return handler(srv, stream)
	}
}