1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
|
// Copyright (c) The go-grpc-middleware Authors.
// Licensed under the Apache License 2.0.
package selector
import (
"context"
"errors"
"testing"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
)
// allow returns a Matcher that selects a call only when the call's full
// method name is exactly equal to one of the given methods. The slice is not
// copied, so later changes to its elements are visible to the matcher.
func allow(methods []string) Matcher {
	return MatchFunc(func(_ context.Context, callMeta interceptors.CallMeta) bool {
		target := callMeta.FullMethod()
		for i := range methods {
			if methods[i] == target {
				return true
			}
		}
		return false
	})
}
// mockGRPCServerStream is a minimal grpc.ServerStream stub for tests. Only
// Context is implemented here; every other ServerStream method is promoted
// from the embedded interface, which these tests leave unset (nil), so
// calling any of them would panic.
type mockGRPCServerStream struct {
	grpc.ServerStream
	ctx context.Context
}

// Context returns the context the stub was constructed with.
func (s *mockGRPCServerStream) Context() context.Context {
	return s.ctx
}
// svcMethod is the full gRPC method name that these tests pass to allow,
// i.e. the only method routed through the wrapped interceptor.
const (
	svcMethod = "/v1beta1.SomeService/NeedsAuth"
)
// TestUnaryServerInterceptor checks that UnaryServerInterceptor routes a call
// through the wrapped interceptor only when the Matcher selects its method.
// The wrapped interceptor always fails, so a successful "good" response proves
// it was bypassed and the handler was invoked directly.
func TestUnaryServerInterceptor(t *testing.T) {
	interceptor := UnaryServerInterceptor(
		func(context.Context, any, *grpc.UnaryServerInfo, grpc.UnaryHandler) (any, error) {
			return nil, errors.New("always error")
		},
		allow([]string{svcMethod}),
	)
	handler := func(ctx context.Context, req any) (any, error) {
		return "good", nil
	}

	t.Run("not-selected", func(t *testing.T) {
		info := &grpc.UnaryServerInfo{FullMethod: "FakeMethod"}
		resp, err := interceptor(context.Background(), nil, info, handler)
		assert.NoError(t, err)
		// testify takes (expected, actual); keep that order so failure
		// messages label the values correctly.
		assert.Equal(t, "good", resp)
	})

	t.Run("selected", func(t *testing.T) {
		info := &grpc.UnaryServerInfo{FullMethod: svcMethod}
		resp, err := interceptor(context.Background(), nil, info, handler)
		assert.Nil(t, resp)
		assert.EqualError(t, err, "always error")
	})
}
// TestStreamServerInterceptor checks that StreamServerInterceptor routes a
// stream through the wrapped interceptor only when the Matcher selects its
// method, and otherwise calls the handler directly.
func TestStreamServerInterceptor(t *testing.T) {
	interceptor := StreamServerInterceptor(
		func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
			// Never calls handler, so a selected stream must not reach it.
			return errors.New("always error")
		},
		allow([]string{svcMethod}),
	)

	// newHandler returns a StreamHandler that records whether it was invoked.
	newHandler := func(called *bool) grpc.StreamHandler {
		return func(srv any, stream grpc.ServerStream) error {
			*called = true
			return nil
		}
	}

	t.Run("not-selected", func(t *testing.T) {
		called := false
		info := &grpc.StreamServerInfo{FullMethod: "FakeMethod"}
		err := interceptor(nil, &mockGRPCServerStream{ctx: context.Background()}, info, newHandler(&called))
		assert.NoError(t, err)
		// A nil error alone does not prove the handler ran, so check it explicitly.
		assert.True(t, called, "handler should run when the method is not selected")
	})

	t.Run("selected", func(t *testing.T) {
		called := false
		info := &grpc.StreamServerInfo{FullMethod: svcMethod}
		err := interceptor(nil, &mockGRPCServerStream{ctx: context.Background()}, info, newHandler(&called))
		assert.EqualError(t, err, "always error")
		assert.False(t, called, "handler should not run when the wrapped interceptor short-circuits")
	})
}
// TestAllow verifies that the allow matcher selects a call only when its full
// method name exactly equals one of the configured methods.
func TestAllow(t *testing.T) {
	tests := []struct {
		name    string
		methods []string
		method  string
		want    bool
	}{
		{
			name:    "false",
			methods: []string{"/auth.v1beta1.AuthService/Login"},
			method:  "/testing.testpb.v1.TestService/PingList",
			want:    false,
		},
		{
			name:    "true",
			methods: []string{"/auth.v1beta1.AuthService/Login"},
			method:  "/auth.v1beta1.AuthService/Login",
			want:    true,
		},
		{
			name:    "no-methods",
			methods: nil,
			method:  "/auth.v1beta1.AuthService/Login",
			want:    false,
		},
		{
			name: "one-of-many",
			methods: []string{
				"/testing.testpb.v1.TestService/PingList",
				"/auth.v1beta1.AuthService/Login",
			},
			method: "/auth.v1beta1.AuthService/Login",
			want:   true,
		},
		{
			// Matching is exact equality, not prefix matching.
			name:    "prefix-only",
			methods: []string{"/auth.v1beta1.AuthService/Login"},
			method:  "/auth.v1beta1.AuthService/Log",
			want:    false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Named matcher (not allow) so the package function isn't shadowed.
			matcher := allow(tt.methods)
			got := matcher.Match(context.Background(), interceptors.NewServerCallMeta(tt.method, nil, nil))
			assert.Equalf(t, tt.want, got, "allow(%v).Match(ctx, %q)", tt.methods, tt.method)
		})
	}
}
|