validate.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package impl
  5. import (
  6. "fmt"
  7. "math"
  8. "math/bits"
  9. "reflect"
  10. "unicode/utf8"
  11. "google.golang.org/protobuf/encoding/protowire"
  12. "google.golang.org/protobuf/internal/encoding/messageset"
  13. "google.golang.org/protobuf/internal/flags"
  14. "google.golang.org/protobuf/internal/genid"
  15. "google.golang.org/protobuf/internal/strs"
  16. pref "google.golang.org/protobuf/reflect/protoreflect"
  17. preg "google.golang.org/protobuf/reflect/protoregistry"
  18. piface "google.golang.org/protobuf/runtime/protoiface"
  19. )
  // ValidationStatus is the result of validating the wire-format encoding of a message.
type ValidationStatus int

// Statuses are numbered from 1, so the zero ValidationStatus is not one of
// the named values below.
const (
	// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
	// The validator was unable to render a judgement.
	//
	// The only causes of this status are an aberrant message type appearing somewhere
	// in the message or a failure in the extension resolver.
	ValidationUnknown ValidationStatus = 1 + iota

	// ValidationInvalid indicates that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValid indicates that unmarshaling the message will succeed.
	ValidationValid
)

// String returns the Go identifier of a named status, or
// "ValidationStatus(N)" for any other value N.
func (v ValidationStatus) String() string {
	names := [...]string{
		ValidationUnknown: "ValidationUnknown",
		ValidationInvalid: "ValidationInvalid",
		ValidationValid:   "ValidationValid",
	}
	if v >= ValidationUnknown && int(v) < len(names) {
		return names[v]
	}
	return fmt.Sprintf("ValidationStatus(%d)", int(v))
}
  46. // Validate determines whether the contents of the buffer are a valid wire encoding
  47. // of the message type.
  48. //
  49. // This function is exposed for testing.
  50. func Validate(mt pref.MessageType, in piface.UnmarshalInput) (out piface.UnmarshalOutput, _ ValidationStatus) {
  51. mi, ok := mt.(*MessageInfo)
  52. if !ok {
  53. return out, ValidationUnknown
  54. }
  55. if in.Resolver == nil {
  56. in.Resolver = preg.GlobalTypes
  57. }
  58. o, st := mi.validate(in.Buf, 0, unmarshalOptions{
  59. flags: in.Flags,
  60. resolver: in.Resolver,
  61. })
  62. if o.initialized {
  63. out.Flags |= piface.UnmarshalInitialized
  64. }
  65. return out, st
  66. }
  67. type validationInfo struct {
  68. mi *MessageInfo
  69. typ validationType
  70. keyType, valType validationType
  71. // For non-required fields, requiredBit is 0.
  72. //
  73. // For required fields, requiredBit's nth bit is set, where n is a
  74. // unique index in the range [0, MessageInfo.numRequiredFields).
  75. //
  76. // If there are more than 64 required fields, requiredBit is 0.
  77. requiredBit uint64
  78. }
  // validationType selects how validate checks a field's encoded payload.
type validationType uint8

