tcpip.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. // Copyright 2012 Google, Inc. All rights reserved.
  2. // Copyright 2009-2011 Andreas Krennmair. All rights reserved.
  3. //
  4. // Use of this source code is governed by a BSD-style license
  5. // that can be found in the LICENSE file in the root of the source
  6. // tree.
  7. package layers
  8. import (
  9. "errors"
  10. "fmt"
  11. "github.com/google/gopacket"
  12. )
  13. // Checksum computation for TCP/UDP.
  14. type tcpipchecksum struct {
  15. pseudoheader tcpipPseudoHeader
  16. }
  // tcpipPseudoHeader is the contract a network layer fulfils so that a
  // transport-layer checksum can include its pseudo-header. In this file it is
  // implemented by *IPv4 and *IPv6.
  type tcpipPseudoHeader interface {
  	// pseudoheaderChecksum returns the layer's partial (unfolded) sum, or an
  	// error when that sum cannot be produced.
  	pseudoheaderChecksum() (uint32, error)
  }
  20. func (ip *IPv4) pseudoheaderChecksum() (csum uint32, err error) {
  21. if err := ip.AddressTo4(); err != nil {
  22. return 0, err
  23. }
  24. csum += (uint32(ip.SrcIP[0]) + uint32(ip.SrcIP[2])) << 8
  25. csum += uint32(ip.SrcIP[1]) + uint32(ip.SrcIP[3])
  26. csum += (uint32(ip.DstIP[0]) + uint32(ip.DstIP[2])) << 8
  27. csum += uint32(ip.DstIP[1]) + uint32(ip.DstIP[3])
  28. return csum, nil
  29. }
  30. func (ip *IPv6) pseudoheaderChecksum() (csum uint32, err error) {
  31. if err := ip.AddressTo16(); err != nil {
  32. return 0, err
  33. }
  34. for i := 0; i < 16; i += 2 {
  35. csum += uint32(ip.SrcIP[i]) << 8
  36. csum += uint32(ip.SrcIP[i+1])
  37. csum += uint32(ip.DstIP[i]) << 8
  38. csum += uint32(ip.DstIP[i+1])
  39. }
  40. return csum, nil
  41. }
  // tcpipChecksum calculates the Internet checksum defined in RFC 1071 over
  // data. csum is any partial sum that has already been computed (for example
  // a pseudo-header sum); it is added in before folding.
  //
  // data is read as big-endian 16-bit words. When its length is odd, the last
  // byte is treated as the high octet of a word whose low octet is zero. The
  // result is the one's complement of the folded 16-bit sum.
  func tcpipChecksum(data []byte, csum uint32) uint16 {
  	// Accumulate in 64 bits. A 32-bit accumulator wraps once roughly 128 KiB
  	// of large-valued data has been added, and every wrap silently drops a
  	// carry that the end-around fold below must see.
  	sum := uint64(csum)

  	// Walk the whole 16-bit words; a trailing odd byte is handled afterwards.
  	// The bytes are combined by hand rather than through
  	// binary.BigEndian.Uint16, which the original authors measured as slower.
  	evenLen := len(data) &^ 1
  	for i := 0; i < evenLen; i += 2 {
  		sum += uint64(data[i])<<8 | uint64(data[i+1])
  	}
  	if evenLen < len(data) {
  		sum += uint64(data[evenLen]) << 8
  	}

  	// Fold the carries back into the low 16 bits (end-around carry).
  	for sum > 0xffff {
  		sum = (sum >> 16) + (sum & 0xffff)
  	}
  	return ^uint16(sum)
  }
  63. // computeChecksum computes a TCP or UDP checksum. headerAndPayload is the
  64. // serialized TCP or UDP header plus its payload, with the checksum zero'd
  65. // out. headerProtocol is the IP protocol number of the upper-layer header.
  66. func (c *tcpipchecksum) computeChecksum(headerAndPayload []byte, headerProtocol IPProtocol) (uint16, error) {
  67. if c.pseudoheader == nil {
  68. return 0, errors.New("TCP/IP layer 4 checksum cannot be computed without network layer... call SetNetworkLayerForChecksum to set which layer to use")
  69. }
  70. length := uint32(len(headerAndPayload))
  71. csum, err := c.pseudoheader.pseudoheaderChecksum()
  72. if err != nil {
  73. return 0, err
  74. }
  75. csum += uint32(headerProtocol)
  76. csum += length & 0xffff
  77. csum += length >> 16
  78. return tcpipChecksum(headerAndPayload, csum), nil
  79. }
  80. // SetNetworkLayerForChecksum tells this layer which network layer is wrapping it.
  81. // This is needed for computing the checksum when serializing, since TCP/IP transport
  82. // layer checksums depends on fields in the IPv4 or IPv6 layer that contains it.
  83. // The passed in layer must be an *IPv4 or *IPv6.
  84. func (i *tcpipchecksum) SetNetworkLayerForChecksum(l gopacket.NetworkLayer) error {
  85. switch v := l.(type) {
  86. case *IPv4:
  87. i.pseudoheader = v
  88. case *IPv6:
  89. i.pseudoheader = v
  90. default:
  91. return fmt.Errorf("cannot use layer type %v for tcp checksum network layer", l.LayerType())
  92. }
  93. return nil
  94. }