1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
|
import asyncio
import logging
import os
from typing import no_type_check
from unittest.mock import MagicMock
import pytest
import pytest_asyncio
import zmq
from jupyter_client.session import Session
from tornado.ioloop import IOLoop
from zmq.eventloop.zmqstream import ZMQStream
from ipykernel.ipkernel import IPythonKernel
from ipykernel.kernelbase import Kernel
from ipykernel.zmqshell import ZMQInteractiveShell
try:
import resource
except ImportError:
# Windows
resource = None # type:ignore
from .utils import new_kernel
# Handle resource limit.
# Ensure the soft limit on open file descriptors is at least DEFAULT_SOFT when the
# hard limit allows it (presumably so the many zmq sockets opened across a test
# session do not exhaust a low default). The soft limit is only ever raised:
# a soft limit already above DEFAULT_SOFT is left alone, and the hard limit is
# never modified (lowering it is irreversible for an unprivileged process).
if resource is not None:
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    DEFAULT_SOFT = 4096
    # RLIM_INFINITY means "no limit" but may be reported as a negative value
    # (e.g. -1 on Linux), so it must be special-cased before comparing numerically.
    hard_allows_default = hard == resource.RLIM_INFINITY or hard >= DEFAULT_SOFT
    soft_below_default = soft != resource.RLIM_INFINITY and soft < DEFAULT_SOFT
    if hard_allows_default and soft_below_default:
        soft = DEFAULT_SOFT
        resource.setrlimit(resource.RLIMIT_NOFILE, (soft, hard))

# Enforce selector event loop on Windows (presumably because tornado/pyzmq need
# add_reader/add_writer, which the default Windows event loop does not provide).
if os.name == "nt":
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())  # type:ignore
class KernelMixin:
    """Test scaffolding mixed into kernel classes.

    Gives the kernel in-process zmq sockets/streams and a Session, records the
    message sent on the shell/control streams as the "reply", and offers
    helpers that dispatch a request and wait for that reply.
    """

    log = logging.getLogger()

    def _initialize(self):
        """Create the zmq context, sockets, streams and session used by the tests.

        The subclasses in this file call this before the kernel base-class
        ``__init__``.
        """
        self.context = context = zmq.Context()
        self.iopub_socket = context.socket(zmq.PUB)
        self.stdin_socket = context.socket(zmq.ROUTER)
        self.session = Session()
        # Track every socket created here so destroy() closes each one explicitly.
        self.test_sockets = [self.iopub_socket, self.stdin_socket]
        self.test_streams = []
        for name in ["shell", "control"]:
            socket = context.socket(zmq.ROUTER)
            stream = ZMQStream(socket)
            # Whatever the kernel sends on this stream is recorded as the reply.
            stream.on_send(self._on_send)
            self.test_sockets.append(socket)
            self.test_streams.append(stream)
            setattr(self, f"{name}_stream", stream)

    async def do_debug_request(self, msg):
        """Stub debug handler: always return an empty reply content."""
        return {}

    def destroy(self):
        """Close all test streams and sockets, then destroy the zmq context."""
        for stream in self.test_streams:
            stream.close()
        for socket in self.test_sockets:
            socket.close()
        self.context.destroy()

    @no_type_check
    async def test_shell_message(self, *args, **kwargs):
        """Build a message from ``Session.msg(*args, **kwargs)``, dispatch it on
        the shell channel, and return the deserialized reply."""
        msg_list = self._prep_msg(*args, **kwargs)
        await self.dispatch_shell(msg_list)
        self.shell_stream.flush()
        return await self._wait_for_msg()

    @no_type_check
    async def test_control_message(self, *args, **kwargs):
        """Build a message from ``Session.msg(*args, **kwargs)``, process it on
        the control channel, and return the deserialized reply."""
        msg_list = self._prep_msg(*args, **kwargs)
        await self.process_control(msg_list)
        self.control_stream.flush()
        return await self._wait_for_msg()

    def _on_send(self, msg, *args, **kwargs):
        """Stream ``on_send`` callback: remember the outgoing message frames."""
        self._reply = msg

    def _prep_msg(self, *args, **kwargs):
        """Clear any previous reply and serialize a new message into zmq frames."""
        self._reply = None
        raw_msg = self.session.msg(*args, **kwargs)
        msg = self.session.serialize(raw_msg)
        return [zmq.Message(m) for m in msg]

    async def _wait_for_msg(self, timeout=60.0):
        """Wait until ``_on_send`` records a reply, then deserialize and return it.

        :param timeout: seconds to wait before raising ``TimeoutError``;
            ``None`` waits indefinitely (the previous behaviour).
        :raises TimeoutError: if no reply is recorded within ``timeout``.
        """
        loop = asyncio.get_running_loop()
        deadline = None if timeout is None else loop.time() + timeout
        while not self._reply:
            # Without a deadline a missing reply would hang the test forever.
            if deadline is not None and loop.time() >= deadline:
                raise TimeoutError(f"no reply received within {timeout} seconds")
            await asyncio.sleep(0.1)
        _, msg = self.session.feed_identities(self._reply)
        return self.session.deserialize(msg)

    def _send_interrupt_children(self):
        # override to prevent deadlock
        pass
