File: test_version_negotiation.py

package info (click to toggle)
aws-crt-python 0.28.4%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 78,428 kB
  • sloc: ansic: 437,955; python: 27,657; makefile: 5,855; sh: 4,289; ruby: 208; java: 82; perl: 73; cpp: 25; xml: 11
file content (203 lines) | stat: -rw-r--r-- 7,308 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import copy
import pytest

from configuration import (
    available_ports,
    ALL_TEST_CIPHERS,
    ALL_TEST_CURVES,
    MINIMAL_TEST_CERTS,
)
from common import ProviderOptions, Protocols, data_bytes
from fixtures import managed_process  # noqa: F401
from providers import Provider, S2N, OpenSSL, GnuTLS
from utils import (
    invalid_test_parameters,
    get_parameter_name,
    get_expected_s2n_version,
    get_expected_openssl_version,
    to_bytes,
    get_expected_gnutls_version,
)


def test_nothing():
    """
    Always-passing placeholder test.

    With some s2n libcrypto builds, every parametrized version negotiation
    combination in this module can be uncollected, leaving the module with no
    tests at all. This no-op test guarantees at least one test is present so
    the entire codebuild run is not marked as failed.
    """
    return None


def invalid_version_negotiation_test_parameters(*args, **kwargs):
    """
    Decide whether a version negotiation parameter combination should be skipped.

    Serves as the ``uncollect_if`` filter for the tests in this module. Every
    combination is rejected up front when s2n cannot be run at TLS 1.3 with the
    current libcrypto, because the s2n endpoint in these tests always uses
    TLS 1.3. Otherwise the decision is delegated unchanged to the generic
    ``invalid_test_parameters`` check.
    """
    # s2nd/s2nc is always configured for TLS 1.3 here, so an incompatible
    # libcrypto invalidates every combination regardless of the other params.
    s2n_tls13_unsupported = invalid_test_parameters(
        provider=S2N, protocol=Protocols.TLS13
    )
    if s2n_tls13_unsupported:
        return True
    return invalid_test_parameters(*args, **kwargs)


@pytest.mark.flaky(reruns=5, reruns_delay=2)
@pytest.mark.uncollect_if(func=invalid_version_negotiation_test_parameters)
@pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name)
@pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name)
@pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name)
@pytest.mark.parametrize(
    "protocol",
    [Protocols.TLS12, Protocols.TLS11, Protocols.TLS10],
    ids=get_parameter_name,
)
@pytest.mark.parametrize("provider", [S2N, OpenSSL, GnuTLS], ids=get_parameter_name)
@pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name)
def test_s2nc_tls13_negotiates_tls12(
    managed_process,  # noqa: F811
    cipher,
    curve,
    certificate,
    protocol,
    provider,
    other_provider,
):
    """
    Verify that an s2n client offering TLS 1.3 falls back to the server's version.

    The s2n client is always configured for TLS 1.3, while the server
    (``provider``) is limited to ``protocol``. Both processes must succeed. The
    s2n client must report the TLS 1.3 client version and the expected
    negotiated version. The server must have received the client's payload,
    excluding its first byte. When the server is also s2n, its reported client
    and negotiated versions are checked too.

    ``other_provider`` is not used in the body. NOTE(review): presumably it is
    consumed by the ``uncollect_if`` filter via parametrization — confirm.
    """
    port = next(available_ports)

    random_bytes = data_bytes(24)
    client_options = ProviderOptions(
        mode=Provider.ClientMode,
        port=port,
        cipher=cipher,
        curve=curve,
        data_to_send=random_bytes,
        insecure=True,
        protocol=Protocols.TLS13,
    )

    # The server reuses the client's settings, but sends nothing, presents the
    # certificate, and only accepts the lower `protocol`.
    server_options = copy.copy(client_options)
    server_options.data_to_send = None
    server_options.mode = Provider.ServerMode
    server_options.key = certificate.key
    server_options.cert = certificate.cert
    server_options.protocol = protocol

    # Only the GnuTLS server gets a kill marker (the client payload).
    # NOTE(review): presumably managed_process uses it to know when to stop that
    # server — confirm in fixtures.
    kill_marker = random_bytes if provider is GnuTLS else None

    server = managed_process(
        provider, server_options, timeout=5, kill_marker=kill_marker
    )
    client = managed_process(S2N, client_options, timeout=5)

    client_version = get_expected_s2n_version(Protocols.TLS13, provider)
    actual_version = get_expected_s2n_version(protocol, provider)

    for results in client.get_results():
        results.assert_success()
        assert (
            to_bytes(f"Client protocol version: {client_version}")
            in results.stdout
        )
        assert (
            to_bytes(f"Actual protocol version: {actual_version}")
            in results.stdout
        )

    for results in server.get_results():
        results.assert_success()
        # This check only cares about S2N. Trying to maintain expected output of other providers doesn't add benefit to
        # whether the S2N client was able to negotiate a lower TLS version.
        if provider is S2N:
            # The client sends a TLS 1.3 client hello so a client protocol version of TLS 1.3 should always be expected.
            assert (
                to_bytes(f"Client protocol version: {Protocols.TLS13.value}")
                in results.stdout
            )
            assert (
                to_bytes(f"Actual protocol version: {actual_version}")
                in results.stdout
            )

        # The payload may show up on any captured stream of the server. A
        # generator lets any() stop at the first match instead of building a list.
        assert any(
            random_bytes[1:] in stream for stream in results.output_streams()
        )


