1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
package grpctool
import (
"reflect"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)
var (
	// errorType is the reflect.Type of the built-in error interface. WithCallback compares
	// a callback's single return type against it.
	errorType = reflect.TypeOf(new(error)).Elem()
)
// config collects the settings applied by StreamVisitorOption functions.
// NOTE(review): the options in this file write into notExpectingFields, msgCallbacks and
// fieldCallbacks without creating them, so these maps are presumably initialized by the code
// that builds the config before any option runs — confirm against the constructor.
type config struct {
	// reflectMessage is read by WithCallback to obtain the Go value type of a oneof field.
	reflectMessage protoreflect.Message
	// goMessageType is the Go type of the whole message. WithCallback registers a callback as a
	// message callback when this type is assignable to the callback's parameter type.
	goMessageType reflect.Type
	// oneof is the oneof whose fields are the valid targets for the options (see checkField).
	oneof protoreflect.OneofDescriptor
	// eofCallback is set by WithEOFCallback.
	eofCallback EOFCallback
	// invalidTransitionCallback is set by WithInvalidTransitionCallback.
	invalidTransitionCallback InvalidTransitionCallback
	// startState is set by WithStartState: the visitor acts as if it has just visited this field.
	startState protoreflect.FieldNumber
	// notExpectingFields holds fields that are not expected during this invocation, mapped to the
	// code passed to WithNotExpectingToGet.
	notExpectingFields map[protoreflect.FieldNumber]codes.Code
	// msgCallbacks holds callbacks that accept the whole message.
	msgCallbacks map[protoreflect.FieldNumber]reflect.Value
	// fieldCallbacks holds callbacks that accept a specific field type of the oneof.
	fieldCallbacks map[protoreflect.FieldNumber]reflect.Value
}
// StreamVisitorOption is an option for the visitor. It mutates the visitor's config.
// It must return nil or an error compatible with the gRPC status package
// (the options in this file return status errors with codes.Internal).
type StreamVisitorOption func(*config) error
// WithEOFCallback sets a callback for end of stream.
// The returned option stores cb in the config as-is (a nil cb is stored as nil) and never fails.
func WithEOFCallback(cb EOFCallback) StreamVisitorOption {
	setEOF := func(cfg *config) error {
		cfg.eofCallback = cb
		return nil
	}
	return setEOF
}
// WithNotExpectingToGet is used to list fields that the caller is not expecting to get during this Visit() invocation.
//
// At least one field number must be given. Each number in transitionTo is validated with checkField
// (it must belong to the oneof, must not have a callback and must not already be marked as unexpected)
// and is then recorded together with code. Validation stops at the first invalid number; numbers
// processed before it remain recorded in the config.
func WithNotExpectingToGet(code codes.Code, transitionTo ...protoreflect.FieldNumber) StreamVisitorOption {
	return func(cfg *config) error {
		if len(transitionTo) == 0 {
			return status.Error(codes.Internal, "at least one field number is required")
		}
		for i := range transitionTo {
			fieldNum := transitionTo[i]
			// Re-checking against notExpectingFields also rejects duplicates within transitionTo.
			if _, err := checkField(cfg, fieldNum); err != nil {
				return err
			}
			cfg.notExpectingFields[fieldNum] = code
		}
		return nil
	}
}
// WithCallback registers cb to be called when entering transitionTo when parsing the stream.
// Only one callback can be registered per target.
//
// cb must be a non-nil function that takes exactly one parameter and returns exactly one value of
// type error. If the Go type of the whole message is assignable to the parameter type, cb is stored
// as a message callback. Otherwise the parameter type must accept the Go type of the transitionTo
// oneof field's value, and cb is stored as a field callback. transitionTo is validated with
// checkField. All failures are reported as codes.Internal status errors.
func WithCallback(transitionTo protoreflect.FieldNumber, cb MessageCallback) StreamVisitorOption {
	return func(c *config) error {
		cbType := reflect.TypeOf(cb)
		// reflect.TypeOf returns a nil Type for a nil interface value; calling Kind() on it would panic.
		if cbType == nil || cbType.Kind() != reflect.Func {
			return status.Errorf(codes.Internal, "cb must be a function, got: %T", cb)
		}
		cbValue := reflect.ValueOf(cb)
		// A typed nil function passes all type checks below but would panic when invoked via reflection.
		if cbValue.IsNil() {
			return status.Errorf(codes.Internal, "cb must not be a nil function, got: %T", cb)
		}
		if cbType.NumIn() != 1 {
			return status.Errorf(codes.Internal, "cb must take one parameter only, got: %T", cb)
		}
		if cbType.NumOut() != 1 {
			return status.Errorf(codes.Internal, "cb must return one value only, got: %T", cb)
		}
		if cbType.Out(0) != errorType {
			return status.Errorf(codes.Internal, "cb must return an error, got: %T", cb)
		}
		field, err := checkField(c, transitionTo)
		if err != nil {
			return err
		}
		inType := cbType.In(0)
		if c.goMessageType.AssignableTo(inType) {
			c.msgCallbacks[transitionTo] = cbValue
			return nil
		}
		// Build a value of the field's Go type so its type can be matched against the callback parameter.
		var goField interface{}
		switch field.Kind() { // nolint:exhaustive
		case protoreflect.MessageKind:
			goField = c.reflectMessage.Get(field).Message().Interface()
		case protoreflect.EnumKind:
			// A reflective enum value would be a protoreflect.EnumNumber, not the generated enum type,
			// so look up the registered enum type to get its concrete Go type.
			et, err := protoregistry.GlobalTypes.FindEnumByName(field.Enum().FullName())
			if err != nil {
				return status.Errorf(codes.Internal, "FindEnumByName(): %v", err)
			}
			goField = et.New(0)
		default:
			goField = c.reflectMessage.Get(field).Interface()
		}
		if !reflect.TypeOf(goField).AssignableTo(inType) {
			return status.Errorf(codes.Internal, "callback must be a function with one parameter of type %s or the oneof field type %T, got: %T", c.goMessageType, goField, cb)
		}
		c.fieldCallbacks[transitionTo] = cbValue
		return nil
	}
}
// WithInvalidTransitionCallback sets the callback stored as the config's invalidTransitionCallback.
// The returned option stores cb as-is and never fails.
// NOTE(review): judging by defaultInvalidTransitionCallback, the callback is presumably invoked with
// the source and target field numbers, the allowed targets and the message when a disallowed
// transition is seen — confirm in the visitor.
func WithInvalidTransitionCallback(cb InvalidTransitionCallback) StreamVisitorOption {
	setInvalidTransition := func(cfg *config) error {
		cfg.invalidTransitionCallback = cb
		return nil
	}
	return setInvalidTransition
}
// WithStartState sets a custom start state for the visitor's automaton.
// The visitor then acts as if it had just visited the field whose number is startState.
// The returned option never fails.
func WithStartState(startState protoreflect.FieldNumber) StreamVisitorOption {
	setStart := func(cfg *config) error {
		cfg.startState = startState
		return nil
	}
	return setStart
}
// checkField verifies that transitionTo can still be configured and returns its descriptor in c.oneof.
// It fails with a codes.Internal status error in these cases, checked in this order:
//   - the field is already marked as unexpected;
//   - a message callback or field callback is already registered for it;
//   - it is not a member of c.oneof.
func checkField(c *config, transitionTo protoreflect.FieldNumber) (protoreflect.FieldDescriptor, error) {
	if _, marked := c.notExpectingFields[transitionTo]; marked {
		return nil, status.Errorf(codes.Internal, "field %d has already been marked as unexpected", transitionTo)
	}
	registries := [...]map[protoreflect.FieldNumber]reflect.Value{c.msgCallbacks, c.fieldCallbacks}
	for _, registry := range registries {
		if registered, dup := registry[transitionTo]; dup {
			return nil, status.Errorf(codes.Internal, "callback for %d has already been defined: %v", transitionTo, registered)
		}
	}
	fd := c.oneof.Fields().ByNumber(transitionTo)
	if fd == nil {
		return nil, status.Errorf(codes.Internal, "oneof %s does not have a field %d", c.oneof.FullName(), transitionTo)
	}
	return fd, nil
}
// defaultInvalidTransitionCallback rejects a disallowed transition with a codes.InvalidArgument status
// error that names the source and target field numbers and lists the allowed targets.
// The message argument is not used.
func defaultInvalidTransitionCallback(from, to protoreflect.FieldNumber, allowed []protoreflect.FieldNumber, message proto.Message) error {
	const format = "transition from %d to %d is not allowed. Allowed: %v"
	return status.Errorf(codes.InvalidArgument, format, from, to, allowed)
}
// defaultEOFCallback accepts end of stream unconditionally by reporting no error.
// NOTE(review): presumably used when WithEOFCallback is not supplied — confirm in the visitor constructor.
func defaultEOFCallback() error { return nil }
|