File: test_ssl.py

package info (click to toggle)
geventhttpclient 2.3.5-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,456 kB
  • sloc: ansic: 16,557; python: 3,823; makefile: 24
file content (334 lines) | stat: -rw-r--r-- 10,258 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
import os
import ssl
from unittest.mock import patch, MagicMock
from contextlib import contextmanager
from ssl import CertificateError
from unittest import mock

import dpkt.ssl
import gevent.queue
import gevent.server
import gevent.socket
import gevent.ssl
import pytest
from gevent import joinall
from gevent.socket import error as socket_error

from geventhttpclient import HTTPClient
from tests.common import LISTENER

# TLS fixtures (private key and self-signed certificate) stored next to this file.
BASEDIR = os.path.dirname(__file__)
KEY, CERT = (os.path.join(BASEDIR, fname) for fname in ("server.key", "server.crt"))


@contextmanager
def sslserver(handler, backlog=1):
    """Run a TLS ``gevent.server.StreamServer`` on ``LISTENER`` for ``handler``.

    ``handler`` is called as ``handler(sock, addr)`` for every accepted
    connection. The context yields ``(host, port)`` of the running server.
    If ``handler`` raised while the ``with`` body was running, the first
    recorded exception is re-raised when the body exits normally. The
    server is stopped in every case.
    """
    failures = gevent.queue.Queue()

    def recording_handler(client_sock, client_addr):
        # Record handler errors so the test can re-raise them on exit.
        try:
            return handler(client_sock, client_addr)
        except Exception as exc:
            failures.put(exc)
            raise

    server = gevent.server.StreamServer(
        LISTENER,
        backlog=backlog,
        handle=recording_handler,
        keyfile=KEY,
        certfile=CERT,
        ssl_version=ssl.PROTOCOL_TLS_SERVER,
    )
    server.start()
    try:
        yield server.server_host, server.server_port
        if not failures.empty():
            raise failures.get()
    finally:
        server.stop()


@contextmanager
def timeout_connect_server():
    """Start a TLS listener that accepts clients but never answers them.

    For each accepted connection the server greenlet reads up to 1024 bytes
    and then sleeps for 10 seconds before accepting the next one. Combined
    with a listen backlog of 1, a second client cannot be served while the
    first is held. ``test_timeout_on_connect`` relies on this to provoke a
    connect timeout.

    Yields the listening socket's address. On exit, the server greenlet is
    killed and the accepted connections and the listener are all closed.
    """
    sock = gevent.socket.socket(gevent.socket.AF_INET, gevent.socket.SOCK_STREAM, 0)
    sock = gevent.ssl.wrap_socket(
        sock, keyfile=KEY, certfile=CERT, ssl_version=ssl.PROTOCOL_TLS_SERVER
    )
    sock.setsockopt(gevent.socket.SOL_SOCKET, gevent.socket.SO_REUSEADDR, 1)
    sock.bind(("localhost", 0))
    sock.listen(1)

    # Keep references to accepted connections so they stay open for the
    # duration of the test, and so they can be closed on teardown.
    conns = []

    def run(listener):
        while True:
            conn, _ = listener.accept()
            conns.append(conn)
            conn.recv(1024)
            gevent.sleep(10)

    job = gevent.spawn(run, sock)
    try:
        yield sock.getsockname()
    finally:
        # Stop the server greenlet first so it is no longer waiting on the
        # sockets, then release every socket even if the test body failed.
        job.kill()
        for accepted in conns:
            accepted.close()
        sock.close()


def simple_ssl_response(sock, addr):
    """Read (up to 1024 bytes of) the request, reply ``200 Ok`` and close.

    The reply carries ``Connection: close``; ``addr`` is unused.
    """
    reply = b"HTTP/1.1 200 Ok\r\nConnection: close\r\n\r\n"
    sock.recv(1024)
    sock.sendall(reply)
    sock.close()


def test_simple_ssl():
    """A plain GET over TLS against the local test server returns 200."""
    with sslserver(simple_ssl_response) as (host, port):
        client = HTTPClient(host, port, insecure=True, ssl=True, ssl_options={"ca_certs": CERT})
        response = client.get("/")
        assert response.status_code == 200
        response.read()


def timeout_on_connect(sock, addr):
    """Read (up to 1024 bytes of) the request and answer with an empty 200.

    The socket is not closed here; ``addr`` is unused.
    """
    payload = b"HTTP/1.1 200 Ok\r\nContent-Length: 0\r\n\r\n"
    sock.recv(1024)
    sock.sendall(payload)


def test_implicit_sni_from_host_in_ssl():
    """Without overrides, the SNI equals the host the client was created for."""
    expected_host, _, sni = _get_sni_sent_from_client()
    assert sni == expected_host


def test_implicit_sni_from_header_in_ssl():
    """An explicit ``host`` header is used as the SNI instead of the client host."""
    host_header = "ololo_special_host"
    *_, sni = _get_sni_sent_from_client(headers={"host": host_header})
    assert sni == host_header


