File: server_limiting_test.go

package info (click to toggle)
gitlab-agent 16.1.3-2
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 6,324 kB
  • sloc: makefile: 175; sh: 52; ruby: 3
file content (57 lines) | stat: -rw-r--r-- 1,715 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
package grpctool_test

import (
	"context"
	"testing"

	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/grpctool"
	"gitlab.com/gitlab-org/cluster-integration/gitlab-agent/v16/internal/tool/testing/mock_rpc"
	"google.golang.org/grpc"
)

// testServerLimiter is a stub limiter for the tests in this file. Its answer
// is fixed when the value is constructed.
type testServerLimiter struct {
	allow bool // verdict returned by every call to Allow
}

// Allow returns the preconfigured verdict; the supplied context is not
// consulted.
func (l *testServerLimiter) Allow(_ context.Context) bool {
	return l.allow
}

func TestServerInterceptors(t *testing.T) {
	ctrl := gomock.NewController(t)
	usHandler := func(ctx context.Context, req interface{}) (interface{}, error) {
		return struct{}{}, nil
	}
	ssHandler := func(interface{}, grpc.ServerStream) error {
		return nil
	}
	t.Run("It lets the connection through when allowed", func(t *testing.T) {
		limiter := &testServerLimiter{allow: true}

		usi := grpctool.UnaryServerLimitingInterceptor(limiter)
		_, err := usi(context.Background(), struct{}{}, nil, usHandler)
		require.NoError(t, err)

		ssi := grpctool.StreamServerLimitingInterceptor(limiter)
		ss := mock_rpc.NewMockServerStream(ctrl)
		ss.EXPECT().Context().Return(context.Background())
		err = ssi(struct{}{}, ss, nil, ssHandler)
		require.NoError(t, err)
	})

	t.Run("It blocks the connection when not allowed", func(t *testing.T) {
		limiter := &testServerLimiter{false}

		usi := grpctool.UnaryServerLimitingInterceptor(limiter)
		_, err := usi(context.Background(), struct{}{}, nil, usHandler)
		require.Error(t, err)

		ssi := grpctool.StreamServerLimitingInterceptor(limiter)
		ss := mock_rpc.NewMockServerStream(ctrl)
		ss.EXPECT().Context().Return(context.Background())
		err = ssi(struct{}{}, ss, nil, ssHandler)
		require.Error(t, err)
	})
}