File: NIOHTTPObjectAggregator.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (400 lines) | stat: -rw-r--r-- 16,226 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2020 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIO

/// The parts of a complete HTTP request from the view of the server.
///
/// A full HTTP request is made up of a request header encoded by `.head`
/// and an optional `.body`.
///
/// This is the message type that `NIOHTTPServerRequestAggregator` fires
/// once it has received all parts of a request.
public struct NIOHTTPServerRequestFull {
    /// The request head (including any trailer fields merged in by the aggregator).
    public var head: HTTPRequestHead
    /// The aggregated request body, or `nil` when no body bytes were received.
    public var body: ByteBuffer?

    /// Creates a full request from its head and optional body.
    ///
    /// - Parameters:
    ///   - head: The request head.
    ///   - body: The complete request body, or `nil` if there is none.
    public init(head: HTTPRequestHead, body: ByteBuffer?) {
        self.head = head
        self.body = body
    }
}

// Equality is synthesized from `head` and `body`.
extension NIOHTTPServerRequestFull: Equatable {}

/// The parts of a complete HTTP response from the view of the client.
///
/// A full HTTP response is made up of a response header encoded by `.head`
/// and an optional `.body`.
///
/// This is the message type that `NIOHTTPClientResponseAggregator` fires
/// once it has received all parts of a response.
public struct NIOHTTPClientResponseFull {
    /// The response head (including any trailer fields merged in by the aggregator).
    public var head: HTTPResponseHead
    /// The aggregated response body, or `nil` when no body bytes were received.
    public var body: ByteBuffer?

    /// Creates a full response from its head and optional body.
    ///
    /// - Parameters:
    ///   - head: The response head.
    ///   - body: The complete response body, or `nil` if there is none.
    public init(head: HTTPResponseHead, body: ByteBuffer?) {
        self.head = head
        self.body = body
    }
}

// Equality is synthesized from `head` and `body`.
extension NIOHTTPClientResponseFull: Equatable {}

/// Errors reported by the HTTP object aggregators.
///
/// The set of errors is modelled as a struct wrapping a private enum, so the
/// only way to obtain a value is through the static constants below.
public struct NIOHTTPObjectAggregatorError: Error, Equatable {
    private enum Base {
        case frameTooLong
        case connectionClosed
        case endingIgnoredMessage
        case unexpectedMessageHead
        case unexpectedMessageBody
        case unexpectedMessageEnd
    }

    private var base: Base

    private init(base: Base) {
        self.base = base
    }

    /// The message exceeds the configured maximum content length; also raised for
    /// body parts that arrive while an oversize message is being ignored.
    public static let frameTooLong: NIOHTTPObjectAggregatorError = .init(base: .frameTooLong)

    /// A message part arrived after the aggregator marked the connection as closed.
    public static let connectionClosed: NIOHTTPObjectAggregatorError = .init(base: .connectionClosed)

    /// The end of a message whose contents were being ignored has been reached.
    public static let endingIgnoredMessage: NIOHTTPObjectAggregatorError = .init(base: .endingIgnoredMessage)

    /// A message head arrived while another message was still being received or ignored.
    public static let unexpectedMessageHead: NIOHTTPObjectAggregatorError = .init(base: .unexpectedMessageHead)

    /// A body part arrived while no message was in progress.
    public static let unexpectedMessageBody: NIOHTTPObjectAggregatorError = .init(base: .unexpectedMessageBody)

    /// A message end arrived while no message was in progress.
    public static let unexpectedMessageEnd: NIOHTTPObjectAggregatorError = .init(base: .unexpectedMessageEnd)
}

/// User events that the HTTP object aggregators may fire through the pipeline.
///
/// Like the aggregator error type, values can only be obtained through the
/// static constants below.
public struct NIOHTTPObjectAggregatorEvent: Hashable {
    private enum Base {
        case httpExpectationFailed
        case httpFrameTooLong
    }

    private var base: Base

    private init(base: Base) {
        self.base = base
    }

