1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
|
package self_test
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"sync/atomic"
"testing"
"time"
"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/stretchr/testify/require"
)
// TestConnectionMigration checks that a client connection can migrate from the
// transport it was dialed on (tr1) to a second transport (tr2), and back again.
//
// All traffic is routed through a proxy whose DelayPacket hook attributes each
// packet to a client-side path by the client's UDP port. The test then asserts
// which path carried traffic at each stage:
//  1. after the handshake, only path 1 is used;
//  2. probing path 2 sends only a few packets on it, and a data transfer before
//     switching does not touch path 2;
//  3. after switching to path 2, a data transfer does not touch path 1;
//  4. after probing and switching back to tr1, traffic flows on path 1 again.
func TestConnectionMigration(t *testing.T) {
	ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
	require.NoError(t, err)
	defer ln.Close()

	tr1 := &quic.Transport{Conn: newUDPConnLocalhost(t)}
	defer tr1.Close()
	tr2 := &quic.Transport{Conn: newUDPConnLocalhost(t)}
	defer tr2.Close()

	// Per-path packet counters; they are updated from the proxy's callback,
	// hence the atomics.
	var packetsPath1, packetsPath2 atomic.Int64
	const rtt = 5 * time.Millisecond
	proxy := quicproxy.Proxy{
		Conn:       newUDPConnLocalhost(t),
		ServerAddr: ln.Addr().(*net.UDPAddr),
		DelayPacket: func(dir quicproxy.Direction, from, to net.Addr, _ []byte) time.Duration {
			// The client-side address is the sender of incoming packets and
			// the receiver of outgoing ones.
			var clientAddr net.Addr
			switch dir {
			case quicproxy.DirectionIncoming:
				clientAddr = from
			case quicproxy.DirectionOutgoing:
				clientAddr = to
			}
			// A checked assertion avoids panicking inside the proxy on an
			// unexpected address; such packets fall through to the log below.
			var port int
			if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
				port = udpAddr.Port
			}
			switch port {
			case tr1.Conn.LocalAddr().(*net.UDPAddr).Port:
				packetsPath1.Add(1)
			case tr2.Conn.LocalAddr().(*net.UDPAddr).Port:
				packetsPath2.Add(1)
			default:
				fmt.Println("address not found", clientAddr)
			}
			return rtt / 2
		},
	}
	require.NoError(t, proxy.Start())
	defer proxy.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	conn, err := tr1.Dial(ctx, proxy.LocalAddr(), getTLSClientConfig(), getQuicConfig(nil))
	require.NoError(t, err)
	defer conn.CloseWithError(0, "")
	sconn, err := ln.Accept(ctx)
	require.NoError(t, err)
	defer sconn.CloseWithError(0, "")

	// sendAndReceiveFile writes PRData on a new client-initiated unidirectional
	// stream and fails the test unless the server reads back exactly PRData
	// within one second.
	sendAndReceiveFile := func(t *testing.T) {
		t.Helper()
		str, err := conn.OpenUniStream()
		require.NoError(t, err)
		errChan := make(chan error, 1)
		go func() {
			// Closing the channel delivers a nil error on success.
			defer close(errChan)
			sstr, err := sconn.AcceptUniStream(ctx)
			if err != nil {
				errChan <- fmt.Errorf("accepting stream: %w", err)
				return
			}
			data, err := io.ReadAll(sstr)
			if err != nil {
				errChan <- fmt.Errorf("reading stream data: %w", err)
				return
			}
			if !bytes.Equal(data, PRData) {
				errChan <- errors.New("unexpected data")
			}
		}()
		_, err = str.Write(PRData)
		require.NoError(t, err)
		require.NoError(t, str.Close())
		select {
		case err := <-errChan:
			require.NoError(t, err)
		case <-time.After(time.Second):
			t.Fatal("timed out waiting for data")
		}
	}

	sendAndReceiveFile(t) // stream 2
	require.NotZero(t, packetsPath1.Load())
	require.Zero(t, packetsPath2.Load())

	// probing the path causes a few packets to be sent on path 2
	path, err := conn.AddPath(tr2)
	require.NoError(t, err)
	require.ErrorIs(t, path.Switch(), quic.ErrPathNotValidated)
	require.NoError(t, path.Probe(ctx))
	require.Less(t, int(packetsPath2.Load()), 5)

	// make sure that no more packets are sent on path 2 before switching to the path
	c2 := packetsPath2.Load()
	sendAndReceiveFile(t) // stream 6
	require.Equal(t, c2, packetsPath2.Load())
	time.Sleep(3 * rtt) // wait for ACKs

	// now switch and make sure that no packets are sent on path 1
	require.NoError(t, path.Switch())
	// Snapshot the path 1 counter before the transfer: snapshotting it
	// afterwards would make the equality check below trivially true.
	c1 := packetsPath1.Load()
	sendAndReceiveFile(t) // stream 10
	require.Equal(t, c1, packetsPath1.Load())
	require.Greater(t, packetsPath2.Load(), c2)
	require.Equal(t, tr2.Conn.LocalAddr(), conn.LocalAddr())

	// switch back to the handshake path
	time.Sleep(3 * rtt) // wait for ACKs
	c1BeforeSwitch := packetsPath1.Load()
	c2BeforeSwitch := packetsPath2.Load()
	returnPath, err := conn.AddPath(tr1)
	require.NoError(t, err)
	require.NoError(t, returnPath.Probe(ctx))
	time.Sleep(3 * rtt) // wait for ACKs
	require.NoError(t, returnPath.Switch())
	sendAndReceiveFile(t) // stream 14
	require.Greater(t, packetsPath1.Load(), c1BeforeSwitch)
	// some path probing might have happened
	require.Less(t, int(packetsPath2.Load()-c2BeforeSwitch), 20)
	require.Equal(t, tr1.Conn.LocalAddr(), conn.LocalAddr())
}
|