1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222
|
package quic
import (
"fmt"
"time"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
// headerDecryptor removes header protection from a packet.
//
// Callers in this package pass:
//   - as sample, the 16 bytes that follow an assumed 4-byte packet number field;
//   - a pointer to the packet's first byte;
//   - the 4 bytes at the packet number offset.
//
// Implementations are expected to unmask *firstByte and pnBytes in place.
// This is presumably RFC 9001 header protection; confirm against the implementations.
type headerDecryptor interface {
	DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte)
}
// headerParseError marks an error that occurred while parsing a packet header.
// It lets callers tell header-parsing failures apart from other unpacking
// errors, while errors.Is / errors.As can still reach the underlying cause.
type headerParseError struct {
	err error // the underlying parse error
}

// Error returns the wrapped error's message unchanged.
func (e *headerParseError) Error() string { return e.err.Error() }

// Unwrap exposes the wrapped error to errors.Is, errors.As and errors.Unwrap.
func (e *headerParseError) Unwrap() error { return e.err }
// unpackedPacket holds the result of unpacking a Long Header packet, as
// returned by UnpackLongHeader.
type unpackedPacket struct {
	// hdr is the parsed extended header. UnpackLongHeader replaces its
	// PacketNumber with the value returned by the opener's DecodePacketNumber.
	hdr *wire.ExtendedHeader
	// encryptionLevel is the level whose opener was used to decrypt the packet.
	encryptionLevel protocol.EncryptionLevel
	// data is the decrypted payload; UnpackLongHeader rejects empty payloads.
	data []byte
}
// The packetUnpacker unpacks QUIC packets: it removes header protection and
// decrypts the payload. It uses the openers that the crypto setup provides
// for each encryption level.
type packetUnpacker struct {
	// cs provides the openers for the Initial, Handshake, 0-RTT and 1-RTT levels.
	cs handshake.CryptoSetup
	// shortHdrConnIDLen is the connection ID length used when parsing Short
	// Header packets; their packet number starts at offset 1+shortHdrConnIDLen.
	shortHdrConnIDLen int
}

// Compile-time check that *packetUnpacker satisfies the unpacker interface.
var _ unpacker = (*packetUnpacker)(nil)

