1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
|
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2021 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIO
/// `NIOWebSocketFrameAggregator` buffers inbound fragmented `WebSocketFrame`s and aggregates them into a single `WebSocketFrame`.
/// It guarantees that a `WebSocketFrame` with an `opcode` of `.continuation` is never forwarded.
/// Frames which are not fragmented are just forwarded without any processing.
/// Fragmented frames are unmasked, concatenated and forwarded as a new `WebSocketFrame` which is either a `.binary` or `.text` frame.
/// `extensionData`, `rsv1`, `rsv2` and `rsv3` are lost if a frame is fragmented because they cannot be concatenated.
/// - Note: `.ping`, `.pong`, `.closeConnection` frames are forwarded during frame aggregation
public final class NIOWebSocketFrameAggregator: ChannelInboundHandler {
    /// Errors sent down the pipeline via `fireErrorCaught` when an inbound fragment sequence
    /// violates the configured limits or arrives in an invalid order.
    public enum Error: Swift.Error {
        /// A fragment without `fin` set was shorter than `minNonFinalFragmentSize`.
        case nonFinalFragmentSizeIsTooSmall
        /// The fragment (plus the one that must follow a non-final fragment) does not fit within `maxAccumulatedFrameCount`.
        case tooManyFragments
        /// The summed length of the buffered fragments exceeded `maxAccumulatedFrameSize`.
        case accumulatedFrameSizeIsTooLarge
        /// A `.text` or `.binary` frame arrived while an earlier fragmented frame was still incomplete.
        case receivedNewFrameWithoutFinishingPrevious
        /// A `.continuation` frame arrived while no fragments were buffered.
        case didReceiveFragmentBeforeReceivingTextOrBinaryFrame
    }

    public typealias InboundIn = WebSocketFrame
    public typealias InboundOut = WebSocketFrame

    private let minNonFinalFragmentSize: Int
    private let maxAccumulatedFrameCount: Int
    private let maxAccumulatedFrameSize: Int

    /// Fragments received so far for the frame currently being assembled, in arrival order.
    private var bufferedFrames: [WebSocketFrame] = []
    /// Sum of `length` over all entries in `bufferedFrames`.
    private var accumulatedFrameSize: Int = 0

    /// Configures a `NIOWebSocketFrameAggregator`.
    /// - Parameters:
    ///   - minNonFinalFragmentSize: Minimum size in bytes of a fragment which is not the last fragment of a complete frame. Used to defend against many really small payloads.
    ///   - maxAccumulatedFrameCount: Maximum number of fragments which are allowed to result in a complete frame.
    ///   - maxAccumulatedFrameSize: Maximum accumulated size in bytes of buffered fragments. It is essentially the maximum allowed size of an incoming frame after all fragments are concatenated.
    public init(
        minNonFinalFragmentSize: Int,
        maxAccumulatedFrameCount: Int,
        maxAccumulatedFrameSize: Int
    ) {
        self.minNonFinalFragmentSize = minNonFinalFragmentSize
        self.maxAccumulatedFrameCount = maxAccumulatedFrameCount
        self.maxAccumulatedFrameSize = maxAccumulatedFrameSize
    }

