1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
|
package smb2
import (
"errors"
"io"
"net"
)
// maxDirectTCPSize is the largest payload directTCP will frame. The length is
// written into a 4-byte header whose top byte must stay zero (ReadSize rejects
// anything else), so only the low 24 bits are usable:
// 1<<24 - 1 == 0xffffff == 16777215.
//
// A NetBIOS limit (maxNetBTSize = 0x1ffff, i.e. 131071) was noted alongside
// this one; it is unused.
const maxDirectTCPSize = 1<<24 - 1
// transport abstracts a framed byte stream: every message travels behind a
// length header, and a receiver reads that header before the payload. The
// method contracts below describe the directTCP implementation in this file.
type transport interface {
	// ReadSize consumes the next frame header and returns the length of the
	// payload that follows it.
	ReadSize() (size int, err error)

	// Read reads exactly len(p) payload bytes.
	Read(p []byte) (n int, err error)

	// Write sends p as one framed message. The returned count includes the
	// framing bytes, so on success it exceeds len(p).
	Write(p []byte) (n int, err error)

	// Close releases the underlying connection.
	Close() error
}
// directTCP implements transport over a stream connection by prefixing every
// message with a 4-byte length header (Direct TCP framing).
//
// sb and rb are per-connection scratch arrays for the outgoing (Write) and
// incoming (ReadSize) headers, reused so no header is allocated per message.
// Nothing guards them, so concurrent Write calls (or concurrent ReadSize
// calls) on the same directTCP would race. Presumably callers serialize each
// direction; confirm against the caller.
type directTCP struct {
	sb, rb [4]byte  // send / receive header scratch space
	conn   net.Conn // connection that frames are written to and read from
}
// direct wraps tcpConn in a transport that frames each message with a 4-byte
// length header. The returned transport owns tcpConn: closing the transport
// closes tcpConn.
func direct(tcpConn net.Conn) transport {
	t := new(directTCP)
	t.conn = tcpConn
	return t
}
// Write sends p as a single framed message: a 4-byte length header (encoded
// with be, presumably big-endian, so the top byte is zero for any allowed
// length) followed by the payload.
//
// Payloads longer than maxDirectTCPSize are rejected before anything is sent.
// On success Write returns the total number of bytes put on the wire, i.e.
// len(p)+4. On failure it returns 0 and the error. A failed write may still
// have left a partial frame on the connection.
func (t *directTCP) Write(p []byte) (n int, err error) {
	if len(p) > maxDirectTCPSize {
		return 0, errors.New("max transport size exceeds")
	}

	hdr := t.sb[:]
	be.PutUint32(hdr, uint32(len(p)))

	// Hand header and payload to the connection together. On a *net.TCPConn
	// net.Buffers becomes one writev call instead of two writes, which with
	// Go's default TCP_NODELAY would typically emit a separate 4-byte segment
	// per message. Other conns fall back to sequential Writes. WriteTo
	// consumes bufs, so it must be a fresh local on every call.
	bufs := net.Buffers{hdr, p}
	written, err := bufs.WriteTo(t.conn)
	if err != nil {
		return 0, err
	}
	return int(written), nil
}
// ReadSize reads the next 4-byte frame header from the connection and returns
// the payload length it announces.
//
// The header is decoded with be (presumably big-endian). Its first byte must
// be zero, otherwise the stream is rejected as an invalid transport format.
// Errors from the underlying read are returned unchanged. size is 0 whenever
// err is non-nil.
func (t *directTCP) ReadSize() (size int, err error) {
	hdr := t.rb[:]
	if _, err = io.ReadFull(t.conn, hdr); err != nil {
		return 0, err
	}
	if hdr[0] == 0 {
		return int(be.Uint32(hdr)), nil
	}
	return 0, errors.New("invalid transport format")
}
// Read fills p entirely with the next len(p) bytes from the connection, using
// io.ReadFull. If the connection fails before p is full, the partial count is
// dropped: Read reports 0 together with io.ReadFull's error, which is
// io.ErrUnexpectedEOF for a stream that ends mid-buffer.
func (t *directTCP) Read(p []byte) (n int, err error) {
	got, err := io.ReadFull(t.conn, p)
	if err == nil {
		return got, nil
	}
	return 0, err
}
// Close shuts down the wrapped connection and passes back whatever error the
// connection's own Close reports.
func (t *directTCP) Close() error {
	conn := t.conn
	return conn.Close()
}
|