1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
|
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
// extractSNI reads one TLS record from r, parses it as a ClientHello, and
// returns the SNI hostname it carries together with the TLS minor version
// reported by the record-reading helper.
//
// A ClientHello with no SNI extension, or with an SNI extension that has no
// host_name entry, is not an error: it yields ("", tlsver, nil). On any error
// the hostname is "" and the version is 0. Errors are wrapped with %w, so the
// cause reported by the helpers remains inspectable with errors.Is/errors.As.
func extractSNI(r io.Reader) (string, int, error) {
	handshake, tlsver, err := handshakeRecord(r)
	if err != nil {
		return "", 0, fmt.Errorf("reading TLS record: %w", err)
	}
	sni, err := parseHello(handshake)
	if err != nil {
		return "", 0, fmt.Errorf("reading ClientHello: %w", err)
	}
	if len(sni) == 0 {
		// ClientHello did not present an SNI extension. Valid packet,
		// no hostname.
		return "", tlsver, nil
	}
	hostname, err := parseSNI(sni)
	if err != nil {
		return "", 0, fmt.Errorf("parsing SNI extension: %w", err)
	}
	return hostname, tlsver, nil
}
// parseSNI returns the host_name entry from b, the body of a server_name
// (SNI) extension.
//
// b must begin with the 2-byte length of the ServerNameList; any bytes after
// that list are ignored. Entries of other name types are skipped. If the list
// holds no host_name entry, the result is "" with a nil error. A truncated
// entry, or 1-2 leftover bytes too short to form an entry, is an error.
func parseSNI(b []byte) (string, error) {
	list, _, err := vector(b, 2)
	if err != nil {
		return "", err
	}
	// Each entry is a 1-byte NameType followed by a 2-byte-length name.
	for len(list) >= 3 {
		nameType := list[0]
		var name []byte
		name, list, err = vector(list[1:], 2)
		if err != nil {
			return "", errors.New("truncated SNI extension")
		}
		if nameType == sniHostnameID {
			return string(name), nil
		}
	}
	if len(list) > 0 {
		return "", errors.New("trailing garbage at end of SNI extension")
	}
	// No DNS-based SNI present.
	return "", nil
}
// Identifiers from the TLS Server Name Indication extension (RFC 6066, §3).
const (
	// sniExtensionID is the ExtensionType value of the server_name extension.
	sniExtensionID = 0
	// sniHostnameID is the NameType value of a host_name entry in the
	// ServerNameList.
	sniHostnameID = 0
)
// parseHello parses b, the payload of a TLS handshake record, as a
// ClientHello message and returns the raw body of its SNI extension.
//
// It returns (nil, nil) when the ClientHello is well-formed but has no
// extensions block, or has one without an SNI extension. Framing is checked
// from the handshake header through the extensions block; bytes following
// the ClientHello message itself (e.g. another handshake message) are
// ignored. Errors from framing checks wrap the underlying cause with %w.
func parseHello(b []byte) ([]byte, error) {
	if len(b) == 0 {
		return nil, errors.New("zero length handshake record")
	}
	// HandshakeType 1 is client_hello.
	if b[0] != 1 {
		return nil, fmt.Errorf("non-ClientHello handshake record type %d", b[0])
	}
	// We're expecting a stricter TLS parser to run after we've
	// proxied, so we ignore any trailing bytes that might be present
	// (e.g. another handshake message).
	b, _, err := vector(b[1:], 3)
	if err != nil {
		// The caller already prefixes "reading ClientHello", so name the
		// specific piece here instead of repeating it.
		return nil, fmt.Errorf("reading handshake body: %w", err)
	}
	// ClientHello must be at least 34 bytes (2-byte client_version plus
	// 32-byte random) to reach the first vector length byte. The actual
	// minimal size is larger than that, but vector() will correctly
	// handle truncated packets.
	if len(b) < 34 {
		return nil, errors.New("ClientHello packet too short")
	}
	if b[0] != 3 {
		return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1])
	}
	switch b[1] {
	case 1, 2, 3:
		// TLS 1.0, TLS 1.1, TLS 1.2
	default:
		// This is the ClientHello's client_version, not the record
		// header's version, so report it as such.
		return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1])
	}
	// Skip over version and random struct.
	b = b[34:]

	// We don't technically care about SessionID, but we care that the
	// framing is well-formed all the way up to the SNI field, so that
	// we are sure that we're pulling the same SNI bytes as the
	// eventual TLS implementation.
	vec, b, err := vector(b, 1)
	if err != nil {
		return nil, fmt.Errorf("reading ClientHello SessionID: %w", err)
	}
	if len(vec) > 32 {
		return nil, fmt.Errorf("ClientHello SessionID too long (%db)", len(vec))
	}

	// Likewise, we're just checking the bare minimum of framing.
	vec, b, err = vector(b, 2)
	if err != nil {
		return nil, fmt.Errorf("reading ClientHello CipherSuites: %w", err)
	}
	if len(vec) < 2 || len(vec)%2 != 0 {
		return nil, fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec))
	}

	vec, b, err = vector(b, 1)
	if err != nil {
		return nil, fmt.Errorf("reading ClientHello CompressionMethods: %w", err)
	}
	if len(vec) < 1 {
		return nil, fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec))
	}

	// Finally, we reach the extensions.
	if len(b) == 0 {
		// No extensions. This is not an error, it just means we have
		// no SNI payload.
		return nil, nil
	}
	exts, rest, err := vector(b, 2)
	if err != nil {
		return nil, fmt.Errorf("reading ClientHello extensions: %w", err)
	}
	// The extensions block must be the last thing in the ClientHello.
	if len(rest) != 0 {
		return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(rest))
	}
	// Each extension is a 2-byte type followed by a 2-byte-length body.
	for len(exts) >= 4 {
		typ := binary.BigEndian.Uint16(exts[:2])
		var body []byte
		body, exts, err = vector(exts[2:], 2)
		if err != nil {
			return nil, fmt.Errorf("reading ClientHello extension %d: %w", typ, err)
		}
		if typ == sniExtensionID {
			// Found the SNI extension, return its payload. We don't
			// care about anything in the packet beyond this point.
			return body, nil
		}
	}
	if len(exts) != 0 {
		return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(exts))
	}
	// Successfully parsed all extensions, but there was no SNI.
	return nil, nil
}
// maxTLSRecordLength is the largest record payload accepted: 2^14 bytes, the
// TLSPlaintext length limit (RFC 5246, §6.2.1).
const maxTLSRecordLength = 16384

