File: NIOWebSocketFrameAggregator.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (138 lines) | stat: -rw-r--r-- 5,949 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2021 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import NIO


/// `NIOWebSocketFrameAggregator` collects inbound fragmented `WebSocketFrame`s and merges them into a single `WebSocketFrame`.
/// It guarantees that no frame with an `opcode` of `.continuation` is ever passed on to the next handler.
/// Frames that are not fragmented are forwarded untouched, without any limit checks or copying.
/// The fragments of a fragmented message are unmasked, concatenated and forwarded as one new, final frame
/// whose opcode is that of the first fragment (`.binary` or `.text`).
/// `extensionData`, `rsv1`, `rsv2` and `rsv3` of a fragmented message are dropped because they cannot be concatenated.
/// When a limit or the fragmentation order is violated, the buffered fragments are discarded and an
/// `NIOWebSocketFrameAggregator.Error` is passed to `fireErrorCaught`.
/// - Note: Control frames (ping, pong and connection close) are forwarded as soon as they arrive, even in
///   the middle of aggregating a fragmented message.
public final class NIOWebSocketFrameAggregator: ChannelInboundHandler {
    /// Errors delivered through `fireErrorCaught` when inbound frames break a configured limit
    /// or the fragmentation order.
    public enum Error: Swift.Error {
        /// A fragment other than the final one was shorter than `minNonFinalFragmentSize`.
        case nonFinalFragmentSizeIsTooSmall
        /// The message would consist of more fragments than `maxAccumulatedFrameCount` permits.
        case tooManyFragments
        /// The combined length of the buffered fragments exceeded `maxAccumulatedFrameSize`.
        case accumulatedFrameSizeIsTooLarge
        /// A new `.text` or `.binary` frame arrived before the previous fragmented message was finished.
        case receivedNewFrameWithoutFinishingPrevious
        /// A `.continuation` frame arrived while no fragmented message was in progress.
        case didReceiveFragmentBeforeReceivingTextOrBinaryFrame
    }

    public typealias InboundIn = WebSocketFrame
    public typealias InboundOut = WebSocketFrame

    // MARK: - Configuration

    private let minNonFinalFragmentSize: Int
    private let maxAccumulatedFrameCount: Int
    private let maxAccumulatedFrameSize: Int

    // MARK: - Aggregation state

    /// Fragments of the message currently being assembled; the first one carries the message opcode.
    private var bufferedFrames: [WebSocketFrame] = []
    /// Running total of `length` over every element of `bufferedFrames`.
    private var accumulatedFrameSize: Int = 0

    /// Creates an aggregator that enforces the given limits on fragmented messages.
    /// - Parameters:
    ///   - minNonFinalFragmentSize: Minimum size in bytes of every fragment except the final one of a message.
    ///     Defends against peers that split a message into many tiny fragments.
    ///   - maxAccumulatedFrameCount: Maximum number of fragments, including the final one, that a single
    ///     message may consist of.
    ///   - maxAccumulatedFrameSize: Maximum combined size in bytes of all buffered fragments, i.e. the largest
    ///     message accepted after its fragments are concatenated.
    public init(
        minNonFinalFragmentSize: Int,
        maxAccumulatedFrameCount: Int,
        maxAccumulatedFrameSize: Int
    ) {
        self.minNonFinalFragmentSize = minNonFinalFragmentSize
        self.maxAccumulatedFrameCount = maxAccumulatedFrameCount
        self.maxAccumulatedFrameSize = maxAccumulatedFrameSize
    }

    // MARK: - ChannelInboundHandler

