1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
|
//+build linux
package vsock
import (
"errors"
"syscall"
"testing"
"github.com/google/go-cmp/cmp"
"golang.org/x/sys/unix"
)
// Test_dialLinuxErrorClosesFile verifies that when connect fails, dialLinux
// reports an error and invokes the fd's close hook so the descriptor is not
// leaked.
func Test_dialLinuxErrorClosesFile(t *testing.T) {
	var connected, closed bool

	cfd := &testConnFD{
		// Always return an error on connect, and record that the call happened
		// so the test proves the failure came from the connect path.
		connect: func(sa unix.Sockaddr) error {
			connected = true
			return errors.New("error during connect")
		},
		// Track when fd.Close is called.
		close: func() error {
			closed = true
			return nil
		},
	}

	if _, err := dialLinux(cfd, 0, 0); err == nil {
		t.Fatal("expected an error, but none occurred")
	}

	if !connected {
		t.Fatal("expected a call to connect, but none occurred")
	}
	if !closed {
		t.Fatal("expected the file to be closed after a failed dial, but it was not")
	}
}
// Test_dialLinuxFull drives a successful dialLinux against a fake connection
// fd. It checks four things:
//   - the sockaddr passed to connect;
//   - the name given to the non-blocking setup;
//   - the local and remote addresses reported by the returned conn;
//   - that SyscallConn, CloseRead, CloseWrite, and Close reach the fd hooks.
func Test_dialLinuxFull(t *testing.T) {
	const (
		localCID   uint32 = 3
		localPort  uint32 = 1024
		remoteCID  uint32 = Host
		remotePort uint32 = 2048
	)

	lsa := &unix.SockaddrVM{
		CID:  localCID,
		Port: localPort,
	}
	rsa := &unix.SockaddrVM{
		CID:  remoteCID,
		Port: remotePort,
	}

	var (
		connected   bool
		closed      bool
		closedRead  bool
		closedWrite bool
		syscallConn bool
	)

	cfd := &testConnFD{
		connect: func(sa unix.Sockaddr) error {
			connected = true

			// Use a checked assertion so a wrong sockaddr type fails the test
			// with a message instead of panicking.
			vsa, ok := sa.(*unix.SockaddrVM)
			if !ok {
				t.Fatalf("unexpected connect sockaddr type: %T", sa)
			}
			if diff := cmp.Diff(rsa, vsa, cmp.AllowUnexported(*rsa)); diff != "" {
				t.Fatalf("unexpected connect sockaddr (-want +got):\n%s", diff)
			}
			return nil
		},
		getsockname: func() (unix.Sockaddr, error) {
			return lsa, nil
		},
		setNonblocking: func(name string) error {
			// The expected name embeds localCID (3) and localPort (1024); keep
			// it in sync if those constants change.
			if diff := cmp.Diff("vsock:vm(3):1024", name); diff != "" {
				t.Fatalf("unexpected non-blocking file name (-want +got):\n%s", diff)
			}
			return nil
		},
		close: func() error {
			closed = true
			return nil
		},
		shutdown: func(how int) error {
			switch how {
			case unix.SHUT_RD:
				closedRead = true
			case unix.SHUT_WR:
				closedWrite = true
			default:
				t.Fatalf("unexpected how constant in shutdown: %d", how)
			}
			return nil
		},
		syscallConn: func() (syscall.RawConn, error) {
			// No need to really do anything.
			syscallConn = true
			return nil, nil
		},
	}

	c, err := dialLinux(cfd, remoteCID, remotePort)
	if err != nil {
		t.Fatalf("failed to dial: %v", err)
	}
	if !connected {
		t.Fatal("expected call to connect, but none occurred")
	}

	localAddr := &Addr{
		ContextID: localCID,
		Port:      localPort,
	}
	if diff := cmp.Diff(localAddr, c.LocalAddr()); diff != "" {
		t.Fatalf("unexpected local address (-want +got):\n%s", diff)
	}

	remoteAddr := &Addr{
		ContextID: remoteCID,
		Port:      remotePort,
	}
	if diff := cmp.Diff(remoteAddr, c.RemoteAddr()); diff != "" {
		t.Fatalf("unexpected remote address (-want +got):\n%s", diff)
	}

	if _, err := c.SyscallConn(); err != nil {
		t.Fatalf("failed to test syscall conn: %v", err)
	}
	if !syscallConn {
		t.Fatal("expected call to SyscallConn, but none occurred")
	}

	// Verify Close/Shutdown plumbing. Shut down each direction before the
	// full Close so no call is made on an already-closed connection, and name
	// each entry so a failure identifies the method that broke.
	closers := []struct {
		name string
		fn   func() error
	}{
		{name: "CloseRead", fn: c.CloseRead},
		{name: "CloseWrite", fn: c.CloseWrite},
		{name: "Close", fn: c.Close},
	}
	for _, cl := range closers {
		if err := cl.fn(); err != nil {
			t.Fatalf("failed to invoke %s: %v", cl.name, err)
		}
	}

	if !closed || !closedRead || !closedWrite {
		t.Fatalf("expected calls to Close (%t), CloseRead (%t), and CloseWrite (%t)",
			closed, closedRead, closedWrite)
	}
}
|