1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
|
//go:build go1.24
package quic
import (
"fmt"
mrand "math/rand/v2"
"slices"
"strings"
"testing"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/require"
)
// randomDomainName returns a pseudo-random string of exactly length bytes,
// built from the lowercase ASCII letters a-z with '.' separators mixed in.
// A separator is never placed at the first or last position, nor directly
// after another separator, so the result never has an empty label.
func randomDomainName(length int) string {
	const letters = "abcdefghijklmnopqrstuvwxyz"
	name := make([]byte, length)
	for pos := 0; pos < len(name); pos++ {
		interior := pos > 0 && pos < length-1
		// Each interior position has a 1-in-5 chance of becoming a separator,
		// unless the previous byte already is one.
		if interior && mrand.IntN(5) == 0 && name[pos-1] != '.' {
			name[pos] = '.'
			continue
		}
		name[pos] = letters[mrand.IntN(len(letters))]
	}
	return string(name)
}
// TestInitialCryptoStreamClientRandomizedSizes runs 100 randomized subtests
// (after calling skipIfDisableScramblingEnvSet, which may skip the test).
// In about three out of four runs a random server name of 6 to 25 bytes is
// generated; otherwise the server name is empty. A ClientHello is then built
// for that name, using ECH on roughly half of the runs whose name contains a
// dot, and checked by testInitialCryptoStreamClientRandomizedSizes.
func TestInitialCryptoStreamClientRandomizedSizes(t *testing.T) {
	skipIfDisableScramblingEnvSet(t)
	for run := range 100 {
		t.Run(fmt.Sprintf("run %d", run), func(t *testing.T) {
			serverName := ""
			if mrand.Int()%4 != 0 {
				serverName = randomDomainName(6 + mrand.IntN(20))
			}
			// ECH is only considered for names containing a dot; for those,
			// a coin flip decides.
			withECH := serverName != "" && strings.Contains(serverName, ".") && mrand.Int()%2 != 0

			var clientHello []byte
			if withECH {
				t.Logf("using a ClientHello with ECH, hostname: %q", serverName)
				clientHello = getClientHelloWithECH(t, serverName)
			} else {
				t.Logf("using a ClientHello without ECH, hostname: %q", serverName)
				clientHello = getClientHello(t, serverName)
			}
			testInitialCryptoStreamClientRandomizedSizes(t, clientHello, serverName)
		})
	}
}
// testInitialCryptoStreamClientRandomizedSizes writes clientHello into a new
// initial crypto stream (newInitialCryptoStream(true)) using randomly sized
// writes, then writes the bytes "foobar", and drains the stream with randomly
// sized PopCryptoFrame calls.
//
// It checks that:
//   - every popped frame's encoded length fits within the maxSize requested,
//   - no two frames share an offset,
//   - if expectedServerName is non-empty, no single frame contains it,
//   - the frames reassemble to exactly clientHello followed by "foobar", and
//   - if expectedServerName is non-empty, the reassembled data contains it.
func testInitialCryptoStreamClientRandomizedSizes(t *testing.T, clientHello []byte, expectedServerName string) {
	const trailer = "foobar"

	str := newInitialCryptoStream(true)
	// Write from a copy, presumably so the stream cannot alias or modify
	// clientHello, which is compared against below.
	remaining := slices.Clone(clientHello)
	for len(remaining) > 0 {
		// Sizes are drawn from [0, 2*len): zero-length writes can happen, and
		// half of the draws are clamped to everything that is left.
		n := min(len(remaining), mrand.IntN(2*len(remaining)))
		_, err := str.Write(remaining[:n])
		require.NoError(t, err)
		remaining = remaining[n:]
	}
	require.True(t, str.HasData())
	_, err := str.Write([]byte(trailer))
	require.NoError(t, err)

	var frames []*wire.CryptoFrame
	for str.HasData() {
		// Three out of four budgets are 1-32 bytes, the rest 1-512 bytes.
		var maxSize protocol.ByteCount
		if mrand.Int()%4 == 0 {
			maxSize = protocol.ByteCount(mrand.IntN(512) + 1)
		} else {
			maxSize = protocol.ByteCount(mrand.IntN(32) + 1)
		}
		f := str.PopCryptoFrame(maxSize)
		if f == nil {
			// Nothing was popped for this budget (presumably too small for a
			// frame); draw a new size.
			continue
		}
		frames = append(frames, f)
		require.LessOrEqual(t, f.Length(protocol.Version1), maxSize)
	}
	t.Logf("received %d frames", len(frames))

	segments := make(map[protocol.ByteCount][]byte, len(frames))
	for _, f := range frames {
		t.Logf("offset %d: %d bytes", f.Offset, len(f.Data))
		if expectedServerName != "" {
			require.NotContainsf(t, string(f.Data), expectedServerName, "frame at offset %d contains the server name", f.Offset)
		}
		// A repeated offset would silently overwrite an earlier segment and
		// could hide a bug from the reassembly check below.
		_, dup := segments[f.Offset]
		require.Falsef(t, dup, "duplicate frame at offset %d", f.Offset)
		segments[f.Offset] = f.Data
	}
	reassembled := reassembleCryptoData(t, segments)
	// slices.Concat allocates a fresh slice; append(clientHello, ...) could
	// write into clientHello's backing array if it had spare capacity.
	require.Equal(t, slices.Concat(clientHello, []byte(trailer)), reassembled)
	if expectedServerName != "" {
		require.Contains(t, string(reassembled), expectedServerName)
	}
}
|