File: session_test.go

package info (click to toggle)
golang-github-microsoft-dev-tunnels 0.0.25-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,988 kB
  • sloc: cs: 9,969; java: 2,767; javascript: 328; xml: 186; makefile: 5
file content (149 lines) | stat: -rw-r--r-- 3,135 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package tunnelssh

import (
	"context"
	"errors"
	"testing"

	"golang.org/x/crypto/ssh"
)

// mockActivator is a test double whose Activate method forwards to the
// ActivateFunc supplied by the test.
type mockActivator struct {
	ActivateFunc func(context.Context, *Session) error
}

// Activate calls ActivateFunc with ctx and s and returns whatever it returns.
// ActivateFunc must be set; calling a nil func panics.
func (a *mockActivator) Activate(ctx context.Context, s *Session) error {
	activate := a.ActivateFunc
	return activate(ctx, s)
}

// TestSessionActivate checks that Session.Activate returns no error when it is
// given an activator that only accepts the session under test.
func TestSessionActivate(t *testing.T) {
	session := NewSession(nil)

	// The activator fails unless it is handed the exact session created above.
	checkSession := func(_ context.Context, got *Session) error {
		if got == session {
			return nil
		}
		return errors.New("invalid session")
	}
	activator := &mockActivator{ActivateFunc: checkSession}

	err := session.Activate(context.Background(), activator)
	if err != nil {
		t.Errorf("session.Activate() error = %v", err)
	}
}

// mockNewChannel is a test double with the Accept, ExtraData, ChannelType and
// Reject methods; each one forwards to the matching *Func field. A test only
// needs to set the fields for the methods it expects to be called, because
// calling a method whose field is nil panics.
type mockNewChannel struct {
	AcceptFunc      func() (ssh.Channel, <-chan *ssh.Request, error)
	ChannelTypeFunc func() string
	RejectFunc      func(ssh.RejectionReason, string) error
	ExtraDataFunc   func() []byte
}

// Accept returns the result of AcceptFunc.
func (nc *mockNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
	accept := nc.AcceptFunc
	return accept()
}

// ExtraData returns the result of ExtraDataFunc.
func (nc *mockNewChannel) ExtraData() []byte {
	extraData := nc.ExtraDataFunc
	return extraData()
}

// ChannelType returns the result of ChannelTypeFunc.
func (nc *mockNewChannel) ChannelType() string {
	channelType := nc.ChannelTypeFunc
	return channelType()
}

// Reject passes reason and message to RejectFunc and returns its result.
func (nc *mockNewChannel) Reject(reason ssh.RejectionReason, message string) error {
	reject := nc.RejectFunc
	return reject(reason, message)
}

// TestSessionChannels checks how handleChannels dispatches incoming channels:
// a channel whose type has a registered handler is passed to that handler, and
// a channel of any other type is rejected.
func TestSessionChannels(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	channelType := "testChannel"
	session := NewSession(nil)

	// The handler runs off the test goroutine (handleChannels is started with
	// `go` below), so n must not be read until the handler has signalled on
	// handled; the channel operation orders the increment before the read no
	// matter how handleChannels chooses to invoke its handlers.
	var n int
	handled := make(chan struct{}, 1)
	session.AddChannelHandler(channelType, func(ctx context.Context, newChannel ssh.NewChannel) {
		n++
		handled <- struct{}{}
	})

	chans := make(chan ssh.NewChannel)
	go session.handleChannels(ctx, chans)

	// successful channel
	chans <- &mockNewChannel{
		ChannelTypeFunc: func() string {
			return channelType
		},
	}

	// wait for the handler to run before asserting on its side effect
	<-handled
	if n != 1 {
		t.Errorf("n = %d, want 1", n)
	}

	// rejected channel
	called := make(chan struct{})
	chans <- &mockNewChannel{
		ChannelTypeFunc: func() string {
			return "otherChannel"
		},
		RejectFunc: func(reason ssh.RejectionReason, message string) error {
			close(called)
			return nil
		},
	}

	// wait for the channel to be rejected
	<-called
}

// mockSSHRequest is a test double with Type and Reply methods; each one
// forwards to the matching *Func field. A test only needs to set the fields
// for the methods it expects to be called, because calling a method whose
// field is nil panics.
type mockSSHRequest struct {
	TypeFunc  func() string
	ReplyFunc func(bool, []byte) error
}

// Type returns the result of TypeFunc.
func (r *mockSSHRequest) Type() string {
	typeOf := r.TypeFunc
	return typeOf()
}

// Reply passes ok and message to ReplyFunc and returns its result.
func (r *mockSSHRequest) Reply(ok bool, message []byte) error {
	reply := r.ReplyFunc
	return reply(ok, message)
}

// TestSessionRequests checks how handleRequests dispatches incoming requests:
// a request whose type has a registered handler is passed to that handler, and
// a request of any other type is replied to.
func TestSessionRequests(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	session := NewSession(nil)

	// The handler runs off the test goroutine (handleRequests is started with
	// `go` below), so n must not be read until the handler has signalled on
	// handled; the channel operation orders the increment before the read.
	var n int
	handled := make(chan struct{}, 1)
	session.AddRequestHandler("testRequest", func(ctx context.Context, req SSHRequest) {
		n++
		handled <- struct{}{}
	})

	reqs := make(chan SSHRequest)
	go session.handleRequests(ctx, reqs)

	// handled request
	reqs <- &mockSSHRequest{
		TypeFunc: func() string {
			return "testRequest"
		},
	}

	// A completed unbuffered send only proves the request was received, not
	// that its handler has run; wait for the handler before asserting.
	<-handled
	if n != 1 {
		t.Errorf("n = %d, want 1", n)
	}

	// unhandled request
	called := make(chan struct{})
	reqs <- &mockSSHRequest{
		TypeFunc: func() string {
			return "otherRequest"
		},
		ReplyFunc: func(ok bool, message []byte) error {
			close(called)
			return nil
		},
	}

	// wait for the request to be rejected
	<-called
}