1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271
|
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIO
/// Errors thrown by the NIO websocket module.
public enum NIOWebSocketError: Error {
/// The frame's length is larger than the configured maximum
/// acceptable frame size.
case invalidFrameLength
/// A control frame may not be fragmented (its FIN bit must be set).
case fragmentedControlFrame
/// A control frame may not have a length of more than 125 bytes.
case multiByteControlFrameLength
}
extension WebSocketErrorCode {
    /// Translates a `NIOWebSocketError` into the matching `WebSocketErrorCode`.
    ///
    /// An over-long frame becomes `.messageTooLarge`; both control-frame violations
    /// become `.protocolError`.
    init(_ error: NIOWebSocketError) {
        let code: WebSocketErrorCode
        switch error {
        case .fragmentedControlFrame, .multiByteControlFrameLength:
            code = .protocolError
        case .invalidFrameLength:
            code = .messageTooLarge
        }
        self = code
    }
}
extension ByteBuffer {
    /// Applies the WebSocket unmasking operation.
    ///
    /// - parameters:
    ///     - maskingKey: The masking key.
    ///     - indexOffset: An integer offset to apply to the index into the masking key; see
    ///         `webSocketMask(_:indexOffset:)`.
    public mutating func webSocketUnmask(_ maskingKey: WebSocketMaskingKey, indexOffset: Int = 0) {
        // Masking XORs each byte with a key byte, so applying it a second time undoes it:
        // unmasking and masking are the same operation.
        self.webSocketMask(maskingKey, indexOffset: indexOffset)
    }

    /// Applies the websocket masking operation.
    ///
    /// Every readable byte is XORed with the masking key byte at position
    /// `(i + indexOffset) mod 4`, where `i` is the byte's position within the readable bytes.
    ///
    /// - parameters:
    ///     - maskingKey: The masking key.
    ///     - indexOffset: An integer offset to apply to the index into the masking key.
    ///         This is used when masking multiple "contiguous" byte buffers, to ensure that
    ///         the masking key is applied uniformly to the collection rather than from the
    ///         start each time. Only its residue modulo 4 matters; any value is accepted.
    public mutating func webSocketMask(_ maskingKey: WebSocketMaskingKey, indexOffset: Int = 0) {
        // Reduce the offset to 0...3 once, up front. Computing `(index + indexOffset) % 4` per
        // byte traps on overflow for offsets near `Int.max`, and produces a negative (out of
        // range) key index for negative offsets because Swift's `%` keeps the dividend's sign.
        let keyOffset = ((indexOffset % 4) + 4) % 4
        self.withUnsafeMutableReadableBytes { bytes in
            for index in bytes.indices {
                bytes[index] ^= maskingKey[(index + keyOffset) % 4]
            }
        }
    }
}
/// The current state of the frame decoder.
///
/// Every state after `idle` carries the frame's first byte (FIN/RSV bits and opcode) forward,
/// so the completed frame can be built from it once the application data arrives.
enum DecoderState {
/// Waiting for a frame.
case idle
/// The initial frame byte has been received, but the length byte
/// has not.
case firstByteReceived(firstByte: UInt8)
/// The length byte indicates that we need to wait for the length word (a 16-bit
/// extended length), and we're currently waiting for it. `masked` records the
/// length byte's mask bit.
case waitingForLengthWord(firstByte: UInt8, masked: Bool)
/// The length byte indicates that we need to wait for the length qword (a 64-bit
/// extended length), and we're currently waiting for it. `masked` records the
/// length byte's mask bit.
case waitingForLengthQWord(firstByte: UInt8, masked: Bool)
/// The mask bit indicates we are expecting a mask key; `length` is the fully
/// decoded payload length.
case waitingForMask(firstByte: UInt8, length: Int)
/// All the header data is complete, we are waiting for the application data.
/// `maskingKey` is `nil` when the frame's mask bit was not set.
case waitingForData(firstByte: UInt8, length: Int, maskingKey: WebSocketMaskingKey?)
}
/// The outcome of a single `WSParser.parseStep(_:)` call.
enum ParseResult {
/// The buffer does not yet hold enough bytes to complete the current parse step.
case insufficientData
/// The parser advanced to its next state and should be stepped again.
case continueParsing
/// A complete frame was parsed; the parser has returned to `.idle`.
case result(WebSocketFrame)
}
/// An incremental websocket frame parser.
///
/// This parser attempts to parse a websocket frame incrementally, keeping as much parsing state
/// around as possible to ensure that we don't repeatedly partially parse the data.
struct WSParser {
    /// The current state of the decoder during incremental parse.
    var state: DecoderState = .idle

