1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
|
package httpz
import (
"bufio"
"crypto/tls"
"fmt"
"net"
"net/http"
"k8s.io/apimachinery/third_party/forked/golang/netutil"
)
// UpgradeRoundTripper gives access to the underlying network connection after round tripping a request/response.
// An http.RoundTripper must be safe for concurrent use by multiple goroutines, but this implementation is not:
// RoundTrip stores the connection it used in the Conn and ConnReader fields of the receiver.
// It does not pool network connections, so making it concurrency-safe is not worth the complexity.
// Instead, the using code must use a new instance for each request.
// See https://pkg.go.dev/net/http#RoundTripper.
type UpgradeRoundTripper struct {
// Dialer is the dialer used to connect when the request URL scheme is "http".
Dialer *net.Dialer
// TLSDialer is the dialer used to connect over TLS when the request URL scheme is "https".
TLSDialer *tls.Dialer
// Conn is the underlying network connection to the remote server.
// It is set by RoundTrip only when the round trip succeeds.
Conn net.Conn
// ConnReader is a buffered reader for Conn, set together with Conn.
// It may contain some data that has been buffered from Conn while reading the server's response,
// so subsequent reads from the connection should go through it rather than through Conn directly.
ConnReader *bufio.Reader
}
// RoundTrip dials the server named by the request URL, writes the request to the
// new connection and reads a single response from it.
//
// On success the connection and the buffered reader used to parse the response are
// stored in u.Conn and u.ConnReader, and the connection is left open for the caller.
// If writing the request or reading the response fails, the connection is closed
// before the error is returned.
func (u *UpgradeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	rawConn, dialErr := u.dial(req)
	if dialErr != nil {
		// Nothing was sent, so release the request body here before bailing out.
		if body := req.Body; body != nil {
			_ = body.Close()
		}
		return nil, dialErr
	}

	ctxConn := NewContextConn(rawConn)
	// NOTE(review): CloseOnDone is defined elsewhere; presumably it closes the
	// connection once the request context is done - confirm against ContextConn.
	go ctxConn.CloseOnDone(req.Context())

	// The connection is handed over to the caller only after a response has been
	// read; every earlier return closes it.
	handedOver := false
	defer func() {
		if handedOver {
			return
		}
		_ = ctxConn.Close()
	}()

	if writeErr := req.Write(ctxConn); writeErr != nil {
		return nil, writeErr
	}

	reader := bufio.NewReader(ctxConn)
	resp, readErr := http.ReadResponse(reader, req)
	if readErr != nil {
		return nil, readErr
	}

	// Expose the connection together with the reader, which may already hold
	// bytes that were read past the end of the response head.
	u.Conn, u.ConnReader = ctxConn, reader
	handedOver = true
	return resp, nil
}
// dial opens a TCP connection to the host and port derived from req.URL.
//
// The "http" scheme uses u.Dialer and the "https" scheme uses u.TLSDialer; any other
// scheme yields an error. The request context is passed to the dialer.
// A nil dialer is replaced by a zero-value one instead of causing a nil-pointer
// panic, so a zero-value UpgradeRoundTripper dials with the standard library defaults.
func (u *UpgradeRoundTripper) dial(req *http.Request) (net.Conn, error) {
	dialAddr := netutil.CanonicalAddr(req.URL)
	ctx := req.Context()
	switch req.URL.Scheme {
	case "http":
		d := u.Dialer
		if d == nil {
			// The zero net.Dialer dials without any extra options.
			d = &net.Dialer{}
		}
		return d.DialContext(ctx, "tcp", dialAddr)
	case "https":
		d := u.TLSDialer
		if d == nil {
			// The zero tls.Dialer uses a zero-value net.Dialer and the default TLS configuration.
			d = &tls.Dialer{}
		}
		return d.DialContext(ctx, "tcp", dialAddr)
	default:
		return nil, fmt.Errorf("unsupported URL scheme: %s", req.URL.Scheme)
	}
}
|