configure_transport.go 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. // Copyright 2015 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // +build go1.6
  5. package http2
  6. import (
  7. "crypto/tls"
  8. "fmt"
  9. "net/http"
  10. )
  11. func configureTransport(t1 *http.Transport) (*Transport, error) {
  12. connPool := new(clientConnPool)
  13. t2 := &Transport{
  14. ConnPool: noDialClientConnPool{connPool},
  15. t1: t1,
  16. }
  17. connPool.t = t2
  18. if err := registerHTTPSProtocol(t1, noDialH2RoundTripper{t2}); err != nil {
  19. return nil, err
  20. }
  21. if t1.TLSClientConfig == nil {
  22. t1.TLSClientConfig = new(tls.Config)
  23. }
  24. if !strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
  25. t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
  26. }
  27. if !strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
  28. t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
  29. }
  30. upgradeFn := func(authority string, c *tls.Conn) http.RoundTripper {
  31. addr := authorityAddr("https", authority)
  32. if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
  33. go c.Close()
  34. return erringRoundTripper{err}
  35. } else if !used {
  36. // Turns out we don't need this c.
  37. // For example, two goroutines made requests to the same host
  38. // at the same time, both kicking off TCP dials. (since protocol
  39. // was unknown)
  40. go c.Close()
  41. }
  42. return t2
  43. }
  44. if m := t1.TLSNextProto; len(m) == 0 {
  45. t1.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{
  46. "h2": upgradeFn,
  47. }
  48. } else {
  49. m["h2"] = upgradeFn
  50. }
  51. return t2, nil
  52. }
  53. // registerHTTPSProtocol calls Transport.RegisterProtocol but
  54. // converting panics into errors.
  55. func registerHTTPSProtocol(t *http.Transport, rt noDialH2RoundTripper) (err error) {
  56. defer func() {
  57. if e := recover(); e != nil {
  58. err = fmt.Errorf("%v", e)
  59. }
  60. }()
  61. t.RegisterProtocol("https", rt)
  62. return nil
  63. }
  64. // noDialH2RoundTripper is a RoundTripper which only tries to complete the request
  65. // if there's already has a cached connection to the host.
  66. // (The field is exported so it can be accessed via reflect from net/http; tested
  67. // by TestNoDialH2RoundTripperType)
  68. type noDialH2RoundTripper struct{ *Transport }
  69. func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
  70. res, err := rt.Transport.RoundTrip(req)
  71. if isNoCachedConnError(err) {
  72. return nil, http.ErrSkipAltProtocol
  73. }
  74. return res, err
  75. }