File: extension.go

package info (click to toggle)
gitlab-shell 14.35.0%2Bds1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 23,652 kB
  • sloc: ruby: 1,129; makefile: 583; sql: 391; sh: 384
file content (78 lines) | stat: -rw-r--r-- 2,771 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package protoutil

import (
	"errors"
	"fmt"

	"gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protoregistry"
	"google.golang.org/protobuf/runtime/protoimpl"
	"google.golang.org/protobuf/types/descriptorpb"
)

// GetOpExtension returns the gitalypb.OperationMsg stored in the
// gitalypb.E_OpType extension of the method descriptor's options.
//
// If the extension is not set, the returned error wraps
// protoregistry.NotFound, so callers can detect that case with errors.Is.
// If the extension holds a value of an unexpected type, an error is returned
// instead of panicking.
func GetOpExtension(m *descriptorpb.MethodDescriptorProto) (*gitalypb.OperationMsg, error) {
	ext, err := getExtension(m.GetOptions(), gitalypb.E_OpType)
	if err != nil {
		return nil, err
	}

	// Use the comma-ok form: this function already reports failures through
	// its error result, so a mismatched extension type must not panic.
	opMsg, ok := ext.(*gitalypb.OperationMsg)
	if !ok {
		return nil, fmt.Errorf("protoutil.GetOpExtension: unexpected extension value type %T", ext)
	}

	return opMsg, nil
}

// IsInterceptedMethod reports whether Praefect intercepts the given RPC method.
//
// The method counts as intercepted when the service options carry the
// gitalypb.E_Intercepted extension set to true, or the method options carry
// gitalypb.E_InterceptedMethod set to true. An extension that is not set
// counts as false. Both extensions are always read (service first), and an
// error reading either one is wrapped and returned.
func IsInterceptedMethod(s *descriptorpb.ServiceDescriptorProto, m *descriptorpb.MethodDescriptorProto) (bool, error) {
	serviceFlag, serviceErr := getBoolExtension(s.GetOptions(), gitalypb.E_Intercepted)
	if serviceErr != nil {
		return false, fmt.Errorf("is service intercepted: %w", serviceErr)
	}

	methodFlag, methodErr := getBoolExtension(m.GetOptions(), gitalypb.E_InterceptedMethod)
	if methodErr != nil {
		return false, fmt.Errorf("is method intercepted: %w", methodErr)
	}

	intercepted := serviceFlag || methodFlag
	return intercepted, nil
}

// GetRepositoryExtension returns the value of the boolean
// gitalypb.E_Repository extension on the field descriptor's options.
// A field without the extension yields false and a nil error.
func GetRepositoryExtension(m *descriptorpb.FieldDescriptorProto) (bool, error) {
	opts := m.GetOptions()
	return getBoolExtension(opts, gitalypb.E_Repository)
}

// GetStorageExtension returns the value of the boolean gitalypb.E_Storage
// extension on the field descriptor's options.
// A field without the extension yields false and a nil error.
func GetStorageExtension(m *descriptorpb.FieldDescriptorProto) (bool, error) {
	isStorage, err := getBoolExtension(m.GetOptions(), gitalypb.E_Storage)
	return isStorage, err
}

// GetTargetRepositoryExtension returns the value of the boolean
// gitalypb.E_TargetRepository extension on the field descriptor's options.
// A field without the extension yields false and a nil error.
func GetTargetRepositoryExtension(m *descriptorpb.FieldDescriptorProto) (bool, error) {
	fieldOpts := m.GetOptions()
	isTarget, err := getBoolExtension(fieldOpts, gitalypb.E_TargetRepository)
	return isTarget, err
}

// GetAdditionalRepositoryExtension returns the value of the boolean
// gitalypb.E_AdditionalRepository extension on the field descriptor's options.
// A field without the extension yields false and a nil error.
func GetAdditionalRepositoryExtension(m *descriptorpb.FieldDescriptorProto) (bool, error) {
	return getBoolExtension(m.GetOptions(), gitalypb.E_AdditionalRepository)
}

// getBoolExtension reads a boolean extension from options.
//
// An extension that is not set (getExtension reports protoregistry.NotFound)
// is treated as false with a nil error. Any other error from getExtension is
// returned unchanged. If the extension's value is not a bool, for example
// because a non-bool extension was passed in, an error is returned rather
// than panicking on the type assertion.
func getBoolExtension(options proto.Message, extension *protoimpl.ExtensionInfo) (bool, error) {
	val, err := getExtension(options, extension)
	if err != nil {
		if errors.Is(err, protoregistry.NotFound) {
			// An unset boolean extension simply means "false".
			return false, nil
		}

		return false, err
	}

	b, ok := val.(bool)
	if !ok {
		return false, fmt.Errorf("protoutil.getBoolExtension %q: unexpected extension value type %T",
			extension.TypeDescriptor().FullName(), val)
	}

	return b, nil
}

// getExtension returns the raw value of extension from options.
//
// When proto.HasExtension reports the extension as not populated, it returns
// a nil value and an error that names the extension and wraps
// protoregistry.NotFound, so callers can test for it with errors.Is.
func getExtension(options proto.Message, extension *protoimpl.ExtensionInfo) (interface{}, error) {
	if proto.HasExtension(options, extension) {
		return proto.GetExtension(options, extension), nil
	}

	fullName := extension.TypeDescriptor().FullName()
	return nil, fmt.Errorf("protoutil.getExtension %q: %w", fullName, protoregistry.NotFound)
}