123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576 |
- // Copyright 2019 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package impl
- import (
- "fmt"
- "math"
- "math/bits"
- "reflect"
- "unicode/utf8"
- "google.golang.org/protobuf/encoding/protowire"
- "google.golang.org/protobuf/internal/encoding/messageset"
- "google.golang.org/protobuf/internal/flags"
- "google.golang.org/protobuf/internal/genid"
- "google.golang.org/protobuf/internal/strs"
- pref "google.golang.org/protobuf/reflect/protoreflect"
- preg "google.golang.org/protobuf/reflect/protoregistry"
- piface "google.golang.org/protobuf/runtime/protoiface"
- )
// ValidationStatus is the result of validating the wire-format encoding of a message.
//
// Numbering starts at 1, so the zero ValidationStatus matches none of the
// named statuses below.
type ValidationStatus int

const (
	// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
	// The validator was unable to render a judgement.
	//
	// The only causes of this status are an aberrant message type appearing somewhere
	// in the message or a failed registry lookup (the extension resolver, or the
	// global type registry when resolving a weak field).
	ValidationUnknown ValidationStatus = 1 + iota

	// ValidationInvalid indicates that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValid indicates that unmarshaling the message will succeed.
	ValidationValid
)

// String returns the identifier of a named status, or "ValidationStatus(N)"
// for any other value N.
func (v ValidationStatus) String() string {
	var name string
	switch v {
	case ValidationUnknown:
		name = "ValidationUnknown"
	case ValidationInvalid:
		name = "ValidationInvalid"
	case ValidationValid:
		name = "ValidationValid"
	default:
		return fmt.Sprintf("ValidationStatus(%d)", int(v))
	}
	return name
}
- // Validate determines whether the contents of the buffer are a valid wire encoding
- // of the message type.
- //
- // This function is exposed for testing.
- func Validate(mt pref.MessageType, in piface.UnmarshalInput) (out piface.UnmarshalOutput, _ ValidationStatus) {
- mi, ok := mt.(*MessageInfo)
- if !ok {
- return out, ValidationUnknown
- }
- if in.Resolver == nil {
- in.Resolver = preg.GlobalTypes
- }
- o, st := mi.validate(in.Buf, 0, unmarshalOptions{
- flags: in.Flags,
- resolver: in.Resolver,
- })
- if o.initialized {
- out.Flags |= piface.UnmarshalInitialized
- }
- return out, st
- }
- type validationInfo struct {
- mi *MessageInfo
- typ validationType
- keyType, valType validationType
- // For non-required fields, requiredBit is 0.
- //
- // For required fields, requiredBit's nth bit is set, where n is a
- // unique index in the range [0, MessageInfo.numRequiredFields).
- //
- // If there are more than 64 required fields, requiredBit is 0.
- requiredBit uint64
- }
// validationType classifies how the validator checks a field's encoded value.
type validationType uint8

