File: stream_visitor_options.go

package info (click to toggle)
gitlab-agent 16.1.3-2
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 6,324 kB
  • sloc: makefile: 175; sh: 52; ruby: 3
file content (143 lines) | stat: -rw-r--r-- 5,287 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package grpctool

import (
	"reflect"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/reflect/protoregistry"
)

// errorType is the reflect.Type of the built-in error interface.
// It is obtained from a nil *error because an interface type cannot be
// passed to reflect.TypeOf directly.
var errorType = reflect.TypeOf((*error)(nil)).Elem()

// config accumulates the settings applied by StreamVisitorOption functions.
//
// reflectMessage and goMessageType describe the stream message: WithCallback uses
// goMessageType to recognize callbacks that take the whole message, and reflectMessage
// to discover the Go type of an individual oneof field.
// oneof is the descriptor whose fields are looked up by field number in checkField.
// eofCallback, invalidTransitionCallback and startState are set by WithEOFCallback,
// WithInvalidTransitionCallback and WithStartState respectively.
// checkField refuses to register a field number that is already present in any of
// the three maps below.
type config struct {
	reflectMessage            protoreflect.Message
	goMessageType             reflect.Type
	oneof                     protoreflect.OneofDescriptor
	eofCallback               EOFCallback
	invalidTransitionCallback InvalidTransitionCallback
	startState                protoreflect.FieldNumber
	notExpectingFields        map[protoreflect.FieldNumber]codes.Code    // fields that are not expected during this invocation
	msgCallbacks              map[protoreflect.FieldNumber]reflect.Value // callbacks that accept the whole message
	fieldCallbacks            map[protoreflect.FieldNumber]reflect.Value // callbacks that accept a specific field type of the oneof
}

// StreamVisitorOption is an option for the visitor.
// It receives the configuration being built and may modify it.
// It must return nil or an error that is compatible with the gRPC status package.
type StreamVisitorOption func(*config) error

// WithEOFCallback returns an option that stores cb as the end-of-stream callback.
// The option itself never fails.
func WithEOFCallback(cb EOFCallback) StreamVisitorOption {
	return func(cfg *config) error {
		cfg.eofCallback = cb
		return nil
	}
}

// WithNotExpectingToGet lists oneof fields that the caller is not expecting to get during this Visit() invocation.
// code is recorded for every listed field.
//
// The option fails with a codes.Internal status error when no field number is given,
// or with the error from checkField when a field is rejected. Fields are processed in
// order, so fields listed before a rejected one have already been recorded.
func WithNotExpectingToGet(code codes.Code, transitionTo ...protoreflect.FieldNumber) StreamVisitorOption {
	return func(cfg *config) error {
		if len(transitionTo) < 1 {
			return status.Error(codes.Internal, "at least one field number is required")
		}
		for i := range transitionTo {
			number := transitionTo[i]
			// Only the validation result matters here; the descriptor is not needed.
			if _, err := checkField(cfg, number); err != nil {
				return err
			}
			cfg.notExpectingFields[number] = code
		}
		return nil
	}
}

