File: extended_header.go

package info (click to toggle)
golang-github-lucas-clemente-quic-go 0.54.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 4,312 kB
  • sloc: sh: 54; makefile: 7
file content (164 lines) | stat: -rw-r--r-- 5,362 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
package wire

import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"

	"github.com/quic-go/quic-go/internal/protocol"
	"github.com/quic-go/quic-go/internal/utils"
	"github.com/quic-go/quic-go/quicvarint"
)

var (
	// ErrInvalidReservedBits reports that the reserved bits of a long header's
	// first byte were not set as required.
	//
	// Returning it does not abort parsing: an ExtendedHeader is still handed
	// back alongside this error, because the packet has to be decrypted anyway
	// to avoid creating a timing side-channel.
	ErrInvalidReservedBits = errors.New("invalid reserved bits")
)

// ExtendedHeader is the header of a QUIC packet.
//
// It embeds the already-parsed Header and adds the fields that parse reads
// from the (by then unencrypted) first byte and the bytes that follow the
// Header: the packet number length and the packet number.
type ExtendedHeader struct {
	Header

	// typeByte is the first byte of the packet as read by parse. Its low two
	// bits encode PacketNumberLen-1; the bits under mask 0x0c are the reserved
	// bits that parse checks for zero.
	typeByte byte

	// NOTE(review): KeyPhase is neither read nor written in this file;
	// presumably it is only used for other packet forms — confirm.
	KeyPhase protocol.KeyPhaseBit

	// PacketNumberLen is the length of the encoded packet number in bytes
	// (parse derives it as (first byte & 0x3) + 1, i.e. 1 to 4).
	// PacketNumber is the value parse obtains from readPacketNumber and the
	// value Append writes using PacketNumberLen bytes.
	PacketNumberLen protocol.PacketNumberLen
	PacketNumber    protocol.PacketNumber

	// parsedLen is the number of bytes parse consumed: the embedded Header's
	// parsed length plus PacketNumberLen.
	parsedLen protocol.ByteCount
}

// parse completes an ExtendedHeader from data, which must start at the
// beginning of the packet and whose first byte must already be unencrypted.
//
// It reads the first byte to obtain the packet number length and the reserved
// bits, then reads the packet number located right after the embedded
// Header's parsed length, and records the total number of bytes consumed.
//
// The boolean result reports whether the reserved bits (mask 0x0c of the
// first byte) are zero. io.EOF is returned if data is empty or too short to
// contain the packet number. Errors from readPacketNumber are propagated so
// that callers never receive a header whose PacketNumber and parsedLen were
// left unset.
func (h *ExtendedHeader) parse(data []byte) (bool /* reserved bits valid */, error) {
	// Guard before indexing: data[0] on an empty slice would panic.
	if len(data) == 0 {
		return false, io.EOF
	}
	// read the (now unencrypted) first byte
	h.typeByte = data[0]
	h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
	hdrLen := h.Header.ParsedLen()
	if protocol.ByteCount(len(data)) < hdrLen+protocol.ByteCount(h.PacketNumberLen) {
		return false, io.EOF
	}

	pn, err := readPacketNumber(data[hdrLen:], h.PacketNumberLen)
	if err != nil {
		return false, err
	}
	h.PacketNumber = pn
	h.parsedLen = hdrLen + protocol.ByteCount(h.PacketNumberLen)

	// A reserved-bit violation is reported via the boolean rather than an
	// error, so the header is still fully populated (see ErrInvalidReservedBits).
	return h.typeByte&0xc == 0, nil
}

