1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
|
package validation
import (
"errors"
"fmt"
"strings"
"google.golang.org/genproto/googleapis/rpc/errdetails"
)
// MessageValidator provides primitives for validating the fields of a message.
//
// Violations are accumulated with AddFieldViolation and AddFieldError and then
// reported together as a single error by Err. The zero value is ready to use:
// it has no parent field and no violations.
type MessageValidator struct {
	// parentField, when non-empty, is prepended (joined with a '.') to the field
	// of every violation added afterwards; see SetParentField.
	parentField string
	// fieldViolations holds the violations added so far, in insertion order.
	fieldViolations []*errdetails.BadRequest_FieldViolation
}
// SetParentField sets a parent field which will be prepended to all the subsequently added violations.
//
// The parent and each violation's field are joined with a '.'. Passing "" clears
// the parent. Violations that were already added are not affected.
func (m *MessageValidator) SetParentField(parentField string) {
	m.parentField = parentField
}
// AddFieldViolation records a violation of field with the given description.
//
// If a parent field has been set with SetParentField, the recorded field is
// "parent.field". When formatArgs are supplied, description is treated as an
// fmt format string; otherwise it is stored verbatim, so a literal '%' survives.
func (m *MessageValidator) AddFieldViolation(field, description string, formatArgs ...interface{}) {
	// makeFieldWithParent returns field unchanged when there is no parent.
	fullField := makeFieldWithParent(m.parentField, field)
	text := description
	if len(formatArgs) != 0 {
		text = fmt.Sprintf(description, formatArgs...)
	}
	violation := &errdetails.BadRequest_FieldViolation{Field: fullField, Description: text}
	m.fieldViolations = append(m.fieldViolations, violation)
}
// AddFieldError adds a field violation from the provided error.
//
// If err is (or wraps) a validation.Error, each of its individual field
// violations is added with field as an extra parent segment, so nested message
// errors keep their full path (e.g. "parent.field.child"). Any other error is
// added as a single violation on field, using err.Error() as the description.
// A nil err is ignored.
func (m *MessageValidator) AddFieldError(field string, err error) {
	if err == nil {
		// Nothing to report; calling err.Error() on a nil error would panic.
		return
	}
	var errValidation *Error
	if !errors.As(err, &errValidation) {
		m.AddFieldViolation(field, err.Error())
		return
	}
	// Build each child path relative to field. AddFieldViolation then prepends
	// the current parent. This avoids temporarily mutating m.parentField, and an
	// empty field no longer yields a doubled dot ("parent..child").
	for _, fieldViolation := range errValidation.fieldViolations {
		m.AddFieldViolation(makeFieldWithParent(field, fieldViolation.GetField()), fieldViolation.GetDescription())
	}
}
// Err reports every violation accumulated so far as a single validation error.
// It returns a literal nil (not a typed nil) when no violations have been added.
func (m *MessageValidator) Err() error {
	if len(m.fieldViolations) == 0 {
		return nil
	}
	return NewError(m.fieldViolations)
}
// makeFieldWithParent joins parentField and field into a dot-separated field path.
//
// An empty segment is omitted rather than producing a leading or trailing dot:
// ("", "a") yields "a", ("p", "") yields "p", and ("p", "a") yields "p.a".
func makeFieldWithParent(parentField, field string) string {
	switch {
	case parentField == "":
		return field
	case field == "":
		return parentField
	}
	var result strings.Builder
	result.Grow(len(parentField) + 1 + len(field))
	// strings.Builder writes always return a nil error, so the results are discarded.
	_, _ = result.WriteString(parentField)
	_ = result.WriteByte('.')
	_, _ = result.WriteString(field)
	return result.String()
}
|