    /// Signals a failed HTTP expectation.
    /// NOTE(review): no code visible in this file fires this event.
    public static let httpExpectationFailed: NIOHTTPObjectAggregatorEvent = .init(base: .httpExpectationFailed)

    /// Signals that a message exceeded the configured maximum content length.
    public static let httpFrameTooLong: NIOHTTPObjectAggregatorEvent = .init(base: .httpFrameTooLong)
}

/// The state of the aggregator connection.
///
/// Each `message…Received()` method validates that the corresponding message
/// part is acceptable in the current state and throws a
/// `NIOHTTPObjectAggregatorError` when it is not.
internal enum AggregatorState {
    /// No message is in progress; the next part we expect is a `.head`.
    case idle

    /// The current message turned out to be too large, so its remaining parts are rejected.
    case ignoringContent

    /// A message head has been accepted and the message is being aggregated.
    case receiving

    /// The connection is to be closed; every further message part is rejected.
    case closed

    /// Records that a message head arrived, moving from `.idle` to `.receiving`.
    ///
    /// - Throws: `connectionClosed` when closed, `unexpectedMessageHead` when a
    ///     message is already being received or ignored.
    mutating func messageHeadReceived() throws {
        switch self {
        case .closed:
            throw NIOHTTPObjectAggregatorError.connectionClosed
        case .receiving, .ignoringContent:
            throw NIOHTTPObjectAggregatorError.unexpectedMessageHead
        case .idle:
            self = .receiving
        }
    }

    /// Records that a body part arrived. The state itself never changes here.
    ///
    /// - Throws: `connectionClosed` when closed, `unexpectedMessageBody` when no
    ///     message is in progress, `frameTooLong` when the message is being ignored.
    mutating func messageBodyReceived() throws {
        switch self {
        case .closed:
            throw NIOHTTPObjectAggregatorError.connectionClosed
        case .idle:
            throw NIOHTTPObjectAggregatorError.unexpectedMessageBody
        case .ignoringContent:
            throw NIOHTTPObjectAggregatorError.frameTooLong
        case .receiving:
            // Body parts are expected while receiving; nothing to update.
            break
        }
    }

    /// Records that the end of a message arrived.
    ///
    /// - Throws: `connectionClosed` when closed, `unexpectedMessageEnd` when no
    ///     message is in progress, `endingIgnoredMessage` when the finished
    ///     message was being ignored.
    mutating func messageEndReceived() throws {
        switch self {
        case .closed:
            throw NIOHTTPObjectAggregatorError.connectionClosed
        case .idle:
            throw NIOHTTPObjectAggregatorError.unexpectedMessageEnd
        case .ignoringContent:
            // The oversize message is over, so we go back to `.idle`. We still
            // throw so that the caller does not go on to dispatch the invalid
            // message to the next handler.
            self = .idle
            throw NIOHTTPObjectAggregatorError.endingIgnoredMessage
        case .receiving:
            // This is the end of the message we were aggregating.
            self = .idle
        }
    }

    /// Switches to `.ignoringContent` after an oversize message was detected.
    ///
    /// Calling this while already ignoring content or after closing is a
    /// programmer error and traps.
    mutating func handlingOversizeMessage() {
        switch self {
        case .idle, .receiving:
            self = .ignoringContent
        case .closed, .ignoringContent:
            // Oversize handling must not be triggered twice or after closing.
            preconditionFailure("Unreachable state: should never handle overized message in \(self)")
        }
    }

    /// Unconditionally marks the connection as closed.
    mutating func closed() {
        self = .closed
    }
}

