1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
|
package stdlib_test
import (
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/require"
"github.com/ulule/limiter/v3"
"github.com/ulule/limiter/v3/drivers/middleware/stdlib"
"github.com/ulule/limiter/v3/drivers/store/memory"
)
// TestHTTPMiddleware checks that the stdlib rate-limiting middleware lets
// through exactly `success` requests under the "10-M" rate and answers every
// further request with 429 Too Many Requests. It checks this twice: once with
// requests issued one after another, and once with requests issued
// concurrently against a fresh store.
func TestHTTPMiddleware(t *testing.T) {
	is := require.New(t)

	request, err := http.NewRequest(http.MethodGet, "/", nil)
	is.NoError(err)
	is.NotNil(request)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Report instead of panicking: this handler also runs on the goroutines
		// spawned below, where a panic would abort the whole test binary.
		// t.Errorf (unlike t.Fatal or require) may be called from them.
		if _, err := w.Write([]byte("hello")); err != nil {
			t.Errorf("writing response body: %v", err)
		}
	})

	store := memory.NewStore()
	is.NotZero(store)

	// "10-M" is limiter's formatted rate: 10 requests per minute.
	rate, err := limiter.NewRateFromFormatted("10-M")
	is.NoError(err)
	is.NotZero(rate)

	middleware := stdlib.NewMiddleware(limiter.New(store, rate)).Handler(handler)
	is.NotZero(middleware)

	success := int64(10)
	clients := int64(100)

	//
	// Sequential
	//

	for i := int64(1); i <= clients; i++ {
		resp := httptest.NewRecorder()
		middleware.ServeHTTP(resp, request)
		// testify's Equal takes (expected, actual); keep that order so that
		// failure messages label the two values correctly.
		if i <= success {
			is.Equal(http.StatusOK, resp.Code)
		} else {
			is.Equal(http.StatusTooManyRequests, resp.Code)
		}
	}

	//
	// Concurrent
	//

	// Start from a fresh store so the limit consumed by the sequential phase
	// does not carry over.
	store = memory.NewStore()
	is.NotZero(store)
	middleware = stdlib.NewMiddleware(limiter.New(store, rate)).Handler(handler)
	is.NotZero(middleware)

	wg := &sync.WaitGroup{}
	counter := int64(0) // requests that were let through (200)
	limited := int64(0) // requests that were rejected (429)
	for i := int64(1); i <= clients; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			resp := httptest.NewRecorder()
			middleware.ServeHTTP(resp, request)
			// require must not be used off the test goroutine, so only
			// tally the outcomes here and assert on them after wg.Wait.
			switch resp.Code {
			case http.StatusOK:
				atomic.AddInt64(&counter, 1)
			case http.StatusTooManyRequests:
				atomic.AddInt64(&limited, 1)
			}
		}()
	}
	wg.Wait()

	is.Equal(success, atomic.LoadInt64(&counter))
	// Every request that was not let through must have been rate-limited,
	// not failed with some other status.
	is.Equal(clients-success, atomic.LoadInt64(&limited))
}
|