File: HTTPServerProtocolErrorHandlerTest.swift

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (155 lines) | stat: -rw-r--r-- 6,483 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import XCTest
import NIO
import NIOHTTP1

class HTTPServerProtocolErrorHandlerTest: XCTestCase {
    /// Reads the single `ByteBuffer` the channel is expected to have written and returns its
    /// readable bytes as a string.
    ///
    /// Fails the test (by throwing from `XCTUnwrap`) when nothing was written, and records a
    /// failure when more than one outbound item is pending.
    private func readOnlyOutboundResponse(from channel: EmbeddedChannel) throws -> String {
        var written = try XCTUnwrap(try channel.readOutbound(as: ByteBuffer.self), "No writes")
        XCTAssertNil(try channel.readOutbound())
        let text = written.readString(length: written.readableBytes)
        return try XCTUnwrap(text)
    }

    /// A malformed request (negative `Content-Length`) on a connection with no response in
    /// flight must be answered with `400 Bad Request`, `Connection: close` and an empty body,
    /// and the channel must end up closed.
    func testHandlesBasicErrors() throws {
        /// Forwards `HTTPParserError`s down the pipeline and then closes the channel.
        /// Errors of any other type are not forwarded.
        final class CloseOnHTTPErrorHandler: ChannelInboundHandler {
            typealias InboundIn = Never

            func errorCaught(context: ChannelHandlerContext, error: Error) {
                guard error is HTTPParserError else {
                    return
                }
                context.fireErrorCaught(error)
                context.close(promise: nil)
            }
        }
        let channel = EmbeddedChannel()
        XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).wait())
        XCTAssertNoThrow(try channel.pipeline.addHandler(CloseOnHTTPErrorHandler()).wait())

        var buffer = channel.allocator.buffer(capacity: 1024)
        buffer.writeStaticString("GET / HTTP/1.1\r\nContent-Length: -4\r\n\r\n")
        // The parser error is forwarded to the end of the pipeline, so `writeInbound` must
        // rethrow it; assert on it rather than silently tolerating its absence.
        XCTAssertThrowsError(try channel.writeInbound(buffer)) { error in
            XCTAssertEqual(.invalidContentLength, error as? HTTPParserError)
        }
        channel.embeddedEventLoop.run()

        // The channel should be closed at this stage.
        XCTAssertNoThrow(try channel.closeFuture.wait())

        // We expect exactly one ByteBuffer in the output: the error response.
        let response = try self.readOnlyOutboundResponse(from: channel)
        assertResponseIs(response: response,
                         expectedResponseLine: "HTTP/1.1 400 Bad Request",
                         expectedResponseHeaders: ["Connection: close", "Content-Length: 0"])
    }

    /// Errors that are not `HTTPParserError`s must pass through the error handling untouched
    /// and reach the end of the pipeline.
    func testIgnoresNonParserErrors() throws {
        enum DummyError: Error {
            case error
        }
        let channel = EmbeddedChannel()
        XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).wait())

        channel.pipeline.fireErrorCaught(DummyError.error)
        XCTAssertThrowsError(try channel.throwIfErrorCaught()) { error in
            XCTAssertEqual(DummyError.error, error as? DummyError)
        }

        XCTAssertNoThrow(try channel.finish())
    }

    /// Once a response head has been written, a parser error must not inject a second
    /// (`400`) response into the stream; the error itself must still be propagated.
    func testDoesNotSendAResponseIfResponseHasAlreadyStarted() throws {
        let channel = EmbeddedChannel()
        defer {
            XCTAssertNoThrow(try channel.finish())
        }

        XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, withErrorHandling: true).wait())
        let res = HTTPServerResponsePart.head(.init(version: .http1_1,
                                                    status: .ok,
                                                    headers: .init([("Content-Length", "0")])))
        XCTAssertNoThrow(try channel.writeAndFlush(res).wait())
        // Now we have started a response but it's not complete yet; inject a parser error.
        channel.pipeline.fireErrorCaught(HTTPParserError.invalidEOFState)
        var allOutbound = try channel.readAllOutboundBuffers()
        let allOutboundString = allOutbound.readString(length: allOutbound.readableBytes)
        // Only the 200 head we wrote may appear; no "HTTP/1.1 400" or anything else.
        XCTAssertEqual("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n", allOutboundString)
        XCTAssertThrowsError(try channel.throwIfErrorCaught()) { error in
            XCTAssertEqual(.invalidEOFState, error as? HTTPParserError)
        }
    }

    /// With the default pipeline, a handler starts (but never finishes) a response to the
    /// first request. The input holds two complete requests and a truncated third; after the
    /// channel is closed, the only bytes on the wire must be that first `200 OK` head.
    func testCanHandleErrorsWhenResponseHasStarted() throws {
        enum NextExpectedState {
            case head
            case end
            case done
        }
        /// Answers the first request head with a `200 OK` response head and never writes the
        /// matching `.end`, leaving the response unterminated. Asserts it sees exactly one
        /// request head followed by its end.
        final class StartResponseOnHeadHandler: ChannelInboundHandler {
            typealias InboundIn = HTTPServerRequestPart
            typealias OutboundOut = HTTPServerResponsePart

            private var nextExpected: NextExpectedState = .head

            func channelRead(context: ChannelHandlerContext, data: NIOAny) {
                switch self.unwrapInboundIn(data) {
                case .head:
                    XCTAssertEqual(.head, self.nextExpected)
                    self.nextExpected = .end
                    let res = HTTPServerResponsePart.head(.init(version: .http1_1,
                                                                status: .ok,
                                                                headers: .init([("Content-Length", "0")])))
                    context.writeAndFlush(self.wrapOutboundOut(res), promise: nil)
                case .body:
                    // The GET requests in this test carry no body.
                    XCTFail("Unexpected request body")
                case .end:
                    XCTAssertEqual(.end, self.nextExpected)
                    self.nextExpected = .done
                }
            }
        }
        let channel = EmbeddedChannel()
        XCTAssertNoThrow(try channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap {
            channel.pipeline.addHandler(StartResponseOnHeadHandler())
        }.wait())

        var buffer = channel.allocator.buffer(capacity: 1024)
        buffer.writeStaticString("GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\nGET / HT")
        XCTAssertNoThrow(try channel.writeInbound(buffer))
        XCTAssertNoThrow(try channel.close().wait())
        channel.embeddedEventLoop.run()

        // The channel should be closed at this stage.
        XCTAssertNoThrow(try channel.closeFuture.wait())

        // We expect exactly one ByteBuffer in the output: the response head already started.
        let response = try self.readOnlyOutboundResponse(from: channel)
        assertResponseIs(response: response,
                         expectedResponseLine: "HTTP/1.1 200 OK",
                         expectedResponseHeaders: ["Content-Length: 0"])
    }
}