class MockKernel(KernelMixin, Kernel):  # type:ignore
    """Minimal "no-op" kernel used by the tests.

    Executing code only echoes it to stdout over iopub (unless the request is
    silent) and always reports success; ``shell`` is a ``MagicMock``.
    """

    implementation = "test"
    implementation_version = "1.0"
    language = "no-op"
    language_version = "0.1"
    language_info = dict(
        name="test",
        mimetype="text/plain",
        file_extension=".txt",
    )
    banner = "test kernel"

    def __init__(self, *args, **kwargs):
        # Test sockets/streams first, then a stand-in shell, then the real kernel init.
        self._initialize()
        self.shell = MagicMock()
        super().__init__(*args, **kwargs)

    async def do_execute(
        self, code, silent, store_history=True, user_expressions=None, allow_stdin=False
    ):
        """Echo ``code`` as a stdout stream message unless ``silent``; reply "ok"."""
        if not silent:
            self.send_response(
                self.iopub_socket, "stream", dict(name="stdout", text=code)
            )
        # The base class increments the execution count
        return dict(
            status="ok",
            execution_count=self.execution_count,
            payload=[],
            user_expressions={},
        )
class MockIPyKernel(KernelMixin, IPythonKernel):  # type:ignore
    """An IPythonKernel that talks over the in-process test sockets of KernelMixin."""

    def __init__(self, *args, **kwargs):
        # Set up the test sockets/streams/session, then run the normal IPythonKernel init.
        self._initialize()
        super().__init__(*args, **kwargs)
@pytest_asyncio.fixture()
def kernel():
    """Yield a MockKernel bound to the current IOLoop and destroy it afterwards.

    The try/finally tears down the kernel's zmq context even if binding the
    IOLoop fails or the fixture generator is closed without being resumed.
    """
    kernel = MockKernel()
    try:
        kernel.io_loop = IOLoop.current()
        yield kernel
    finally:
        kernel.destroy()
@pytest_asyncio.fixture()
def ipkernel():
    """Yield a MockIPyKernel bound to the current IOLoop.

    Teardown destroys the kernel and then clears the ZMQInteractiveShell
    singleton (presumably so later tests start with a fresh shell). The nested
    try/finally ensures the singleton is cleared even if destroy() raises, and
    that teardown runs even if binding the IOLoop fails.
    """
    kernel = MockIPyKernel()
    try:
        kernel.io_loop = IOLoop.current()
        yield kernel
    finally:
        try:
            kernel.destroy()
        finally:
            ZMQInteractiveShell.clear_instance()
@pytest.fixture
def kc():
    """Provide the client of a freshly started kernel for the duration of a test."""
    with new_kernel() as client:
        yield client
|