proxy.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. /*
  2. *
  3. * Copyright 2017 gRPC authors.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *
  17. */
  18. package grpc
  19. import (
  20. "bufio"
  21. "errors"
  22. "fmt"
  23. "io"
  24. "net"
  25. "net/http"
  26. "net/http/httputil"
  27. "net/url"
  28. "golang.org/x/net/context"
  29. )
  // errDisabled is the sentinel error mapAddress returns when no proxy should
  // be used for the requested address.
  var errDisabled = errors.New("proxy is disabled for the address")

  // httpProxyFromEnvironment picks the proxy URL for a request. It is held in a
  // variable rather than calling http.ProxyFromEnvironment directly so that
  // tests can substitute their own implementation.
  var httpProxyFromEnvironment = http.ProxyFromEnvironment
  36. func mapAddress(ctx context.Context, address string) (string, error) {
  37. req := &http.Request{
  38. URL: &url.URL{
  39. Scheme: "https",
  40. Host: address,
  41. },
  42. }
  43. url, err := httpProxyFromEnvironment(req)
  44. if err != nil {
  45. return "", err
  46. }
  47. if url == nil {
  48. return "", errDisabled
  49. }
  50. return url.Host, nil
  51. }
  // bufConn is a net.Conn whose Read is served by r instead of the embedded
  // Conn. Every other net.Conn method goes straight to the embedded Conn.
  //
  // http.ReadResponse needs a bufio.Reader, and that reader may pull more bytes
  // off the connection than the HTTP response itself occupies, keeping the
  // surplus in its buffer. Continuing to read through that same reader (r)
  // ensures those buffered bytes are delivered rather than lost.
  type bufConn struct {
  	// Conn handles Write, Close, deadlines and addresses.
  	net.Conn
  	// r is the source for Read; presumably the bufio.Reader that wraps Conn.
  	r io.Reader
  }

  // Read fills p from c.r rather than from the embedded net.Conn.
  func (c *bufConn) Read(p []byte) (int, error) {
  	n, err := c.r.Read(p)
  	return n, err
  }
  64. func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_ net.Conn, err error) {
  65. defer func() {
  66. if err != nil {
  67. conn.Close()
  68. }
  69. }()
  70. req := (&http.Request{
  71. Method: http.MethodConnect,
  72. URL: &url.URL{Host: addr},
  73. Header: map[string][]string{"User-Agent": {grpcUA}},
  74. })
  75. if err := sendHTTPRequest(ctx, req, conn); err != nil {
  76. return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
  77. }
  78. r := bufio.NewReader(conn)
  79. resp, err := http.ReadResponse(r, req)
  80. if err != nil {
  81. return nil, fmt.Errorf("reading server HTTP response: %v", err)
  82. }
  83. defer resp.Body.Close()
  84. if resp.StatusCode != http.StatusOK {
  85. dump, err := httputil.DumpResponse(resp, true)
  86. if err != nil {
  87. return nil, fmt.Errorf("failed to do connect handshake, status code: %s", resp.Status)
  88. }
  89. return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump)
  90. }
  91. return &bufConn{Conn: conn, r: r}, nil
  92. }
  93. // newProxyDialer returns a dialer that connects to proxy first if necessary.
  94. // The returned dialer checks if a proxy is necessary, dial to the proxy with the
  95. // provided dialer, does HTTP CONNECT handshake and returns the connection.
  96. func newProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) {
  97. return func(ctx context.Context, addr string) (conn net.Conn, err error) {
  98. var skipHandshake bool
  99. newAddr, err := mapAddress(ctx, addr)
  100. if err != nil {
  101. if err != errDisabled {
  102. return nil, err
  103. }
  104. skipHandshake = true
  105. newAddr = addr
  106. }
  107. conn, err = dialer(ctx, newAddr)
  108. if err != nil {
  109. return
  110. }
  111. if !skipHandshake {
  112. conn, err = doHTTPConnectHandshake(ctx, conn, addr)
  113. }
  114. return
  115. }
  116. }