1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
|
package main
import (
"encoding/binary"
"errors"
"log"
"net"
"os"
"sync"
"syscall"
"time"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
const (
// DefaultConnTrackTimeout is the default timeout used for UDP connection
// tracking.
DefaultConnTrackTimeout = 90 * time.Second
// UDPBufSize is the buffer size for the UDP proxy
UDPBufSize = 65507
)
// A net.Addr where the IP is split into two fields so you can use it as a key
// in a map:
type connTrackKey struct {
IPHigh uint64
IPLow uint64
Port int
}
func newConnTrackKey(addr *net.UDPAddr) *connTrackKey {
if len(addr.IP) == net.IPv4len {
return &connTrackKey{
IPHigh: 0,
IPLow: uint64(binary.BigEndian.Uint32(addr.IP)),
Port: addr.Port,
}
}
return &connTrackKey{
IPHigh: binary.BigEndian.Uint64(addr.IP[:8]),
IPLow: binary.BigEndian.Uint64(addr.IP[8:]),
Port: addr.Port,
}
}
type connTrackMap map[connTrackKey]*connTrackEntry
// connTrackEntry wraps a UDP connection to provide thread-safe [net.Conn.Write]
// and [net.Conn.Close] operations.
type connTrackEntry struct {
conn *net.UDPConn
lastW time.Time
// This lock should be held before calling Write or Close on the wrapped
// net.UDPConn. Read can be called concurrently to these operations.
//
// Never lock mu without locking UDPProxy.connTrackLock first.
mu sync.Mutex
}
func newConnTrackEntry(conn *net.UDPConn) *connTrackEntry {
return &connTrackEntry{
conn: conn,
mu: sync.Mutex{},
}
}
func (cte *connTrackEntry) lastWrite() time.Time {
cte.mu.Lock()
defer cte.mu.Unlock()
return cte.lastW
}
// UDPProxy is proxy for which handles UDP datagrams. It implements the Proxy
// interface to handle UDP traffic forwarding between the frontend and backend
// addresses.
type UDPProxy struct {
listener *net.UDPConn
frontendAddr *net.UDPAddr
backendAddr *net.UDPAddr
connTrackTable connTrackMap
connTrackLock sync.Mutex
connTrackTimeout time.Duration
ipVer ipVersion
}
// NewUDPProxy creates a new UDPProxy.
func NewUDPProxy(listener *net.UDPConn, backendAddr *net.UDPAddr, ipVer ipVersion) (*UDPProxy, error) {
return &UDPProxy{
listener: listener,
frontendAddr: listener.LocalAddr().(*net.UDPAddr),
backendAddr: backendAddr,
connTrackTable: make(connTrackMap),
connTrackTimeout: DefaultConnTrackTimeout,
ipVer: ipVer,
}, nil
}
func (proxy *UDPProxy) replyLoop(cte *connTrackEntry, serverAddr net.IP, clientAddr *net.UDPAddr, clientKey *connTrackKey) {
defer func() {
proxy.connTrackLock.Lock()
delete(proxy.connTrackTable, *clientKey)
cte.mu.Lock()
proxy.connTrackLock.Unlock()
cte.conn.Close()
}()
var oob []byte
if proxy.ipVer == ip4 {
cm := &ipv4.ControlMessage{Src: serverAddr}
oob = cm.Marshal()
} else {
cm := &ipv6.ControlMessage{Src: serverAddr}
oob = cm.Marshal()
}
readBuf := make([]byte, UDPBufSize)
for {
cte.conn.SetReadDeadline(time.Now().Add(proxy.connTrackTimeout))
again:
read, err := cte.conn.Read(readBuf)
if err != nil {
if err, ok := err.(*net.OpError); ok && errors.Is(err.Err, syscall.ECONNREFUSED) {
// This will happen if the last write failed
// (e.g: nothing is actually listening on the
// proxied port on the container), ignore it
// and continue until DefaultConnTrackTimeout
// expires:
goto again
}
// If the UDP connection is one-sided (i.e. the backend never sends
// replies), the connTrackEntry should not be GC'd until no writes
// happen for proxy.connTrackTimeout.
//
// Since the ReadDeadline is set to proxy.connTrackTimeout, in such
// case, the connTrackEntry will be GC'd at most after 2 * proxy.connTrackTimeout.
if errors.Is(err, os.ErrDeadlineExceeded) && time.Since(cte.lastWrite()) < proxy.connTrackTimeout {
continue
}
return
}
for i := 0; i != read; {
written, _, err := proxy.listener.WriteMsgUDP(readBuf[i:read], oob, clientAddr)
if err != nil {
return
}
i += written
}
}
}
// Run starts forwarding the traffic using UDP.
func (proxy *UDPProxy) Run() {
readBuf := make([]byte, UDPBufSize)
var oob []byte
if proxy.ipVer == ip4 {
oob = ipv4.NewControlMessage(ipv4.FlagDst)
} else {
oob = ipv6.NewControlMessage(ipv6.FlagDst)
}
for {
read, _, _, from, err := proxy.listener.ReadMsgUDP(readBuf, oob)
if err != nil {
// The frontend listener socket might be closed by the signal
// handler. In that case, don't log anything - it's not an error.
if !errors.Is(err, net.ErrClosed) {
log.Printf("Stopping proxy on udp/%v for udp/%v (%s)", proxy.frontendAddr, proxy.backendAddr, err)
}
break
}
fromKey := newConnTrackKey(from)
proxy.connTrackLock.Lock()
cte, hit := proxy.connTrackTable[*fromKey]
if !hit {
proxyConn, err := net.DialUDP("udp", nil, proxy.backendAddr)
if err != nil {
log.Printf("Can't proxy a datagram to udp/%s: %s\n", proxy.backendAddr, err)
proxy.connTrackLock.Unlock()
continue
}
cte = newConnTrackEntry(proxyConn)
proxy.connTrackTable[*fromKey] = cte
daddr, err := readDestFromCmsg(oob, proxy.ipVer)
if err != nil {
log.Printf("Failed to parse control message: %v", err)
proxy.connTrackLock.Unlock()
continue
}
go proxy.replyLoop(cte, daddr, from, fromKey)
}
cte.mu.Lock()
proxy.connTrackLock.Unlock()
cte.conn.SetWriteDeadline(time.Now().Add(proxy.connTrackTimeout))
for i := 0; i != read; {
written, err := cte.conn.Write(readBuf[i:read])
if err != nil {
log.Printf("Can't proxy a datagram to udp/%s: %s\n", proxy.backendAddr, err)
break
}
i += written
cte.lastW = time.Now()
}
cte.mu.Unlock()
}
}
func readDestFromCmsg(oob []byte, ipVer ipVersion) (_ net.IP, err error) {
defer func() {
// In case of partial upgrade / downgrade, docker-proxy could read
// control messages from a socket which doesn't have the sockopt
// IP_PKTINFO enabled. In that case, the control message will be all-0
// and Go's ControlMessage.Parse() will report an 'invalid header
// length' error. In that case, ignore the error and return an empty
// daddr - the kernel will pick a source address for us anyway (but
// maybe it'll be the wrong one).
if err != nil && err.Error() == "invalid header length" {
err = nil
}
}()
if ipVer == ip4 {
cm := &ipv4.ControlMessage{}
if err := cm.Parse(oob); err != nil {
return nil, err
}
return cm.Dst, nil
}
cm := &ipv6.ControlMessage{}
if err := cm.Parse(oob); err != nil {
return nil, err
}
return cm.Dst, nil
}
// Close ungracefully stops forwarding the traffic.
func (proxy *UDPProxy) Close() {
proxy.listener.Close()
proxy.connTrackLock.Lock()
defer proxy.connTrackLock.Unlock()
for _, cte := range proxy.connTrackTable {
// Unlike the GC logic in replyLoop, we want to close the connections
// immediately, even if there are pending and in-progress writes. So no
// need to lock cte.mu here.
cte.conn.Close()
}
}
|