// handshakeRecord reads one TLS record from r, which must carry the
// handshake protocol (ContentType 22), and returns its payload together with
// the minor version from the record header (1, 2 or 3 for TLS 1.0-1.2).
//
// Records with any other content type, a version other than 3.1-3.3, or a
// declared length above maxTLSRecordLength are rejected. Read errors from r
// are wrapped with %w. A stream that ends before the header or body is
// complete can therefore be recognized with errors.Is(err, io.EOF) or
// errors.Is(err, io.ErrUnexpectedEOF).
func handshakeRecord(r io.Reader) ([]byte, int, error) {
	// Record header: ContentType(1) | ProtocolVersion major(1), minor(1) |
	// length(2, big-endian). Decoded by hand rather than via binary.Read on
	// a struct, which would use reflection; the read errors are the same
	// (binary.Read is built on io.ReadFull).
	var hdr [5]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return nil, 0, fmt.Errorf("reading TLS record header: %w", err)
	}
	typ, major, minor := hdr[0], hdr[1], hdr[2]
	length := binary.BigEndian.Uint16(hdr[3:])

	if typ != 22 {
		return nil, 0, fmt.Errorf("TLS record is not a handshake")
	}
	if major != 3 {
		return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", major, minor)
	}
	switch minor {
	case 1, 2, 3:
		// TLS 1.0, TLS 1.1, TLS 1.2
	default:
		return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", major, minor)
	}
	if length > maxTLSRecordLength {
		return nil, 0, fmt.Errorf("TLS record length is greater than %d", maxTLSRecordLength)
	}

	ret := make([]byte, length)
	if _, err := io.ReadFull(r, ret); err != nil {
		return nil, 0, fmt.Errorf("reading TLS record body: %w", err)
	}
	return ret, int(minor), nil
}
// vector splits a TLS variable-length vector off the front of b.
//
// The vector is encoded as a lenBytes-byte big-endian length followed by that
// many bytes of payload. vector returns the payload and the bytes that follow
// it. The payload's capacity is clipped to its length, so appending to it can
// never overwrite the returned remainder, which shares the same backing
// array. It returns an error if b is too short to hold the length prefix or
// the payload the prefix announces.
func vector(b []byte, lenBytes int) ([]byte, []byte, error) {
	if len(b) < lenBytes {
		return nil, nil, errors.New("not enough space in packet for vector")
	}
	var n int
	for _, c := range b[:lenBytes] {
		n = n<<8 | int(c)
	}
	end := lenBytes + n
	if len(b) < end {
		return nil, nil, errors.New("not enough space in packet for vector")
	}
	return b[lenBytes:end:end], b[end:], nil
}
|