package httpz

import (
	"context"
	"crypto/tls"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

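// Compile-time check that *UpgradeRoundTripper implements http.RoundTripper.
var (
	_ http.RoundTripper = (*UpgradeRoundTripper)(nil)
)

// Payloads for each phase of the exchange: the body of the initial HTTP
// request, then the raw bytes sent in each direction over the upgraded
// connection.
const (
	requestBodyData         = "request_jkasdbfkadsbfkadbfkjasbfkasbdf"
	requestUpgradeBodyData  = "upgrade_request_asdfjkasbfkasdf"
	responseUpgradeBodyData = "upgrade_response_asdfasdfadsf"
)
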
func TestUpgradeRoundTripper_HappyPath(t *testing.T) {
	// Block until the handler goroutine has finished before the test returns,
	// so that its assertions are not reported after the test has completed.
	// Deferred first, so it runs last (after server.Close and rt.Conn.Close).
	var wg sync.WaitGroup
	defer wg.Wait()
	wg.Add(1)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer wg.Done()
		// Phase 1: consume the body of the ordinary HTTP request.
		t.Log("SRV: Reading request")
		reqBody, err := io.ReadAll(r.Body)
		if !assert.NoError(t, err) {
			return
		}
		t.Log("SRV: Read request")
		assert.Equal(t, requestBodyData, string(reqBody))
		t.Log("SRV: Writing response")
		w.WriteHeader(http.StatusSwitchingProtocols)
		// A 101 response must not have a body, so nothing is written here.
		t.Log("SRV: Wrote response")
		// Phase 2: take over the underlying connection and exchange raw bytes.
		conn, wr, err := w.(http.Hijacker).Hijack()
		if !assert.NoError(t, err) {
			return
		}
		defer func() {
			t.Log("SRV: Closing conn")
			assert.NoError(t, conn.Close())
			t.Log("SRV: Closed conn")
		}()
		// Read through the bufio.ReadWriter returned by Hijack rather than
		// conn directly: it may already hold buffered bytes from the client.
		connBody := make([]byte, len(requestUpgradeBodyData))
		t.Log("SRV: Reading conn request")
		_, err = io.ReadFull(wr, connBody)
		t.Log("SRV: Read conn request")
		if !assert.NoError(t, err) {
			return
		}
		assert.Equal(t, requestUpgradeBodyData, string(connBody))
		t.Log("SRV: Writing conn response")
		_, err = conn.Write([]byte(responseUpgradeBodyData))
		t.Log("SRV: Wrote conn response")
		if !assert.NoError(t, err) {
			return
		}
	}))
	defer server.Close()

	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, strings.NewReader(requestBodyData))
	require.NoError(t, err)
	rt := UpgradeRoundTripper{
		Dialer:    &net.Dialer{},
		TLSDialer: &tls.Dialer{},
	}
	resp, err := rt.RoundTrip(req)
	require.NoError(t, err)
	defer func() {
		assert.NoError(t, rt.Conn.Close())
	}()
	// The 101 response itself carries no body.
	respData, err := io.ReadAll(resp.Body)
	assert.NoError(t, resp.Body.Close())
	require.NoError(t, err) // check err from ReadAll
	assert.Empty(t, respData)
	require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)

	// After the upgrade, talk to the server over the raw connection. ReadAll
	// returns once the server handler closes its end of the connection.
	_, err = rt.Conn.Write([]byte(requestUpgradeBodyData))
	require.NoError(t, err)
	connBody, err := io.ReadAll(rt.ConnReader)
	require.NoError(t, err)
	assert.Equal(t, responseUpgradeBodyData, string(connBody))
}