@pytest.mark.uncollect_if(func=invalid_version_negotiation_test_parameters)
@pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name)
@pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name)
@pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name)
@pytest.mark.parametrize(
    "protocol",
    [Protocols.TLS12, Protocols.TLS11, Protocols.TLS10],
    ids=get_parameter_name,
)
@pytest.mark.parametrize("provider", [S2N, OpenSSL, GnuTLS], ids=get_parameter_name)
@pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name)
def test_s2nd_tls13_negotiates_tls12(
    managed_process,  # noqa: F811
    cipher,
    curve,
    certificate,
    protocol,
    provider,
    other_provider,
):
    """
    Check that an s2n server configured for TLS 1.3 downgrades to the client's version.

    The s2n server always runs with TLS 1.3. The client (``provider``) is
    restricted to ``protocol``. Both processes must succeed. The server must
    report the TLS 1.3 server version and the expected negotiated version, and
    must have received the client's payload (excluding its first byte). The
    client's output is checked in a provider-specific way to confirm it
    negotiated ``protocol``.

    ``other_provider`` is not referenced in the body. NOTE(review): presumably it
    is consumed by the ``uncollect_if`` filter via parametrization — confirm.
    """
    listen_port = next(available_ports)

    payload = data_bytes(24)
    client_opts = ProviderOptions(
        mode=Provider.ClientMode,
        port=listen_port,
        cipher=cipher,
        curve=curve,
        data_to_send=payload,
        insecure=True,
        protocol=protocol,
    )

    # Server: same settings as the client, minus the payload, plus the cert.
    server_opts = copy.copy(client_opts)
    server_opts.data_to_send = None
    server_opts.mode = Provider.ServerMode
    server_opts.key = certificate.key
    server_opts.cert = certificate.cert
    # Asking for TLS13 makes the s2n server provider fall back to its full
    # cipher list rather than the TLS 1.3-only suites; the test depends on that.
    server_opts.protocol = Protocols.TLS13

    s2nd = managed_process(S2N, server_opts, timeout=5)
    peer = managed_process(provider, client_opts, timeout=5)

    expected_server_version = get_expected_s2n_version(Protocols.TLS13, provider)
    negotiated_version = get_expected_s2n_version(protocol, provider)

    for outcome in peer.get_results():
        outcome.assert_success()
        if provider is S2N:
            # An s2n client learns the server's version from the SERVER HELLO,
            # i.e. it sees the negotiated version there as well.
            assert (
                to_bytes(f"Server protocol version: {negotiated_version}")
                in outcome.stdout
            )
            assert (
                to_bytes(f"Actual protocol version: {negotiated_version}")
                in outcome.stdout
            )
        elif provider is OpenSSL:
            # For third-party clients, confirm they really landed on the version
            # the s2n server was expected to negotiate.
            openssl_expected = get_expected_openssl_version(protocol)
            assert to_bytes(f"Protocol  : {openssl_expected}") in outcome.stdout
        elif provider is GnuTLS:
            gnutls_expected = get_expected_gnutls_version(protocol)
            assert to_bytes(f"Version: {gnutls_expected}") in outcome.stdout

    for outcome in s2nd.get_results():
        outcome.assert_success()
        assert (
            to_bytes(f"Server protocol version: {expected_server_version}")
            in outcome.stdout
        )
        assert (
            to_bytes(f"Actual protocol version: {negotiated_version}")
            in outcome.stdout
        )
        assert payload[1:] in outcome.stdout