1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
|
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package protobuild constructs messages.
//
// This package is used to construct multiple types of message with a similar shape
// from a common template.
package protobuild
import (
"fmt"
"math"
"reflect"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/reflect/protoregistry"
)
// A Value is a value assignable to a field.
// A Value may be a value accepted by protoreflect.ValueOf. In addition:
//
// • An int may be assigned to any numeric field.
//
// • A float64 may be assigned to a double field.
//
// • Either a string or []byte may be assigned to a string or bytes field.
//
// • A string containing the value name may be assigned to an enum field.
//
// • A slice may be assigned to a list, and a map may be assigned to a map.
type Value interface{}
// A Message is a template to apply to a message. Keys are field names, including
// extension names.
type Message map[protoreflect.Name]Value
// Unknown is a key associated with the unknown fields of a message.
// The value should be a []byte.
const Unknown = "@unknown"
// Build applies the template to a message.
func (template Message) Build(m protoreflect.Message) {
md := m.Descriptor()
fields := md.Fields()
exts := make(map[protoreflect.Name]protoreflect.FieldDescriptor)
protoregistry.GlobalTypes.RangeExtensionsByMessage(md.FullName(), func(xt protoreflect.ExtensionType) bool {
xd := xt.TypeDescriptor()
exts[xd.Name()] = xd
return true
})
for k, v := range template {
if k == Unknown {
m.SetUnknown(protoreflect.RawFields(v.([]byte)))
continue
}
fd := fields.ByName(k)
if fd == nil {
fd = exts[k]
}
if fd == nil {
panic(fmt.Sprintf("%v.%v: not found", md.FullName(), k))
}
switch {
case fd.IsList():
list := m.Mutable(fd).List()
s := reflect.ValueOf(v)
for i := 0; i < s.Len(); i++ {
if fd.Message() == nil {
list.Append(fieldValue(fd, s.Index(i).Interface()))
} else {
e := list.NewElement()
s.Index(i).Interface().(Message).Build(e.Message())
list.Append(e)
}
}
case fd.IsMap():
mapv := m.Mutable(fd).Map()
rm := reflect.ValueOf(v)
for _, k := range rm.MapKeys() {
mk := fieldValue(fd.MapKey(), k.Interface()).MapKey()
if fd.MapValue().Message() == nil {
mv := fieldValue(fd.MapValue(), rm.MapIndex(k).Interface())
mapv.Set(mk, mv)
} else if mapv.Has(mk) {
mv := mapv.Get(mk).Message()
rm.MapIndex(k).Interface().(Message).Build(mv)
} else {
mv := mapv.NewValue()
rm.MapIndex(k).Interface().(Message).Build(mv.Message())
mapv.Set(mk, mv)
}
}
default:
if fd.Message() == nil {
m.Set(fd, fieldValue(fd, v))
} else {
v.(Message).Build(m.Mutable(fd).Message())
}
}
}
}
func fieldValue(fd protoreflect.FieldDescriptor, v interface{}) protoreflect.Value {
switch o := v.(type) {
case int:
switch fd.Kind() {
case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
if o < math.MinInt32 || math.MaxInt32 < o {
panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, int32(math.MinInt32), int32(math.MaxInt32)))
}
v = int32(o)
case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
if o < 0 || math.MaxUint32 < 0 {
panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, uint32(0), uint32(math.MaxUint32)))
}
v = uint32(o)
case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
v = int64(o)
case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
if o < 0 {
panic(fmt.Sprintf("%v: value %v out of range [%v, %v]", fd.FullName(), o, uint64(0), uint64(math.MaxUint64)))
}
v = uint64(o)
case protoreflect.FloatKind:
v = float32(o)
case protoreflect.DoubleKind:
v = float64(o)
case protoreflect.EnumKind:
v = protoreflect.EnumNumber(o)
default:
panic(fmt.Sprintf("%v: invalid value type int", fd.FullName()))
}
case float64:
switch fd.Kind() {
case protoreflect.FloatKind:
v = float32(o)
}
case string:
switch fd.Kind() {
case protoreflect.BytesKind:
v = []byte(o)
case protoreflect.EnumKind:
v = fd.Enum().Values().ByName(protoreflect.Name(o)).Number()
}
case []byte:
return protoreflect.ValueOf(append([]byte{}, o...))
}
return protoreflect.ValueOf(v)
}
|