1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
|
package transfer
import (
"bytes"
"crypto/hmac"
"crypto/md5"
"encoding/binary"
"errors"
"fmt"
"hash"
"io"
"net"
"syscall"
"time"
)
// digestMD5IntegrityConn is a net.Conn wrapper that performs MD5-digest
// integrity checks on the data passing over it: every Write is sent as a
// framed message carrying a truncated HMAC-MD5, and every frame received by
// Read is verified before its payload is handed to the caller.
type digestMD5IntegrityConn struct {
	// conn is the wrapped connection; all network I/O is delegated to it.
	conn net.Conn

	// readDeadline mirrors the last deadline passed to SetDeadline or
	// SetReadDeadline so Read can fail fast once it has passed.
	readDeadline time.Time

	// readBuf holds bytes read from conn; once a frame is decoded it holds
	// the verified payload that has not yet been returned by Read.
	readBuf bytes.Buffer
	// writeBuf is scratch space used by Write to assemble an outgoing frame.
	writeBuf bytes.Buffer

	// sendSeqNum is the sequence number for the next outgoing frame.
	sendSeqNum int
	// readSeqNum is the sequence number expected on the next incoming frame.
	readSeqNum int

	// encodeMAC authenticates outgoing frames; decodeMAC verifies incoming
	// ones. Both are HMAC-MD5 instances created by newDigestMD5IntegrityConn.
	encodeMAC hash.Hash
	decodeMAC hash.Hash
}
// newDigestMD5IntegrityConn wraps conn with integrity protection. Outgoing
// frames are authenticated with HMAC-MD5 keyed by kic, and incoming frames are
// verified with HMAC-MD5 keyed by kis.
func newDigestMD5IntegrityConn(conn net.Conn, kic, kis []byte) digestMD5Conn {
	wrapped := &digestMD5IntegrityConn{conn: conn}
	wrapped.encodeMAC = hmac.New(md5.New, kic)
	wrapped.decodeMAC = hmac.New(md5.New, kis)
	return wrapped
}
// Close closes the wrapped connection and reports its error, if any.
func (d *digestMD5IntegrityConn) Close() (err error) {
	err = d.conn.Close()
	return err
}
// LocalAddr reports the local network address of the wrapped connection.
func (d *digestMD5IntegrityConn) LocalAddr() net.Addr {
	underlying := d.conn
	return underlying.LocalAddr()
}
// RemoteAddr reports the remote network address of the wrapped connection.
func (d *digestMD5IntegrityConn) RemoteAddr() net.Addr {
	underlying := d.conn
	return underlying.RemoteAddr()
}
// SetDeadline remembers t as the read deadline checked by Read and then
// applies t as both the read and write deadline of the wrapped connection.
func (d *digestMD5IntegrityConn) SetDeadline(t time.Time) error {
	d.readDeadline = t
	err := d.conn.SetDeadline(t)
	return err
}
// SetReadDeadline remembers t so Read can fail fast once it has passed, and
// forwards t to the wrapped connection's read deadline.
func (d *digestMD5IntegrityConn) SetReadDeadline(t time.Time) error {
	d.readDeadline = t
	err := d.conn.SetReadDeadline(t)
	return err
}
// SetWriteDeadline forwards t to the wrapped connection's write deadline.
// Write does no deadline bookkeeping of its own, so nothing is recorded here.
func (d *digestMD5IntegrityConn) SetWriteDeadline(t time.Time) error {
	underlying := d.conn
	return underlying.SetWriteDeadline(t)
}
// Write sends b as a single integrity-protected frame on the wrapped
// connection.
//
// The frame is a 4-byte big-endian length (counting everything after it),
// followed by
//
//	b || msgHMAC(encodeMAC, seqNum, b) || macMsgType || seqNum
//
// where seqNum is sendSeqNum, which is advanced once per call.
//
// On success Write reports len(b), as io.Writer requires, not the larger
// on-the-wire frame size. Reporting the frame size would make callers such as
// io.Copy treat the write as invalid. If the frame cannot be sent completely,
// Write returns 0 and the error: the peer cannot verify a partial frame, so
// none of b counts as delivered.
func (d *digestMD5IntegrityConn) Write(b []byte) (n int, err error) {
	seqBuf := lenEncodeBytes(d.sendSeqNum)
	frameLen := macDataLen + len(b) + macHMACLen + macMsgTypeLen + macSeqNumLen

	d.writeBuf.Reset()
	d.writeBuf.Grow(frameLen)
	// Writes into a bytes.Buffer cannot fail, so these errors are not checked.
	binary.Write(&d.writeBuf, binary.BigEndian, int32(frameLen-macDataLen))
	d.writeBuf.Write(b)
	d.writeBuf.Write(msgHMAC(d.encodeMAC, seqBuf, b))
	d.writeBuf.Write(macMsgType[:])
	binary.Write(&d.writeBuf, binary.BigEndian, int32(d.sendSeqNum))
	d.sendSeqNum++

	if _, err = d.writeBuf.WriteTo(d.conn); err != nil {
		return 0, err
	}
	return len(b), nil
}
// Read returns integrity-verified payload bytes from the connection.
//
// Payload left over from a previously decoded frame is returned first. A new
// frame is read from the wrapped connection only when that buffer is empty,
// so Read never blocks while verified data is already available. Like any
// io.Reader, it may return fewer than len(b) bytes.
func (d *digestMD5IntegrityConn) Read(b []byte) (int, error) {
	if !d.readDeadline.IsZero() && d.readDeadline.Before(time.Now()) {
		return 0, syscall.ETIMEDOUT
	}
	if len(b) == 0 {
		return 0, nil
	}
	// A frame may carry an empty payload (e.g. a zero-length Write by the
	// peer); keep reading rather than reporting the empty buffer as io.EOF.
	for d.readBuf.Len() == 0 {
		if err := d.readFrame(); err != nil {
			return 0, err
		}
	}
	return d.readBuf.Read(b)
}

