1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
|
package accesscontrol_test
import (
"bytes"
"io"
"testing"
"github.com/charmbracelet/wish/accesscontrol"
"github.com/charmbracelet/wish/testsession"
"github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
)
// out is the fixed payload the test SSH handler writes to the session; the
// subtests compare the captured session stdout against it.
const (
	out = "hello world"
)
// TestMiddleware exercises accesscontrol.Middleware through a test SSH
// session. Each subtest builds a fresh server via setup with a list of
// allowed commands, runs a command line, and checks both whether the run
// errored and what the wrapped handler wrote to stdout.
func TestMiddleware(t *testing.T) {
	// requireEmpty reports a failure on tb if s is not the empty string.
	requireEmpty := func(tb testing.TB, s string) {
		tb.Helper()
		if s != "" {
			tb.Errorf("expected output to be empty, got %q", s)
		}
	}
	// requireOutput reports a failure on tb if s differs from out. It must
	// report through tb (the subtest) rather than the enclosing t so the
	// failure is attributed to the subtest that produced it and tb.Helper()
	// takes effect.
	requireOutput := func(tb testing.TB, s string) {
		tb.Helper()
		if s != out {
			tb.Errorf("expected %q, got %q", out, s)
		}
	}
	t.Run("no allowed cmds no cmd", func(t *testing.T) {
		var b bytes.Buffer
		if err := setup(t, &b).Run(""); err != nil {
			t.Error(err)
		}
		requireOutput(t, b.String())
	})
	t.Run("no allowed cmds with cmd", func(t *testing.T) {
		var b bytes.Buffer
		if err := setup(t, &b).Run("echo"); err == nil {
			t.Error("should have errored")
		}
		requireEmpty(t, b.String())
	})
	t.Run("allowed cmds no cmd", func(t *testing.T) {
		var b bytes.Buffer
		if err := setup(t, &b, "echo").Run(""); err != nil {
			t.Error(err)
		}
		requireOutput(t, b.String())
	})
	t.Run("allowed cmds with allowed cmd", func(t *testing.T) {
		var b bytes.Buffer
		if err := setup(t, &b, "echo").Run("echo"); err != nil {
			t.Error(err)
		}
		requireOutput(t, b.String())
	})
	t.Run("allowed cmds with disallowed cmd", func(t *testing.T) {
		var b bytes.Buffer
		// err is nil here, so log a real message instead of "<nil>".
		if err := setup(t, &b, "echo").Run("cat"); err == nil {
			t.Error("should have errored")
		}
		requireEmpty(t, b.String())
	})
	// NOTE: despite the subtest name, the command line is "cat echo": the
	// disallowed command comes first and "echo" is only its argument.
	t.Run("allowed cmds with allowed cmd followed disallowed cmd", func(t *testing.T) {
		var b bytes.Buffer
		if err := setup(t, &b, "echo").Run("cat echo"); err == nil {
			t.Error("should have errored")
		}
		requireEmpty(t, b.String())
	})
}
// setup wires an ssh.Server whose handler, wrapped by
// accesscontrol.Middleware(allowedCmds...), writes out to the session. It
// opens a client session to that server with testsession.New and redirects
// the session's stdout to w. The cleanup function returned by testsession.New
// is registered with t.Cleanup, so teardown happens when the test ends.
func setup(t *testing.T, w io.Writer, allowedCmds ...string) *gossh.Session {
	handler := func(s ssh.Session) {
		// Best effort: the handler runs server-side and cannot fail the test.
		_, _ = s.Write([]byte(out))
	}
	srv := &ssh.Server{
		Handler: accesscontrol.Middleware(allowedCmds...)(handler),
	}
	sess, _, cleanup := testsession.New(t, srv, nil)
	t.Cleanup(cleanup)
	sess.Stdout = w
	return sess
}
|