1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
|
package grpctool_test
import (
"context"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/grpctool"
"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/testing/mock_rpc"
"google.golang.org/grpc"
)
// testServerLimiter is a stub limiter whose decision is fixed up front by the
// allow field, so tests can force the interceptors down either branch.
type testServerLimiter struct {
	allow bool // decision returned by every Allow call
}

// Allow returns the preconfigured decision; the context is not consulted.
func (s *testServerLimiter) Allow(_ context.Context) bool {
	decision := s.allow
	return decision
}
// TestServerInterceptors exercises the unary and stream limiting interceptors
// against a stub limiter. When the limiter allows the call, the interceptor
// must succeed and invoke the wrapped handler. When it does not, the
// interceptor must return an error without invoking the handler.
func TestServerInterceptors(t *testing.T) {
	tests := []struct {
		name  string
		allow bool
	}{
		{name: "It lets the connection through when allowed", allow: true},
		{name: "It blocks the connection when not allowed", allow: false},
	}
	for _, tc := range tests {
		tc := tc // capture per iteration (needed before Go 1.22)
		t.Run(tc.name, func(t *testing.T) {
			// A controller per subtest so mock expectations are verified and
			// reported against this subtest's t, not the parent's.
			ctrl := gomock.NewController(t)
			limiter := &testServerLimiter{allow: tc.allow}

			// The handlers record whether they ran, so the test can tell a
			// blocked call apart from one that was forwarded and then failed.
			var usHandlerCalled, ssHandlerCalled bool
			usHandler := func(ctx context.Context, req interface{}) (interface{}, error) {
				usHandlerCalled = true
				return struct{}{}, nil
			}
			ssHandler := func(interface{}, grpc.ServerStream) error {
				ssHandlerCalled = true
				return nil
			}

			usi := grpctool.UnaryServerLimitingInterceptor(limiter)
			_, err := usi(context.Background(), struct{}{}, nil, usHandler)
			if tc.allow {
				require.NoError(t, err)
			} else {
				require.Error(t, err)
			}
			require.Equal(t, tc.allow, usHandlerCalled, "unary handler invocation must match limiter decision")

			ssi := grpctool.StreamServerLimitingInterceptor(limiter)
			ss := mock_rpc.NewMockServerStream(ctrl)
			ss.EXPECT().Context().Return(context.Background())
			err = ssi(struct{}{}, ss, nil, ssHandler)
			if tc.allow {
				require.NoError(t, err)
			} else {
				require.Error(t, err)
			}
			require.Equal(t, tc.allow, ssHandlerCalled, "stream handler invocation must match limiter decision")
		})
	}
}
|