1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
|
package quic
import (
"encoding/binary"
"errors"
"io"
)
// TLS extension code points that findSNIAndECH looks for.
const (
	extTypeSNI = 0      // server_name (RFC 6066)
	extTypeECH = 0xfe0d // encrypted_client_hello (draft-ietf-tls-esni)
)

// findSNIAndECH parses data as a single TLS ClientHello handshake message
// (4-byte handshake header followed by the message body) and locates:
//   - the offset and length of the first host_name entry in the Server Name
//     Indication (SNI) extension, and
//   - the offset of the Encrypted Client Hello (ECH) extension.
//
// All offsets are relative to the start of data. sniPos points at the first
// byte of the host name itself (past the SNI list and entry headers); echPos
// points at the ECH extension's 2-byte type field.
//
// If no SNI extension is found, sniPos is -1 and sniLen is 0.
// If no ECH extension is found, echPos is -1.
//
// Errors (all three integers are 0 on error):
//   - io.ErrUnexpectedEOF if data is truncated, does not match the length in
//     the handshake header, or contains an inconsistent length field;
//   - a descriptive error if data is not a ClientHello, if the SNI extension
//     carries no host_name entry, or if the SNI or ECH extension is repeated.
//     Parsing stops as soon as both extensions have been found, so a repeat
//     appearing after that point is not detected.
func findSNIAndECH(data []byte) (sniPos, sniLen, echPos int, err error) {
	if len(data) < 4 {
		return 0, 0, 0, io.ErrUnexpectedEOF
	}
	if data[0] != 1 { // handshake type 1 = client_hello
		return 0, 0, 0, errors.New("not a ClientHello")
	}
	handshakeLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
	if len(data) != 4+handshakeLen {
		return 0, 0, 0, io.ErrUnexpectedEOF
	}

	// Skip the handshake header, legacy_version (2 bytes) and random (32 bytes).
	parsePos := 4 + 2 + 32
	if parsePos > len(data) {
		return 0, 0, 0, io.ErrUnexpectedEOF
	}

	// legacy_session_id: 1-byte length prefix.
	if parsePos+1 > len(data) {
		return 0, 0, 0, io.ErrUnexpectedEOF
	}
	sessionIDLen := int(data[parsePos])
	parsePos++
	if parsePos+sessionIDLen > len(data) {
		return 0, 0, 0, io.ErrUnexpectedEOF
	}
	parsePos += sessionIDLen

	// cipher_suites: 2-byte length prefix.
	if parsePos+2 > len(data) {
		return 0, 0, 0, io.ErrUnexpectedEOF
	}
	cipherSuitesLen := int(binary.BigEndian.Uint16(data[parsePos:]))
	parsePos += 2
	if parsePos+cipherSuitesLen > len(data) {
		return 0, 0, 0, io.ErrUnexpectedEOF
	}
	parsePos += cipherSuitesLen

	// legacy_compression_methods: 1-byte length prefix.
	if parsePos+1 > len(data) {
		return 0, 0, 0, io.ErrUnexpectedEOF
	}
	compressionMethodsLen := int(data[parsePos])
	parsePos++
	if parsePos+compressionMethodsLen > len(data) {
		return 0, 0, 0, io.ErrUnexpectedEOF
	}
	parsePos += compressionMethodsLen

	// extensions: 2-byte length prefix.
	if parsePos+2 > len(data) {
		return 0, 0, 0, io.ErrUnexpectedEOF
	}
	extensionsLen := int(binary.BigEndian.Uint16(data[parsePos:]))
	parsePos += 2
	if parsePos+extensionsLen > len(data) {
		return 0, 0, 0, io.ErrUnexpectedEOF
	}
	extensionsStart := parsePos
	extensions := data[extensionsStart : extensionsStart+extensionsLen]

	sniPos, echPos = -1, -1
	// Each extension is: type (2 bytes) | length (2 bytes) | payload.
	extPos := 0
	for extPos+4 <= extensionsLen {
		extType := binary.BigEndian.Uint16(extensions[extPos:])
		extLen := int(binary.BigEndian.Uint16(extensions[extPos+2:]))
		if extPos+4+extLen > extensionsLen {
			return 0, 0, 0, io.ErrUnexpectedEOF
		}
		switch extType {
		case extTypeSNI:
			if sniPos != -1 {
				return 0, 0, 0, errors.New("multiple SNI extensions")
			}
			sniData := extensions[extPos+4 : extPos+4+extLen]
			if len(sniData) < 2 {
				return 0, 0, 0, io.ErrUnexpectedEOF
			}
			nameListLen := int(binary.BigEndian.Uint16(sniData))
			if len(sniData) != 2+nameListLen {
				return 0, 0, 0, io.ErrUnexpectedEOF
			}
			// Each ServerName entry is: name_type (1 byte) | length (2 bytes) | name.
			for listPos := 2; listPos+3 <= len(sniData); {
				nameType := sniData[listPos]
				nameLen := int(binary.BigEndian.Uint16(sniData[listPos+1:]))
				if listPos+3+nameLen > len(sniData) {
					return 0, 0, 0, io.ErrUnexpectedEOF
				}
				if nameType == 0 { // host_name
					sniPos = extensionsStart + extPos + 4 + listPos + 3
					sniLen = nameLen
					break // only the first host_name is reported
				}
				listPos += 3 + nameLen
			}
			// sniPos is still -1 when the list contained no host_name entry.
			if sniPos == -1 {
				return 0, 0, 0, errors.New("SNI host_name not found")
			}
		case extTypeECH:
			if echPos != -1 {
				return 0, 0, 0, errors.New("multiple ECH extensions")
			}
			echPos = extensionsStart + extPos
		}
		extPos += 4 + extLen
		// Both targets located; the remaining extensions are not inspected.
		if sniPos != -1 && echPos != -1 {
			break
		}
	}
	return sniPos, sniLen, echPos, nil
}
|