def test_explicit_sni_in_ssl():
    """``ssl_options["server_hostname"]`` overrides both the host and the host header."""
    options = {"server_hostname": "test_sni"}
    headers = {"host": "ololo_special_host"}
    *_, sni = _get_sni_sent_from_client(ssl_options=options, headers=headers)
    assert sni == "test_sni"


def _get_sni_sent_from_client(**additional_client_args):
    """Point an ``HTTPClient`` at ``sni_checker_server`` and capture the SNI.

    ``gevent.socket.getaddrinfo`` is mocked so the fake host ``"some_foo"``
    resolves to the checker's loopback port. ``additional_client_args`` are
    forwarded to ``HTTPClient``.

    Returns ``(server_host, server_port, sni)``, where ``sni`` is the value the
    checker greenlet returned (``None`` if no SNI was found).
    """
    server_host = "some_foo"
    with sni_checker_server() as (server_sock, server_greenlet):
        server_port = server_sock.getsockname()[1]

        fake_addrinfo = (
            gevent.socket.AF_INET,
            gevent.socket.SOCK_STREAM,
            gevent.socket.IPPROTO_TCP,
            "localhost",
            ("127.0.0.1", server_port),
        )
        with mock.patch("gevent.socket.getaddrinfo", mock.Mock(return_value=[fake_addrinfo])):
            client = HTTPClient(
                server_host,
                server_port,
                insecure=True,
                ssl=True,
                connection_timeout=0.1,
                ssl_context_factory=gevent.ssl.create_default_context,
                **additional_client_args,
            )

            def attempt_request(http):
                try:
                    http.get("/")
                except socket_error:
                    pass  # the checker never completes the TLS handshake

            joinall([gevent.spawn(attempt_request, client), server_greenlet])

    return server_host, server_port, server_greenlet.value


@contextmanager
def sni_checker_server():
    """Listen on plain TCP and report the SNI from the first TLS ClientHello.

    Yields ``(sock, job)``: the listening socket and the greenlet serving it.
    The greenlet accepts one connection, reads its first TLS record, closes
    the connection without completing the handshake, and returns the SNI
    host name as ``str`` (or ``None`` if no server_name extension was found).
    If the greenlet failed, its exception is re-raised when the ``with`` body
    exits normally. The greenlet is killed and the listener closed in all cases.
    """
    sock = gevent.socket.socket(gevent.socket.AF_INET, gevent.socket.SOCK_STREAM, 0)
    sock.setsockopt(gevent.socket.SOL_SOCKET, gevent.socket.SO_REUSEADDR, 1)
    sock.bind(("localhost", 0))
    sock.listen(1)

    def read_first_tls_record(conn, chunk_size=4096):
        # One recv() may return only part of the ClientHello, which
        # tls_multi_factory would silently drop. Keep reading until at least
        # one complete record parses or the peer closes the connection.
        buf = b""
        while True:
            chunk = conn.recv(chunk_size)
            if not chunk:
                return buf
            buf += chunk
            records, _ = dpkt.ssl.tls_multi_factory(buf)
            if records:
                return buf

    def extract_sni_from_client_hello(hello_packet):
        records, _ = dpkt.ssl.tls_multi_factory(hello_packet)
        for record in records:
            # Only non-empty handshake records (content type 22) whose
            # handshake type is ClientHello (1) are of interest.
            if record.type != 22 or not record.data or record.data[0] != 1:
                continue

            client_hello = dpkt.ssl.TLSHandshake(record.data).data
            sni_payloads = [
                ext_data
                for ext_type, ext_data in client_hello.extensions
                if ext_type == 0x0  # server_name
            ]
            if sni_payloads:
                # server_name extension: 2-byte-length-prefixed list of entries,
                # each a 1-byte name type followed by a 2-byte-length-prefixed name.
                server_name_list, _ = dpkt.ssl.parse_variable_array(sni_payloads[0], 2)
                first_entry, _ = dpkt.ssl.parse_variable_array(server_name_list[1:], 2)
                return first_entry.decode()
        return None

    def run(listener):
        conn, _ = listener.accept()
        try:
            return extract_sni_from_client_hello(read_first_tls_record(conn))
        finally:
            conn.close()

    job = gevent.spawn(run, sock)
    try:
        yield sock, job
        if job.exception:
            raise job.exception
    finally:
        # Close the listener even when the greenlet failed or the body raised.
        job.kill()
        sock.close()


def test_timeout_on_connect():
    """A second client times out while the server is held by the first one.

    The first client is parked on ``timeout_connect_server``, which reads its
    request and then sleeps. A second client with a short
    ``connection_timeout`` must then fail with a timeout.
    """
    with timeout_connect_server() as listener:
        client = HTTPClient(*listener, insecure=True, ssl=True, ssl_options={"ca_certs": CERT})

        def run(http, wait_time=100):
            # Best effort: this client only has to occupy the server; any
            # error it runs into is irrelevant to the assertion below.
            try:
                response = http.get("/")
                gevent.sleep(wait_time)
                response.read()
            except Exception:
                pass

        blocker = gevent.spawn(run, client)
        # Yield once so the first client's greenlet starts before the second connects.
        gevent.sleep(0)

        try:
            http2 = HTTPClient(
                *listener,
                insecure=True,
                ssl=True,
                connection_timeout=0.1,
                ssl_options={"ca_certs": CERT},
            )
            # Either a plain socket timeout or an SSLError reporting a timeout
            # is accepted; anything else propagates and fails the test.
            with pytest.raises((gevent.ssl.SSLError, gevent.socket.timeout)) as raised:
                http2.get("/")
        finally:
            # Don't leave the first client's greenlet sleeping past the test.
            blocker.kill()

        if isinstance(raised.value, gevent.ssl.SSLError):
            assert "operation timed out" in str(raised.value)