    /// Performs a single step of the frame parse, reading from `buffer` only what the current
    /// state needs: one header field, the masking key, or the whole application data.
    ///
    /// - parameters:
    ///     - buffer: The buffer to read frame bytes from.
    /// - returns: `.insufficientData` if `buffer` does not hold enough bytes for the current
    ///     step, `.continueParsing` after moving to the next header state, or `.result` with a
    ///     complete frame (the parser is then back in `.idle`).
    mutating func parseStep(_ buffer: inout ByteBuffer) -> ParseResult {
        switch self.state {
        case .idle:
            // This is a new frame. We want to find the first octet and save it off.
            guard let firstByte = buffer.readInteger(as: UInt8.self) else {
                return .insufficientData
            }
            self.state = .firstByteReceived(firstByte: firstByte)
            return .continueParsing

        case .firstByteReceived(let firstByte):
            // Now we're looking for the length. We begin by finding the length byte to see if we
            // need any more data.
            guard let lengthByte = buffer.readInteger(as: UInt8.self) else {
                return .insufficientData
            }
            let masked = (lengthByte & 0x80) != 0
            switch lengthByte & 0x7F {
            case 126:
                // A 16-bit extended payload length follows.
                self.state = .waitingForLengthWord(firstByte: firstByte, masked: masked)
            case 127:
                // A 64-bit extended payload length follows.
                self.state = .waitingForLengthQWord(firstByte: firstByte, masked: masked)
            case let len:
                assert(len <= 125)
                self.payloadLengthReceived(firstByte: firstByte, masked: masked, length: Int(len))
            }
            return .continueParsing

        case .waitingForLengthWord(let firstByte, let masked):
            // We've got a one-word length here.
            guard let lengthWord = buffer.readInteger(as: UInt16.self) else {
                return .insufficientData
            }
            self.payloadLengthReceived(firstByte: firstByte, masked: masked, length: Int(lengthWord))
            return .continueParsing

        case .waitingForLengthQWord(let firstByte, let masked):
            // We've got a qword of length here.
            guard let lengthQWord = buffer.readInteger(as: UInt64.self) else {
                return .insufficientData
            }
            // These 64 bits come straight from the peer. `Int(lengthQWord)` traps for any value
            // above `Int.max`, which would let a remote peer crash the process. Clamp instead:
            // the resulting `Int.max` length is then rejected by `validateState` with
            // `invalidFrameLength` unless `maxFrameSize` is itself `Int.max`.
            self.payloadLengthReceived(firstByte: firstByte,
                                       masked: masked,
                                       length: Int(clamping: lengthQWord))
            return .continueParsing

        case .waitingForMask(let firstByte, let length):
            // We're waiting for the masking key.
            guard let maskingKey = buffer.readInteger(as: UInt32.self) else {
                return .insufficientData
            }
            self.state = .waitingForData(firstByte: firstByte,
                                         length: length,
                                         maskingKey: WebSocketMaskingKey(networkRepresentation: maskingKey))
            return .continueParsing

        case .waitingForData(let firstByte, let length, let maskingKey):
            // Wait until the entire payload is available; frames are never emitted partially.
            guard let data = buffer.readSlice(length: length) else {
                return .insufficientData
            }
            let frame = WebSocketFrame(firstByte: firstByte, maskKey: maskingKey, applicationData: data)
            self.state = .idle
            return .result(frame)
        }
    }

    /// Moves to the state that follows a fully decoded payload length. That is the masking key
    /// if the mask bit was set; otherwise it is the (unmasked) application data.
    private mutating func payloadLengthReceived(firstByte: UInt8, masked: Bool, length: Int) {
        if masked {
            self.state = .waitingForMask(firstByte: firstByte, length: length)
        } else {
            self.state = .waitingForData(firstByte: firstByte, length: length, maskingKey: nil)
        }
    }