    /// Dispatches an inbound frame by opcode: continuation and fragmented data frames are
    /// buffered, everything else is forwarded untouched. Any error raised while buffering
    /// drops the buffered fragments and is reported with `fireErrorCaught`.
    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let inbound = self.unwrapInboundIn(data)
        do {
            switch inbound.opcode {
            case .continuation:
                try self.receiveContinuation(inbound, context: context)
            case .binary, .text:
                try self.receiveDataFrame(inbound, original: data, context: context)
            default:
                // Control frames are never fragmented, so they pass straight through.
                context.fireChannelRead(data)
            }
        } catch {
            // Release the buffered fragments right away instead of holding on to them.
            self.clearBuffer()
            context.fireErrorCaught(error)
        }
    }

    /// Buffers a `.continuation` fragment and, if it is the final one, forwards the assembled frame.
    /// - Throws: `Error.didReceiveFragmentBeforeReceivingTextOrBinaryFrame` when nothing is buffered yet,
    ///   or any error thrown by `bufferFrame(_:)`.
    private func receiveContinuation(_ fragment: WebSocketFrame, context: ChannelHandlerContext) throws {
        // The opcode of the aggregated frame is taken from the first buffered fragment.
        guard let initialOpcode = self.bufferedFrames.first?.opcode else {
            throw Error.didReceiveFragmentBeforeReceivingTextOrBinaryFrame
        }
        try self.bufferFrame(fragment)

        if fragment.fin {
            let assembled = self.aggregateFrames(
                opcode: initialOpcode,
                allocator: context.channel.allocator
            )
            // Reset state before forwarding the aggregated frame.
            self.clearBuffer()
            context.fireChannelRead(self.wrapInboundOut(assembled))
        }
    }

    /// Handles a `.text` or `.binary` frame: a non-final frame is buffered, a final frame is forwarded as-is.
    /// - Throws: `Error.receivedNewFrameWithoutFinishingPrevious` when a final frame arrives while fragments
    ///   are still buffered, or any error thrown by `bufferFrame(_:)`.
    private func receiveDataFrame(
        _ frame: WebSocketFrame,
        original: NIOAny,
        context: ChannelHandlerContext
    ) throws {
        guard frame.fin else {
            try self.bufferFrame(frame)
            return
        }
        if !self.bufferedFrames.isEmpty {
            throw Error.receivedNewFrameWithoutFinishingPrevious
        }
        // Fast path: an unfragmented frame needs no limit checks, unmasking or copying.
        context.fireChannelRead(original)
    }

    /// Validates `frame` against the configured limits and appends it to the buffer.
    /// - Throws: One of the `Error` cases describing the violated constraint. The size limit is
    ///   checked after the frame has been appended.
    private func bufferFrame(_ frame: WebSocketFrame) throws {
        let isFinal = frame.fin

        if !self.bufferedFrames.isEmpty && frame.opcode != .continuation {
            throw Error.receivedNewFrameWithoutFinishingPrevious
        }
        if !isFinal && frame.length < self.minNonFinalFragmentSize {
            throw Error.nonFinalFragmentSizeIsTooSmall
        }

        // A non-final fragment is always followed by at least one more, so it needs room for two.
        let slotsNeeded = isFinal ? 1 : 2
        if self.bufferedFrames.count + slotsNeeded > self.maxAccumulatedFrameCount {
            throw Error.tooManyFragments
        }

        self.bufferedFrames.append(frame)
        self.accumulatedFrameSize += frame.length

        if self.accumulatedFrameSize > self.maxAccumulatedFrameSize {
            throw Error.accumulatedFrameSizeIsTooLarge
        }
    }

    /// Concatenates the unmasked payloads of all buffered fragments into a single final frame.
    /// - Parameters:
    ///   - opcode: Opcode to give the resulting frame.
    ///   - allocator: Allocator used for the combined payload buffer.
    /// - Returns: A frame with `fin` set, the given `opcode` and the concatenated payload.
    private func aggregateFrames(opcode: WebSocketOpcode, allocator: ByteBufferAllocator) -> WebSocketFrame {
        let combined = self.bufferedFrames.reduce(
            into: allocator.buffer(capacity: self.accumulatedFrameSize)
        ) { payload, fragment in
            var bytes = fragment.unmaskedData
            payload.writeBuffer(&bytes)
        }
        return WebSocketFrame(fin: true, opcode: opcode, data: combined)
    }

    /// Drops all buffered fragments (keeping the array's capacity) and resets the size counter.
    private func clearBuffer() {
        self.accumulatedFrameSize = 0
        self.bufferedFrames.removeAll(keepingCapacity: true)
    }
}
|