File: WebSocketFrameDecoder.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (271 lines) | stat: -rw-r--r-- 11,353 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIO

/// Errors thrown by the NIO websocket module.
///
/// Each case can be translated into a `WebSocketErrorCode` using the initializer
/// declared in the `WebSocketErrorCode` extension in this file.
public enum NIOWebSocketError: Error {
    /// The declared length of a frame is larger than the configured maximum
    /// acceptable frame size. The frame parser's validation step in this file
    /// throws this when the parsed length exceeds `maxFrameSize`.
    case invalidFrameLength

    /// A control frame may not be fragmented. The validation step throws this when
    /// the first byte of a frame has the control bit (`0x08`) set while its top bit
    /// (`0x80`) is clear.
    case fragmentedControlFrame

    /// A control frame may not have a length more than 125 bytes.
    case multiByteControlFrameLength
}

extension WebSocketErrorCode {
    /// Creates the error code that corresponds to a `NIOWebSocketError`.
    ///
    /// - parameters:
    ///     - error: The websocket error to translate into an error code.
    init(_ error: NIOWebSocketError) {
        switch error {
        case .fragmentedControlFrame, .multiByteControlFrameLength:
            // Both of these are violations of the rules for control frames.
            self = .protocolError
        case .invalidFrameLength:
            self = .messageTooLarge
        }
    }
}

extension ByteBuffer {
    /// Applies the WebSocket unmasking operation to the readable bytes of this buffer.
    ///
    /// - parameters:
    ///     - maskingKey: The masking key.
    ///     - indexOffset: An integer offset to apply to the index into the masking key.
    ///         It is passed through unchanged to `webSocketMask`.
    public mutating func webSocketUnmask(_ maskingKey: WebSocketMaskingKey, indexOffset: Int = 0) {
        // Masking is an XOR against the key, so running it a second time undoes it:
        // unmasking and masking are the same operation.
        self.webSocketMask(maskingKey, indexOffset: indexOffset)
    }

    /// Applies the WebSocket masking operation to the readable bytes of this buffer.
    ///
    /// Every readable byte is XORed in place with one byte of the masking key.
    ///
    /// - parameters:
    ///     - maskingKey: The masking key.
    ///     - indexOffset: An integer offset to apply to the index into the masking key.
    ///         This is used when masking multiple "contiguous" byte buffers, to ensure that
    ///         the masking key is applied uniformly to the collection rather than from the
    ///         start each time.
    public mutating func webSocketMask(_ maskingKey: WebSocketMaskingKey, indexOffset: Int = 0) {
        self.withUnsafeMutableReadableBytes { readableBytes in
            for position in readableBytes.indices {
                // The key is cycled through four positions, shifted by `indexOffset`.
                let keyByte = maskingKey[(position + indexOffset) % 4]
                readableBytes[position] = readableBytes[position] ^ keyByte
            }
        }
    }
}

/// The current state of the frame decoder.
///
/// The associated values carry the header fields that have been parsed so far, so
/// that the parser can resume from where it stopped when more bytes arrive.
enum DecoderState {
    /// Waiting for a frame.
    case idle

    /// The initial frame byte has been received, but the length byte
    /// has not.
    case firstByteReceived(firstByte: UInt8)

    /// The length byte indicates that we need to wait for the length word, and we're
    /// currently waiting for it. `masked` records whether the top bit of the length
    /// byte was set.
    case waitingForLengthWord(firstByte: UInt8, masked: Bool)

    /// The length byte indicates that we need to wait for the length qword, and
    /// we're currently waiting for it. `masked` records whether the top bit of the
    /// length byte was set.
    case waitingForLengthQWord(firstByte: UInt8, masked: Bool)

    /// The mask bit indicates we are expecting a mask key. `length` is the frame
    /// length that has already been parsed from the header.
    case waitingForMask(firstByte: UInt8, length: Int)

    /// All the header data is complete, we are waiting for the application data.
    /// `maskingKey` is `nil` when the frame was not masked.
    case waitingForData(firstByte: UInt8, length: Int, maskingKey: WebSocketMaskingKey?)
}

/// The outcome of a single call to `WSParser.parseStep`.
enum ParseResult {
    /// The buffer did not hold enough bytes to complete the current parsing step.
    case insufficientData
    /// The parser moved to a new state; the caller should validate it and parse again.
    case continueParsing
    /// A complete frame was parsed and the parser has returned to the idle state.
    case result(WebSocketFrame)
}

