1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
|
"""Tests for kernel connection utilities"""
# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.
import errno
import json
import os
from tempfile import TemporaryDirectory
from typing import no_type_check
from unittest.mock import patch
import pytest
import zmq
from traitlets.config.loader import Config
from ipykernel import connect
from ipykernel.kernelapp import IPKernelApp
from .utils import TemporaryWorkingDirectory
@pytest.fixture(scope="module", autouse=True)
def _enable_tracemalloc():
    """Trace memory allocations for the duration of this test module.

    ``tracemalloc`` is imported lazily because it is not available on PyPy;
    in that case the fixture does nothing.  Tracing is stopped on teardown
    only if this fixture was the one that started it, so a run that was
    already tracing (e.g. ``python -X tracemalloc``) keeps tracing after this
    module finishes.
    """
    try:
        import tracemalloc
    except ModuleNotFoundError:
        # pypy
        tracemalloc = None

    # Remember whether *we* turned tracing on; never stop tracing that was
    # enabled by someone else before this module ran.
    started_here = tracemalloc is not None and not tracemalloc.is_tracing()
    if started_here:
        tracemalloc.start()
    yield
    if started_here:
        tracemalloc.stop()
# Fixed connection parameters used when writing a connection file and reading
# it back.  ``key`` is given as bytes.
sample_info: dict = dict(
    ip="1.2.3.4",
    transport="ipc",
    shell_port=1,
    hb_port=2,
    iopub_port=3,
    stdin_port=4,
    control_port=5,
    key=b"abc123",
    signature_scheme="hmac-md5",
)
class DummyKernelApp(IPKernelApp):
    """Stripped-down ``IPKernelApp`` for tests.

    Its ``initialize`` does not call the parent implementation.  It only
    prepares the profile directory and the connection file, leaving tests to
    invoke later setup steps (such as ``init_sockets``) themselves.
    """

    def initialize(self, argv=None):
        """Set up the profile directory, then the connection file; ``argv`` is ignored."""
        self.init_profile_dir()
        self.init_connection_file()

    def _default_shell_port(self):
        # NOTE(review): traitlets only treats this as the ``shell_port`` default
        # when it is registered with ``@default("shell_port")``; confirm whether
        # this method is actually consulted.
        return 0
def test_get_connection_file():
    """A bare connection-file name resolves to a path inside ``connection_dir``."""
    filename = "kernel.json"
    config = Config()
    with TemporaryWorkingDirectory() as workdir:
        config.ProfileDir.location = workdir
        kernel_app = DummyKernelApp(config=config, connection_file=filename)
        kernel_app.initialize()

        expected = os.path.join(kernel_app.connection_dir, filename)
        assert kernel_app.abs_connection_file == expected

        # Create a minimal (empty JSON object) connection file at that path.
        with open(expected, "w") as handle:
            handle.write("{}")
        assert os.path.exists(expected)
        assert connect.get_connection_file(kernel_app) == expected

        # Assigning the same relative name again must resolve to the same path.
        kernel_app.connection_file = filename
        assert connect.get_connection_file(kernel_app) == expected
def test_get_connection_info():
    """Connection info survives a write/read round-trip, packed and unpacked.

    ``get_connection_info`` is checked both in its default form (a JSON
    string) and with ``unpack=True`` (a dict).  In each case the entries that
    correspond to ``sample_info`` must match it exactly.
    """
    with TemporaryDirectory() as d:
        cf = os.path.join(d, "kernel.json")
        connect.write_connection_file(cf, **sample_info)
        json_info = connect.get_connection_info(cf)
        info = connect.get_connection_info(cf, unpack=True)
    assert isinstance(json_info, str)

    # The unpacked dict is expected to carry ``key`` as bytes already.
    sub_info = {k: v for k, v in info.items() if k in sample_info}
    assert sub_info == sample_info

    # The packed form is plain JSON, so ``key`` comes back as str and must be
    # encoded before comparing.  Filter ``info2`` (not ``info``); otherwise
    # this check never exercises the JSON string at all.
    info2 = json.loads(json_info)
    info2["key"] = info2["key"].encode("utf-8")
    sub_info2 = {k: v for k, v in info2.items() if k in sample_info}
    assert sub_info2 == sample_info
def test_port_bind_failure_raises(request):
    """A bind error that is not 'address in use' propagates after a single attempt."""
    config = Config()
    with TemporaryWorkingDirectory() as workdir:
        config.ProfileDir.location = workdir
        kernel_app = DummyKernelApp(config=config, connection_file="kernel.json")
        request.addfinalizer(kernel_app.close)
        kernel_app.initialize()

        unknown_error = zmq.ZMQError(-100, "fails for unknown error types")
        with patch.object(
            kernel_app, "_try_bind_socket", side_effect=unknown_error
        ) as bind_mock:
            with pytest.raises(zmq.ZMQError):
                kernel_app.init_sockets()
            assert bind_mock.call_count == 1
@no_type_check
def test_port_bind_failure_recovery(request):
    """Address-in-use bind failures are retried instead of aborting socket setup.

    The first bind attempt fails with ``EADDRINUSE`` and the second with
    ``WSAEADDRINUSE``; every later attempt succeeds.  ``init_sockets`` must
    not raise, and the mock must show that the retry path actually ran.
    """
    if not hasattr(errno, "WSAEADDRINUSE"):
        # Fake windows address in-use code so the Windows error path can be
        # exercised on other platforms; the patch is undone on teardown.
        p = patch.object(errno, "WSAEADDRINUSE", 12345, create=True)
        p.start()
        request.addfinalizer(p.stop)

    cfg = Config()
    with TemporaryWorkingDirectory() as d:
        cfg.ProfileDir.location = d
        cf = "kernel.json"
        app = DummyKernelApp(config=cfg, connection_file=cf)
        request.addfinalizer(app.close)
        app.initialize()
        with patch.object(app, "_try_bind_socket") as mock_try_bind:
            mock_try_bind.side_effect = [
                zmq.ZMQError(errno.EADDRINUSE, "fails for non-bind unix"),
                zmq.ZMQError(errno.WSAEADDRINUSE, "fails for non-bind windows"),
            ] + [0] * 100
            # Shouldn't raise anything as retries will kick in
            app.init_sockets()
            # Both in-use failures were consumed and at least one further
            # attempt followed; without this the test would also pass if the
            # errors were silently swallowed with no retry.
            assert mock_try_bind.call_count >= 3
def test_port_bind_failure_gives_up_retries(request):
    """A persistent 'address in use' error is retried 100 times, then propagates."""
    config = Config()
    with TemporaryWorkingDirectory() as workdir:
        config.ProfileDir.location = workdir
        kernel_app = DummyKernelApp(config=config, connection_file="kernel.json")
        request.addfinalizer(kernel_app.close)
        kernel_app.initialize()

        in_use_error = zmq.ZMQError(errno.EADDRINUSE, "fails for non-bind")
        with patch.object(
            kernel_app, "_try_bind_socket", side_effect=in_use_error
        ) as bind_mock:
            with pytest.raises(zmq.ZMQError):
                kernel_app.init_sockets()
            assert bind_mock.call_count == 100
|