1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
|
package wstunnel
import (
"context"
"fmt"
"hash/fnv"
"io"
"math/rand"
"net"
"net/http"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nhooyr.io/websocket"
)
var (
_ net.Listener = (*onceCloseListener)(nil)
_ net.Listener = (*protocolListener)(nil)
_ net.Listener = (*wrapperServer)(nil)
_ http.Handler = (*HttpHandler)(nil)
_ net.Conn = (*readerConn)(nil)
)
// testStuff bundles the values testHarness hands to each test body.
type testStuff struct {
	// ctx is cancelled by testHarness once the test body has returned.
	ctx context.Context

	// serverAddr is the address of the underlying TCP listener; clients dial it.
	serverAddr net.Addr

	// wrappedLis is the listener produced by ListenerWrapper.Wrap over that
	// TCP listener; the server side of a test accepts connections from it.
	wrappedLis Listener
}
// TestClientServerVariousBufferSizes runs the echo round-trip from testEcho
// with 128 messages at each of several message sizes.
//
// wstunnel+gRPC is tested in internal/tool/grpctool/max_conn_age_wstunnel_test.go.
func TestClientServerVariousBufferSizes(t *testing.T) {
	const messageCount = 128

	cases := []struct {
		name string
		size int
	}{
		{name: "1kbyte", size: 1024},
		{name: "64kbyte", size: 64 * 1024},
		{name: "128kbyte", size: 128 * 1024},
	}

	for _, tc := range cases {
		// Subtests do not call t.Parallel, so t.Run returns before the next
		// iteration and capturing tc in the closure is safe.
		t.Run(tc.name, func(t *testing.T) {
			testHarness(t, func(t *testing.T, stuff *testStuff) {
				testEcho(t, tc.size, messageCount, stuff)
			})
		})
	}
}
// testEcho checks that data survives a round trip through the wrapped listener.
//
// It dials stuff.serverAddr over a websocket and writes writeCount random
// binary messages of writeSize bytes each. It then verifies that the bytes
// echoed back hash to the same FNV-128 digest as the bytes that were sent. The
// server side accepts a single connection from stuff.wrappedLis and copies
// everything it reads back to the peer.
//
// If either the client reader or the client writer fails, the other one is
// aborted by cancelling the I/O context, so a failure surfaces as an assertion
// error instead of a hung test.
func testEcho(t *testing.T, writeSize, writeCount int, stuff *testStuff) {
	var serverWg sync.WaitGroup
	defer serverWg.Wait()
	defer stuff.wrappedLis.Close()
	serverWg.Add(1)
	go func() {
		defer serverWg.Done()
		serverConn, err := stuff.wrappedLis.Accept()
		if !assert.NoError(t, err) {
			return
		}
		defer serverConn.Close()
		_, err = io.Copy(serverConn, serverConn) // echo
		assert.NoError(t, err)
	}()

	// ioCtx governs the client's reads and writes. Cancelling it while one of
	// them is blocked closes the websocket and unblocks the other goroutine.
	// It is registered before conn.Close, so on the success path the deferred
	// cancel runs only after the normal closure below (defers are LIFO).
	ioCtx, cancelIO := context.WithCancel(stuff.ctx)
	defer cancelIO()

	conn, _, err := Dial(stuff.ctx, fmt.Sprintf("ws://%s", stuff.serverAddr.String()), &websocket.DialOptions{}) // nolint: bodyclose
	require.NoError(t, err)
	defer conn.Close(websocket.StatusNormalClosure, "")
	conn.SetReadLimit(1024 * 1024)

	// Read and hash everything echoed back. The multiplication is done in
	// int64 so large sizes cannot overflow int on 32-bit platforms.
	var clientWg sync.WaitGroup
	readHash := fnv.New128()
	toRead := int64(writeSize) * int64(writeCount)
	clientWg.Add(1)
	go func() {
		defer clientWg.Done()
		netConn := websocket.NetConn(ioCtx, conn, websocket.MessageBinary)
		copied, readErr := io.Copy(readHash, io.LimitReader(netConn, toRead))
		if readErr != nil || copied != toRead {
			// The writer may be stuck behind back-pressure from an echo that
			// nobody reads any more; abort it.
			cancelIO()
		}
		if assert.NoError(t, readErr) {
			assert.Equal(t, toRead, copied)
		}
	}()

	// Generate, hash and write random data.
	writeHash := fnv.New128()
	buf := make([]byte, writeSize)
	for i := 0; i < writeCount; i++ {
		rand.Read(buf)       // math/rand.Read always returns len(buf), nil
		writeHash.Write(buf) // hash.Hash.Write never returns an error
		if writeErr := conn.Write(ioCtx, websocket.MessageBinary, buf); !assert.NoError(t, writeErr) {
			// Without this, the reader would wait forever for bytes that
			// were never sent, and clientWg.Wait below would deadlock.
			cancelIO()
			break
		}
	}
	clientWg.Wait() // wait for the reader to finish
	assert.Equal(t, writeHash.Sum(nil), readHash.Sum(nil))
}
// testHarness runs test against a freshly wrapped localhost listener.
//
// It listens on an ephemeral TCP port on localhost and wraps that listener
// with a ListenerWrapper whose ReadLimit is 1024*1024. It then calls test with
// the wrapped listener, the raw listener's address and a cancellable context.
//
// After test returns, teardown runs in this order:
//  1. The wrapped listener is closed and then shut down; both must succeed.
//  2. The context is cancelled.
//  3. The TCP listener is closed.
func testHarness(t *testing.T, test func(*testing.T, *testStuff)) {
	tcpLis, err := net.Listen("tcp", "localhost:0")
	require.NoError(t, err)
	defer tcpLis.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	w := ListenerWrapper{ReadLimit: 1024 * 1024}
	lis := w.Wrap(tcpLis, false)
	defer func() {
		// Stop accepting new connections first, then wait for running ones.
		assert.NoError(t, lis.Close())
		assert.NoError(t, lis.Shutdown(ctx))
	}()

	test(t, &testStuff{
		ctx:        ctx,
		serverAddr: tcpLis.Addr(),
		wrappedLis: lis,
	})
}
|