const (
	// validationTypeOther: no type-specific check beyond well-formed wire data.
	validationTypeOther validationType = iota
	// validationTypeMessage: a length-delimited nested message, validated in its own frame.
	validationTypeMessage
	// validationTypeGroup: a group, validated in its own frame up to its end-group tag.
	validationTypeGroup
	// validationTypeMap: a map entry, whose key and value are checked using keyType/valType.
	validationTypeMap
	// validationTypeRepeatedVarint: a packed list of varints.
	validationTypeRepeatedVarint
	// validationTypeRepeatedFixed32: a packed list of 4-byte values.
	validationTypeRepeatedFixed32
	// validationTypeRepeatedFixed64: a packed list of 8-byte values.
	validationTypeRepeatedFixed64
	// validationTypeVarint: a singular varint-encoded scalar.
	validationTypeVarint
	// validationTypeFixed32: a singular 4-byte scalar.
	validationTypeFixed32
	// validationTypeFixed64: a singular 8-byte scalar.
	validationTypeFixed64
	// validationTypeBytes: length-delimited bytes, or a string without UTF-8 enforcement.
	validationTypeBytes
	// validationTypeUTF8String: a string whose contents must be valid UTF-8.
	validationTypeUTF8String
	// validationTypeMessageSetItem: an item group of a legacy MessageSet.
	validationTypeMessageSetItem
)
- func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
- var vi validationInfo
- switch {
- case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
- switch fd.Kind() {
- case pref.MessageKind:
- vi.typ = validationTypeMessage
- if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
- vi.mi = getMessageInfo(ot.Field(0).Type)
- }
- case pref.GroupKind:
- vi.typ = validationTypeGroup
- if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
- vi.mi = getMessageInfo(ot.Field(0).Type)
- }
- case pref.StringKind:
- if strs.EnforceUTF8(fd) {
- vi.typ = validationTypeUTF8String
- }
- }
- default:
- vi = newValidationInfo(fd, ft)
- }
- if fd.Cardinality() == pref.Required {
- // Avoid overflow. The required field check is done with a 64-bit mask, with
- // any message containing more than 64 required fields always reported as
- // potentially uninitialized, so it is not important to get a precise count
- // of the required fields past 64.
- if mi.numRequiredFields < math.MaxUint8 {
- mi.numRequiredFields++
- vi.requiredBit = 1 << (mi.numRequiredFields - 1)
- }
- }
- return vi
- }
- func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
- var vi validationInfo
- switch {
- case fd.IsList():
- switch fd.Kind() {
- case pref.MessageKind:
- vi.typ = validationTypeMessage
- if ft.Kind() == reflect.Slice {
- vi.mi = getMessageInfo(ft.Elem())
- }
- case pref.GroupKind:
- vi.typ = validationTypeGroup
- if ft.Kind() == reflect.Slice {
- vi.mi = getMessageInfo(ft.Elem())
- }
- case pref.StringKind:
- vi.typ = validationTypeBytes
- if strs.EnforceUTF8(fd) {
- vi.typ = validationTypeUTF8String
- }
- default:
- switch wireTypes[fd.Kind()] {
- case protowire.VarintType:
- vi.typ = validationTypeRepeatedVarint
- case protowire.Fixed32Type:
- vi.typ = validationTypeRepeatedFixed32
- case protowire.Fixed64Type:
- vi.typ = validationTypeRepeatedFixed64
- }
- }
- case fd.IsMap():
- vi.typ = validationTypeMap
- switch fd.MapKey().Kind() {
- case pref.StringKind:
- if strs.EnforceUTF8(fd) {
- vi.keyType = validationTypeUTF8String
- }
- }
- switch fd.MapValue().Kind() {
- case pref.MessageKind:
- vi.valType = validationTypeMessage
- if ft.Kind() == reflect.Map {
- vi.mi = getMessageInfo(ft.Elem())
- }
- case pref.StringKind:
- if strs.EnforceUTF8(fd) {
- vi.valType = validationTypeUTF8String
- }
- }
- default:
- switch fd.Kind() {
- case pref.MessageKind:
- vi.typ = validationTypeMessage
- if !fd.IsWeak() {
- vi.mi = getMessageInfo(ft)
- }
- case pref.GroupKind:
- vi.typ = validationTypeGroup
- vi.mi = getMessageInfo(ft)
- case pref.StringKind:
- vi.typ = validationTypeBytes
- if strs.EnforceUTF8(fd) {
- vi.typ = validationTypeUTF8String
- }
- default:
- switch wireTypes[fd.Kind()] {
- case protowire.VarintType:
- vi.typ = validationTypeVarint
- case protowire.Fixed32Type:
- vi.typ = validationTypeFixed32
- case protowire.Fixed64Type:
- vi.typ = validationTypeFixed64
- case protowire.BytesType:
- vi.typ = validationTypeBytes
- }
- }
- }
- return vi
- }
- func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
- mi.init()
- type validationState struct {
- typ validationType
- keyType, valType validationType
- endGroup protowire.Number
- mi *MessageInfo
- tail []byte
- requiredMask uint64
- }
- // Pre-allocate some slots to avoid repeated slice reallocation.
- states := make([]validationState, 0, 16)
- states = append(states, validationState{
- typ: validationTypeMessage,
- mi: mi,
- })
- if groupTag > 0 {
- states[0].typ = validationTypeGroup
- states[0].endGroup = groupTag
- }
- initialized := true
- start := len(b)
- State:
- for len(states) > 0 {
- st := &states[len(states)-1]
- for len(b) > 0 {
- // Parse the tag (field number and wire type).
- var tag uint64
- if b[0] < 0x80 {
- tag = uint64(b[0])
- b = b[1:]
- } else if len(b) >= 2 && b[1] < 128 {
- tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
- b = b[2:]
- } else {
- var n int
- tag, n = protowire.ConsumeVarint(b)
- if n < 0 {
- return out, ValidationInvalid
- }
- b = b[n:]
- }
- var num protowire.Number
- if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
- return out, ValidationInvalid
- } else {
- num = protowire.Number(n)
- }
- wtyp := protowire.Type(tag & 7)
- if wtyp == protowire.EndGroupType {
- if st.endGroup == num {
- goto PopState
- }
- return out, ValidationInvalid
- }
- var vi validationInfo
- switch {
- case st.typ == validationTypeMap:
- switch num {
- case genid.MapEntry_Key_field_number:
- vi.typ = st.keyType
- case genid.MapEntry_Value_field_number:
- vi.typ = st.valType
- vi.mi = st.mi
- vi.requiredBit = 1
- }
- case flags.ProtoLegacy && st.mi.isMessageSet:
- switch num {
- case messageset.FieldItem:
- vi.typ = validationTypeMessageSetItem
- }
- default:
- var f *coderFieldInfo
- if int(num) < len(st.mi.denseCoderFields) {
- f = st.mi.denseCoderFields[num]
- } else {
- f = st.mi.coderFields[num]
- }
- if f != nil {
- vi = f.validation
- if vi.typ == validationTypeMessage && vi.mi == nil {
- // Probable weak field.
- //
- // TODO: Consider storing the results of this lookup somewhere
- // rather than recomputing it on every validation.
- fd := st.mi.Desc.Fields().ByNumber(num)
- if fd == nil || !fd.IsWeak() {
- break
- }
- messageName := fd.Message().FullName()
- messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
- switch err {
- case nil:
- vi.mi, _ = messageType.(*MessageInfo)
- case preg.NotFound:
- vi.typ = validationTypeBytes
- default:
- return out, ValidationUnknown
- }
- }
- break
- }
- // Possible extension field.
- //
- // TODO: We should return ValidationUnknown when:
- // 1. The resolver is not frozen. (More extensions may be added to it.)
- // 2. The resolver returns preg.NotFound.
- // In this case, a type added to the resolver in the future could cause
- // unmarshaling to begin failing. Supporting this requires some way to
- // determine if the resolver is frozen.
- xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
- if err != nil && err != preg.NotFound {
- return out, ValidationUnknown
- }
- if err == nil {
- vi = getExtensionFieldInfo(xt).validation
- }
- }
- if vi.requiredBit != 0 {
- // Check that the field has a compatible wire type.
- // We only need to consider non-repeated field types,
- // since repeated fields (and maps) can never be required.
- ok := false
- switch vi.typ {
- case validationTypeVarint:
- ok = wtyp == protowire.VarintType
- case validationTypeFixed32:
- ok = wtyp == protowire.Fixed32Type
- case validationTypeFixed64:
- ok = wtyp == protowire.Fixed64Type
- case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
- ok = wtyp == protowire.BytesType
- case validationTypeGroup:
- ok = wtyp == protowire.StartGroupType
- }
- if ok {
- st.requiredMask |= vi.requiredBit
- }
- }
- switch wtyp {
- case protowire.VarintType:
- if len(b) >= 10 {
- switch {
- case b[0] < 0x80:
- b = b[1:]
- case b[1] < 0x80:
- b = b[2:]
- case b[2] < 0x80:
- b = b[3:]
- case b[3] < 0x80:
- b = b[4:]
- case b[4] < 0x80:
- b = b[5:]
- case b[5] < 0x80:
- b = b[6:]
- case b[6] < 0x80:
- b = b[7:]
- case b[7] < 0x80:
- b = b[8:]
- case b[8] < 0x80:
- b = b[9:]
- case b[9] < 0x80 && b[9] < 2:
- b = b[10:]
- default:
- return out, ValidationInvalid
- }
- } else {
- switch {
- case len(b) > 0 && b[0] < 0x80:
- b = b[1:]
- case len(b) > 1 && b[1] < 0x80:
- b = b[2:]
- case len(b) > 2 && b[2] < 0x80:
- b = b[3:]
- case len(b) > 3 && b[3] < 0x80:
- b = b[4:]
- case len(b) > 4 && b[4] < 0x80:
- b = b[5:]
- case len(b) > 5 && b[5] < 0x80:
- b = b[6:]
- case len(b) > 6 && b[6] < 0x80:
- b = b[7:]
- case len(b) > 7 && b[7] < 0x80:
- b = b[8:]
- case len(b) > 8 && b[8] < 0x80:
- b = b[9:]
- case len(b) > 9 && b[9] < 2:
- b = b[10:]
- default:
- return out, ValidationInvalid
- }
- }
- continue State
- case protowire.BytesType:
- var size uint64
- if len(b) >= 1 && b[0] < 0x80 {
- size = uint64(b[0])
- b = b[1:]
- } else if len(b) >= 2 && b[1] < 128 {
- size = uint64(b[0]&0x7f) + uint64(b[1])<<7
- b = b[2:]
- } else {
- var n int
- size, n = protowire.ConsumeVarint(b)
- if n < 0 {
- return out, ValidationInvalid
- }
- b = b[n:]
- }
- if size > uint64(len(b)) {
- return out, ValidationInvalid
- }
- v := b[:size]
- b = b[size:]
- switch vi.typ {
- case validationTypeMessage:
- if vi.mi == nil {
- return out, ValidationUnknown
- }
- vi.mi.init()
- fallthrough
- case validationTypeMap:
- if vi.mi != nil {
- vi.mi.init()
- }
- states = append(states, validationState{
- typ: vi.typ,
- keyType: vi.keyType,
- valType: vi.valType,
- mi: vi.mi,
- tail: b,
- })
- b = v
- continue State
- case validationTypeRepeatedVarint:
- // Packed field.
- for len(v) > 0 {
- _, n := protowire.ConsumeVarint(v)
- if n < 0 {
- return out, ValidationInvalid
- }
- v = v[n:]
- }
- case validationTypeRepeatedFixed32:
- // Packed field.
- if len(v)%4 != 0 {
- return out, ValidationInvalid
- }
- case validationTypeRepeatedFixed64:
- // Packed field.
- if len(v)%8 != 0 {
- return out, ValidationInvalid
- }
- case validationTypeUTF8String:
- if !utf8.Valid(v) {
- return out, ValidationInvalid
- }
- }
- case protowire.Fixed32Type:
- if len(b) < 4 {
- return out, ValidationInvalid
- }
- b = b[4:]
- case protowire.Fixed64Type:
- if len(b) < 8 {
- return out, ValidationInvalid
- }
- b = b[8:]
- case protowire.StartGroupType:
- switch {
- case vi.typ == validationTypeGroup:
- if vi.mi == nil {
- return out, ValidationUnknown
- }
- vi.mi.init()
- states = append(states, validationState{
- typ: validationTypeGroup,
- mi: vi.mi,
- endGroup: num,
- })
- continue State
- case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
- typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
- if err != nil {
- return out, ValidationInvalid
- }
- xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
- switch {
- case err == preg.NotFound:
- b = b[n:]
- case err != nil:
- return out, ValidationUnknown
- default:
- xvi := getExtensionFieldInfo(xt).validation
- if xvi.mi != nil {
- xvi.mi.init()
- }
- states = append(states, validationState{
- typ: xvi.typ,
- mi: xvi.mi,
- tail: b[n:],
- })
- b = v
- continue State
- }
- default:
- n := protowire.ConsumeFieldValue(num, wtyp, b)
- if n < 0 {
- return out, ValidationInvalid
- }
- b = b[n:]
- }
- default:
- return out, ValidationInvalid
- }
- }
- if st.endGroup != 0 {
- return out, ValidationInvalid
- }
- if len(b) != 0 {
- return out, ValidationInvalid
- }
- b = st.tail
- PopState:
- numRequiredFields := 0
- switch st.typ {
- case validationTypeMessage, validationTypeGroup:
- numRequiredFields = int(st.mi.numRequiredFields)
- case validationTypeMap:
- // If this is a map field with a message value that contains
- // required fields, require that the value be present.
- if st.mi != nil && st.mi.numRequiredFields > 0 {
- numRequiredFields = 1
- }
- }
- // If there are more than 64 required fields, this check will
- // always fail and we will report that the message is potentially
- // uninitialized.
- if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
- initialized = false
- }
- states = states[:len(states)-1]
- }
- out.n = start - len(b)
- if initialized {
- out.initialized = true
- }
- return out, ValidationValid
- }
|