// Append serializes the long header, appends it to b and returns the
// extended slice.
//
// Layout written here:
//   - a first byte made of the long-header bits 0xc0, the two packet-type bits
//     (their mapping differs between QUIC v2 and other versions) and, except
//     for Retry, PacketNumberLen-1;
//   - the 4-byte version;
//   - the destination and source connection IDs, each prefixed by its length.
//
// A Retry header then ends with the raw token. All other types continue with
// a varint-length-prefixed token (Initial only), a Length field that is always
// encoded as a 2-byte varint, and the packet number.
//
// An error and a nil slice are returned in these cases:
//   - either connection ID is longer than protocol.MaxConnIDLen;
//   - h.Type is not one of the four long-header packet types;
//   - h.Length does not fit into a 2-byte varint;
//   - PacketNumberLen is not a valid packet number length.
func (h *ExtendedHeader) Append(b []byte, v protocol.Version) ([]byte, error) {
	if h.DestConnectionID.Len() > protocol.MaxConnIDLen {
		return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
	}
	if h.SrcConnectionID.Len() > protocol.MaxConnIDLen {
		return nil, fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
	}

	// Type bits for QUIC v1 (and other non-v2 versions) and for QUIC v2.
	// An unknown type has no valid encoding; reject it instead of emitting
	// bits that would be read back as a different packet type.
	var v1Bits, v2Bits uint8
	switch h.Type {
	case protocol.PacketTypeInitial:
		v1Bits, v2Bits = 0b00, 0b01
	case protocol.PacketType0RTT:
		v1Bits, v2Bits = 0b01, 0b10
	case protocol.PacketTypeHandshake:
		v1Bits, v2Bits = 0b10, 0b11
	case protocol.PacketTypeRetry:
		v1Bits, v2Bits = 0b11, 0b00
	default:
		return nil, fmt.Errorf("invalid long header packet type: %s", h.Type)
	}
	packetType := v1Bits
	if v == protocol.Version2 {
		packetType = v2Bits
	}

	firstByte := 0xc0 | packetType<<4
	if h.Type != protocol.PacketTypeRetry {
		// Retry packets don't have a packet number
		firstByte |= uint8(h.PacketNumberLen - 1)
	}

	b = append(b, firstByte)
	b = append(b, make([]byte, 4)...)
	binary.BigEndian.PutUint32(b[len(b)-4:], uint32(h.Version))
	b = append(b, uint8(h.DestConnectionID.Len()))
	b = append(b, h.DestConnectionID.Bytes()...)
	b = append(b, uint8(h.SrcConnectionID.Len()))
	b = append(b, h.SrcConnectionID.Bytes()...)

	//nolint:exhaustive
	switch h.Type {
	case protocol.PacketTypeRetry:
		b = append(b, h.Token...)
		return b, nil
	case protocol.PacketTypeInitial:
		b = quicvarint.Append(b, uint64(len(h.Token)))
		b = append(b, h.Token...)
	}

	// The Length field is always written with a fixed 2-byte varint encoding
	// (GetLength assumes the same), whose largest value is 2^14-1. Report an
	// oversized value as an error instead of handing it to AppendWithLen.
	const maxTwoByteVarint = 1<<14 - 1
	if h.Length > maxTwoByteVarint {
		return nil, fmt.Errorf("length too large for a 2-byte varint: %d", h.Length)
	}
	b = quicvarint.AppendWithLen(b, uint64(h.Length), 2)
	return appendPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}

// ParsedLen reports how many bytes parse consumed for this header: the
// embedded Header's parsed length plus the packet number length.
func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
	return h.parsedLen
}

// GetLength returns the number of bytes a successful Append writes for this
// header.
//
// Every long header has the type byte, the 4-byte version and both
// length-prefixed connection IDs. A Retry header then carries only its raw
// token, with no length prefix, no Length field and no packet number,
// matching Append. All other types carry the 2-byte Length field and the
// packet number, and Initial packets add the varint-prefixed token.
// The version argument is unused because the size does not depend on it.
func (h *ExtendedHeader) GetLength(_ protocol.Version) protocol.ByteCount {
	length := 1 /* type byte */ + 4 /* version */ +
		1 /* dest conn ID len */ + protocol.ByteCount(h.DestConnectionID.Len()) +
		1 /* src conn ID len */ + protocol.ByteCount(h.SrcConnectionID.Len())
	if h.Type == protocol.PacketTypeRetry {
		return length + protocol.ByteCount(len(h.Token))
	}
	length += protocol.ByteCount(h.PacketNumberLen) + 2 /* length */
	if h.Type == protocol.PacketTypeInitial {
		length += protocol.ByteCount(quicvarint.Len(uint64(len(h.Token))) + len(h.Token))
	}
	return length
}

// Log writes a single debug line describing the long header to logger.
// Initial and Retry headers include their token (or "(empty)"). Retry
// headers are logged without the packet number, packet number length and
// Length fields.
func (h *ExtendedHeader) Log(logger utils.Logger) {
	isRetry := h.Type == protocol.PacketTypeRetry
	tokenDesc := ""
	if isRetry || h.Type == protocol.PacketTypeInitial {
		tokenDesc = "Token: (empty), "
		if len(h.Token) > 0 {
			tokenDesc = fmt.Sprintf("Token: %#x, ", h.Token)
		}
	}

	if isRetry {
		logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, tokenDesc, h.Version)
		return
	}
	logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, tokenDesc, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
}

// appendPacketNumber appends the low pnLen bytes of pn to b in big-endian
// (network) byte order and returns the extended slice.
// If pnLen is not one of the four supported lengths, it returns a nil slice
// and an error.
func appendPacketNumber(b []byte, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) ([]byte, error) {
	switch pnLen {
	case protocol.PacketNumberLen1:
		return append(b, uint8(pn)), nil
	case protocol.PacketNumberLen2:
		return append(b, uint8(pn>>8), uint8(pn)), nil
	case protocol.PacketNumberLen3:
		return append(b, uint8(pn>>16), uint8(pn>>8), uint8(pn)), nil
	case protocol.PacketNumberLen4:
		return append(b, uint8(pn>>24), uint8(pn>>16), uint8(pn>>8), uint8(pn)), nil
	default:
		return nil, fmt.Errorf("invalid packet number length: %d", pnLen)
	}
}