/// An incremental websocket frame parser.
///
/// This parser attempts to parse a websocket frame incrementally, keeping as much parsing state around as possible to ensure that
/// we don't repeatedly partially parse the data.
struct WSParser {
    /// The current state of the decoder during incremental parse.
    var state: DecoderState = .idle

    /// Attempts to advance the parser by exactly one state transition.
    ///
    /// - parameters:
    ///     - buffer: The buffer to read the bytes for the current step from.
    /// - returns: `.insufficientData` if the buffer could not supply the bytes needed for the
    ///     current step, `.continueParsing` if the parser moved to a new state, or `.result`
    ///     carrying the frame once its application data has been read.
    mutating func parseStep(_ buffer: inout ByteBuffer) -> ParseResult {
        switch self.state {
        case .idle:
            // This is a new buffer. We want to find the first octet and save it off.
            guard let firstByte = buffer.readInteger(as: UInt8.self) else {
                return .insufficientData
            }
            self.state = .firstByteReceived(firstByte: firstByte)
            return .continueParsing

        case .firstByteReceived(let firstByte):
            // Now we're looking for the length. We begin by finding the length byte to see if we
            // need any more data.
            guard let lengthByte = buffer.readInteger(as: UInt8.self) else {
                return .insufficientData
            }

            // The top bit is the mask flag, the low seven bits are the length (or a marker
            // saying that an extended length follows).
            let masked = (lengthByte & 0x80) != 0

            switch lengthByte & 0x7F {
            case 126:
                self.state = .waitingForLengthWord(firstByte: firstByte, masked: masked)
            case 127:
                self.state = .waitingForLengthQWord(firstByte: firstByte, masked: masked)
            case let shortLength:
                assert(shortLength <= 125)
                self.state = WSParser.stateFollowingLength(Int(shortLength), firstByte: firstByte, masked: masked)
            }
            return .continueParsing

        case .waitingForLengthWord(let firstByte, let masked):
            // We've got a one-word length here.
            guard let lengthWord = buffer.readInteger(as: UInt16.self) else {
                return .insufficientData
            }

            self.state = WSParser.stateFollowingLength(Int(lengthWord), firstByte: firstByte, masked: masked)
            return .continueParsing

        case .waitingForLengthQWord(let firstByte, let masked):
            // We've got a qword of length here.
            guard let lengthQWord = buffer.readInteger(as: UInt64.self) else {
                return .insufficientData
            }

            // The length comes straight off the wire, so it may not fit in `Int`. A plain
            // `Int(lengthQWord)` would trap for such values, letting a remote peer crash the
            // process. Clamp instead: the clamped value is `Int.max`, which `validateState`
            // then rejects as exceeding the maximum frame size.
            let length = Int(clamping: lengthQWord)
            self.state = WSParser.stateFollowingLength(length, firstByte: firstByte, masked: masked)
            return .continueParsing

        case .waitingForMask(let firstByte, let length):
            // We're waiting for the masking key.
            guard let maskingKey = buffer.readInteger(as: UInt32.self) else {
                return .insufficientData
            }

            self.state = .waitingForData(firstByte: firstByte, length: length, maskingKey: WebSocketMaskingKey(networkRepresentation: maskingKey))
            return .continueParsing

        case .waitingForData(let firstByte, let length, let maskingKey):
            guard let data = buffer.readSlice(length: length) else {
                return .insufficientData
            }

            let frame = WebSocketFrame(firstByte: firstByte, maskKey: maskingKey, applicationData: data)
            // The frame is complete, so the next byte we see starts a new frame.
            self.state = .idle
            return .result(frame)
        }
    }

    /// Picks the state that follows a fully-parsed frame length: the masking key if the frame
    /// is masked, otherwise the application data.
    private static func stateFollowingLength(_ length: Int, firstByte: UInt8, masked: Bool) -> DecoderState {
        if masked {
            return .waitingForMask(firstByte: firstByte, length: length)
        }
        return .waitingForData(firstByte: firstByte, length: length, maskingKey: nil)
    }