/// A `ChannelInboundHandler` that turns a stream of `HTTPServerRequestPart`
/// messages into a single `NIOHTTPServerRequestFull` per request.
///
/// All body parts are buffered until `HTTPServerRequestPart.end` arrives; the
/// head (with any trailer fields merged in) and the buffered body are then
/// fired as one channel read. Use it when handling the request in one piece is
/// worth the extra memory and the delay until the whole request has arrived.
///
/// `NIOHTTPServerRequestAggregator` may write a response of its own:
/// - A `413` (`.payloadTooLarge`) response when either the request's
///     `content-length` or the bytes buffered so far exceed `maxContentLength`.
///
/// `NIOHTTPServerRequestAggregator` closes the connection after writing that
/// response when the response is not keep-alive. A `connection: close` header
/// is added to the response:
/// - If the `content-length` is too large and the request is not keep-alive.
/// - If the limit is exceeded while body parts are already being received.
public final class NIOHTTPServerRequestAggregator: ChannelInboundHandler, RemovableChannelHandler {
    public typealias InboundIn = HTTPServerRequestPart
    public typealias InboundOut = NIOHTTPServerRequestFull

    // The aggregator writes its own response when a request is too large.
    public typealias OutboundOut = HTTPServerResponsePart

    /// Head of the request currently being aggregated, if any.
    private var fullMessageHead: HTTPRequestHead? = nil
    /// Accumulates the body bytes; allocated in `handlerAdded(context:)`.
    private var buffer: ByteBuffer! = nil
    private var maxContentLength: Int
    // NOTE(review): stored at init but not read by any code in this class.
    private var closeOnExpectationFailed: Bool
    private var state: AggregatorState

    /// Creates an aggregator.
    ///
    /// - Parameters:
    ///   - maxContentLength: Largest request body, in bytes, that will be aggregated.
    ///       Must not be negative.
    ///   - closeOnExpectationFailed: Stored by the handler; defaults to `false`.
    public init(maxContentLength: Int, closeOnExpectationFailed: Bool = false) {
        precondition(maxContentLength >= 0, "maxContentLength must not be negative")
        self.maxContentLength = maxContentLength
        self.closeOnExpectationFailed = closeOnExpectationFailed
        self.state = .idle
    }

    /// Allocates the (initially empty) aggregation buffer from the channel's allocator.
    public func handlerAdded(context: ChannelHandlerContext) {
        self.buffer = context.channel.allocator.buffer(capacity: 0)
    }

    /// Feeds one request part into the aggregation, validating it against the
    /// current state first. State violations are reported via `fireErrorCaught`.
    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let part = self.unwrapInboundIn(data)
        var reply: HTTPResponseHead? = nil

        do {
            switch part {
            case .head(let requestHead):
                try self.state.messageHeadReceived()
                reply = self.beginAggregation(context: context, request: requestHead, message: part)
            case .body(var chunk):
                try self.state.messageBodyReceived()
                reply = self.aggregate(context: context, content: &chunk, message: part)
            case .end(let trailers):
                try self.state.messageEndReceived()
                self.endAggregation(context: context, trailingHeaders: trailers)
            }
        } catch let aggregatorError as NIOHTTPObjectAggregatorError {
            context.fireErrorCaught(aggregatorError)
            // The message in progress can no longer be completed; drop what we have.
            self.fullMessageHead = nil
            self.buffer.clear()
        } catch {
            context.fireErrorCaught(error)
        }

