asndb.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. package asndb
  2. import (
  3. "archive/zip"
  4. "bytes"
  5. "crypto/md5"
  6. "encoding/csv"
  7. "encoding/hex"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "io/ioutil"
  12. "log"
  13. "net"
  14. "net/http"
  15. "os"
  16. "regexp"
  17. "strings"
  18. "sync"
  19. "git.scraperwall.com/scw/ip"
  20. privip "git.scraperwall.com/scw/ip"
  21. "github.com/google/btree"
  22. )
  23. const (
  24. asnFile = "GeoLite2-ASN-CSV.zip"
  25. asnMd5File = "GeoLite2-ASN-CSV.zip.md5"
  26. )
  27. // DB contains a b-tree of ASNs
  28. type DB struct {
  29. db *btree.BTree
  30. mutex sync.Mutex
  31. privIPs *ip.IP
  32. }
  33. // Lookup returns the ASN struct of the network that contains ip
  34. func (a *DB) Lookup(ip net.IP) *ASN {
  35. var asn *ASN
  36. privNet := a.privIPs.Network(ip)
  37. if privNet != nil {
  38. pasn, _ := NewASN(privNet.String(), "-1", "Private Network")
  39. return pasn
  40. }
  41. ipNorm := ip.To16()
  42. dummy := ASN{
  43. To: &ipNorm,
  44. }
  45. a.mutex.Lock()
  46. defer a.mutex.Unlock()
  47. a.db.AscendGreaterOrEqual(&dummy, func(item btree.Item) bool {
  48. asn = item.(*ASN)
  49. if !asn.Network.Contains(ip) {
  50. asn, _ = NewASN("0.0.0.0/32", "-1", "Unknown Network")
  51. }
  52. return false
  53. })
  54. return asn
  55. }
  56. // Size returns the number of networks in the database
  57. func (a *DB) Size() int {
  58. return a.db.Len()
  59. }
  60. // Each iterates over each element in the database
  61. func (a *DB) Each(f func(a *ASN) bool) {
  62. a.db.Ascend(func(item btree.Item) bool {
  63. return f(item.(*ASN))
  64. })
  65. }
  66. // load pulls fresh data from maxmind
  67. func (a *DB) load(baseURL string) error {
  68. asndb, err := fromURL(baseURL)
  69. if err != nil {
  70. return err
  71. }
  72. if asndb == nil {
  73. return errors.New("asndb is nil")
  74. }
  75. a.mutex.Lock()
  76. defer a.mutex.Unlock()
  77. a.db = asndb
  78. return nil
  79. }
  80. // fromURL loads data from maxmind and creates an ASNDB with this fresh data
  81. func fromURL(baseURL string) (*btree.BTree, error) {
  82. // Get MD5 sum for tar.gz file
  83. asnMd5URL := baseURL + "/" + asnMd5File
  84. resp, err := http.Get(asnMd5URL)
  85. if err != nil {
  86. return nil, err
  87. }
  88. md5Sum, err := ioutil.ReadAll(resp.Body)
  89. if err != nil {
  90. return nil, err
  91. }
  92. resp.Body.Close()
  93. asnURL := baseURL + "/" + asnFile
  94. // Load the tar.gz file
  95. resp, err = http.Get(asnURL)
  96. if err != nil {
  97. return nil, err
  98. }
  99. defer resp.Body.Close()
  100. if resp.StatusCode != http.StatusOK {
  101. return nil, fmt.Errorf("%s status %d", asnURL, resp.StatusCode)
  102. }
  103. bodyData, err := ioutil.ReadAll(resp.Body)
  104. if err != nil {
  105. return nil, err
  106. }
  107. // Build the MD5 sum of the downloaded tar.gz
  108. hash := md5.New()
  109. if _, err := io.Copy(hash, bytes.NewReader(bodyData)); err != nil {
  110. return nil, err
  111. }
  112. if string(md5Sum) != hex.EncodeToString(hash.Sum(nil)) {
  113. log.Println("asndb checksum mismatch")
  114. return nil, fmt.Errorf("checksum mismatch: %s != %s", md5Sum, hash.Sum(nil))
  115. }
  116. // Copy the data to a temporary file for zip to be able to open it
  117. tmpF, err := ioutil.TempFile("/tmp", "asndb-")
  118. if err != nil {
  119. return nil, err
  120. }
  121. defer os.Remove(tmpF.Name())
  122. io.Copy(tmpF, bytes.NewReader(bodyData))
  123. tmpF.Close()
  124. return fromFile(tmpF.Name())
  125. }
  126. func parseCSV(reader io.Reader) (*btree.BTree, error) {
  127. csvr := csv.NewReader(reader)
  128. numMatch := regexp.MustCompile(`^[0-9a-fA-F]+[\.:]`)
  129. tree := btree.New(8)
  130. for {
  131. record, err := csvr.Read()
  132. if err == io.EOF {
  133. break
  134. }
  135. if err != nil {
  136. log.Fatal(err)
  137. }
  138. // ignore the header and anything that doesn't look like an IP
  139. if !numMatch.MatchString(record[0]) {
  140. continue
  141. }
  142. a, err := NewASN(record[0], record[1], record[2])
  143. if err != nil {
  144. return nil, err
  145. }
  146. tree.ReplaceOrInsert(a)
  147. }
  148. return tree, nil
  149. }
  150. func fromFile(filename string) (*btree.BTree, error) {
  151. zipReader, err := zip.OpenReader(filename)
  152. if err != nil {
  153. return nil, err
  154. }
  155. defer zipReader.Close()
  156. buf := bytes.NewBufferString("")
  157. // find the data in the zip file
  158. for _, f := range zipReader.File {
  159. if strings.HasSuffix(f.Name, "GeoLite2-ASN-Blocks-IPv4.csv") || strings.HasSuffix(f.Name, "GeoLite2-ASN-Blocks-IPv6.csv") {
  160. asn, err := f.Open()
  161. if err != nil {
  162. return nil, err
  163. }
  164. io.Copy(buf, asn)
  165. }
  166. }
  167. if buf.Len() <= 0 {
  168. return nil, fmt.Errorf("not enough data")
  169. }
  170. // generate the tree
  171. tree, err := parseCSV(buf)
  172. if err != nil {
  173. return nil, err
  174. }
  175. return tree, nil
  176. }
  177. // New creates a new ASN database. fname denotes the path to the Maxmind ASN CSV file
  178. func New(baseURLOrFile string) (*DB, error) {
  179. db := &DB{
  180. mutex: sync.Mutex{},
  181. privIPs: privip.NewIP(),
  182. }
  183. if strings.HasPrefix(baseURLOrFile, "https://") || strings.HasPrefix(baseURLOrFile, "http://") {
  184. err := db.load(baseURLOrFile)
  185. if err != nil {
  186. return nil, err
  187. }
  188. } else {
  189. var err error
  190. db.db, err = fromFile(baseURLOrFile)
  191. if err != nil {
  192. return nil, err
  193. }
  194. }
  195. return db, nil
  196. }