// WithCallback registers cb to be called when entering transitionTo when parsing the stream.
// Only one callback can be registered per target field.
//
// cb must be a non-nil function that takes exactly one parameter and returns exactly one
// value of type error. The parameter type must be assignable from either the stream message
// type or the Go type of the oneof field with number transitionTo.
//
// The returned option yields a codes.Internal status error when cb does not meet these
// requirements, or the error from checkField when transitionTo is rejected.
func WithCallback(transitionTo protoreflect.FieldNumber, cb MessageCallback) StreamVisitorOption {
	return func(c *config) error {
		cbType := reflect.TypeOf(cb)
		// reflect.TypeOf returns nil for a nil interface value; calling Kind() on a nil
		// Type would panic, so a nil cb is reported as an error instead.
		if cbType == nil || cbType.Kind() != reflect.Func {
			return status.Errorf(codes.Internal, "cb must be a function, got: %T", cb)
		}
		if cbType.NumIn() != 1 {
			return status.Errorf(codes.Internal, "cb must take one parameter only, got: %T", cb)
		}
		if cbType.NumOut() != 1 {
			return status.Errorf(codes.Internal, "cb must return one value only, got: %T", cb)
		}
		if cbType.Out(0) != errorType {
			return status.Errorf(codes.Internal, "cb must return an error, got: %T", cb)
		}
		field, err := checkField(c, transitionTo)
		if err != nil {
			return err
		}
		inType := cbType.In(0)
		// A callback whose parameter accepts the whole message is stored separately
		// from callbacks that accept only the oneof field value.
		if c.goMessageType.AssignableTo(inType) {
			c.msgCallbacks[transitionTo] = reflect.ValueOf(cb)
			return nil
		}
		// Obtain a sample Go value for the field so its Go type can be compared with
		// the callback's parameter type.
		var goField interface{}
		switch field.Kind() { // nolint:exhaustive
		case protoreflect.MessageKind:
			goField = c.reflectMessage.Get(field).Message().Interface()
		case protoreflect.EnumKind:
			// The Go enum type is looked up in the global registry by its full name.
			et, findErr := protoregistry.GlobalTypes.FindEnumByName(field.Enum().FullName())
			if findErr != nil {
				return status.Errorf(codes.Internal, "FindEnumByName(): %v", findErr)
			}
			goField = et.New(0)
		default:
			goField = c.reflectMessage.Get(field).Interface()
		}
		if !reflect.TypeOf(goField).AssignableTo(inType) {
			return status.Errorf(codes.Internal, "callback must be a function with one parameter of type %s or the oneof field type %T, got: %T", c.goMessageType, goField, cb)
		}
		c.fieldCallbacks[transitionTo] = reflect.ValueOf(cb)
		return nil
	}
}

// WithInvalidTransitionCallback returns an option that stores cb as the invalid transition callback.
// The option itself never fails.
// NOTE(review): presumably cb is invoked when the stream makes a transition that is not allowed - confirm against the visitor.
func WithInvalidTransitionCallback(cb InvalidTransitionCallback) StreamVisitorOption {
	return func(cfg *config) error {
		cfg.invalidTransitionCallback = cb
		return nil
	}
}

// WithStartState returns an option that sets a custom automata start state.
// The visitor then acts as if it has just visited the field with number startState.
// The value is stored as is; this option performs no validation and never fails.
func WithStartState(startState protoreflect.FieldNumber) StreamVisitorOption {
	return func(cfg *config) error {
		cfg.startState = startState
		return nil
	}
}

// checkField verifies that transitionTo can still be registered and resolves its descriptor.
//
// It returns a codes.Internal status error when the field number has already been marked
// as unexpected, already has a callback registered, or does not belong to the configured
// oneof. Otherwise it returns the field descriptor found in the oneof.
func checkField(c *config, transitionTo protoreflect.FieldNumber) (protoreflect.FieldDescriptor, error) {
	if _, unexpected := c.notExpectingFields[transitionTo]; unexpected {
		return nil, status.Errorf(codes.Internal, "field %d has already been marked as unexpected", transitionTo)
	}
	// Whole-message callbacks are checked before field callbacks.
	for _, registry := range []map[protoreflect.FieldNumber]reflect.Value{c.msgCallbacks, c.fieldCallbacks} {
		if registered, ok := registry[transitionTo]; ok {
			return nil, status.Errorf(codes.Internal, "callback for %d has already been defined: %v", transitionTo, registered)
		}
	}
	if descriptor := c.oneof.Fields().ByNumber(transitionTo); descriptor != nil {
		return descriptor, nil
	}
	return nil, status.Errorf(codes.Internal, "oneof %s does not have a field %d", c.oneof.FullName(), transitionTo)
}

// defaultInvalidTransitionCallback reports a transition as a codes.InvalidArgument status error
// that names the source field, the target field and the allowed target fields.
// The message argument is not used.
func defaultInvalidTransitionCallback(from, to protoreflect.FieldNumber, allowed []protoreflect.FieldNumber, message proto.Message) error {
	const format = "transition from %d to %d is not allowed. Allowed: %v"
	return status.Errorf(codes.InvalidArgument, format, from, to, allowed)
}

// defaultEOFCallback is an end-of-stream callback that does nothing and always returns nil.
func defaultEOFCallback() error {
	return nil
}