1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
|
package retry_test
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/avast/retry-go/v4"
"github.com/stretchr/testify/assert"
)
// RetriableError is a custom error that carries a server-recommended delay
// to wait before the next retry attempt. RetryAfter is intended to be
// positive; the type itself does not enforce this.
type RetriableError struct {
	// Err is the underlying failure. It is exposed to errors.Is / errors.As
	// through Unwrap.
	Err error
	// RetryAfter is how long the caller should wait before retrying.
	RetryAfter time.Duration
}

// Error returns the wrapped error's message followed by the Retry-After
// duration, e.g. "HTTP 429: busy (retry after 1s)". A nil Err is rendered as
// "<nil>" rather than causing a nil-pointer panic.
func (e *RetriableError) Error() string {
	// %v calls Err.Error() for a non-nil error and prints "<nil>" for a nil one.
	return fmt.Sprintf("%v (retry after %v)", e.Err, e.RetryAfter)
}

// Unwrap returns the underlying error so that errors.Is and errors.As can
// inspect the cause of a RetriableError.
func (e *RetriableError) Unwrap() error {
	return e.Err
}

// Compile-time check that *RetriableError implements the error interface.
var _ error = (*RetriableError)(nil)
// TestCustomRetryFunction shows how to combine retry.Do with a custom
// retry.DelayType that honours a server-supplied Retry-After header.
//
// The httptest server answers the first five requests with HTTP 429 and
// "Retry-After: 1", then answers with 200 "hello". The retryable function
// turns a 429 into one of three errors:
//   - a *RetriableError when Retry-After is a positive number of seconds;
//   - an unrecoverable error when Retry-After is <= 0;
//   - a plain error when Retry-After cannot be parsed.
//
// Any other non-200 status also becomes a plain error. The delay function
// waits for RetryAfter when it receives a *RetriableError and otherwise falls
// back to retry.BackOffDelay.
//
// Because the server requests a one-second pause five times, this test takes
// roughly five seconds to run.
func TestCustomRetryFunction(t *testing.T) {
	attempts := 5 // server succeeds after 5 attempts
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if attempts > 0 {
			// Ask the client to retry after one second, using the standard
			// HTTP 429 status code with a Retry-After header in seconds.
			w.Header().Add("Retry-After", "1")
			w.WriteHeader(http.StatusTooManyRequests)
			w.Write([]byte("Server limit reached"))
			attempts--
			return
		}
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("hello"))
	}))
	defer ts.Close()

	var body []byte
	err := retry.Do(
		func() error {
			resp, err := http.Get(ts.URL)
			if err != nil {
				return err
			}
			defer func() {
				// Fail the test instead of panicking, so a bad Close does not
				// crash the whole test binary.
				if cerr := resp.Body.Close(); cerr != nil {
					t.Errorf("closing response body: %v", cerr)
				}
			}()

			body, err = ioutil.ReadAll(resp.Body)
			if resp.StatusCode == http.StatusOK {
				return err
			}

			httpErr := fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
			if resp.StatusCode != http.StatusTooManyRequests {
				return httpErr
			}

			// Only the delay-seconds form of Retry-After is handled, because this
			// server only ever sends seconds. A real client should also accept
			// the HTTP-date form (see http.ParseTime).
			retryAfter, perr := strconv.ParseInt(resp.Header.Get("Retry-After"), 10, 32)
			if perr != nil {
				return httpErr
			}
			// The server returns 0 (or less) to say the operation cannot be retried.
			if retryAfter <= 0 {
				return retry.Unrecoverable(httpErr)
			}
			return &RetriableError{
				Err:        httpErr,
				RetryAfter: time.Duration(retryAfter) * time.Second,
			}
		},
		retry.DelayType(func(n uint, err error, config *retry.Config) time.Duration {
			t.Logf("Server fails with: %v", err)
			if retriable, ok := err.(*RetriableError); ok {
				t.Logf("Client follows server recommendation to retry after %v", retriable.RetryAfter)
				return retriable.RetryAfter
			}
			// No server hint: apply the default exponential back-off strategy.
			return retry.BackOffDelay(n, err, config)
		}),
	)
	t.Logf("Server responds with: %s", body)
	assert.NoError(t, err)
	assert.Equal(t, "hello", string(body))
}
|