1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
|
package testing
import (
"context"
"net"
"sync"
)
// ConnWaitGroup tracks network connections created through its Dial method.
// Call Wait (from the embedded sync.WaitGroup) to block until every
// connection obtained from Dial has been closed.
type ConnWaitGroup struct {
	// DialFunc establishes the underlying connection; its signature matches
	// net.Dialer.DialContext. If nil, Dial uses (&net.Dialer{}).DialContext.
	DialFunc func(context.Context, string, string) (net.Conn, error)

	// The embedded WaitGroup counts connections that are being dialed or are
	// open and not yet closed.
	sync.WaitGroup
}

// Dial connects to address on the named network using g.DialFunc and returns
// a connection whose Close marks it done in g. If the dial fails, Dial
// returns nil and the dial error and g's count is left unchanged.
func (g *ConnWaitGroup) Dial(ctx context.Context, network, address string) (net.Conn, error) {
	dial := g.DialFunc
	if dial == nil {
		// Make the zero value usable instead of panicking on a nil func call.
		dial = (&net.Dialer{}).DialContext
	}
	// Register before dialing: sync.WaitGroup requires a positive Add on a
	// zero counter to happen before Wait, and registering up front keeps a
	// concurrent Wait from missing a connection that is still being dialed.
	g.Add(1)
	c, err := dial(ctx, network, address)
	if err != nil {
		g.Done()
		return nil, err
	}
	return &groupConn{Conn: c, group: g}, nil
}

// groupConn wraps a net.Conn returned by ConnWaitGroup.Dial and decrements
// the owning group's counter exactly once when closed.
type groupConn struct {
	net.Conn
	group *ConnWaitGroup
	once  sync.Once
}

// Close closes the underlying connection and then marks it done in the
// owning group. Every call is forwarded to the underlying Close, but only the
// first call decrements the group, so repeated Close calls cannot drive the
// WaitGroup counter negative.
func (c *groupConn) Close() error {
	// Deferred so the group is released even if the underlying Close panics.
	defer c.once.Do(c.group.Done)
	return c.Conn.Close()
}
|