        // Only continue when the aggregation produced a server response to send back.
        guard let response = reply else {
            return
        }
        self.send(response, context: context)
    }

    /// Writes a self-generated response (head immediately followed by end) and
    /// performs the follow-up state changes it implies.
    private func send(_ response: HTTPResponseHead, context: ChannelHandlerContext) {
        context.write(self.wrapOutboundOut(.head(response)), promise: nil)
        context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)

        if response.status == .payloadTooLarge {
            // The message is too large: ignore the rest of it and tell the pipeline.
            self.state.handlingOversizeMessage()
            context.fireErrorCaught(NIOHTTPObjectAggregatorError.frameTooLong)
            context.fireUserInboundEventTriggered(NIOHTTPObjectAggregatorEvent.httpFrameTooLong)
        }

        let keepsConnectionOpen = response.headers.isKeepAlive(version: response.version)
        if !keepsConnectionOpen {
            context.close(promise: nil)
            self.state.closed()
        }
    }

    /// Remembers the request head and rejects the request right away when its
    /// declared `content-length` is above the limit.
    ///
    /// - Returns: The response to send when the request is rejected, otherwise `nil`.
    private func beginAggregation(context: ChannelHandlerContext, request: HTTPRequestHead, message: InboundIn) -> HTTPResponseHead? {
        self.fullMessageHead = request
        guard let declaredLength = request.contentLength, declaredLength > self.maxContentLength else {
            return nil
        }
        return self.handleOversizeMessage(message: message)
    }

    /// Appends a body chunk to the buffer unless doing so would exceed the limit.
    ///
    /// - Returns: The response to send when the limit would be exceeded, otherwise `nil`.
    private func aggregate(context: ChannelHandlerContext, content: inout ByteBuffer, message: InboundIn) -> HTTPResponseHead? {
        let remainingCapacity = self.maxContentLength - self.buffer.readableBytes
        guard content.readableBytes <= remainingCapacity else {
            return self.handleOversizeMessage(message: message)
        }
        self.buffer.writeBuffer(&content)
        return nil
    }

    /// Assembles the full request, resets the aggregation storage and fires the
    /// request as a channel read. Does nothing when no head is stored.
    private func endAggregation(context: ChannelHandlerContext, trailingHeaders: HTTPHeaders?) {
        guard var aggregated = self.fullMessageHead else {
            return
        }

        // Remove `Trailer` from existing header fields and append trailer fields to existing header fields
        // See rfc7230 4.1.3 Decoding Chunked
        if let trailers = trailingHeaders {
            aggregated.headers.remove(name: "trailer")
            aggregated.headers.add(contentsOf: trailers)
        }

        // An empty buffer is reported as "no body" rather than as an empty body.
        let body: ByteBuffer? = self.buffer.readableBytes > 0 ? self.buffer : nil
        let fullMessage = NIOHTTPServerRequestFull(head: aggregated, body: body)

        self.fullMessageHead = nil
        self.buffer.clear()
        context.fireChannelRead(NIOAny(fullMessage))
    }

    /// Builds the `413` response for an oversize request.
    ///
    /// - Parameter message: The part that triggered the rejection; it decides
    ///     whether the response asks for the connection to be closed.
    private func handleOversizeMessage(message: InboundIn) -> HTTPResponseHead {
        let shouldClose: Bool
        switch message {
        case .head(let request):
            // With keep-alive off there is no need to leave the connection open.
            shouldClose = !request.isKeepAlive
        case .body, .end:
            // The client started to send data already, it's impossible to recover.
            shouldClose = true
        }

        var headers = HTTPHeaders([("content-length", "0")])
        if shouldClose {
            headers.add(name: "connection", value: "close")
        }

        return HTTPResponseHead(
            version: self.fullMessageHead?.version ?? .http1_1,
            status: .payloadTooLarge,
            headers: headers)
    }
}

/// A `ChannelInboundHandler` that turns a stream of `HTTPClientResponsePart`
/// messages into a single `NIOHTTPClientResponseFull` per response.
///
/// All body parts are buffered until `HTTPClientResponsePart.end` arrives; the
/// head (with any trailer fields merged in) and the buffered body are then
/// fired as one channel read. Use it when handling the response in one piece is
/// worth the extra memory and the delay until the whole response has arrived.
///
/// When a response is larger than `maxContentLength`, the aggregator stops
/// aggregating it until the next `HTTPClientResponsePart.end` and signals this
/// through `fireUserInboundEventTriggered` and `fireErrorCaught`.
public final class NIOHTTPClientResponseAggregator: ChannelInboundHandler, RemovableChannelHandler {
    public typealias InboundIn = HTTPClientResponsePart
    public typealias InboundOut = NIOHTTPClientResponseFull