def network_timeout(sock, addr):
    """Read the request, stall for 10 seconds, then answer with an empty 200.

    Used to trip the client's ``network_timeout``; ``addr`` is unused.
    """
    stall_seconds = 10
    sock.recv(1024)
    gevent.sleep(stall_seconds)
    sock.sendall(b"HTTP/1.1 200 Ok\r\nContent-Length: 0\r\n\r\n")


def test_network_timeout():
    """A server that stalls after reading the request trips ``network_timeout``."""
    with sslserver(network_timeout) as (host, port):
        client = HTTPClient(
            host,
            port,
            ssl=True,
            insecure=True,
            network_timeout=0.1,
            ssl_options={"ca_certs": CERT},
        )
        with pytest.raises(gevent.socket.timeout):
            client.get("/")


def check_client_cert_required(client):
    """Assert that ``client`` verifies both hostnames and server certificates.

    Checks the connection pool's SSL context (``check_hostname`` on,
    ``CERT_REQUIRED``) and the context of every socket currently pooled.
    """
    pool = client._connection_pool
    context = pool.ssl_context
    assert context.check_hostname
    assert context.verify_mode == gevent.ssl.CERT_REQUIRED
    for pooled_sock in pool._socket_queue.queue:
        assert pooled_sock._context.verify_mode == gevent.ssl.CERT_REQUIRED


def test_verify_self_signed_fail(capsys):
    """Without ``insecure``/``ca_certs``, the self-signed test server is rejected."""
    with sslserver(simple_ssl_response) as listener:
        client = HTTPClient(*listener, ssl=True)
        try:
            with pytest.raises(CertificateError) as raised:
                client.get("/")
            assert "CERTIFICATE_VERIFY_FAILED" in str(raised.value)
            # OpenSSL 3 spells this "self-signed certificate" while OpenSSL 1.1
            # used "self signed certificate"; accept either spelling.
            assert raised.value.verify_message.replace("-", " ") == "self signed certificate"
            check_client_cert_required(client)
        finally:
            client.close()

    # This tests breaking server side socket confusingly prints its certificate error message delayed
    # into other tests output, if we don't give it a split second for printing now.
    gevent.sleep(0.01)
    captured = capsys.readouterr().err
    assert "ssl.SSLError" in captured
    assert "ALERT_UNKNOWN_CA" in captured


@patch("ssl.create_default_context")
def test_ssl_context_cert_and_keyfile(mock_create_default_context):
    """``ssl_options`` map onto ``create_default_context(cafile=...)`` and ``load_cert_chain``."""
    fake_context = MagicMock()
    mock_create_default_context.return_value = fake_context

    ssl_options = dict(
        certfile="/path/to/certfile.pem",
        keyfile="/path/to/keyfile.pem",
        ca_certs="/path/to/ca-certificates.crt",
    )
    # ssl.create_default_context is the patched mock at this point.
    http_client = HTTPClient(
        "github.com",
        ssl_context_factory=ssl.create_default_context,
        ssl_options=ssl_options,
    )

    mock_create_default_context.assert_called_once_with(cafile=ssl_options["ca_certs"])
    fake_context.load_cert_chain.assert_called_once_with(
        certfile=ssl_options["certfile"],
        keyfile=ssl_options["keyfile"],
    )
    assert isinstance(http_client, HTTPClient)


@pytest.mark.network
def test_client_ssl():
    """Live HTTPS GET to github.com with default, certificate-verifying settings."""
    client = HTTPClient("github.com", ssl=True)
    assert client.port == 443

    response = client.get("/")
    assert response.status_code == 200
    assert response.read()  # body must be non-empty
    check_client_cert_required(client)


@pytest.mark.network
def test_fail_invalid_ca_certificate():
    """Verifying github.com against the bundled ``oncert.pem`` must fail.

    Presumably ``oncert.pem`` does not contain github.com's CA; the test
    expects ``CERTIFICATE_VERIFY_FAILED``.
    """
    tests_dir = os.path.dirname(os.path.abspath(__file__))
    client = HTTPClient(
        "github.com", ssl_options={"ca_certs": os.path.join(tests_dir, "oncert.pem")}
    )
    assert client.port == 443

    with pytest.raises(gevent.ssl.SSLError) as failure:
        client.get("/")
    assert failure.value.reason == "CERTIFICATE_VERIFY_FAILED"
    check_client_cert_required(client)