// newPacketUnpacker returns a packetUnpacker that obtains its openers from cs.
// It parses Short Header packets assuming a connection ID of shortHdrConnIDLen bytes.
func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetUnpacker {
	return &packetUnpacker{shortHdrConnIDLen: shortHdrConnIDLen, cs: cs}
}
// UnpackLongHeader unpacks a Long Header packet (Initial, Handshake or 0-RTT).
// It removes header protection and decrypts the payload using the opener of
// the matching encryption level.
//
// Errors:
//   - If the opener for the packet's level can't be obtained, that error is returned as is.
//   - If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
//   - If any other error occurred when parsing the header, the error is of type headerParseError.
//   - If decrypting the payload fails for any reason, the error is the error returned by the AEAD.
//   - Any other packet type (Retry packets can't be unpacked) yields a plain formatted error.
//   - An empty decrypted payload yields a *qerr.TransportError with code ProtocolViolation.
func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
	var (
		level   protocol.EncryptionLevel
		header  *wire.ExtendedHeader
		payload []byte
	)
	//nolint:exhaustive // Retry packets can't be unpacked.
	switch hdr.Type {
	case protocol.PacketTypeInitial:
		level = protocol.EncryptionInitial
		opener, err := u.cs.GetInitialOpener()
		if err != nil {
			return nil, err
		}
		if header, payload, err = u.unpackLongHeaderPacket(opener, hdr, data); err != nil {
			return nil, err
		}
	case protocol.PacketTypeHandshake:
		level = protocol.EncryptionHandshake
		opener, err := u.cs.GetHandshakeOpener()
		if err != nil {
			return nil, err
		}
		if header, payload, err = u.unpackLongHeaderPacket(opener, hdr, data); err != nil {
			return nil, err
		}
	case protocol.PacketType0RTT:
		level = protocol.Encryption0RTT
		opener, err := u.cs.Get0RTTOpener()
		if err != nil {
			return nil, err
		}
		if header, payload, err = u.unpackLongHeaderPacket(opener, hdr, data); err != nil {
			return nil, err
		}
	default:
		return nil, fmt.Errorf("unknown packet type: %s", hdr.Type)
	}

	if len(payload) == 0 {
		return nil, &qerr.TransportError{
			ErrorCode:    qerr.ProtocolViolation,
			ErrorMessage: "empty packet",
		}
	}
	return &unpackedPacket{hdr: header, encryptionLevel: level, data: payload}, nil
}
// UnpackShortHeader unpacks a Short Header (1-RTT) packet received at rcvTime.
// The header bytes of data are unmasked in place, and the returned payload may
// alias data.
//
// On success it returns the decoded packet number, the packet number length,
// the key phase bit and the decrypted payload.
//
// Errors:
//   - If the 1-RTT opener can't be obtained, that error is returned as is.
//   - If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits.
//   - If any other error occurred when parsing the header, the error is of type headerParseError.
//   - If decrypting the payload fails, the error is the one returned by the opener.
//   - An empty decrypted payload yields a *qerr.TransportError with code ProtocolViolation.
func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
	opener, err := u.cs.Get1RTTOpener()
	if err != nil {
		return 0, 0, 0, nil, err
	}

	pn, pnLen, kp, payload, err := u.unpackShortHeaderPacket(opener, rcvTime, data)
	switch {
	case err != nil:
		return 0, 0, 0, nil, err
	case len(payload) == 0:
		return 0, 0, 0, nil, &qerr.TransportError{
			ErrorCode:    qerr.ProtocolViolation,
			ErrorMessage: "empty packet",
		}
	default:
		return pn, pnLen, kp, payload, nil
	}
}
// unpackLongHeaderPacket removes header protection from a Long Header packet.
// It then decodes the packet number and decrypts the payload with opener.
//
// On success it returns the extended header and the decrypted payload. The
// header's PacketNumber is replaced by opener.DecodePacketNumber's result.
// On error both results are nil, and the error is, in order of precedence:
//   - a *headerParseError if the header could not be parsed;
//   - the error returned by opener.Open;
//   - wire.ErrInvalidReservedBits if decryption succeeded but the reserved bits were invalid.
func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, []byte, error) {
	extHdr, parseErr := u.unpackLongHeader(opener, hdr, data)
	// If the reserved bits are set incorrectly, we still need to continue unpacking.
	// This avoids a timing side-channel, which otherwise might allow an attacker
	// to gain information about the header encryption.
	if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
		return nil, nil, parseErr
	}
	// Length of the unprotected header, including the packet number field.
	extHdrLen := extHdr.ParsedLen()
	// Presumably reconstructs the full packet number from the truncated
	// on-the-wire value and its length.
	extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen)
	// dst is the zero-length slice data[extHdrLen:extHdrLen], so the plaintext
	// is presumably written over the ciphertext in place. The unprotected
	// header data[:extHdrLen] is passed last, presumably as the AEAD
	// associated data.
	decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen])
	if err != nil {
		return nil, nil, err
	}
	// Only now, after decryption has run, report invalid reserved bits (see above).
	if parseErr != nil {
		return nil, nil, parseErr
	}
	return extHdr, decrypted, nil
}
// unpackShortHeaderPacket removes header protection from a Short Header packet.
// It then decodes the packet number and decrypts the payload with opener.
//
// On success it returns the decoded packet number, the packet number length,
// the key phase bit and the decrypted payload. On any error all other results
// are zero values, as in unpackLongHeaderPacket. The error is, in order of
// precedence:
//   - a *headerParseError if the header could not be parsed;
//   - the error returned by opener.Open;
//   - wire.ErrInvalidReservedBits if decryption succeeded but the reserved bits were invalid.
func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
	l, pn, pnLen, kp, parseErr := u.unpackShortHeader(opener, data)
	// If the reserved bits are set incorrectly, we still need to continue unpacking.
	// This avoids a timing side-channel, which otherwise might allow an attacker
	// to gain information about the header encryption.
	if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
		return 0, 0, 0, nil, &headerParseError{err: parseErr}
	}
	pn = opener.DecodePacketNumber(pn, pnLen)
	// l is the header length including the packet number. The plaintext is
	// written via the zero-length dst data[l:l], and the header data[:l] is
	// passed as the final argument (presumably the AEAD associated data).
	decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, pn, kp, data[:l])
	if err != nil {
		return 0, 0, 0, nil, err
	}
	// Report invalid reserved bits only after decryption has run (see above).
	// Like every other error, it is returned with zero values rather than a
	// partially populated result.
	if parseErr != nil {
		return 0, 0, 0, nil, parseErr
	}
	return pn, pnLen, kp, decrypted, nil
}
// unpackShortHeader removes header protection from a Short Header packet in
// place and parses the header.
//
// It returns, in order:
//   - the header length l (as reported by wire.ParseShortHeader);
//   - the truncated packet number and its length;
//   - the key phase bit;
//   - an error.
//
// The error is nil, wire.ErrInvalidReservedBits, a "packet too small" error, or
// another error from wire.ParseShortHeader. In the last case the bytes touched
// by DecryptHeader are not restored.
func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int, protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, error) {
	const (
		maxPNLen  = 4  // header protection is removed assuming the longest packet number
		sampleLen = 16 // number of bytes handed to DecryptHeader as the sample
	)
	pnOffset := 1 /* first header byte */ + u.shortHdrConnIDLen
	if len(data) < pnOffset+maxPNLen+sampleLen {
		return 0, 0, 0, 0, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-pnOffset)
	}
	// 1. Remember the 4 bytes at the packet number offset. Until the header is
	//    unmasked we don't know how many of them belong to the packet number.
	var savedPN [maxPNLen]byte
	copy(savedPN[:], data[pnOffset:pnOffset+maxPNLen])
	// 2. Unmask the first byte and all 4 candidate packet number bytes in place.
	sampleStart := pnOffset + maxPNLen
	hd.DecryptHeader(data[sampleStart:sampleStart+sampleLen], &data[0], data[pnOffset:sampleStart])
	// 3. Parse the unmasked header, which reveals the real packet number length.
	l, pn, pnLen, kp, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen)
	if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
		return l, pn, pnLen, kp, parseErr
	}
	// 4. The bytes past the real packet number were not part of the header, so
	//    put their original values back.
	if pnLen != protocol.PacketNumberLen4 {
		n := int(pnLen)
		copy(data[pnOffset+n:pnOffset+maxPNLen], savedPN[n:])
	}
	return l, pn, pnLen, kp, parseErr
}
// unpackLongHeader removes header protection from a Long Header packet and
// parses its extended header.
//
// The error is one of:
//   - nil;
//   - wire.ErrInvalidReservedBits, with the header passed through so the
//     caller can keep unpacking;
//   - a *headerParseError wrapping any other failure, with a nil header.
func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
	extHdr, err := unpackLongHeader(hd, hdr, data)
	switch err {
	case nil, wire.ErrInvalidReservedBits:
		return extHdr, err
	default:
		return nil, &headerParseError{err: err}
	}
}
// unpackLongHeader removes header protection from a Long Header packet in
// place and parses its extended header.
//
// hdr is presumably the Header already parsed from data. hdr.ParsedLen() is
// used as the offset of the packet number field. Header protection is first
// removed assuming a 4-byte packet number, because the real length isn't known
// until the header is decrypted. After parsing, any of those 4 bytes that lie
// beyond the real packet number are restored to their original values.
//
// It returns the extended header with either a nil error or
// wire.ErrInvalidReservedBits. Any other error is returned with a nil header.
// This includes a "packet too small" error when fewer than 4+16 bytes follow
// hdr.ParsedLen(). If parsing fails after DecryptHeader ran, data is left as
// DecryptHeader modified it.
func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte) (*wire.ExtendedHeader, error) {
	// Offset at which the (still protected) packet number begins.
	hdrLen := hdr.ParsedLen()
	// Need room for the longest (4-byte) packet number plus the 16-byte sample.
	if protocol.ByteCount(len(data)) < hdrLen+4+16 {
		return nil, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen)
	}
	// The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it.
	// 1. save a copy of the 4 bytes
	origPNBytes := make([]byte, 4)
	copy(origPNBytes, data[hdrLen:hdrLen+4])
	// 2. decrypt the header, assuming a 4 byte packet number
	hd.DecryptHeader(
		data[hdrLen+4:hdrLen+4+16],
		&data[0],
		data[hdrLen:hdrLen+4],
	)
	// 3. parse the header (and learn the actual length of the packet number)
	extHdr, parseErr := hdr.ParseExtended(data)
	if parseErr != nil && parseErr != wire.ErrInvalidReservedBits {
		return nil, parseErr
	}
	// 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier
	if extHdr.PacketNumberLen != protocol.PacketNumberLen4 {
		copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):])
	}
	return extHdr, parseErr
}
|