1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
|
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package dtls
import (
"context"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/pion/transport/v3/dpipe"
"github.com/pion/transport/v3/test"
"github.com/stretchr/testify/assert"
)
// TestReplayProtection verifies that duplicated (replayed) datagrams injected
// between the two endpoints are not delivered to the application.
//
// Two dpipe pairs are bridged by a man-in-the-middle "replayer" that forwards
// every datagram it reads and then writes the very same datagram again about a
// millisecond later. Both endpoints returned by pipeConn write numMsgs
// single-byte messages; the test then asserts that each side received exactly
// the sent sequence, which fails if any replayed copy got through.
func TestReplayProtection(t *testing.T) { //nolint:cyclop
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(5 * time.Second)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	c0, c1 := dpipe.Pipe()
	c2, c3 := dpipe.Pipe()
	conn := []net.Conn{c0, c1, c2, c3}

	var wgRoutines sync.WaitGroup

	// cntReplays counts outstanding work that must finish before the pipes may
	// be closed: one reference held by the sender loop below, one per endpoint
	// until it has seen the final message, and one per scheduled replay.
	// ctxReplayDone is cancelled when the count drops to zero.
	var cntReplays int32 = 1
	ctxReplayDone, replayDone := context.WithCancel(context.Background())
	defer replayDone()
	replaySendDone := func() {
		if atomic.AddInt32(&cntReplays, -1) == 0 {
			replayDone()
		}
	}

	// replayer relays datagrams from src to dst and re-sends each one shortly
	// afterwards. It returns once src.Read fails (i.e. the pipe is closed).
	replayer := func(src, dst net.Conn) {
		defer wgRoutines.Done()
		// Man in the middle
		for {
			b := make([]byte, 2048)
			n, rerr := src.Read(b)
			if rerr != nil {
				return
			}
			// Take the replay's reference BEFORE forwarding the original:
			// otherwise the receiver could consume it and drive the counter to
			// zero before this replay is accounted for.
			atomic.AddInt32(&cntReplays, 1)
			if _, werr := dst.Write(b[:n]); !assert.NoError(t, werr) {
				replaySendDone()
				return
			}
			go func() {
				defer replaySendDone()
				// Replay bit later
				time.Sleep(time.Millisecond)
				_, werr := dst.Write(b[:n])
				assert.NoError(t, werr)
			}()
		}
	}

	wgRoutines.Add(2)
	go replayer(conn[1], conn[2])
	go replayer(conn[2], conn[1])

	ca, cb, err := pipeConn(conn[0], conn[3])
	if !assert.NoError(t, err) {
		// ca/cb are unusable; drain pending replays so nothing writes to a
		// closed pipe, then tear down the relays instead of dereferencing nil.
		replaySendDone()
		<-ctxReplayDone.Done()
		for _, c := range conn {
			_ = c.Close()
		}
		wgRoutines.Wait()
		return
	}

	const numMsgs = 10

	// received[i] is appended to only by reader goroutine i and read only after
	// wgRoutines.Wait(), so no additional locking is required.
	var received [2][][]byte
	for i, c := range []net.Conn{ca, cb} {
		i := i
		c := c
		var lastMsgDone sync.Once
		wgRoutines.Add(1)
		atomic.AddInt32(&cntReplays, 1) // Keep locked until the final message
		go func() {
			defer wgRoutines.Done()
			for {
				b := make([]byte, 2048)
				n, rerr := c.Read(b)
				if rerr != nil {
					return
				}
				received[i] = append(received[i], b[:n])
				if n > 0 && b[0] == numMsgs-1 {
					// Final message received: release this endpoint's reference.
					lastMsgDone.Do(replaySendDone)
				}
			}
		}()
	}

	sent := make([][]byte, 0, numMsgs)
	for i := 0; i < numMsgs; i++ {
		data := []byte{byte(i)}
		sent = append(sent, data)
		_, werr := ca.Write(data)
		assert.NoError(t, werr)
		_, werr = cb.Write(data)
		assert.NoError(t, werr)
	}

	// Release the sender's reference, then wait until both endpoints have seen
	// the final message and every scheduled replay has been written.
	replaySendDone()
	<-ctxReplayDone.Done()
	// Grace period for any traffic arriving after the counter reached zero.
	time.Sleep(10 * time.Millisecond)

	for _, c := range conn {
		assert.NoError(t, c.Close())
	}
	assert.NoError(t, ca.Close())
	assert.NoError(t, cb.Close())

	wgRoutines.Wait()

	for _, r := range received {
		assert.Equal(t, sent, r)
	}
}
|