File: mtu_test.go

package info (click to toggle)
golang-github-lucas-clemente-quic-go 0.54.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 4,312 kB
  • sloc: sh: 54; makefile: 7
file content (195 lines) | stat: -rw-r--r-- 5,259 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
package self_test

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net"
	"sync"
	"testing"
	"time"

	"github.com/quic-go/quic-go"
	quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/logging"

	"github.com/stretchr/testify/require"
)

// TestInitialPacketSize checks that the first datagram a client sends is exactly
// the configured InitialPacketSize.
//
// The "server" is a bare UDP socket that never answers, so the handshake can
// never complete; the dial is aborted via the context once the first datagram
// has been read.
func TestInitialPacketSize(t *testing.T) {
	server := newUDPConnLocalhost(t)
	client := newUDPConnLocalhost(t)

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	go func() {
		defer close(done)
		// Nobody answers, so Dial only returns once ctx is canceled.
		// Its result (an error) is irrelevant to this test.
		_, _ = quic.Dial(ctx, client, server.LocalAddr(), getTLSClientConfig(), getQuicConfig(&quic.Config{
			InitialPacketSize: 1337,
		}))
	}()
	// Abort the dial and wait for the goroutine on every exit path,
	// including a failed assertion below.
	defer func() {
		cancel()
		<-done
	}()

	// Bound the read so that a client that never sends fails the test
	// instead of hanging until the global test timeout.
	require.NoError(t, server.SetReadDeadline(time.Now().Add(10*time.Second)))
	buf := make([]byte, 2000)
	n, _, err := server.ReadFrom(buf)
	require.NoError(t, err)
	require.Equal(t, 1337, n)
}

