1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
|
package grpccorrelation
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/labkit/correlation"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// Fixture values shared by the client interceptor tests in this file.
const (
	// methodName is the gRPC method name handed to the interceptors.
	methodName = "METHOD_NAME"
	// clientName is passed to WithClientName and is expected back in the
	// outgoing client-name metadata.
	clientName = "CLIENT_NAME"
	// correlationID is stored in the caller's context and is expected back in
	// the outgoing correlation metadata.
	correlationID = "CORRELATION_ID"
)
// verifyContextMetadata asserts that ctx carries outgoing gRPC metadata in
// which the first value stored under metadataCorrelatorKey equals
// expCorrelationID and the first value stored under metadataClientNameKey
// equals expClientName. Failures are reported through r.
func verifyContextMetadata(ctx context.Context, r *require.Assertions, expCorrelationID, expClientName string) {
	md, ok := metadata.FromOutgoingContext(ctx)
	r.True(ok, "context carries no outgoing gRPC metadata")

	ids := md.Get(metadataCorrelatorKey)
	// NotEmpty also guards the ids[0] index on the next line.
	r.NotEmpty(ids, "correlation ID missing from outgoing metadata")
	r.Equal(expCorrelationID, ids[0])

	clientNames := md.Get(metadataClientNameKey)
	r.NotEmpty(clientNames, "client name missing from outgoing metadata")
	r.Equal(expClientName, clientNames[0])
}
// getTestUnaryInvoker returns a stub grpc.UnaryInvoker that stands in for the
// real RPC. The stub runs verifyContextMetadata against the context it is
// handed and then reports success. It contacts no server and ignores every
// other argument.
func getTestUnaryInvoker(r *require.Assertions, wantCorrelationID, wantClientName string) grpc.UnaryInvoker {
	invoker := func(ctx context.Context, _ string, _, _ interface{}, _ *grpc.ClientConn, _ ...grpc.CallOption) error {
		verifyContextMetadata(ctx, r, wantCorrelationID, wantClientName)
		return nil
	}
	return invoker
}
// getTestStreamer returns a stub grpc.Streamer that runs
// verifyContextMetadata against the context it is handed. It opens no
// stream: it always yields a nil grpc.ClientStream and a nil error.
func getTestStreamer(r *require.Assertions, wantCorrelationID, wantClientName string) grpc.Streamer {
	streamer := func(ctx context.Context, _ *grpc.StreamDesc, _ *grpc.ClientConn, _ string, _ ...grpc.CallOption) (grpc.ClientStream, error) {
		verifyContextMetadata(ctx, r, wantCorrelationID, wantClientName)
		return nil, nil
	}
	return streamer
}
// TestUnaryClientCorrelationInterceptor checks that the unary client
// interceptor, built WithClientName(clientName), calls the next invoker with
// the method name unchanged. The context it passes on must carry the caller's
// correlation ID and the client name as outgoing gRPC metadata.
func TestUnaryClientCorrelationInterceptor(t *testing.T) {
	require := require.New(t)

	clientInterceptor := UnaryClientCorrelationInterceptor(WithClientName(clientName))
	ctx := correlation.ContextWithCorrelation(context.Background(), correlationID)

	verifyingInvoker := getTestUnaryInvoker(require, correlationID, clientName)
	// The metadata assertions run inside the invoker. Without this flag the
	// test would pass vacuously if the interceptor never called it.
	invoked := false
	invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
		invoked = true
		require.Equal(methodName, method, "interceptor changed the method name")
		return verifyingInvoker(ctx, method, req, reply, cc, opts...)
	}

	err := clientInterceptor(
		ctx,
		methodName,
		nil,
		nil,
		nil,
		invoker,
	)
	require.NoError(err)
	require.True(invoked, "interceptor did not call the next invoker")
}
// TestStreamClientCorrelationInterceptor checks that the stream client
// interceptor, built WithClientName(clientName), calls the next streamer with
// the method name unchanged. The context it passes on must carry the caller's
// correlation ID and the client name as outgoing gRPC metadata.
func TestStreamClientCorrelationInterceptor(t *testing.T) {
	require := require.New(t)

	clientInterceptor := StreamClientCorrelationInterceptor(WithClientName(clientName))
	ctx := correlation.ContextWithCorrelation(context.Background(), correlationID)

	verifyingStreamer := getTestStreamer(require, correlationID, clientName)
	// The metadata assertions run inside the streamer. Without this flag the
	// test would pass vacuously if the interceptor never called it.
	streamerCalled := false
	streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
		streamerCalled = true
		require.Equal(methodName, method, "interceptor changed the method name")
		return verifyingStreamer(ctx, desc, cc, method, opts...)
	}

	_, err := clientInterceptor(
		ctx,
		nil,
		nil,
		methodName,
		streamer,
	)
	require.NoError(err)
	require.True(streamerCalled, "interceptor did not call the next streamer")
}
|