traverse.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. package maxminddb
  2. import "net"
  3. // Internal structure used to keep track of nodes we still need to visit.
  4. type netNode struct {
  5. ip net.IP
  6. bit uint
  7. pointer uint
  8. }
  9. // Networks represents a set of subnets that we are iterating over.
  10. type Networks struct {
  11. reader *Reader
  12. nodes []netNode // Nodes we still have to visit.
  13. lastNode netNode
  14. err error
  15. }
  16. // Networks returns an iterator that can be used to traverse all networks in
  17. // the database.
  18. //
  19. // Please note that a MaxMind DB may map IPv4 networks into several locations
  20. // in in an IPv6 database. This iterator will iterate over all of these
  21. // locations separately.
  22. func (r *Reader) Networks() *Networks {
  23. s := 4
  24. if r.Metadata.IPVersion == 6 {
  25. s = 16
  26. }
  27. return &Networks{
  28. reader: r,
  29. nodes: []netNode{
  30. {
  31. ip: make(net.IP, s),
  32. },
  33. },
  34. }
  35. }
  36. // Next prepares the next network for reading with the Network method. It
  37. // returns true if there is another network to be processed and false if there
  38. // are no more networks or if there is an error.
  39. func (n *Networks) Next() bool {
  40. for len(n.nodes) > 0 {
  41. node := n.nodes[len(n.nodes)-1]
  42. n.nodes = n.nodes[:len(n.nodes)-1]
  43. for {
  44. if node.pointer < n.reader.Metadata.NodeCount {
  45. ipRight := make(net.IP, len(node.ip))
  46. copy(ipRight, node.ip)
  47. if len(ipRight) <= int(node.bit>>3) {
  48. n.err = newInvalidDatabaseError(
  49. "invalid search tree at %v/%v", ipRight, node.bit)
  50. return false
  51. }
  52. ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8))
  53. rightPointer, err := n.reader.readNode(node.pointer, 1)
  54. if err != nil {
  55. n.err = err
  56. return false
  57. }
  58. node.bit++
  59. n.nodes = append(n.nodes, netNode{
  60. pointer: rightPointer,
  61. ip: ipRight,
  62. bit: node.bit,
  63. })
  64. node.pointer, err = n.reader.readNode(node.pointer, 0)
  65. if err != nil {
  66. n.err = err
  67. return false
  68. }
  69. } else if node.pointer > n.reader.Metadata.NodeCount {
  70. n.lastNode = node
  71. return true
  72. } else {
  73. break
  74. }
  75. }
  76. }
  77. return false
  78. }
  79. // Network returns the current network or an error if there is a problem
  80. // decoding the data for the network. It takes a pointer to a result value to
  81. // decode the network's data into.
  82. func (n *Networks) Network(result interface{}) (*net.IPNet, error) {
  83. if err := n.reader.retrieveData(n.lastNode.pointer, result); err != nil {
  84. return nil, err
  85. }
  86. return &net.IPNet{
  87. IP: n.lastNode.ip,
  88. Mask: net.CIDRMask(int(n.lastNode.bit), len(n.lastNode.ip)*8),
  89. }, nil
  90. }
  91. // Err returns an error, if any, that was encountered during iteration.
  92. func (n *Networks) Err() error {
  93. return n.err
  94. }