    /// Apply a number of validations to the incremental state, ensuring that the frame we're
    /// receiving is valid.
    ///
    /// Checks only run once the payload length is known (`.waitingForMask` / `.waitingForData`).
    ///
    /// - parameters:
    ///     - maxFrameSize: The largest acceptable payload length, in bytes.
    /// - throws:
    ///     - `NIOWebSocketError.invalidFrameLength` if the length exceeds `maxFrameSize`.
    ///     - `NIOWebSocketError.fragmentedControlFrame` if a control frame has its FIN bit clear.
    ///     - `NIOWebSocketError.multiByteControlFrameLength` if a control frame is longer than
    ///       125 bytes.
    func validateState(maxFrameSize: Int) throws {
        switch self.state {
        case .waitingForMask(let firstByte, let length), .waitingForData(let firstByte, let length, _):
            if length > maxFrameSize {
                throw NIOWebSocketError.invalidFrameLength
            }
            // Opcode bit 3 set (opcodes 0x8-0xF) marks a control frame; FIN is the top bit.
            let isControlFrame = (firstByte & 0x08) != 0
            let isFragment = (firstByte & 0x80) == 0
            if isControlFrame && isFragment {
                throw NIOWebSocketError.fragmentedControlFrame
            }
            if isControlFrame && length > 125 {
                throw NIOWebSocketError.multiByteControlFrameLength
            }
        case .idle, .firstByteReceived, .waitingForLengthWord, .waitingForLengthQWord:
            // No validation necessary in this state as we have no length to validate.
            break
        }
    }
}
/// An inbound `ChannelHandler` that turns a stream of bytes into `WebSocketFrame`s.
///
/// Only a limited subset of RFC 6455 is enforced. So that arbitrary extensions can be layered
/// on top, the decoder checks only normative MUST/MUST NOT requirements that are independent
/// of extensions. For example, control frames must be neither fragmented nor longer than
/// 125 bytes.
///
/// No extension decoding is performed; every payload byte is treated as application data. To
/// support an extension, add a message-to-message decoder that transforms the frames as
/// required.
public final class WebSocketFrameDecoder: ByteToMessageDecoder {
    public typealias InboundIn = ByteBuffer
    public typealias InboundOut = WebSocketFrame
    public typealias OutboundOut = WebSocketFrame

    /// The largest frame, in bytes, that will be accepted from the remote peer.
    /* private but tests */ let maxFrameSize: Int

    /// Incremental parse state, carried across `decode` calls.
    private var parser = WSParser()

    /// Creates a `WebSocketFrameDecoder`.
    ///
    /// - parameters:
    ///     - maxFrameSize: The largest frame the decoder accepts from the remote peer.
    ///         - WebSocket permits frames of up to `2**64` bytes. That is far more than can
    ///           actually be allocated, so a much smaller limit is applied.
    ///         - The default, `2**14` bytes, matches the default HTTP/2 max frame size.
    ///         - Any value up to `UInt32.max` may be supplied, but raising the limit is
    ///           discouraged unless strictly necessary. Frames are never delivered partially,
    ///           so the decoder holds on to each frame's *entire* payload before emitting it.
    public init(maxFrameSize: Int = 1 << 14) {
        precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size")
        self.maxFrameSize = maxFrameSize
    }

    public func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
        // The caller re-invokes us on `.continue`, but it does not promise to do so once the
        // buffer is empty. Some steps consume zero bytes (e.g. a zero-length payload), so keep
        // stepping here until we either emit a frame or run out of data.
        while true {
            let step = self.parser.parseStep(&buffer)
            if case .result(let frame) = step {
                context.fireChannelRead(self.wrapInboundOut(frame))
                return .continue
            }
            if case .insufficientData = step {
                return .needMoreData
            }
            // `.continueParsing`: vet the header parsed so far, then take another step.
            try self.parser.validateState(maxFrameSize: self.maxFrameSize)
        }
    }
}
|