1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141
|
package retryablehttp
import (
"context"
"errors"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"sync/atomic"
"testing"
)
// TestRoundTripper_implements is a compile-time assertion that *RoundTripper
// satisfies http.RoundTripper. If the method set ever drifts, this file stops
// compiling; at run time the test does nothing.
func TestRoundTripper_implements(t *testing.T) {
	var _ http.RoundTripper = (*RoundTripper)(nil)
}
// TestRoundTripper_init checks that a zero-value RoundTripper populates its
// Client field on the first RoundTrip and keeps using that same Client on
// later calls.
func TestRoundTripper_init(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
	}))
	defer ts.Close()

	// Start with a new empty RoundTripper.
	rt := &RoundTripper{}

	// roundTrip performs one GET against the test server through rt. Every
	// error is checked, and the response body is closed so the connection is
	// released before the server is shut down.
	roundTrip := func() {
		t.Helper()
		req, err := http.NewRequest("GET", ts.URL, nil)
		if err != nil {
			t.Fatal(err)
		}
		resp, err := rt.RoundTrip(req)
		if err != nil {
			t.Fatal(err)
		}
		resp.Body.Close()
	}

	// RoundTrip once.
	roundTrip()

	// Check that the Client was initialized.
	if rt.Client == nil {
		t.Fatal("expected rt.Client to be initialized")
	}

	// Save the Client for later comparison.
	initialClient := rt.Client

	// RoundTrip again.
	roundTrip()

	// Check that the underlying Client is unchanged.
	if rt.Client != initialClient {
		t.Fatalf("expected %v, got %v", initialClient, rt.Client)
	}
}
// TestRoundTripper_RoundTrip drives a request through the http.Client
// returned by StandardClient. The server answers 404 to the first two
// requests and 200 "success!" to the third. A custom CheckRetry retries on
// 404, so the caller should only ever observe the final successful response.
func TestRoundTripper_RoundTrip(t *testing.T) {
	var reqCount int32
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reqNo := atomic.AddInt32(&reqCount, 1)
		if reqNo < 3 {
			w.WriteHeader(404)
		} else {
			w.WriteHeader(200)
			// The write error is ignored on purpose: the handler runs on a
			// server goroutine where t.Fatal must not be called. A short
			// body is detected by the client-side assertion below.
			_, _ = w.Write([]byte("success!"))
		}
	}))
	defer ts.Close()

	// Make a client with some custom settings to verify they are used.
	retryClient := NewClient()
	retryClient.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
		// On a transport error resp is nil. Surface the error instead of
		// dereferencing resp and panicking the whole test binary.
		if err != nil {
			return false, err
		}
		return resp.StatusCode == 404, nil
	}

	// Get the standard client and execute the request.
	client := retryClient.StandardClient()
	resp, err := client.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()

	// Check the response to ensure the client behaved as expected.
	if resp.StatusCode != 200 {
		t.Fatalf("expected 200, got %d", resp.StatusCode)
	}
	if v, err := ioutil.ReadAll(resp.Body); err != nil {
		t.Fatal(err)
	} else if string(v) != "success!" {
		t.Fatalf("expected %q, got %q", "success!", v)
	}
}
// TestRoundTripper_TransportFailureErrorHandling checks that, with
// PassthroughErrorHandler installed, a transport failure (an unresolvable
// host) reaches the caller of the standard http.Client as the expected
// *url.Error chain.
func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) {
	// Make a client with some custom settings to verify they are used.
	retryClient := NewClient()
	retryClient.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
		if err != nil {
			return true, err
		}
		return false, nil
	}
	retryClient.ErrorHandler = PassthroughErrorHandler

	expectedError := &url.Error{
		Op:  "Get",
		URL: "http://999.999.999.999:999/",
		Err: &net.OpError{
			Op:  "dial",
			Net: "tcp",
			Err: &net.DNSError{
				Name:       "999.999.999.999",
				Err:        "no such host",
				IsNotFound: true,
			},
		},
	}

	// Get the standard client and execute the request.
	client := retryClient.StandardClient()
	resp, err := client.Get("http://999.999.999.999:999/")
	if err == nil {
		// Release the unexpected response before failing, and report the
		// real problem instead of a confusing DeepEqual mismatch against nil.
		resp.Body.Close()
		t.Fatalf("expected error %q, got a successful response (status %d)", expectedError, resp.StatusCode)
	}

	// Assert expectations. normalizeError strips environment-dependent
	// details (the DNS server) so the comparison is stable across machines.
	got := normalizeError(err)
	if !reflect.DeepEqual(expectedError, got) {
		t.Fatalf("expected %q, got %q", expectedError, got)
	}
}
// normalizeError removes environment-dependent detail from err so it can be
// compared with reflect.DeepEqual. If err's chain contains a *net.DNSError,
// the first one found has its Server field cleared in place. Server is
// populated on CI but not locally. err itself is always returned, including
// when it is nil or contains no DNS error.
func normalizeError(err error) error {
	var dnsErr *net.DNSError
	if !errors.As(err, &dnsErr) {
		return err
	}
	dnsErr.Server = ""
	return err
}
|