1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
|
// Copyright (c) The go-grpc-middleware Authors.
// Licensed under the Apache License 2.0.
package protovalidate
import (
"context"
"errors"
"buf.build/go/protovalidate"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// UnaryServerInterceptor builds a grpc.UnaryServerInterceptor that runs the
// supplied protovalidate.Validator over each incoming unary request before the
// handler is invoked. Invalid requests are rejected with the status error
// produced by validateMsg. For constraint violations that is an InvalidArgument
// status which carries the structured violations as an error detail when they
// can be attached, so clients can inspect exactly what failed.
//
// Options are evaluated once, when the interceptor is built, not per call.
func UnaryServerInterceptor(validator protovalidate.Validator, opts ...Option) grpc.UnaryServerInterceptor {
	cfg := evaluateOpts(opts)

	interceptor := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		// Short-circuit: the handler never observes a request that failed validation.
		if vErr := validateMsg(req, validator, cfg); vErr != nil {
			return nil, vErr
		}
		return handler(ctx, req)
	}
	return interceptor
}
// StreamServerInterceptor builds a grpc.StreamServerInterceptor that hands the
// stream handler a wrapped grpc.ServerStream. The wrapper validates every
// message received through RecvMsg with the supplied protovalidate.Validator.
// A message that fails validation surfaces as an error from RecvMsg, carrying
// the structured violations as an error detail when they can be attached.
//
// Options are evaluated once, when the interceptor is built, not per stream.
func StreamServerInterceptor(validator protovalidate.Validator, opts ...Option) grpc.StreamServerInterceptor {
	cfg := evaluateOpts(opts)

	return func(srv interface{}, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		validating := &wrappedServerStream{
			ServerStream: ss,
			validator:    validator,
			options:      cfg,
		}
		return handler(srv, validating)
	}
}
// wrappedServerStream decorates a grpc.ServerStream so that every message
// received through RecvMsg is validated before it is returned to the caller.
// It does not alter the stream's context: every method other than RecvMsg
// (including Context) is promoted unchanged from the embedded ServerStream.
type wrappedServerStream struct {
	grpc.ServerStream
	// validator checks each message received on the stream.
	validator protovalidate.Validator
	// options is the interceptor configuration passed through to validateMsg
	// (consulted there via shouldIgnoreMessage).
	options *options
}

// Compile-time check that the wrapper still satisfies grpc.ServerStream.
var _ grpc.ServerStream = (*wrappedServerStream)(nil)

// RecvMsg reads the next message into m from the underlying stream and then
// validates it. An error from the underlying stream is returned unchanged.
// Otherwise the result of validateMsg is returned: nil for a valid (or
// ignored) message, or a gRPC status error describing why it was rejected.
func (w *wrappedServerStream) RecvMsg(m interface{}) error {
	if err := w.ServerStream.RecvMsg(m); err != nil {
		return err
	}
	return validateMsg(m, w.validator, w.options)
}
// validateMsg checks m with validator and translates the outcome into a gRPC
// status error.
//
// It returns:
//   - a codes.Internal status when m does not implement proto.Message;
//   - nil when opts.shouldIgnoreMessage reports true for m's full name;
//   - nil when validation succeeds;
//   - a codes.InvalidArgument status when validation fails with a
//     *protovalidate.ValidationError; the violations are attached as an error
//     detail, or omitted if attaching them fails;
//   - a codes.Internal status for any other validator error (for example a
//     CEL expression that does not compile or type-check).
func validateMsg(m interface{}, validator protovalidate.Validator, opts *options) error {
	msg, isProto := m.(proto.Message)
	if !isProto {
		return status.Errorf(codes.Internal, "unsupported message type: %T", m)
	}
	if opts.shouldIgnoreMessage(msg.ProtoReflect().Descriptor().FullName()) {
		return nil
	}

	validationErr := validator.Validate(msg)
	var violations *protovalidate.ValidationError
	switch {
	case validationErr == nil:
		return nil
	case errors.As(validationErr, &violations):
		// The request broke at least one constraint: report a client error and
		// try to attach the structured violations for programmatic access.
		base := status.New(codes.InvalidArgument, validationErr.Error())
		if detailed, detailErr := base.WithDetails(violations.ToProto()); detailErr == nil {
			return detailed.Err()
		}
		// Attaching the detail failed; fall back to the plain status.
		return base.Err()
	default:
		// Not a constraint violation, e.g. a CEL expression that doesn't
		// compile or type-check.
		return status.Error(codes.Internal, validationErr.Error())
	}
}
|