const (
	// validationTypeOther: the value is only skipped according to its wire
	// type. This is the zero value and is used for unknown fields.
	validationTypeOther validationType = iota
	// validationTypeMessage: a length-delimited embedded message, validated
	// recursively.
	validationTypeMessage
	// validationTypeGroup: a group-encoded message, validated recursively up
	// to its END_GROUP tag.
	validationTypeGroup
	// validationTypeMap: a length-delimited map entry.
	validationTypeMap
	// validationTypeRepeatedVarint: packed repeated varint scalars.
	validationTypeRepeatedVarint
	// validationTypeRepeatedFixed32: packed repeated 32-bit fixed scalars.
	validationTypeRepeatedFixed32
	// validationTypeRepeatedFixed64: packed repeated 64-bit fixed scalars.
	validationTypeRepeatedFixed64
	// validationTypeVarint: a singular varint scalar.
	validationTypeVarint
	// validationTypeFixed32: a singular 32-bit fixed scalar.
	validationTypeFixed32
	// validationTypeFixed64: a singular 64-bit fixed scalar.
	validationTypeFixed64
	// validationTypeBytes: length-delimited bytes with no content check.
	validationTypeBytes
	// validationTypeUTF8String: a length-delimited string that must be valid UTF-8.
	validationTypeUTF8String
	// validationTypeMessageSetItem: an item of a legacy MessageSet.
	validationTypeMessageSetItem
)
  95. func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
  96. var vi validationInfo
  97. switch {
  98. case fd.ContainingOneof() != nil && !fd.ContainingOneof().IsSynthetic():
  99. switch fd.Kind() {
  100. case pref.MessageKind:
  101. vi.typ = validationTypeMessage
  102. if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
  103. vi.mi = getMessageInfo(ot.Field(0).Type)
  104. }
  105. case pref.GroupKind:
  106. vi.typ = validationTypeGroup
  107. if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
  108. vi.mi = getMessageInfo(ot.Field(0).Type)
  109. }
  110. case pref.StringKind:
  111. if strs.EnforceUTF8(fd) {
  112. vi.typ = validationTypeUTF8String
  113. }
  114. }
  115. default:
  116. vi = newValidationInfo(fd, ft)
  117. }
  118. if fd.Cardinality() == pref.Required {
  119. // Avoid overflow. The required field check is done with a 64-bit mask, with
  120. // any message containing more than 64 required fields always reported as
  121. // potentially uninitialized, so it is not important to get a precise count
  122. // of the required fields past 64.
  123. if mi.numRequiredFields < math.MaxUint8 {
  124. mi.numRequiredFields++
  125. vi.requiredBit = 1 << (mi.numRequiredFields - 1)
  126. }
  127. }
  128. return vi
  129. }
  130. func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
  131. var vi validationInfo
  132. switch {
  133. case fd.IsList():
  134. switch fd.Kind() {
  135. case pref.MessageKind:
  136. vi.typ = validationTypeMessage
  137. if ft.Kind() == reflect.Slice {
  138. vi.mi = getMessageInfo(ft.Elem())
  139. }
  140. case pref.GroupKind:
  141. vi.typ = validationTypeGroup
  142. if ft.Kind() == reflect.Slice {
  143. vi.mi = getMessageInfo(ft.Elem())
  144. }
  145. case pref.StringKind:
  146. vi.typ = validationTypeBytes
  147. if strs.EnforceUTF8(fd) {
  148. vi.typ = validationTypeUTF8String
  149. }
  150. default:
  151. switch wireTypes[fd.Kind()] {
  152. case protowire.VarintType:
  153. vi.typ = validationTypeRepeatedVarint
  154. case protowire.Fixed32Type:
  155. vi.typ = validationTypeRepeatedFixed32
  156. case protowire.Fixed64Type:
  157. vi.typ = validationTypeRepeatedFixed64
  158. }
  159. }
  160. case fd.IsMap():
  161. vi.typ = validationTypeMap
  162. switch fd.MapKey().Kind() {
  163. case pref.StringKind:
  164. if strs.EnforceUTF8(fd) {
  165. vi.keyType = validationTypeUTF8String
  166. }
  167. }
  168. switch fd.MapValue().Kind() {
  169. case pref.MessageKind:
  170. vi.valType = validationTypeMessage
  171. if ft.Kind() == reflect.Map {
  172. vi.mi = getMessageInfo(ft.Elem())
  173. }
  174. case pref.StringKind:
  175. if strs.EnforceUTF8(fd) {
  176. vi.valType = validationTypeUTF8String
  177. }
  178. }
  179. default:
  180. switch fd.Kind() {
  181. case pref.MessageKind:
  182. vi.typ = validationTypeMessage
  183. if !fd.IsWeak() {
  184. vi.mi = getMessageInfo(ft)
  185. }
  186. case pref.GroupKind:
  187. vi.typ = validationTypeGroup
  188. vi.mi = getMessageInfo(ft)
  189. case pref.StringKind:
  190. vi.typ = validationTypeBytes
  191. if strs.EnforceUTF8(fd) {
  192. vi.typ = validationTypeUTF8String
  193. }
  194. default:
  195. switch wireTypes[fd.Kind()] {
  196. case protowire.VarintType:
  197. vi.typ = validationTypeVarint
  198. case protowire.Fixed32Type:
  199. vi.typ = validationTypeFixed32
  200. case protowire.Fixed64Type:
  201. vi.typ = validationTypeFixed64
  202. case protowire.BytesType:
  203. vi.typ = validationTypeBytes
  204. }
  205. }
  206. }
  207. return vi
  208. }
  // validate reports whether b is a valid wire-format encoding of the message