// TestPathMTUDiscovery exercises path MTU discovery through a proxy that drops
// every packet larger than mtu bytes. It checks that:
//   - the client, starting at the minimum initial packet size, ends up with an
//     MTU (last value reported to its tracer) close to the proxy's limit, and
//     its maximum datagram payload grows accordingly;
//   - the server, which has path MTU discovery disabled, sends packets of at
//     most its configured InitialPacketSize (1234 bytes);
//   - apart from at most one probe, the client never sends packets larger than
//     the MTU it ended up with.
func TestPathMTUDiscovery(t *testing.T) {
	rtt := scaleDuration(5 * time.Millisecond)
	const mtu = 1400

	ln, err := quic.Listen(
		newUDPConnLocalhost(t),
		getTLSConfig(),
		getQuicConfig(&quic.Config{
			InitialPacketSize:       1234,
			DisablePathMTUDiscovery: true,
			EnableDatagrams:         true,
		}),
	)
	require.NoError(t, err)
	defer ln.Close()

	// The server echoes the first stream back to the client. It only reports
	// failures; success is observed on the client side.
	serverErrChan := make(chan error, 1)
	go func() {
		conn, err := ln.Accept(context.Background())
		if err != nil {
			serverErrChan <- err
			return
		}
		str, err := conn.AcceptStream(context.Background())
		if err != nil {
			serverErrChan <- err
			return
		}
		defer str.Close()
		if _, err := io.Copy(str, str); err != nil {
			serverErrChan <- err
			return
		}
	}()

	// mx guards every variable below that is written from another goroutine:
	// the proxy's DropPacket callback and the client connection's tracer.
	var mx sync.Mutex
	var maxPacketSizeServer int
	var clientPacketSizes []int
	var mtus []logging.ByteCount
	proxy := &quicproxy.Proxy{
		Conn:        newUDPConnLocalhost(t),
		ServerAddr:  ln.Addr().(*net.UDPAddr),
		DelayPacket: func(quicproxy.Direction, net.Addr, net.Addr, []byte) time.Duration { return rtt / 2 },
		DropPacket: func(dir quicproxy.Direction, _, _ net.Addr, packet []byte) bool {
			// Emulate a path MTU: anything larger is silently dropped.
			if len(packet) > mtu {
				return true
			}
			mx.Lock()
			defer mx.Unlock()
			switch dir {
			case quicproxy.DirectionIncoming:
				clientPacketSizes = append(clientPacketSizes, len(packet))
			case quicproxy.DirectionOutgoing:
				if len(packet) > maxPacketSizeServer {
					maxPacketSizeServer = len(packet)
				}
			}
			return false
		},
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

	// Make sure to use v4-only socket here.
	// We can't reliably set the DF bit on dual-stack sockets on older versions of macOS (before Sequoia).
	tr := &quic.Transport{Conn: newUDPConnLocalhost(t)}
	defer tr.Close()

	conn, err := tr.Dial(
		context.Background(),
		proxy.LocalAddr(),
		getTLSClientConfig(),
		getQuicConfig(&quic.Config{
			InitialPacketSize: protocol.MinInitialPacketSize,
			EnableDatagrams:   true,
			Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
				return &logging.ConnectionTracer{
					// This callback is not invoked on the test goroutine, so the
					// append must be synchronized with the reads further below.
					UpdatedMTU: func(newMTU logging.ByteCount, _ bool) {
						mx.Lock()
						defer mx.Unlock()
						mtus = append(mtus, newMTU)
					},
				}
			},
		}),
	)
	require.NoError(t, err)
	defer conn.CloseWithError(0, "")

	// A 2000-byte datagram exceeds any possible payload size; the returned
	// error reports the current maximum.
	err = conn.SendDatagram(make([]byte, 2000))
	require.Error(t, err)
	var datagramErr *quic.DatagramTooLargeError
	require.ErrorAs(t, err, &datagramErr)
	initialMaxDatagramSize := datagramErr.MaxDatagramPayloadSize

	str, err := conn.OpenStream()
	require.NoError(t, err)

	clientErrChan := make(chan error, 1)
	go func() {
		data, err := io.ReadAll(str)
		if err != nil {
			clientErrChan <- err
			return
		}
		if !bytes.Equal(data, PRDataLong) {
			clientErrChan <- fmt.Errorf("echoed data doesn't match: %x", data)
			return
		}
		clientErrChan <- nil
	}()

	// Sending a large amount of data gives the MTU discoverer time to probe.
	_, err = str.Write(PRDataLong)
	require.NoError(t, err)
	str.Close()

	select {
	case err := <-clientErrChan:
		require.NoError(t, err)
	case err := <-serverErrChan:
		t.Fatalf("server error: %v", err)
	case <-time.After(20 * time.Second):
		t.Fatal("timeout")
	}

	err = conn.SendDatagram(make([]byte, 2000))
	require.Error(t, err)
	require.ErrorAs(t, err, &datagramErr)
	finalMaxDatagramSize := datagramErr.MaxDatagramPayloadSize

	// Take one consistent snapshot of the concurrently written state. The
	// connection and the proxy keep running until the deferred cleanups, so
	// further updates may still arrive after this point.
	mx.Lock()
	discoveredMTUs := append([]logging.ByteCount(nil), mtus...)
	clientSizes := append([]int(nil), clientPacketSizes...)
	maxServerSize := maxPacketSizeServer
	mx.Unlock()

	require.NotEmpty(t, discoveredMTUs)

	maxPacketSizeClient := int(discoveredMTUs[len(discoveredMTUs)-1])
	t.Logf("max client packet size: %d, MTU: %d", maxPacketSizeClient, mtu)
	t.Logf("max datagram size: initial: %d, final: %d", initialMaxDatagramSize, finalMaxDatagramSize)
	t.Logf("max server packet size: %d, MTU: %d", maxServerSize, mtu)

	require.GreaterOrEqual(t, maxPacketSizeClient, mtu-25)
	const maxDiff = 40 // this includes the 21 bytes for the short header, 16 bytes for the encryption tag, and framing overhead
	require.GreaterOrEqual(t, int(initialMaxDatagramSize), protocol.MinInitialPacketSize-maxDiff)
	require.GreaterOrEqual(t, int(finalMaxDatagramSize), maxPacketSizeClient-maxDiff)
	// MTU discovery was disabled on the server side
	require.Equal(t, 1234, maxServerSize)

	var numPacketsLargerThanDiscoveredMTU int
	for _, s := range clientSizes {
		if s > maxPacketSizeClient {
			numPacketsLargerThanDiscoveredMTU++
		}
	}
	// The client shouldn't have sent any packets larger than the MTU it discovered,
	// except for at most one MTU probe packet.
	require.LessOrEqual(t, numPacketsLargerThanDiscoveredMTU, 1)
}