123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356 |
- // Copyright 2010 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package proto
- import (
- "errors"
- "fmt"
- "reflect"
- "google.golang.org/protobuf/encoding/protowire"
- "google.golang.org/protobuf/proto"
- "google.golang.org/protobuf/reflect/protoreflect"
- "google.golang.org/protobuf/reflect/protoregistry"
- "google.golang.org/protobuf/runtime/protoiface"
- "google.golang.org/protobuf/runtime/protoimpl"
- )
- type (
- // ExtensionDesc represents an extension descriptor and
- // is used to interact with an extension field in a message.
- //
- // Variables of this type are generated in code by protoc-gen-go.
- ExtensionDesc = protoimpl.ExtensionInfo
- // ExtensionRange represents a range of message extensions.
- // Used in code generated by protoc-gen-go.
- ExtensionRange = protoiface.ExtensionRangeV1
- // Deprecated: Do not use; this is an internal type.
- Extension = protoimpl.ExtensionFieldV1
- // Deprecated: Do not use; this is an internal type.
- XXX_InternalExtensions = protoimpl.ExtensionFields
- )
- // ErrMissingExtension reports whether the extension was not present.
- var ErrMissingExtension = errors.New("proto: missing extension")
- var errNotExtendable = errors.New("proto: not an extendable proto.Message")
- // HasExtension reports whether the extension field is present in m
- // either as an explicitly populated field or as an unknown field.
- func HasExtension(m Message, xt *ExtensionDesc) (has bool) {
- mr := MessageReflect(m)
- if mr == nil || !mr.IsValid() {
- return false
- }
- // Check whether any populated known field matches the field number.
- xtd := xt.TypeDescriptor()
- if isValidExtension(mr.Descriptor(), xtd) {
- has = mr.Has(xtd)
- } else {
- mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
- has = int32(fd.Number()) == xt.Field
- return !has
- })
- }
- // Check whether any unknown field matches the field number.
- for b := mr.GetUnknown(); !has && len(b) > 0; {
- num, _, n := protowire.ConsumeField(b)
- has = int32(num) == xt.Field
- b = b[n:]
- }
- return has
- }
- // ClearExtension removes the extension field from m
- // either as an explicitly populated field or as an unknown field.
- func ClearExtension(m Message, xt *ExtensionDesc) {
- mr := MessageReflect(m)
- if mr == nil || !mr.IsValid() {
- return
- }
- xtd := xt.TypeDescriptor()
- if isValidExtension(mr.Descriptor(), xtd) {
- mr.Clear(xtd)
- } else {
- mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
- if int32(fd.Number()) == xt.Field {
- mr.Clear(fd)
- return false
- }
- return true
- })
- }
- clearUnknown(mr, fieldNum(xt.Field))
- }
- // ClearAllExtensions clears all extensions from m.
- // This includes populated fields and unknown fields in the extension range.
- func ClearAllExtensions(m Message) {
- mr := MessageReflect(m)
- if mr == nil || !mr.IsValid() {
- return
- }
- mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
- if fd.IsExtension() {
- mr.Clear(fd)
- }
- return true
- })
- clearUnknown(mr, mr.Descriptor().ExtensionRanges())
- }
- // GetExtension retrieves a proto2 extended field from m.
- //
- // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
- // then GetExtension parses the encoded field and returns a Go value of the specified type.
- // If the field is not present, then the default value is returned (if one is specified),
- // otherwise ErrMissingExtension is reported.
- //
- // If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil),
- // then GetExtension returns the raw encoded bytes for the extension field.
- func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) {
- mr := MessageReflect(m)
- if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
- return nil, errNotExtendable
- }
- // Retrieve the unknown fields for this extension field.
- var bo protoreflect.RawFields
- for bi := mr.GetUnknown(); len(bi) > 0; {
- num, _, n := protowire.ConsumeField(bi)
- if int32(num) == xt.Field {
- bo = append(bo, bi[:n]...)
- }
- bi = bi[n:]
- }
- // For type incomplete descriptors, only retrieve the unknown fields.
- if xt.ExtensionType == nil {
- return []byte(bo), nil
- }
- // If the extension field only exists as unknown fields, unmarshal it.
- // This is rarely done since proto.Unmarshal eagerly unmarshals extensions.
- xtd := xt.TypeDescriptor()
- if !isValidExtension(mr.Descriptor(), xtd) {
- return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
- }
- if !mr.Has(xtd) && len(bo) > 0 {
- m2 := mr.New()
- if err := (proto.UnmarshalOptions{
- Resolver: extensionResolver{xt},
- }.Unmarshal(bo, m2.Interface())); err != nil {
- return nil, err
- }
- if m2.Has(xtd) {
- mr.Set(xtd, m2.Get(xtd))
- clearUnknown(mr, fieldNum(xt.Field))
- }
- }
- // Check whether the message has the extension field set or a default.
- var pv protoreflect.Value
- switch {
- case mr.Has(xtd):
- pv = mr.Get(xtd)
- case xtd.HasDefault():
- pv = xtd.Default()
- default:
- return nil, ErrMissingExtension
- }
- v := xt.InterfaceOf(pv)
- rv := reflect.ValueOf(v)
- if isScalarKind(rv.Kind()) {
- rv2 := reflect.New(rv.Type())
- rv2.Elem().Set(rv)
- v = rv2.Interface()
- }
- return v, nil
- }
- // extensionResolver is a custom extension resolver that stores a single
- // extension type that takes precedence over the global registry.
- type extensionResolver struct{ xt protoreflect.ExtensionType }
- func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
- if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field {
- return r.xt, nil
- }
- return protoregistry.GlobalTypes.FindExtensionByName(field)
- }
- func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
- if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field {
- return r.xt, nil
- }
- return protoregistry.GlobalTypes.FindExtensionByNumber(message, field)
- }
- // GetExtensions returns a list of the extensions values present in m,
- // corresponding with the provided list of extension descriptors, xts.
- // If an extension is missing in m, the corresponding value is nil.
- func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) {
- mr := MessageReflect(m)
- if mr == nil || !mr.IsValid() {
- return nil, errNotExtendable
- }
- vs := make([]interface{}, len(xts))
- for i, xt := range xts {
- v, err := GetExtension(m, xt)
- if err != nil {
- if err == ErrMissingExtension {
- continue
- }
- return vs, err
- }
- vs[i] = v
- }
- return vs, nil
- }
- // SetExtension sets an extension field in m to the provided value.
- func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error {
- mr := MessageReflect(m)
- if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
- return errNotExtendable
- }
- rv := reflect.ValueOf(v)
- if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) {
- return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType)
- }
- if rv.Kind() == reflect.Ptr {
- if rv.IsNil() {
- return fmt.Errorf("proto: SetExtension called with nil value of type %T", v)
- }
- if isScalarKind(rv.Elem().Kind()) {
- v = rv.Elem().Interface()
- }
- }
- xtd := xt.TypeDescriptor()
- if !isValidExtension(mr.Descriptor(), xtd) {
- return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
- }
- mr.Set(xtd, xt.ValueOf(v))
- clearUnknown(mr, fieldNum(xt.Field))
- return nil
- }
- // SetRawExtension inserts b into the unknown fields of m.
- //
- // Deprecated: Use Message.ProtoReflect.SetUnknown instead.
- func SetRawExtension(m Message, fnum int32, b []byte) {
- mr := MessageReflect(m)
- if mr == nil || !mr.IsValid() {
- return
- }
- // Verify that the raw field is valid.
- for b0 := b; len(b0) > 0; {
- num, _, n := protowire.ConsumeField(b0)
- if int32(num) != fnum {
- panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum))
- }
- b0 = b0[n:]
- }
- ClearExtension(m, &ExtensionDesc{Field: fnum})
- mr.SetUnknown(append(mr.GetUnknown(), b...))
- }
- // ExtensionDescs returns a list of extension descriptors found in m,
- // containing descriptors for both populated extension fields in m and
- // also unknown fields of m that are in the extension range.
- // For the later case, an type incomplete descriptor is provided where only
- // the ExtensionDesc.Field field is populated.
- // The order of the extension descriptors is undefined.
- func ExtensionDescs(m Message) ([]*ExtensionDesc, error) {
- mr := MessageReflect(m)
- if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
- return nil, errNotExtendable
- }
- // Collect a set of known extension descriptors.
- extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc)
- mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
- if fd.IsExtension() {
- xt := fd.(protoreflect.ExtensionTypeDescriptor)
- if xd, ok := xt.Type().(*ExtensionDesc); ok {
- extDescs[fd.Number()] = xd
- }
- }
- return true
- })
- // Collect a set of unknown extension descriptors.
- extRanges := mr.Descriptor().ExtensionRanges()
- for b := mr.GetUnknown(); len(b) > 0; {
- num, _, n := protowire.ConsumeField(b)
- if extRanges.Has(num) && extDescs[num] == nil {
- extDescs[num] = nil
- }
- b = b[n:]
- }
- // Transpose the set of descriptors into a list.
- var xts []*ExtensionDesc
- for num, xt := range extDescs {
- if xt == nil {
- xt = &ExtensionDesc{Field: int32(num)}
- }
- xts = append(xts, xt)
- }
- return xts, nil
- }
- // isValidExtension reports whether xtd is a valid extension descriptor for md.
- func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool {
- return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number())
- }
- // isScalarKind reports whether k is a protobuf scalar kind (except bytes).
- // This function exists for historical reasons since the representation of
- // scalars differs between v1 and v2, where v1 uses *T and v2 uses T.
- func isScalarKind(k reflect.Kind) bool {
- switch k {
- case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
- return true
- default:
- return false
- }
- }
- // clearUnknown removes unknown fields from m where remover.Has reports true.
- func clearUnknown(m protoreflect.Message, remover interface {
- Has(protoreflect.FieldNumber) bool
- }) {
- var bo protoreflect.RawFields
- for bi := m.GetUnknown(); len(bi) > 0; {
- num, _, n := protowire.ConsumeField(bi)
- if !remover.Has(num) {
- bo = append(bo, bi[:n]...)
- }
- bi = bi[n:]
- }
- if bi := m.GetUnknown(); len(bi) != len(bo) {
- m.SetUnknown(bo)
- }
- }
- type fieldNum protoreflect.FieldNumber
- func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool {
- return protoreflect.FieldNumber(n1) == n2
- }
|