// readFrame reads one frame (a 4-byte big-endian length followed by that many
// bytes) from the wrapped connection, verifies it with decode, and leaves only
// the verified payload in readBuf. On any error readBuf is left empty, so
// unverified bytes can never be returned by a later Read.
func (d *digestMD5IntegrityConn) readFrame() error {
	var sz int32
	if err := binary.Read(d.conn, binary.BigEndian, &sz); err != nil {
		return err
	}
	// The length comes from the network; a negative value would make Grow
	// panic. NOTE(review): there is no upper bound, so a huge length still
	// triggers a large up-front allocation; consider a negotiated maximum.
	if sz < 0 {
		return fmt.Errorf("invalid integrity frame length %d", sz)
	}
	d.readBuf.Reset()
	d.readBuf.Grow(int(sz))
	if _, err := io.CopyN(&d.readBuf, d.conn, int64(sz)); err != nil {
		d.readBuf.Reset()
		if errors.Is(err, io.EOF) {
			// The length prefix promised more bytes than the peer sent.
			err = io.ErrUnexpectedEOF
		}
		return err
	}
	decoded, err := d.decode(d.readBuf.Bytes())
	if err != nil {
		d.readBuf.Reset()
		return err
	}
	// decoded aliases the start of readBuf's contents; drop the MAC trailer.
	d.readBuf.Truncate(len(decoded))
	return nil
}
// decode verifies one integrity-protected frame body received from the peer
// and returns its payload.
//
// input is the frame without its length prefix, laid out as
//
//	payload || truncated HMAC || macMsgType || seqNum
//
// The frame is accepted only when all three of the following hold:
//   - the HMAC matches, computed with decodeMAC over the expected sequence
//     number and the payload;
//   - the message type matches macMsgType;
//   - the sequence number matches readSeqNum.
//
// On success readSeqNum is advanced. The returned slice aliases input, so it
// must be used or copied before input is reused or decode is called again.
func (d *digestMD5IntegrityConn) decode(input []byte) ([]byte, error) {
	inputLen := len(input)
	dataLen := inputLen - macHMACLen - macMsgTypeLen - macSeqNumLen
	// Also reject frames too short to hold the MAC trailer: input comes from
	// the network, and input[:dataLen] with a negative dataLen would panic.
	if inputLen < saslIntegrityPrefixLength || dataLen < 0 {
		return nil, fmt.Errorf("Input length smaller than the integrity prefix")
	}

	seqBuf := lenEncodeBytes(d.readSeqNum)
	expected := msgHMAC(d.decodeMAC, seqBuf, input[:dataLen])

	hashStart := dataLen
	msgTypeStart := hashStart + macHMACLen
	seqNumStart := msgTypeStart + macMsgTypeLen
	// hmac.Equal runs in constant time, so response timing does not leak how
	// many leading MAC bytes a forged frame got right.
	if !hmac.Equal(expected, input[hashStart:msgTypeStart]) ||
		!bytes.Equal(macMsgType[:], input[msgTypeStart:seqNumStart]) ||
		!bytes.Equal(seqBuf[:], input[seqNumStart:]) {
		return nil, errors.New("HMAC Integrity Check failed")
	}
	d.readSeqNum++
	return input[:dataLen], nil
}
// msgHMAC computes the truncated per-message MAC defined for DIGEST-MD5
// integrity protection (RFC 2831):
//
//	HMAC(ki, {seqnum, msg})[0..9]
//
// mac must already be keyed with ki. It is reset before use, so one instance
// can be reused for every message. seq is hashed ahead of msg, and only the
// first 10 bytes of the digest are returned.
func msgHMAC(mac hash.Hash, seq [4]byte, msg []byte) []byte {
	mac.Reset()
	// hash.Hash.Write never returns an error.
	for _, part := range [][]byte{seq[:], msg} {
		mac.Write(part)
	}
	digest := mac.Sum(nil)
	return digest[:10]
}
|