    /// Forwards unfragmented and control frames, buffers fragments, and emits the assembled message
    /// once its final fragment arrives. Any violation discards the buffered fragments and is reported
    /// via `fireErrorCaught`.
    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let incoming = self.unwrapInboundIn(data)
        do {
            switch incoming.opcode {
            case .continuation:
                if let message = try self.absorbContinuation(incoming, context: context) {
                    context.fireChannelRead(self.wrapInboundOut(message))
                }
            case .binary, .text:
                if !incoming.fin {
                    // Opening fragment of a fragmented message: start collecting.
                    try self.enqueueFragment(incoming)
                } else if self.bufferedFrames.isEmpty {
                    // Fast path: an unfragmented message needs no limit checks, unmasking or copying,
                    // so the original `NIOAny` is forwarded unchanged.
                    context.fireChannelRead(data)
                } else {
                    throw Error.receivedNewFrameWithoutFinishingPrevious
                }
            default:
                // Control frames are never fragmented, so they bypass the buffer entirely.
                context.fireChannelRead(data)
            }
        } catch {
            // Drop the partial message right away rather than holding its memory until the next message.
            self.discardBufferedFrames()
            context.fireErrorCaught(error)
        }
    }

    // MARK: - Fragment handling

    /// Adds a `.continuation` frame to the message in progress and, if it is the final fragment,
    /// assembles the message and resets the buffer.
    /// - Returns: The assembled frame when `fragment.fin` is set; `nil` while more fragments are expected.
    /// - Throws: `Error.didReceiveFragmentBeforeReceivingTextOrBinaryFrame` if no message is in progress,
    ///   or whichever error `enqueueFragment(_:)` raises.
    private func absorbContinuation(
        _ fragment: WebSocketFrame,
        context: ChannelHandlerContext
    ) throws -> WebSocketFrame? {
        guard let messageOpcode = self.bufferedFrames.first?.opcode else {
            throw Error.didReceiveFragmentBeforeReceivingTextOrBinaryFrame
        }
        try self.enqueueFragment(fragment)
        guard fragment.fin else {
            // More fragments are expected; keep buffering.
            return nil
        }
        let message = self.assembleMessage(opcode: messageOpcode, allocator: context.channel.allocator)
        self.discardBufferedFrames()
        return message
    }

    /// Validates `fragment` against the ordering rules and the configured limits, then appends it.
    /// - Throws: The `Error` for the first rule violated. When the size limit is exceeded the fragment
    ///   has already been appended; the caller is expected to discard the buffer.
    private func enqueueFragment(_ fragment: WebSocketFrame) throws {
        // Once a message is in progress, only continuation frames may extend it.
        if !self.bufferedFrames.isEmpty && fragment.opcode != .continuation {
            throw Error.receivedNewFrameWithoutFinishingPrevious
        }

        let countIncludingFragment = self.bufferedFrames.count + 1
        if fragment.fin {
            if countIncludingFragment > self.maxAccumulatedFrameCount {
                throw Error.tooManyFragments
            }
        } else {
            if fragment.length < self.minNonFinalFragmentSize {
                throw Error.nonFinalFragmentSizeIsTooSmall
            }
            // A non-final fragment means at least one more fragment must follow,
            // so that one has to fit within the count limit as well.
            if countIncludingFragment + 1 > self.maxAccumulatedFrameCount {
                throw Error.tooManyFragments
            }
        }

        self.bufferedFrames.append(fragment)
        self.accumulatedFrameSize += fragment.length

        if self.accumulatedFrameSize > self.maxAccumulatedFrameSize {
            throw Error.accumulatedFrameSizeIsTooLarge
        }
    }

    /// Concatenates the unmasked payloads of all buffered fragments into a single final frame.
    private func assembleMessage(opcode: WebSocketOpcode, allocator: ByteBufferAllocator) -> WebSocketFrame {
        let payload = self.bufferedFrames.reduce(
            into: allocator.buffer(capacity: self.accumulatedFrameSize)
        ) { combined, fragment in
            var unmasked = fragment.unmaskedData
            combined.writeBuffer(&unmasked)
        }
        return WebSocketFrame(fin: true, opcode: opcode, data: payload)
    }

    /// Forgets the message in progress while keeping the array's storage for reuse.
    private func discardBufferedFrames() {
        self.accumulatedFrameSize = 0
        self.bufferedFrames.removeAll(keepingCapacity: true)
    }
}