File: tcp_mux_multi_test.go

package info (click to toggle)
golang-github-pion-ice.v2 2.3.1-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 684 kB
  • sloc: makefile: 5
file content (128 lines) | stat: -rw-r--r-- 3,804 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
//go:build !js
// +build !js

package ice

import (
	"io"
	"net"
	"testing"

	"github.com/pion/logging"
	"github.com/pion/stun"
	"github.com/pion/transport/v2/test"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestMultiTCPMux_Recv builds a MultiTCPMux on top of several TCPMux
// instances, fetches the packet conns it hands out for one ufrag, and for each
// of them checks that a STUN binding request written by a dialed TCP client
// (via writeStreamingPacket) is received intact, and that writing it back
// through the packet conn is read by the client (via readStreamingPacket).
// It runs once with no write buffer and once with a 4MB write buffer.
func TestMultiTCPMux_Recv(t *testing.T) {
	for name, bufSize := range map[string]int{
		"no buffer":    0,
		"buffered 4MB": 4 * 1024 * 1024,
	} {
		bufSize := bufSize
		t.Run(name, func(t *testing.T) {
			report := test.CheckRoutines(t)
			defer report()

			loggerFactory := logging.NewDefaultLoggerFactory()

			const muxCount = 3
			muxInstances := make([]TCPMux, 0, muxCount)
			for i := 0; i < muxCount; i++ {
				listener, err := net.ListenTCP("tcp", &net.TCPAddr{
					IP:   net.IP{127, 0, 0, 1},
					Port: 0,
				})
				require.NoError(t, err, "error starting listener")
				defer func() {
					_ = listener.Close()
				}()

				tcpMux := NewTCPMuxDefault(TCPMuxParams{
					Listener:        listener,
					Logger:          loggerFactory.NewLogger("ice"),
					ReadBufferSize:  20,
					WriteBufferSize: bufSize,
				})
				muxInstances = append(muxInstances, tcpMux)
				require.NotNil(t, tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
			}

			multiMux := NewMultiTCPMuxDefault(muxInstances...)
			defer func() {
				_ = multiMux.Close()
			}()

			pktConns, err := multiMux.GetAllConns("myufrag", false, net.IP{127, 0, 0, 1})
			require.NoError(t, err, "error retrieving muxed connection for ufrag")
			// Without this, an empty result would skip the loop below and the
			// test would pass without exercising any mux.
			require.Len(t, pktConns, len(muxInstances), "expected one conn per TCPMux")

			for _, pktConn := range pktConns {
				// Re-declare per iteration so the deferred close captures this
				// conn rather than the shared range variable (pre-Go 1.22
				// semantics, which this file relies on elsewhere).
				pktConn := pktConn
				defer func() {
					_ = pktConn.Close()
				}()

				conn, err := net.DialTCP("tcp", nil, pktConn.LocalAddr().(*net.TCPAddr))
				require.NoError(t, err, "error dialing test tcp connection")
				// The client side of the connection is owned by the test.
				defer func() {
					_ = conn.Close()
				}()

				msg := stun.New()
				msg.Type = stun.MessageType{Method: stun.MethodBinding, Class: stun.ClassRequest}
				msg.Add(stun.AttrUsername, []byte("myufrag:otherufrag"))
				msg.Encode()

				n, err := writeStreamingPacket(conn, msg.Raw)
				require.NoError(t, err, "error writing tcp stun packet")

				recv := make([]byte, n)
				n2, rAddr, err := pktConn.ReadFrom(recv)
				require.NoError(t, err, "error receiving data")
				assert.Equal(t, conn.LocalAddr(), rAddr, "remote tcp address mismatch")
				assert.Equal(t, n, n2, "received byte size mismatch")
				assert.Equal(t, msg.Raw, recv, "received bytes mismatch")

				// check echo response
				n, err = pktConn.WriteTo(recv, conn.LocalAddr())
				require.NoError(t, err, "error writing echo stun packet")
				recvEcho := make([]byte, n)
				n3, err := readStreamingPacket(conn, recvEcho)
				require.NoError(t, err, "error receiving echo data")
				assert.Equal(t, n2, n3, "received byte size mismatch")
				assert.Equal(t, msg.Raw, recvEcho, "received bytes mismatch")
			}
		})
	}
}

// TestMultiTCPMux_NoDeadlockWhenClosingUnusedPacketConn fetches packet conns
// from a MultiTCPMux, leaves them unused, and then closes the mux. The close
// must succeed, and a later GetAllConns call on the closed mux must return no
// conns and an error matching io.ErrClosedPipe.
func TestMultiTCPMux_NoDeadlockWhenClosingUnusedPacketConn(t *testing.T) {
	report := test.CheckRoutines(t)
	defer report()

	loggerFactory := logging.NewDefaultLoggerFactory()

	const muxCount = 3
	tcpMuxInstances := make([]TCPMux, 0, muxCount)
	for i := 0; i < muxCount; i++ {
		listener, err := net.ListenTCP("tcp", &net.TCPAddr{
			IP:   net.IP{127, 0, 0, 1},
			Port: 0,
		})
		require.NoError(t, err, "error starting listener")
		defer func() {
			_ = listener.Close()
		}()

		tcpMux := NewTCPMuxDefault(TCPMuxParams{
			Listener:       listener,
			Logger:         loggerFactory.NewLogger("ice"),
			ReadBufferSize: 20,
		})
		tcpMuxInstances = append(tcpMuxInstances, tcpMux)
	}
	muxMulti := NewMultiTCPMuxDefault(tcpMuxInstances...)

	// The returned conns are intentionally discarded and never closed by the
	// test: the scenario under test is closing the mux while they are unused.
	_, err := muxMulti.GetAllConns("test", false, net.IP{127, 0, 0, 1})
	require.NoError(t, err, "error getting conn by ufrag")

	require.NoError(t, muxMulti.Close(), "error closing tcpMux")

	conns, err := muxMulti.GetAllConns("test", false, net.IP{127, 0, 0, 1})
	assert.Nil(t, conns, "should receive nil because mux is closed")
	// ErrorIs (errors.Is) keeps the check valid even if the error gets wrapped.
	assert.ErrorIs(t, err, io.ErrClosedPipe, "should receive error because mux is closed")
}