    /// Head of the response currently being aggregated, if any.
    private var fullMessageHead: HTTPResponseHead? = nil
    /// Accumulates the body bytes; allocated in `handlerAdded(context:)`.
    private var buffer: ByteBuffer! = nil
    private var maxContentLength: Int
    private var state: AggregatorState

    /// Creates an aggregator.
    ///
    /// - Parameter maxContentLength: Largest response body, in bytes, that will be
    ///     aggregated. Must not be negative.
    public init(maxContentLength: Int) {
        precondition(maxContentLength >= 0, "maxContentLength must not be negative")
        self.maxContentLength = maxContentLength
        self.state = .idle
    }

    /// Allocates the (initially empty) aggregation buffer from the channel's allocator.
    public func handlerAdded(context: ChannelHandlerContext) {
        self.buffer = context.channel.allocator.buffer(capacity: 0)
    }

    /// Feeds one response part into the aggregation, validating it against the
    /// current state first. State violations are reported via `fireErrorCaught`.
    public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        let part = self.unwrapInboundIn(data)

        do {
            switch part {
            case .head(let responseHead):
                try self.state.messageHeadReceived()
                self.beginAggregation(context: context, response: responseHead)
            case .body(var chunk):
                try self.state.messageBodyReceived()
                self.aggregate(context: context, content: &chunk)
            case .end(let trailers):
                try self.state.messageEndReceived()
                self.endAggregation(context: context, trailingHeaders: trailers)
            }
        } catch let aggregatorError as NIOHTTPObjectAggregatorError {
            context.fireErrorCaught(aggregatorError)
            // The message in progress can no longer be completed; drop what we have.
            self.fullMessageHead = nil
            self.buffer.clear()
        } catch {
            context.fireErrorCaught(error)
        }
    }

    /// Starts ignoring the current response and notifies the pipeline, first
    /// with the `httpFrameTooLong` user event and then with a `frameTooLong` error.
    private func reportOversizeMessage(context: ChannelHandlerContext) {
        self.state.handlingOversizeMessage()
        context.fireUserInboundEventTriggered(NIOHTTPObjectAggregatorEvent.httpFrameTooLong)
        context.fireErrorCaught(NIOHTTPObjectAggregatorError.frameTooLong)
    }

    /// Remembers the response head and flags the response as oversize right away
    /// when its declared `content-length` is above the limit.
    private func beginAggregation(context: ChannelHandlerContext, response: HTTPResponseHead) {
        self.fullMessageHead = response
        guard let declaredLength = response.contentLength, declaredLength > self.maxContentLength else {
            return
        }
        self.reportOversizeMessage(context: context)
    }

    /// Appends a body chunk to the buffer unless doing so would exceed the limit,
    /// in which case the response is flagged as oversize instead.
    private func aggregate(context: ChannelHandlerContext, content: inout ByteBuffer) {
        let remainingCapacity = self.maxContentLength - self.buffer.readableBytes
        guard content.readableBytes <= remainingCapacity else {
            self.reportOversizeMessage(context: context)
            return
        }
        self.buffer.writeBuffer(&content)
    }

    /// Assembles the full response, resets the aggregation storage and fires the
    /// response as a channel read. Does nothing when no head is stored.
    private func endAggregation(context: ChannelHandlerContext, trailingHeaders: HTTPHeaders?) {
        guard var aggregated = self.fullMessageHead else {
            return
        }

        // Remove `Trailer` from existing header fields and append trailer fields to existing header fields
        // See rfc7230 4.1.3 Decoding Chunked
        if let trailers = trailingHeaders {
            aggregated.headers.remove(name: "trailer")
            aggregated.headers.add(contentsOf: trailers)
        }

        // An empty buffer is reported as "no body" rather than as an empty body.
        let body: ByteBuffer? = self.buffer.readableBytes > 0 ? self.buffer : nil
        let fullMessage = NIOHTTPClientResponseFull(head: aggregated, body: body)

        self.fullMessageHead = nil
        self.buffer.clear()
        context.fireChannelRead(NIOAny(fullMessage))
    }
}