1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
|
package gocql
import (
"context"
"net"
"strconv"
"sync"
"testing"
"time"
)
// OneConnTestServer is a throwaway TCP server for tests. It listens on an
// ephemeral local port, accepts at most one connection in Serve, and then
// shuts itself down.
type OneConnTestServer struct {
	// Err holds the error, if any, that Serve got back from Accept. Read it
	// only after the Accepted channel has been closed.
	Err error
	// Addr and Port are the IP and port the listener is bound to.
	Addr net.IP
	Port int

	listener   net.Listener
	acceptChan chan struct{}

	// mu serializes shutdown and guards closed.
	mu     sync.Mutex
	closed bool
}
// NewOneConnTestServer opens an IPv4 TCP listener on an ephemeral localhost
// port and returns a server wrapping it, with Addr and Port set to the bound
// address. Run Serve (typically in its own goroutine) to accept the single
// connection, and call Close to release the listener.
func NewOneConnTestServer() (*OneConnTestServer, error) {
	listener, err := net.Listen("tcp4", "localhost:0")
	if err != nil {
		return nil, err
	}

	srv := &OneConnTestServer{
		listener:   listener,
		acceptChan: make(chan struct{}),
	}
	// Port 0 above asks the OS to pick a free port; read back what it chose.
	srv.Addr, srv.Port = parseAddressPort(listener.Addr().String())
	return srv, nil
}
// Accepted returns the server's shutdown channel. It is closed once the
// server shuts down, either after Serve has handled its single Accept or
// because Close was called. Nothing in this file sends on it; only its
// closing is meaningful.
func (c *OneConnTestServer) Accepted() chan struct{} {
	done := c.acceptChan
	return done
}
// Close shuts the server down. It closes the listener, which unblocks a
// pending Accept in Serve, and it closes the Accepted channel. Calling it
// more than once, or after Serve has already shut down, is a no-op.
func (c *OneConnTestServer) Close() {
	c.lockedClose()
}
// Serve accepts a single connection, closes it immediately, and then shuts
// the server down (closing both the listener and the Accepted channel).
//
// The Accept error, if any, is stored in Err while holding c.mu, and only if
// the server is still open. Because shutdown also happens under c.mu, Err is
// never written after Accepted() has been closed. A goroutine that waits on
// Accepted() and then reads Err therefore cannot race with this write.
//
// If Close won the race (the listener was shut down before any client
// connected), the resulting Accept error is not recorded.
func (c *OneConnTestServer) Serve() {
	conn, err := c.listener.Accept()
	if conn != nil {
		// Only the fact that a dial reached the server matters; the
		// connection itself is never used.
		conn.Close()
	}

	c.mu.Lock()
	if !c.closed {
		c.Err = err
	}
	c.mu.Unlock()

	c.lockedClose()
}
// lockedClose performs the one-time shutdown while holding c.mu. It closes
// acceptChan, closes the listener, and marks the server closed. Later calls
// see closed == true and return early, so both Serve and Close can call it.
func (c *OneConnTestServer) lockedClose() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return
	}
	close(c.acceptChan)
	// A failure to close the listener is not actionable in a test helper.
	_ = c.listener.Close()
	c.closed = true
}
// parseAddressPort splits a "host:port" string into an IP and a numeric port.
//
// If hostPort cannot be split, it returns a nil IP and port 0. A host that is
// not an IP literal (for example "localhost") yields a nil IP. The port
// conversion error is deliberately dropped, so a malformed port becomes
// whatever strconv.Atoi returns on error (0 for non-numeric text). The caller
// in this file passes a listener address, whose port is expected to be numeric.
func parseAddressPort(hostPort string) (net.IP, int) {
	host, portText, err := net.SplitHostPort(hostPort)
	if err != nil {
		return nil, 0
	}
	ip := net.ParseIP(host)
	port, _ := strconv.Atoi(portText)
	return ip, port
}
// testConnErrorHandler returns a ConnErrorHandler that reports every error it
// receives as a failure of t. The connection and the closed flag are ignored;
// only the error is logged.
func testConnErrorHandler(t *testing.T) ConnErrorHandler {
	failOnConnError := func(_ *Conn, err error, _ bool) {
		t.Errorf("in connection handler: %v", err)
	}
	return connErrorHandlerFn(failOnConnError)
}
// assertConnectionEventually fails t unless srvr's Accepted channel is closed
// within wait. If the channel closes in time, any error recorded in srvr.Err
// is also reported as a test failure.
//
// Failures are reported with t.Errorf, so the calling test keeps running.
func assertConnectionEventually(t *testing.T, wait time.Duration, srvr *OneConnTestServer) {
	t.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), wait)
	defer cancel()

	select {
	case <-ctx.Done():
		// cancel only runs on return, so Done closing here means the timeout
		// expired. Err is always non-nil once Done is closed.
		t.Errorf("waiting for connection: %v", ctx.Err())
	case <-srvr.Accepted():
		if srvr.Err != nil {
			t.Errorf("accepting connection: %v", srvr.Err)
		}
	}
}
// TestSession_connect_WithNoTranslator points a HostInfo directly at a local
// one-shot server and checks that Connect reaches it within 500ms.
// createTestSession presumably leaves AddressTranslator unset; confirm if
// that helper changes.
func TestSession_connect_WithNoTranslator(t *testing.T) {
	srvr, err := NewOneConnTestServer()
	assertNil(t, "error when creating tcp server", err)
	defer srvr.Close()

	session := createTestSession()
	defer session.Close()

	go srvr.Serve()

	onConnErr := testConnErrorHandler(t)
	target := &HostInfo{connectAddress: srvr.Addr, port: srvr.Port}
	// The result of Connect is not inspected; the server side is what
	// confirms that a dial happened.
	Connect(target, session.connCfg, onConnErr, session)

	assertConnectionEventually(t, 500*time.Millisecond, srvr)
}
// TestSession_connect_WithTranslator checks that Connect applies the session's
// AddressTranslator. The HostInfo names an address other than the test
// server's, and the session's translator is configured with the server's
// IP/port. The server seeing a connection within 500ms means translation
// happened.
func TestSession_connect_WithTranslator(t *testing.T) {
	srvr, err := NewOneConnTestServer()
	assertNil(t, "error when creating tcp server", err)
	defer srvr.Close()

	session := createTestSession()
	defer session.Close()
	session.cfg.AddressTranslator = staticAddressTranslator(srvr.Addr, srvr.Port)

	go srvr.Serve()

	onConnErr := testConnErrorHandler(t)
	// Not the server's address: Connect must translate it to reach srvr.
	untranslated := &HostInfo{
		connectAddress: net.ParseIP("10.10.10.10"),
		port:           5432,
	}
	Connect(untranslated, session.connCfg, onConnErr, session)

	assertConnectionEventually(t, 500*time.Millisecond, srvr)
}
|