1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
|
package connlimit
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/v2/testutils"
"github.com/vulcand/oxy/v2/utils"
)
// TestHitLimitAndRelease checks that with a limit of 1 per source (the
// "Limit" header), a second concurrent request from the same source is
// rejected with 429 Too Many Requests, a request from a different source is
// still served, and the first source is admitted again once its in-flight
// request has completed.
func TestHitLimitAndRelease(t *testing.T) {
	// wait parks the first request inside the handler, proceed tells the test
	// that the request has reached the handler, and finish is closed by the
	// client goroutine once the first request has returned (or failed).
	wait := make(chan struct{})
	proceed := make(chan struct{})
	finish := make(chan struct{})

	handler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		t.Logf("%v", req.Header)
		if req.Header.Get("Wait") != "" {
			proceed <- struct{}{}
			<-wait
		}
		_, _ = w.Write([]byte("hello"))
	})

	cl, err := New(handler, headerLimit, 1)
	require.NoError(t, err)

	srv := httptest.NewServer(cl)
	t.Cleanup(srv.Close)

	// release unblocks the parked handler exactly once. It is also deferred so
	// that an assertion aborting the test early cannot leave the handler
	// stuck, which would make srv.Close wait forever.
	released := false
	release := func() {
		if !released {
			released = true
			close(wait)
		}
	}
	defer release()

	go func() {
		defer close(finish)
		// Only assert (never require) here: FailNow must be called from the
		// goroutine running the test, not from goroutines it spawns.
		re, _, errGet := testutils.Get(srv.URL, testutils.Header("Limit", "a"), testutils.Header("wait", "yes"))
		if assert.NoError(t, errGet) {
			assert.Equal(t, http.StatusOK, re.StatusCode)
		}
	}()
	// Make sure the client goroutine has finished before the test ends.
	// Cleanups run LIFO, so this runs before srv.Close and after the deferred
	// release above.
	t.Cleanup(func() { <-finish })

	select {
	case <-proceed:
	case <-finish:
		t.Fatal("first request returned before reaching the handler")
	}

	// A second request from the same source exceeds the limit of 1.
	re, _, err := testutils.Get(srv.URL, testutils.Header("Limit", "a"))
	require.NoError(t, err)
	assert.Equal(t, http.StatusTooManyRequests, re.StatusCode)

	// A request from another source is not affected by source "a".
	re, _, err = testutils.Get(srv.URL, testutils.Header("Limit", "b"))
	require.NoError(t, err)
	assert.Equal(t, http.StatusOK, re.StatusCode)

	// Once the first request has finished, source "a" is admitted again.
	release()
	<-finish

	re, _, err = testutils.Get(srv.URL, testutils.Header("Limit", "a"))
	require.NoError(t, err)
	assert.Equal(t, http.StatusOK, re.StatusCode)
}
// TestCustomHandlers checks that an error handler supplied through the
// ErrorHandler option is the one that renders a rejected request. With a
// limit of 0 the request is expected to be refused, and the response must be
// exactly what errHandler writes: 418 I'm a teapot with its status text as
// the body.
func TestCustomHandlers(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write([]byte("hello"))
	})

	errHandler := utils.ErrorHandlerFunc(func(w http.ResponseWriter, _ *http.Request, _ error) {
		w.WriteHeader(http.StatusTeapot)
		_, _ = w.Write([]byte(http.StatusText(http.StatusTeapot)))
	})

	l, err := New(handler, headerLimit, 0, ErrorHandler(errHandler))
	require.NoError(t, err)

	srv := httptest.NewServer(l)
	t.Cleanup(srv.Close)

	re, body, err := testutils.Get(srv.URL, testutils.Header("Limit", "a"))
	require.NoError(t, err)
	assert.Equal(t, http.StatusTeapot, re.StatusCode)
	// The body proves the custom handler wrote the whole response, not just
	// the status code.
	assert.Equal(t, http.StatusText(http.StatusTeapot), string(body))
}
// TestFaultyExtract checks that when the request extractor returns an error,
// the limiter answers the request with 500 Internal Server Error.
func TestFaultyExtract(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write([]byte("hello"))
	})

	// faultyExtract always fails, so no request can be attributed to a source.
	l, err := New(handler, faultyExtract, 1)
	require.NoError(t, err)

	srv := httptest.NewServer(l)
	t.Cleanup(srv.Close)

	re, _, err := testutils.Get(srv.URL)
	require.NoError(t, err)
	assert.Equal(t, http.StatusInternalServerError, re.StatusCode)
}
// headerLimiter identifies a request's source by its "Limit" header (the
// empty string when the header is absent) and reports an amount of 1 for
// every request. It never returns an error.
func headerLimiter(req *http.Request) (string, int64, error) {
	const amount int64 = 1
	source := req.Header.Get("Limit")
	return source, amount, nil
}
// faultyExtractor simulates an extractor that cannot identify the request's
// source: it ignores the request and always returns an empty source, an
// amount of -1 and a non-nil error.
func faultyExtractor(_ *http.Request) (string, int64, error) {
	failure := fmt.Errorf("oops")
	return "", -1, failure
}
// Extractors shared by the tests in this file: headerLimit wraps
// headerLimiter (source taken from the "Limit" header), and faultyExtract
// wraps faultyExtractor (always returns an error).
var (
	headerLimit   = utils.ExtractorFunc(headerLimiter)
	faultyExtract = utils.ExtractorFunc(faultyExtractor)
)
|