1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
|
package poolhttp
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// requireBody reads r until EOF and stops the test unless the bytes read
// equal want. Closing r is registered with t.Cleanup, so it happens when the
// calling test finishes; a Close error also fails the test.
func requireBody(t *testing.T, want string, r io.ReadCloser) {
	t.Helper()

	closeBody := func() {
		require.NoError(t, r.Close())
	}
	t.Cleanup(closeBody)

	got, readErr := io.ReadAll(r)
	require.NoError(t, readErr)
	require.Equal(t, want, string(got))
}
func TestClient(t *testing.T) {
httpSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello World")
}))
t.Cleanup(httpSrv.Close)
tlsSrv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello World")
}))
t.Cleanup(tlsSrv.Close)
tests := []struct {
name string
client *Client
srv *httptest.Server
}{
{"http", New(func() *http.Client { return httpSrv.Client() }), httpSrv},
{"tls", New(func() *http.Client { return tlsSrv.Client() }), tlsSrv},
{"nil", New(func() *http.Client { return nil }), httpSrv},
{"empty", &Client{}, httpSrv},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
resp, err := tc.client.Get(tc.srv.URL)
require.NoError(t, err)
requireBody(t, "Hello World\n", resp.Body)
req, err := http.NewRequest("GET", tc.srv.URL, http.NoBody)
require.NoError(t, err)
resp, err = tc.client.Do(req)
require.NoError(t, err)
requireBody(t, "Hello World\n", resp.Body)
client := &http.Client{
Transport: tc.client.Transport(),
}
resp, err = client.Get(tc.srv.URL)
require.NoError(t, err)
requireBody(t, "Hello World\n", resp.Body)
})
}
}
// TestClient_SetNew checks that SetNew replaces the factory used by an
// existing Client. Against a TLS test server, requests:
//   - succeed with the server's own client (which trusts its test certificate),
//   - fail with http.DefaultClient, and
//   - succeed again after switching back.
//
// The subtests are not parallel and depend on running in order, because each
// one reconfigures the shared Client c.
func TestClient_SetNew(t *testing.T) {
	srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "Hello World")
	}))
	t.Cleanup(srv.Close)

	c := New(func() *http.Client {
		return srv.Client()
	})

	tests := []struct {
		name      string
		client    *http.Client
		assertion assert.ErrorAssertionFunc
	}{
		{"ok", srv.Client(), assert.NoError},
		{"fail", http.DefaultClient, assert.Error},
		{"ok again", srv.Client(), assert.NoError},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			c.SetNew(func() *http.Client {
				return tc.client
			})

			resp, err := c.Get(srv.URL)
			tc.assertion(t, err)
			if err == nil {
				// Every successful response must be read and closed. Previously
				// the response was discarded, leaking its body and connection.
				// This runs even if an expected failure unexpectedly succeeds.
				requireBody(t, "Hello World\n", resp.Body)
			}
		})
	}
}
// TestClient_parallel sends requests through one shared Client from ten
// parallel subtests, using both Client.Get and Client.Do. Running it with
// -race checks Client for data races under concurrent use.
func TestClient_parallel(t *testing.T) {
	t.Parallel()

	srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "Hello World")
	}))
	// Parent cleanups run only after all parallel subtests have finished.
	t.Cleanup(srv.Close)

	c := New(func() *http.Client {
		return srv.Client()
	})

	for i := range 10 {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			t.Parallel()

			resp, err := c.Get(srv.URL)
			require.NoError(t, err)
			requireBody(t, "Hello World\n", resp.Body)

			// Build the request per subtest. net/http forbids reusing a
			// Request until the previous response body is closed, and
			// requireBody only closes bodies at cleanup. So one Request
			// shared across concurrent Do calls would break that contract.
			req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
			require.NoError(t, err)
			resp, err = c.Do(req)
			require.NoError(t, err)
			requireBody(t, "Hello World\n", resp.Body)
		})
	}
}
|