1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216
|
// Package pktline implements utility functions for working with the Git
// pkt-line format. See
// https://git-scm.com/docs/protocol-common#_pkt_line_format
package pktline
import (
"bufio"
"bytes"
"fmt"
"io"
"strconv"
"sync"
)
const (
	// MaxPktSize is the largest total size of a Git pktline side-band-64k
	// packet, counting the 4-byte length header and the 1-byte band number.
	// https://gitlab.com/gitlab-org/git/-/blob/v2.30.0/pkt-line.h#L216
	MaxPktSize = 65520

	// MaxSidebandData is the largest payload that fits into one Git pktline
	// side-band-64k packet: MaxPktSize minus the length header (4 bytes) and
	// the band number (1 byte).
	MaxSidebandData = MaxPktSize - 4 - 1
)
// NewScanner returns a bufio.Scanner that splits on Git pktline boundaries
func NewScanner(r io.Reader) *bufio.Scanner {
scanner := bufio.NewScanner(r)
scanner.Buffer(make([]byte, MaxPktSize), MaxPktSize)
scanner.Split(pktLineSplitter)
return scanner
}
// Data returns the packet pkt without its 4-byte length header. The length
// header is not validated. It returns an empty slice when pkt is a magic
// packet such as '0000', and nil when pkt is shorter than the header
// (instead of panicking on the out-of-range slice).
func Data(pkt []byte) []byte {
	if len(pkt) < 4 {
		return nil
	}
	return pkt[4:]
}
// Payload returns the data carried by the pkt-line pkt, that is, everything
// after its 4-digit hexadecimal length header. Unlike Data, it checks that
// the header parses and equals len(pkt). It returns an error when pkt is
// shorter than the header, is a flush packet, or when the header is malformed
// or disagrees with the actual packet length.
func Payload(pkt []byte) ([]byte, error) {
	switch {
	case len(pkt) < 4:
		return nil, fmt.Errorf("packet too small")
	case IsFlush(pkt):
		return nil, fmt.Errorf("flush packets do not have a payload")
	}

	header := string(pkt[:4])
	want, err := strconv.ParseUint(header, 16, 16)
	if err != nil {
		return nil, fmt.Errorf("parsing length header %q: %w", header, err)
	}

	if got := uint64(len(pkt)); got != want {
		return nil, fmt.Errorf("packet length %d does not match header length %d", got, want)
	}

	return pkt[4:], nil
}
// IsFlush reports whether pkt is exactly the special flush packet "0000".
func IsFlush(pkt []byte) bool {
	flush := PktFlush()
	return bytes.Equal(flush, pkt)
}
// WriteString writes str to w as a single pkt-line, prefixed with its
// 4-digit hexadecimal length header (header included in the length).
//
// It returns the number of bytes of str that were written, excluding the
// header. It fails without writing anything when the framed packet would
// exceed MaxPktSize. If the underlying write fails, the count reflects only
// the payload bytes that actually reached w.
func WriteString(w io.Writer, str string) (int, error) {
	pktLen := len(str) + 4
	if pktLen > MaxPktSize {
		return 0, fmt.Errorf("string too large: %d bytes", len(str))
	}

	n, err := fmt.Fprintf(w, "%04x%s", pktLen, str)
	if err != nil {
		// The first four bytes written are the length header, which is not
		// part of str, so subtract it and clamp at zero.
		written := n - 4
		if written < 0 {
			written = 0
		}
		return written, err
	}
	return len(str), nil
}
// WriteFlush writes the flush packet "0000" to w and returns any error from
// the underlying Write.
func WriteFlush(w io.Writer) error {
	if _, err := w.Write(PktFlush()); err != nil {
		return err
	}
	return nil
}
// WriteDelim writes the delim packet "0001" to w and returns any error from
// the underlying Write.
func WriteDelim(w io.Writer) error {
	if _, err := w.Write(PktDelim()); err != nil {
		return err
	}
	return nil
}
// PktDone returns the bytes of a pkt-line carrying "done\n". Each call
// returns a freshly allocated slice, so callers may modify it.
func PktDone() []byte {
	const done = "0009done\n"
	return []byte(done)
}
// PktDelim returns the bytes of the special delim packet "0001". Each call
// returns a freshly allocated slice, so callers may modify it.
func PktDelim() []byte {
	const delim = "0001"
	return []byte(delim)
}
// PktFlush returns the bytes of the special flush packet "0000". Each call
// returns a freshly allocated slice, so callers may modify it.
func PktFlush() []byte {
	const flush = "0000"
	return []byte(flush)
}
// pktLineSplitter is a bufio.SplitFunc that yields one pkt-line per token,
// including its 4-byte length header.
//
// Every pkt-line starts with exactly four hexadecimal digits giving the total
// packet length, header included. Lengths 0000-0003 denote header-only
// special packets (e.g. flush, delim) and are returned as 4-byte tokens.
// Lengths above MaxPktSize are rejected. At EOF, a truncated header is an
// error and a truncated packet yields io.ErrUnexpectedEOF.
func pktLineSplitter(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if len(data) < 4 {
		if atEOF && len(data) > 0 {
			return 0, nil, fmt.Errorf("pktLineSplitter: incomplete length prefix on %q", data)
		}
		return 0, nil, nil // want more data
	}

	// We have at least 4 bytes available so we can decode the 4-hex-digit
	// length prefix. ParseUint (unlike ParseInt) rejects a leading sign, so
	// headers such as "+fff" or "-000" are errors rather than being read as
	// 4095 or as a bogus flush packet.
	pktLength64, err := strconv.ParseUint(string(data[:4]), 16, 16)
	if err != nil {
		return 0, nil, fmt.Errorf("pktLineSplitter: decode length: %w", err)
	}

	// Cast is safe: bitSize 16 bounds the value to 0xffff.
	pktLength := int(pktLength64)
	if pktLength > MaxPktSize {
		return 0, nil, fmt.Errorf("pktLineSplitter: invalid length: %d", pktLength)
	}

	if pktLength < 4 {
		// Special case: magic empty packet 0000, 0001, 0002 or 0003.
		return 4, data[:4], nil
	}

	if len(data) < pktLength {
		// data contains an incomplete packet.
		if atEOF {
			return 0, nil, io.ErrUnexpectedEOF
		}
		return 0, nil, nil // want more data
	}

	return pktLength, data[:pktLength], nil
}
// SidebandWriter multiplexes byte streams into a single side-band-64k stream.
// All bands share one underlying writer; obtain per-band writers via Writer.
type SidebandWriter struct {
	// w receives the framed side-band packets.
	w io.Writer
	// m serializes writeBand calls, guarding buf and the writes to w.
	m sync.Mutex
	// buf is scratch space that holds one packet's header and payload so
	// they reach w in a single Write call.
	buf [MaxPktSize]byte
}
// NewSidebandWriter returns a SidebandWriter that emits side-band-64k
// packets to w.
func NewSidebandWriter(w io.Writer) *SidebandWriter {
	sw := new(SidebandWriter)
	sw.w = w
	return sw
}
// writeBand frames data as one or more side-band packets tagged with band
// and writes them to the underlying writer. Each packet carries at most
// len(sw.buf)-5 bytes of data. The mutex is held for the whole call, so the
// packets produced by one call are not interleaved with those of a
// concurrent call. It returns the number of data bytes whose packets were
// written successfully, and stops at the first write error.
func (sw *SidebandWriter) writeBand(band byte, data []byte) (int, error) {
	// 4 hex digits of length plus 1 band byte.
	const headerSize = 5

	sw.m.Lock()
	defer sw.m.Unlock()

	written := 0
	for remaining := data; len(remaining) > 0; {
		payloadLen := copy(sw.buf[headerSize:], remaining)
		pktLen := headerSize + payloadLen

		copy(sw.buf[:4], fmt.Sprintf("%04x", pktLen))
		sw.buf[4] = band

		if _, err := sw.w.Write(sw.buf[:pktLen]); err != nil {
			return written, err
		}

		remaining = remaining[payloadLen:]
		written += payloadLen
	}
	return written, nil
}
// Writer returns an io.Writer whose writes are sent into the multiplexed
// stream as packets tagged with band. Writers for different bands can be
// used concurrently; they serialize on the SidebandWriter's mutex.
func (sw *SidebandWriter) Writer(band byte) io.Writer {
	write := func(p []byte) (int, error) {
		return sw.writeBand(band, p)
	}
	return writerFunc(write)
}
// writerFunc adapts an ordinary function to the io.Writer interface.
type writerFunc func([]byte) (int, error)

