1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
|
package turbotunnel
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"time"
)
// RedialPacketConn is a long-lived net.PacketConn layered over a succession of
// short-lived net.PacketConns. Each underlying net.PacketConn is obtained by
// calling the dialContext function supplied at construction time. When the
// current net.PacketConn fails in ReadFrom or WriteTo, it is discarded,
// dialContext is called again, and packets flow over the replacement instead.
// ReadFrom and WriteTo on the RedialPacketConn itself report an error only
// after the RedialPacketConn has been closed, either by Close or because
// dialContext returned an error.
//
// The local and remote addresses reported by a RedialPacketConn are fixed at
// construction and have no relation to the addresses of the dialed
// net.PacketConns.
type RedialPacketConn struct {
	// Static addresses supplied to NewRedialPacketConn.
	localAddr, remoteAddr net.Addr
	// dialContext produces each successive underlying net.PacketConn.
	dialContext func(context.Context) (net.PacketConn, error)
	// recvQueue carries packets read from the active net.PacketConn to
	// ReadFrom; sendQueue carries packets from WriteTo to the active
	// net.PacketConn.
	recvQueue, sendQueue chan []byte
	// closed is closed exactly once (guarded by closeOnce) when the
	// RedialPacketConn shuts down.
	closed    chan struct{}
	closeOnce sync.Once
	// err holds the error that closed the RedialPacketConn (the first dial
	// error, or a generic "closed" error). It is stored before closed is
	// closed and is returned from subsequent read/write operations. Compare
	// to the rerr and werr in io.Pipe.
	err atomic.Value
}
// NewRedialPacketConn makes a new RedialPacketConn with the given static local
// and remote addresses and the given dialContext function. The send and
// receive queues are each created with capacity queueSize.
//
// The returned RedialPacketConn starts dialing immediately: a background
// goroutine running dialLoop is launched before NewRedialPacketConn returns.
// That goroutine runs until the RedialPacketConn is closed or dialContext
// returns an error.
func NewRedialPacketConn(
	localAddr, remoteAddr net.Addr,
	dialContext func(context.Context) (net.PacketConn, error),
) *RedialPacketConn {
	// The err field is deliberately left at its zero value; it is populated
	// by closeWithError.
	c := &RedialPacketConn{
		localAddr:   localAddr,
		remoteAddr:  remoteAddr,
		dialContext: dialContext,
		recvQueue:   make(chan []byte, queueSize),
		sendQueue:   make(chan []byte, queueSize),
		closed:      make(chan struct{}),
	}
	go c.dialLoop()
	return c
}
// dialLoop repeatedly calls c.dialContext and passes the resulting
// net.PacketConn to c.exchange. It returns only when c is closed or dialContext
// returns an error; in the latter case c is closed with that error.
//
// The context handed to dialContext is canceled as soon as c is closed, so a
// dial that is still in progress when Close is called can be interrupted
// rather than running to completion.
func (c *RedialPacketConn) dialLoop() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Propagate closure of c to ctx. The goroutine exits either when c is
	// closed or when dialLoop returns (the deferred cancel fires ctx.Done).
	go func() {
		select {
		case <-c.closed:
			cancel()
		case <-ctx.Done():
		}
	}()

	for {
		select {
		case <-c.closed:
			return
		default:
		}

		conn, err := c.dialContext(ctx)
		if err != nil {
			// If c was already closed, this is a no-op and the
			// original close error is retained.
			c.closeWithError(err)
			return
		}
		c.exchange(conn)
		// Best-effort cleanup; conn has already failed or c is closed,
		// so a Close error carries no additional information.
		conn.Close()
	}
}
// exchange makes conn the currently active net.PacketConn. It starts two
// goroutines: one calls conn.ReadFrom and places the resulting packets in the
// receive queue, the other takes packets from the send queue and calls
// conn.WriteTo on them with c.remoteAddr as the destination. exchange returns
// once either goroutine has stopped, which happens on a read error, a write
// error, or when c is closed.
//
// exchange does not close conn. The reading goroutine may still be blocked in
// conn.ReadFrom when exchange returns; the caller is expected to close conn
// afterwards, which unblocks it.
func (c *RedialPacketConn) exchange(conn net.PacketConn) {
	// Each goroutine sends at most one error before exiting. The channels
	// have capacity 1 so that this send never blocks, even if exchange has
	// already returned and nobody is receiving any more; with unbuffered
	// channels the reader would be stuck forever on its send after a write
	// error or after c is closed.
	readErrCh := make(chan error, 1)
	writeErrCh := make(chan error, 1)

	go func() {
		defer close(readErrCh)
		// Reused across iterations; every packet is copied out before
		// the next read.
		var buf [1500]byte
		for {
			select {
			case <-c.closed:
				return
			case <-writeErrCh:
				// The writer failed or exited; stop reading too.
				return
			default:
			}

			n, _, err := conn.ReadFrom(buf[:])
			if err != nil {
				readErrCh <- err
				return
			}
			p := make([]byte, n)
			copy(p, buf[:])
			select {
			case c.recvQueue <- p:
			default: // OK to drop packets.
			}
		}
	}()

	go func() {
		defer close(writeErrCh)
		for {
			select {
			case <-c.closed:
				return
			case <-readErrCh:
				// The reader failed or exited; stop writing too.
				return
			case p := <-c.sendQueue:
				_, err := conn.WriteTo(p, c.remoteAddr)
				if err != nil {
					writeErrCh <- err
					return
				}
			}
		}
	}()

	// Wait for the first goroutine to report an error or exit (a closed
	// channel is also ready to receive).
	select {
	case <-readErrCh:
	case <-writeErrCh:
	}
}
// ReadFrom returns the next packet from the receive queue, blocking until one
// is available or the RedialPacketConn is closed. The packet is copied into p
// (and truncated if p is too short). The address returned is always the
// RedialPacketConn's own remote address, not the address the packet actually
// arrived from.
//
// Once the RedialPacketConn is closed, ReadFrom returns a *net.OpError that
// wraps the error the RedialPacketConn was closed with.
func (c *RedialPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
	closedErr := func() error {
		local := c.LocalAddr()
		return &net.OpError{
			Op:     "read",
			Net:    local.Network(),
			Source: local,
			Addr:   c.remoteAddr,
			Err:    c.err.Load().(error),
		}
	}

	// Check for closure first so that a closed conn reports an error even
	// when packets are still sitting in the queue.
	select {
	case <-c.closed:
		return 0, nil, closedErr()
	default:
	}

	select {
	case pkt := <-c.recvQueue:
		return copy(p, pkt), c.remoteAddr, nil
	case <-c.closed:
		return 0, nil, closedErr()
	}
}
// WriteTo queues a copy of p for transmission on the currently active
// net.PacketConn. The addr argument is ignored; packets are always sent to the
// RedialPacketConn's own remote address.
//
// WriteTo never blocks: if the send queue is full the packet is silently
// dropped, and len(p) is reported as written either way. Once the
// RedialPacketConn is closed, WriteTo returns a *net.OpError that wraps the
// error the RedialPacketConn was closed with.
func (c *RedialPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
	select {
	case <-c.closed:
		local := c.LocalAddr()
		return 0, &net.OpError{
			Op:     "write",
			Net:    local.Network(),
			Source: local,
			Addr:   c.remoteAddr,
			Err:    c.err.Load().(error),
		}
	default:
	}

	// Copy the slice so that the caller may reuse p after we return.
	pkt := append(make([]byte, 0, len(p)), p...)
	select {
	case c.sendQueue <- pkt:
	default:
		// The send queue is full; drop the outgoing packet.
	}
	return len(pkt), nil
}
// closeWithError closes the RedialPacketConn: it unblocks pending operations
// and makes future operations fail with err. A nil err is replaced by a
// generic "operation on closed connection" error.
//
// Only the first call has any effect and returns nil. Every later call leaves
// the stored error untouched and returns a *net.OpError wrapping it.
func (c *RedialPacketConn) closeWithError(err error) error {
	alreadyClosed := true
	c.closeOnce.Do(func() {
		cause := err
		if cause == nil {
			cause = errors.New("operation on closed connection")
		}
		// Record the error before signalling closure, so that anyone
		// who observes c.closed can load it.
		c.err.Store(cause)
		close(c.closed)
		alreadyClosed = false
	})
	if alreadyClosed {
		local := c.LocalAddr()
		return &net.OpError{
			Op:   "close",
			Net:  local.Network(),
			Addr: local,
			Err:  c.err.Load().(error),
		}
	}
	return nil
}
// Close closes the RedialPacketConn by calling closeWithError with a nil
// error. If this is the call that closes it, pending operations are unblocked,
// future operations fail with an "operation on closed connection" error, and
// Close returns nil. If the RedialPacketConn was already closed (by an earlier
// Close or by a dial error), Close changes nothing and returns a *net.OpError
// wrapping the error it was originally closed with.
func (c *RedialPacketConn) Close() error {
return c.closeWithError(nil)
}
// LocalAddr returns the localAddr value that was passed to NewRedialPacketConn.
// It is a static address and is not taken from any dialed net.PacketConn.
func (c *RedialPacketConn) LocalAddr() net.Addr { return c.localAddr }
// SetDeadline is not supported. It ignores t and always returns
// errNotImplemented.
func (c *RedialPacketConn) SetDeadline(t time.Time) error { return errNotImplemented }
// SetReadDeadline is not supported. It ignores t and always returns
// errNotImplemented.
func (c *RedialPacketConn) SetReadDeadline(t time.Time) error { return errNotImplemented }
// SetWriteDeadline is not supported. It ignores t and always returns
// errNotImplemented.
func (c *RedialPacketConn) SetWriteDeadline(t time.Time) error { return errNotImplemented }
|