    /// Apply a number of validations to the incremental state, ensuring that the frame we're
    /// receiving is valid.
    ///
    /// - parameters:
    ///     - maxFrameSize: The largest frame length that is acceptable.
    /// - throws: `NIOWebSocketError.invalidFrameLength` if the frame length exceeds `maxFrameSize`,
    ///     `NIOWebSocketError.fragmentedControlFrame` if a control frame does not have its top bit
    ///     set, or `NIOWebSocketError.multiByteControlFrameLength` if a control frame is longer
    ///     than 125 bytes.
    func validateState(maxFrameSize: Int) throws {
        switch self.state {
        case .waitingForMask(let firstByte, let length), .waitingForData(let firstByte, let length, _):
            if length > maxFrameSize {
                throw NIOWebSocketError.invalidFrameLength
            }

            let isControlFrame = (firstByte & 0x08) != 0
            let isFragment = (firstByte & 0x80) == 0

            if isControlFrame && isFragment {
                throw NIOWebSocketError.fragmentedControlFrame
            }
            if isControlFrame && length > 125 {
                throw NIOWebSocketError.multiByteControlFrameLength
            }
        case .idle, .firstByteReceived, .waitingForLengthWord, .waitingForLengthQWord:
            // No validation necessary in this state as we have no length to validate.
            break
        }
    }
}

/// An inbound `ChannelHandler` that turns a stream of bytes into structured websocket
/// frames for further processing.
///
/// This decoder has limited enforcement of compliance to RFC 6455. In particular, to guarantee
/// that the decoder can handle arbitrary extensions, only normative MUST/MUST NOTs that do not
/// relate to extensions (e.g. the requirement that control frames not have lengths larger than
/// 125 bytes) are enforced by this decoder.
///
/// This decoder does not have any support for decoding extensions. If you wish to support
/// extensions, you should implement a message-to-message decoder that performs the appropriate
/// frame transformation as needed. All the frame data is assumed to be application data by this
/// parser.
public final class WebSocketFrameDecoder: ByteToMessageDecoder {
    public typealias InboundIn = ByteBuffer
    public typealias InboundOut = WebSocketFrame
    public typealias OutboundOut = WebSocketFrame

    /// The largest frame size that this decoder will accept from the remote peer.
    /* private but tests */ let maxFrameSize: Int

    /// The incremental parser, which remembers how far into the current frame we are.
    private var parser = WSParser()

    /// Construct a new `WebSocketFrameDecoder`
    ///
    /// - parameters:
    ///     - maxFrameSize: The maximum frame size the decoder is willing to tolerate from the
    ///         remote peer. WebSockets in principle allows frame sizes up to `2**64` bytes, but
    ///         this is an objectively unreasonable maximum value (on AMD64 systems it is not
    ///         possible to even allocate a buffer large enough to handle this size), so we
    ///         set a lower one. The default value is the same as the default HTTP/2 max frame
    ///         size, `2**14` bytes. Users may override this to any value up to `UInt32.max`;
    ///         larger values fail a precondition. Users are strongly encouraged not to increase
    ///         this value unless they absolutely must, as the decoder will not produce partial
    ///         frames, meaning that it will hold on to data until the *entire* body is received.
    public init(maxFrameSize: Int = 1 << 14) {
        precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size")
        self.maxFrameSize = maxFrameSize
    }

    /// Parses as much of a single frame as the buffer allows.
    ///
    /// - parameters:
    ///     - context: The context used to fire the decoded frame down the pipeline.
    ///     - buffer: The bytes received so far.
    /// - returns: `.continue` once a frame has been fired, or `.needMoreData` if the buffer ran
    ///     out before the frame was complete.
    /// - throws: Any error thrown while validating the partially-parsed frame.
    public func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState  {
        // We drive the parser in our own loop rather than relying on the caller to call us
        // again: some parsing steps need zero bytes, and the caller doesn't guarantee to call
        // us when there are no bytes left to hand over.
        while true {
            let step = self.parser.parseStep(&buffer)
            switch step {
            case .insufficientData:
                return .needMoreData
            case .continueParsing:
                // Check the new state, then go around again: the next step may be
                // 'waiting' for 0 bytes.
                try self.parser.validateState(maxFrameSize: self.maxFrameSize)
            case .result(let decodedFrame):
                context.fireChannelRead(self.wrapInboundOut(decodedFrame))
                return .continue
            }
        }
    }
}