File: local_port_forwarder_test.go

package info (click to toggle)
golang-github-microsoft-dev-tunnels 0.0.25-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,988 kB
  • sloc: cs: 9,969; java: 2,767; javascript: 328; xml: 186; makefile: 5
file content (105 lines) | stat: -rw-r--r-- 2,389 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package tunnelssh

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"testing"
	"time"

	"github.com/microsoft/dev-tunnels/go/tunnels/ssh/messages"
)

// mockChannelOpener is a test double for the forwarder's channel opener.
// Every openChannel call is handed to openChannelFunc, so each test decides
// which channel (or error) the forwarder receives.
type mockChannelOpener struct {
	// openChannelFunc receives the openChannel arguments in the same order:
	// channel type, origin IP, origin port, host, port.
	openChannelFunc func(channelType, originIP string, originPort int, host string, port int) (io.ReadWriteCloser, error)
}

// openChannel passes all of its arguments, unchanged, to m.openChannelFunc
// and returns exactly what that function returns.
func (m *mockChannelOpener) openChannel(channelType, originIP string, originPort int, host string, port int) (io.ReadWriteCloser, error) {
	delegate := m.openChannelFunc
	return delegate(channelType, originIP, originPort, host, port)
}

// mockChannel is an in-memory stand-in for an opened channel: reads and writes
// go straight to the embedded *bytes.Buffer, and Close is a no-op.
//
// NOTE(review): bytes.Buffer is not safe for concurrent use. If the forwarder
// reads and writes the channel from separate goroutines, this mock may be
// reported by the race detector — confirm against the forwarder implementation.
type mockChannel struct {
	*bytes.Buffer
}

// Compile-time check that *mockChannel can be returned where an
// io.ReadWriteCloser is expected.
var _ io.ReadWriteCloser = (*mockChannel)(nil)

// Close satisfies io.Closer. The mock owns no resources, so it always succeeds.
func (*mockChannel) Close() error { return nil }

// TestLocalPortForwarderPortForwardChannelType checks that a local port
// forwarder created with messages.PortForwardChannelType requests channels of
// that type and relays data over a TCP connection to the forwarded port.
//
// The mock opener fails any open whose channel type is not
// messages.PortForwardChannelType and otherwise returns an in-memory channel
// pre-loaded with streamData. The client side dials the forwarded port, expects
// to read streamData back, then writes its own payload.
func TestLocalPortForwarderPortForwardChannelType(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	streamData := "stream-data"
	host := "127.0.0.1"
	port := 8080

	stream := &mockChannel{bytes.NewBufferString(streamData)}
	co := &mockChannelOpener{
		openChannelFunc: func(channelType, originIP string, originPort int, host string, port int) (io.ReadWriteCloser, error) {
			if channelType != messages.PortForwardChannelType {
				return nil, fmt.Errorf("expected channel type %s, got %s", messages.PortForwardChannelType, channelType)
			}
			return stream, nil
		},
	}

	lpf := newLocalPortForwarder(co, messages.PortForwardChannelType, host, port)

	// Buffered for both producers so neither goroutine blocks on its send
	// after the test has consumed the first result and returned.
	done := make(chan error, 2)

	go func() {
		done <- lpf.startForwarding(ctx)
	}()

	go func() {
		var conn net.Conn

		// We retry DialTimeout in a loop to deal with a race in forwarder startup.
		for tries := 0; conn == nil && tries < 2; tries++ {
			conn, _ = net.DialTimeout("tcp", fmt.Sprintf(":%d", port), 2*time.Second)
			if conn == nil {
				time.Sleep(1 * time.Second)
			}
		}
		if conn == nil {
			done <- errors.New("failed to connect to forwarded port")
			return
		}
		defer conn.Close()

		// Bound the I/O below so this goroutine cannot block forever if the
		// forwarder never relays data.
		if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
			done <- fmt.Errorf("setting connection deadline: %w", err)
			return
		}

		// A single Read on a TCP connection may legally return fewer bytes
		// than were sent; io.ReadFull keeps reading until b is filled.
		b := make([]byte, len(streamData))
		if _, err := io.ReadFull(conn, b); err != nil {
			done <- fmt.Errorf("reading stream: %w", err)
			return
		}
		if string(b) != streamData {
			done <- fmt.Errorf("stream data is not expected value, got: %s", string(b))
			return
		}

		if _, err := conn.Write([]byte("new-data")); err != nil {
			done <- fmt.Errorf("writing to stream: %w", err)
			return
		}

		done <- nil
	}()

	// The first result decides the test: either the client finished (nil or
	// an error) or startForwarding returned early. The timeout keeps a stalled
	// forwarder from hanging the test run indefinitely.
	select {
	case err := <-done:
		if err != nil {
			t.Errorf("Unexpected error: %v", err)
		}
	case <-time.After(30 * time.Second):
		t.Fatal("timed out waiting for port forwarding to complete")
	}
}