File: replayprotection_test.go

package info (click to toggle)
golang-github-pion-dtls-v3 3.0.7-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 2,124 kB
  • sloc: makefile: 4
file content (125 lines) | stat: -rw-r--r-- 2,561 bytes | parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT

package dtls

import (
	"context"
	"net"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/pion/transport/v3/dpipe"
	"github.com/pion/transport/v3/test"
	"github.com/stretchr/testify/assert"
)

// TestReplayProtection checks that records duplicated by a man-in-the-middle
// are not delivered twice to the application.
//
// Two dpipe pairs are bridged by "replayer" goroutines. Each replayer
// forwards every packet and then, about a millisecond later, writes the same
// packet again. The endpoints are wrapped with pipeConn (defined elsewhere in
// this package; presumably it performs the DTLS handshake — confirm there).
// Each side then sends numMsgs one-byte messages. The test asserts that every
// receiver got exactly the sent sequence, with no duplicates and in order.
//
// cntReplays counts outstanding work that must finish before the pipes may
// be closed:
//   - one unit for the sender loop;
//   - one unit per receiver, released once it sees the final message;
//   - one unit per pending replay.
//
// replayDone cancels ctxReplayDone when the counter drops to zero.
func TestReplayProtection(t *testing.T) { //nolint:cyclop
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(5 * time.Second)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	c0, c1 := dpipe.Pipe()
	c2, c3 := dpipe.Pipe()
	conn := []net.Conn{c0, c1, c2, c3}

	var wgRoutines sync.WaitGroup
	// Starts at 1: the sender loop's unit, released after all messages are
	// written.
	var cntReplays int32 = 1

	ctxReplayDone, replayDone := context.WithCancel(context.Background())
	// Cancel is idempotent. Deferring it guarantees the context is released
	// even if the test bails out before the counter reaches zero.
	defer replayDone()

	// replaySendDone releases one unit of outstanding work. It signals
	// ctxReplayDone once nothing is left.
	replaySendDone := func() {
		if atomic.AddInt32(&cntReplays, -1) == 0 {
			replayDone()
		}
	}

	replayer := func(ca, cb net.Conn) {
		defer wgRoutines.Done()
		// Man in the middle
		for {
			b := make([]byte, 2048)
			n, rerr := ca.Read(b)
			if rerr != nil {
				return
			}
			// Register the pending replay BEFORE forwarding the original.
			// If the counter were bumped after the Write, the receiver could
			// observe the final message and release its unit first. The
			// counter could then hit zero while this replay is unaccounted
			// for.
			atomic.AddInt32(&cntReplays, 1)
			_, werr := cb.Write(b[:n])
			assert.NoError(t, werr)

			go func() {
				defer replaySendDone()
				// Replay bit later
				time.Sleep(time.Millisecond)
				_, werr := cb.Write(b[:n])
				assert.NoError(t, werr)
			}()
		}
	}
	wgRoutines.Add(2)
	go replayer(conn[1], conn[2])
	go replayer(conn[2], conn[1])

	ca, cb, err := pipeConn(conn[0], conn[3])
	if !assert.NoError(t, err) {
		// ca/cb are unusable. Calling Read on them from the reader goroutines
		// would panic and mask this failure, so stop the replayers and bail.
		for _, c := range conn {
			_ = c.Close()
		}
		wgRoutines.Wait()
		return
	}

	const numMsgs = 10

	// Only the reader goroutine for index i appends to received[i]. The main
	// goroutine reads it only after wgRoutines.Wait, so no lock is needed.
	var received [2][][]byte
	for i, c := range []net.Conn{ca, cb} {
		i := i
		c := c
		wgRoutines.Add(1)
		atomic.AddInt32(&cntReplays, 1) // Keep locked until the final message
		var lastMsgDone sync.Once
		go func() {
			defer wgRoutines.Done()
			for {
				b := make([]byte, 2048)
				n, rerr := c.Read(b)
				if rerr != nil {
					return
				}
				received[i] = append(received[i], b[:n])
				if b[0] == numMsgs-1 {
					// Final message received. Release this reader's unit
					// exactly once, even if the message were delivered again.
					lastMsgDone.Do(replaySendDone)
				}
			}
		}()
	}

	sent := make([][]byte, 0, numMsgs)
	for i := 0; i < numMsgs; i++ {
		data := []byte{byte(i)}
		sent = append(sent, data)
		_, werr := ca.Write(data)
		assert.NoError(t, werr)
		_, werr = cb.Write(data)
		assert.NoError(t, werr)
	}

	// Release the sender's unit. Then wait until both receivers have seen the
	// final message and every registered replay has been written.
	replaySendDone()
	<-ctxReplayDone.Done()
	// Extra margin so the receivers read (and should discard) the last
	// replayed records before the transports are closed.
	time.Sleep(10 * time.Millisecond)

	for _, c := range conn {
		assert.NoError(t, c.Close())
	}
	assert.NoError(t, ca.Close())
	assert.NoError(t, cb.Close())
	wgRoutines.Wait()

	for _, r := range received {
		assert.Equal(t, sent, r)
	}
}