1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
|
package turn
import (
	"encoding/binary"
	"errors"
	"io"
	"net"
	"time"

	"github.com/pion/stun"
	"github.com/pion/turn/v2/internal/proto"
)
// Sentinel errors produced while splitting the byte stream into TURN frames.
var (
	// errIncompleteTURNFrame signals that the buffered bytes are not yet
	// enough to classify or fully contain the next frame; more data must be
	// read before trying again.
	errIncompleteTURNFrame = errors.New("data contains incomplete STUN or TURN frame")

	// errInvalidTURNFrame signals that the buffered bytes start with neither
	// a STUN message nor a ChannelData message.
	errInvalidTURNFrame = errors.New("data is not a valid TURN frame, no STUN or ChannelData found")
)
// STUNConn wraps a stream-oriented net.Conn and implements net.PacketConn on
// top of it by packetizing the stream: the incoming bytes are split into
// individual STUN messages and TURN ChannelData messages, and each ReadFrom
// call returns one of them.
type STUNConn struct {
	// nextConn is the wrapped stream connection; every read, write, close
	// and deadline call is delegated to it.
	nextConn net.Conn
	// buff holds bytes already read from nextConn that have not yet been
	// returned by ReadFrom (a partial frame and/or frames that follow the
	// one most recently returned).
	buff []byte
}
// Sizes, in bytes, used to find frame boundaries in the stream.
const (
	// stunHeaderSize is the length of the fixed STUN message header. The
	// length field inside the header does not count these bytes.
	stunHeaderSize = 20

	// A ChannelData header is a 2-byte channel number followed by a 2-byte
	// payload length.
	channelDataNumberSize = 2
	channelDataLengthSize = channelDataNumberSize
	channelDataHeaderSize = channelDataNumberSize + channelDataLengthSize

	// channelDataPadding is the alignment the ChannelData payload is padded
	// up to when framing the stream.
	channelDataPadding = 4
)
// Given a buffer give the last offset of the TURN frame
// If the buffer isn't a valid STUN or ChannelData packet
// or the length doesn't match return false
func consumeSingleTURNFrame(p []byte) (int, error) {
// Too short to determine if ChannelData or STUN
if len(p) < 9 {
return 0, errIncompleteTURNFrame
}
var datagramSize uint16
switch {
case stun.IsMessage(p):
datagramSize = binary.BigEndian.Uint16(p[2:4]) + stunHeaderSize
case proto.ChannelNumber(binary.BigEndian.Uint16(p[0:2])).Valid():
datagramSize = binary.BigEndian.Uint16(p[channelDataNumberSize:channelDataHeaderSize])
if paddingOverflow := (datagramSize + channelDataPadding) % channelDataPadding; paddingOverflow != 0 {
datagramSize = (datagramSize + channelDataPadding) - paddingOverflow
}
datagramSize += channelDataHeaderSize
case len(p) < stunHeaderSize:
return 0, errIncompleteTURNFrame
default:
return 0, errInvalidTURNFrame
}
if len(p) < int(datagramSize) {
return 0, errIncompleteTURNFrame
}
return int(datagramSize), nil
}
// ReadFrom implements ReadFrom from net.PacketConn. It returns the next
// complete STUN or ChannelData frame from the wrapped stream. Bytes that
// belong to later frames stay buffered for subsequent calls.
//
// As with a datagram socket, a frame longer than p is truncated. The first
// len(p) bytes are copied, n reports how many were copied, and the rest of
// that frame is discarded. The returned address is always the stream's
// remote address.
//
// Errors:
//   - errInvalidTURNFrame: the buffered data does not start with a
//     recognizable frame. The offending bytes stay buffered, so later calls
//     fail the same way.
//   - Errors from the underlying Read are returned as-is. Any bytes delivered
//     together with such an error are kept buffered for the next call.
//   - io.ErrShortBuffer: p is empty and no complete frame is buffered. In
//     that case nothing can be read into p.
func (s *STUNConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
	for {
		// Serve a frame that is already fully buffered first.
		frameSize, frameErr := consumeSingleTURNFrame(s.buff)
		if errors.Is(frameErr, errInvalidTURNFrame) {
			return 0, nil, frameErr
		}
		if frameErr == nil {
			n = copy(p, s.buff[:frameSize])
			s.buff = s.buff[frameSize:]
			return n, s.nextConn.RemoteAddr(), nil
		}

		// Only a partial frame is buffered: pull more bytes off the stream,
		// using p as scratch space.
		if len(p) == 0 {
			// Reading into an empty slice cannot make progress; without this
			// guard the loop could spin forever.
			return 0, nil, io.ErrShortBuffer
		}
		readN, readErr := s.nextConn.Read(p)
		// Keep whatever was delivered, even alongside an error, so it is not
		// lost; append copies the bytes out of p.
		s.buff = append(s.buff, p[:readN]...)
		if readErr != nil {
			return 0, nil, readErr
		}
	}
}
// WriteTo implements WriteTo from net.PacketConn. p is written verbatim to
// the wrapped stream connection; addr is not consulted, since the stream
// already has a single fixed peer.
func (s *STUNConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
	n, err = s.nextConn.Write(p)
	return n, err
}
// Close implements Close from net.PacketConn by closing the wrapped stream
// connection and reporting whatever error that returns.
func (s *STUNConn) Close() error {
	err := s.nextConn.Close()
	return err
}
// LocalAddr implements LocalAddr from net.PacketConn. It reports the local
// address of the wrapped stream connection.
func (s *STUNConn) LocalAddr() net.Addr {
	conn := s.nextConn
	return conn.LocalAddr()
}
// SetDeadline implements SetDeadline from net.PacketConn by forwarding t
// unchanged to the wrapped stream connection's SetDeadline.
func (s *STUNConn) SetDeadline(t time.Time) error {
	conn := s.nextConn
	return conn.SetDeadline(t)
}
// SetReadDeadline implements SetReadDeadline from net.PacketConn by
// forwarding t unchanged to the wrapped stream connection's SetReadDeadline.
func (s *STUNConn) SetReadDeadline(t time.Time) error {
	err := s.nextConn.SetReadDeadline(t)
	return err
}
// SetWriteDeadline implements SetWriteDeadline from net.PacketConn by
// forwarding t unchanged to the wrapped stream connection's SetWriteDeadline.
func (s *STUNConn) SetWriteDeadline(t time.Time) error {
	err := s.nextConn.SetWriteDeadline(t)
	return err
}
// NewSTUNConn returns a STUNConn that packetizes the byte stream of nextConn.
// The returned connection starts with nothing buffered.
func NewSTUNConn(nextConn net.Conn) *STUNConn {
	c := new(STUNConn)
	c.nextConn = nextConn
	return c
}
|