// described by mi.
//
// If groupTag is non-zero, b is parsed as the body of a group with that field
// number. Validation stops after the matching END_GROUP tag, and running out
// of input before that tag is an error. If groupTag is zero, all of b must be
// consumed.
//
// Nested messages, groups, map entries and MessageSet items are not handled by
// recursion. Instead, an explicit stack of validationState frames is kept.
//
// Results:
//   - ValidationInvalid: the input is malformed.
//   - ValidationUnknown: the *MessageInfo needed for a nested message or group
//     is unavailable, or a type/extension lookup fails with an error other
//     than preg.NotFound.
//   - ValidationValid: out.n is the number of bytes of b consumed. If any frame
//     may be missing a required field, out.initialized is false. Messages with
//     more than 64 required fields always count as possibly missing one.
func (mi *MessageInfo) validate(b []byte, groupTag protowire.Number, opts unmarshalOptions) (out unmarshalOutput, result ValidationStatus) {
	mi.init()
	// validationState is one frame of the explicit parse stack: the message,
	// group, or map entry currently being validated.
	type validationState struct {
		typ              validationType
		keyType, valType validationType
		// endGroup is non-zero only for group frames. An END_GROUP tag with
		// this field number closes the frame.
		endGroup protowire.Number
		mi       *MessageInfo
		// tail is the input that follows this frame's length-delimited
		// payload. It becomes the input again when the frame ends at the
		// end of its payload.
		tail []byte
		// requiredMask collects the requiredBit of each required field seen
		// in this frame with a compatible wire type.
		requiredMask uint64
	}
	// Pre-allocate some slots to avoid repeated slice reallocation.
	states := make([]validationState, 0, 16)
	states = append(states, validationState{
		typ: validationTypeMessage,
		mi:  mi,
	})
	if groupTag > 0 {
		// The caller is validating the body of a group. The input ends at
		// the END_GROUP tag for groupTag rather than at the end of b.
		states[0].typ = validationTypeGroup
		states[0].endGroup = groupTag
	}
	// initialized is cleared as soon as any frame is missing a required field.
	initialized := true
	start := len(b)
State:
	for len(states) > 0 {
		// st points into states. Appending to states may reallocate it, so
		// every push below is followed by "continue State" to reload st.
		st := &states[len(states)-1]
		for len(b) > 0 {
			// Parse the tag (field number and wire type).
			// One- and two-byte tags are decoded inline. Longer ones fall
			// back to protowire.ConsumeVarint.
			var tag uint64
			if b[0] < 0x80 {
				tag = uint64(b[0])
				b = b[1:]
			} else if len(b) >= 2 && b[1] < 128 {
				tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
				b = b[2:]
			} else {
				var n int
				tag, n = protowire.ConsumeVarint(b)
				if n < 0 {
					return out, ValidationInvalid
				}
				b = b[n:]
			}
			var num protowire.Number
			if n := tag >> 3; n < uint64(protowire.MinValidNumber) || n > uint64(protowire.MaxValidNumber) {
				return out, ValidationInvalid
			} else {
				num = protowire.Number(n)
			}
			wtyp := protowire.Type(tag & 7)
			if wtyp == protowire.EndGroupType {
				// END_GROUP is only valid as the terminator of the current
				// group frame. Group frames have no tail (parsing continues
				// in the same buffer), so the jump skips "b = st.tail".
				if st.endGroup == num {
					goto PopState
				}
				return out, ValidationInvalid
			}
			// Work out how this field's payload must be validated.
			var vi validationInfo
			switch {
			case st.typ == validationTypeMap:
				// Inside a map entry only the key and value fields get
				// checks. Any other field keeps the zero validationInfo
				// and is just skipped. requiredBit 1 records that the
				// value was seen; PopState requires it when the value
				// message has required fields.
				switch num {
				case genid.MapEntry_Key_field_number:
					vi.typ = st.keyType
				case genid.MapEntry_Value_field_number:
					vi.typ = st.valType
					vi.mi = st.mi
					vi.requiredBit = 1
				}
			case flags.ProtoLegacy && st.mi.isMessageSet:
				switch num {
				case messageset.FieldItem:
					vi.typ = validationTypeMessageSetItem
				}
			default:
				var f *coderFieldInfo
				if int(num) < len(st.mi.denseCoderFields) {
					f = st.mi.denseCoderFields[num]
				} else {
					f = st.mi.coderFields[num]
				}
				if f != nil {
					vi = f.validation
					if vi.typ == validationTypeMessage && vi.mi == nil {
						// Probable weak field.
						//
						// TODO: Consider storing the results of this lookup somewhere
						// rather than recomputing it on every validation.
						fd := st.mi.Desc.Fields().ByNumber(num)
						if fd == nil || !fd.IsWeak() {
							break
						}
						messageName := fd.Message().FullName()
						messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
						switch err {
						case nil:
							vi.mi, _ = messageType.(*MessageInfo)
						case preg.NotFound:
							// The weak message type is not registered, so
							// its payload is only checked as bytes.
							vi.typ = validationTypeBytes
						default:
							return out, ValidationUnknown
						}
					}
					// Known field: leave the switch and skip the extension lookup.
					break
				}
				// Possible extension field.
				//
				// TODO: We should return ValidationUnknown when:
				//   1. The resolver is not frozen. (More extensions may be added to it.)
				//   2. The resolver returns preg.NotFound.
				// In this case, a type added to the resolver in the future could cause
				// unmarshaling to begin failing. Supporting this requires some way to
				// determine if the resolver is frozen.
				xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
				if err != nil && err != preg.NotFound {
					return out, ValidationUnknown
				}
				if err == nil {
					vi = getExtensionFieldInfo(xt).validation
				}
			}
			if vi.requiredBit != 0 {
				// Check that the field has a compatible wire type.
				// We only need to consider non-repeated field types,
				// since repeated fields (and maps) can never be required.
				ok := false
				switch vi.typ {
				case validationTypeVarint:
					ok = wtyp == protowire.VarintType
				case validationTypeFixed32:
					ok = wtyp == protowire.Fixed32Type
				case validationTypeFixed64:
					ok = wtyp == protowire.Fixed64Type
				case validationTypeBytes, validationTypeUTF8String, validationTypeMessage:
					ok = wtyp == protowire.BytesType
				case validationTypeGroup:
					ok = wtyp == protowire.StartGroupType
				}
				if ok {
					st.requiredMask |= vi.requiredBit
				}
			}
			switch wtyp {
			case protowire.VarintType:
				// Skip the varint without decoding it. With at least 10
				// bytes available, the explicit length checks used in the
				// else branch are unnecessary. The 10th byte of a valid
				// 64-bit varint may only be 0 or 1.
				if len(b) >= 10 {
					switch {
					case b[0] < 0x80:
						b = b[1:]
					case b[1] < 0x80:
						b = b[2:]
					case b[2] < 0x80:
						b = b[3:]
					case b[3] < 0x80:
						b = b[4:]
					case b[4] < 0x80:
						b = b[5:]
					case b[5] < 0x80:
						b = b[6:]
					case b[6] < 0x80:
						b = b[7:]
					case b[7] < 0x80:
						b = b[8:]
					case b[8] < 0x80:
						b = b[9:]
					case b[9] < 0x80 && b[9] < 2:
						b = b[10:]
					default:
						return out, ValidationInvalid
					}
				} else {
					switch {
					case len(b) > 0 && b[0] < 0x80:
						b = b[1:]
					case len(b) > 1 && b[1] < 0x80:
						b = b[2:]
					case len(b) > 2 && b[2] < 0x80:
						b = b[3:]
					case len(b) > 3 && b[3] < 0x80:
						b = b[4:]
					case len(b) > 4 && b[4] < 0x80:
						b = b[5:]
					case len(b) > 5 && b[5] < 0x80:
						b = b[6:]
					case len(b) > 6 && b[6] < 0x80:
						b = b[7:]
					case len(b) > 7 && b[7] < 0x80:
						b = b[8:]
					case len(b) > 8 && b[8] < 0x80:
						b = b[9:]
					case len(b) > 9 && b[9] < 2:
						b = b[10:]
					default:
						return out, ValidationInvalid
					}
				}
				// The stack is unchanged, so this behaves like continuing
				// the inner loop with the same frame.
				continue State
			case protowire.BytesType:
				// Decode the length prefix, with the same one- and two-byte
				// fast paths as the tag.
				var size uint64
				if len(b) >= 1 && b[0] < 0x80 {
					size = uint64(b[0])
					b = b[1:]
				} else if len(b) >= 2 && b[1] < 128 {
					size = uint64(b[0]&0x7f) + uint64(b[1])<<7
					b = b[2:]
				} else {
					var n int
					size, n = protowire.ConsumeVarint(b)
					if n < 0 {
						return out, ValidationInvalid
					}
					b = b[n:]
				}
				if size > uint64(len(b)) {
					return out, ValidationInvalid
				}
				v := b[:size]
				b = b[size:]
				switch vi.typ {
				case validationTypeMessage:
					if vi.mi == nil {
						return out, ValidationUnknown
					}
					vi.mi.init()
					fallthrough
				case validationTypeMap:
					// Descend into the payload v. The current frame resumes
					// at b (saved as the new frame's tail) once it is popped.
					if vi.mi != nil {
						vi.mi.init()
					}
					states = append(states, validationState{
						typ:     vi.typ,
						keyType: vi.keyType,
						valType: vi.valType,
						mi:      vi.mi,
						tail:    b,
					})
					b = v
					continue State
				case validationTypeRepeatedVarint:
					// Packed field.
					for len(v) > 0 {
						_, n := protowire.ConsumeVarint(v)
						if n < 0 {
							return out, ValidationInvalid
						}
						v = v[n:]
					}
				case validationTypeRepeatedFixed32:
					// Packed field.
					if len(v)%4 != 0 {
						return out, ValidationInvalid
					}
				case validationTypeRepeatedFixed64:
					// Packed field.
					if len(v)%8 != 0 {
						return out, ValidationInvalid
					}
				case validationTypeUTF8String:
					if !utf8.Valid(v) {
						return out, ValidationInvalid
					}
				}
			case protowire.Fixed32Type:
				if len(b) < 4 {
					return out, ValidationInvalid
				}
				b = b[4:]
			case protowire.Fixed64Type:
				if len(b) < 8 {
					return out, ValidationInvalid
				}
				b = b[8:]
			case protowire.StartGroupType:
				switch {
				case vi.typ == validationTypeGroup:
					if vi.mi == nil {
						return out, ValidationUnknown
					}
					vi.mi.init()
					// A group frame shares the input buffer with its parent
					// and ends at the END_GROUP tag for num.
					states = append(states, validationState{
						typ:      validationTypeGroup,
						mi:       vi.mi,
						endGroup: num,
					})
					continue State
				case flags.ProtoLegacy && vi.typ == validationTypeMessageSetItem:
					// Legacy MessageSet item. typeid is looked up as an
					// extension number of this message, and v holds the
					// item's encoded value.
					typeid, v, n, err := messageset.ConsumeFieldValue(b, false)
					if err != nil {
						return out, ValidationInvalid
					}
					xt, err := opts.resolver.FindExtensionByNumber(st.mi.Desc.FullName(), typeid)
					switch {
					case err == preg.NotFound:
						// Unknown extension: skip the whole item.
						b = b[n:]
					case err != nil:
						return out, ValidationUnknown
					default:
						xvi := getExtensionFieldInfo(xt).validation
						if xvi.mi != nil {
							xvi.mi.init()
						}
						states = append(states, validationState{
							typ:  xvi.typ,
							mi:   xvi.mi,
							tail: b[n:],
						})
						b = v
						continue State
					}
				default:
					// Not a group this message knows about: skip it.
					n := protowire.ConsumeFieldValue(num, wtyp, b)
					if n < 0 {
						return out, ValidationInvalid
					}
					b = b[n:]
				}
			default:
				// Any other wire type is invalid.
				return out, ValidationInvalid
			}
		}
		// The frame's input is exhausted. A group frame must instead have
		// been closed by its END_GROUP tag.
		if st.endGroup != 0 {
			return out, ValidationInvalid
		}
		// Defensive: the inner loop only exits normally once b is empty.
		if len(b) != 0 {
			return out, ValidationInvalid
		}
		// Resume the enclosing frame after this length-delimited payload.
		b = st.tail
	PopState:
		// Check that the finished frame saw all of its required fields.
		numRequiredFields := 0
		switch st.typ {
		case validationTypeMessage, validationTypeGroup:
			numRequiredFields = int(st.mi.numRequiredFields)
		case validationTypeMap:
			// If this is a map field with a message value that contains
			// required fields, require that the value be present.
			if st.mi != nil && st.mi.numRequiredFields > 0 {
				numRequiredFields = 1
			}
		}
		// If there are more than 64 required fields, this check will
		// always fail and we will report that the message is potentially
		// uninitialized.
		if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
			initialized = false
		}
		states = states[:len(states)-1]
	}
	out.n = start - len(b)
	if initialized {
		out.initialized = true
	}
	return out, ValidationValid
}
  555. }
  556. return out, ValidationValid
  557. }