// Write forwards b to the wrapped function and returns its results unchanged.
func (fn writerFunc) Write(b []byte) (int, error) {
	return fn(b)
}
// invalidSidebandPacketError reports a pkt-line that cannot be a side-band
// packet; pkt holds the raw packet text, header included.
type invalidSidebandPacketError struct{ pkt string }

// Error renders the offending packet as a quoted Go string literal.
func (e *invalidSidebandPacketError) Error() string {
	return "invalid sideband packet: " + strconv.Quote(e.pkt)
}
// EachSidebandPacket reads the side-band-64k pktline stream r and calls fn
// with the band byte and the remaining data of each packet. Iteration stops
// at the first error returned by fn, which is passed through unchanged. A
// packet with no band byte (for example a magic packet such as "0000")
// yields an invalidSidebandPacketError, and scanning errors are returned at
// the end. The slice passed to fn is reused by the scanner, so fn must not
// retain it.
func EachSidebandPacket(r io.Reader, fn func(byte, []byte) error) error {
	s := NewScanner(r)
	for s.Scan() {
		body := Data(s.Bytes())
		if len(body) < 1 {
			return &invalidSidebandPacketError{pkt: s.Text()}
		}

		band, payload := body[0], body[1:]
		if err := fn(band, payload); err != nil {
			return err
		}
	}

	if err := s.Err(); err != nil {
		return err
	}
	return nil
}
|