1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
|
# basic_tests.py -- Basic unit tests for Terminado
# Copyright (c) Jupyter Development Team
# Copyright (c) 2014, Ramalingam Saravanan <sarava@sarava.net>
# Distributed under the terms of the Simplified BSD License.
import asyncio
import datetime
import json
import os
import re
# We must set the policy for python >=3.8, see https://www.tornadoweb.org/en/stable/#installation
# Snippet from https://github.com/tornadoweb/tornado/issues/2608#issuecomment-619524992
import sys
import unittest
from sys import platform
import pytest
import tornado
import tornado.httpserver
import tornado.testing
from tornado.ioloop import IOLoop
from terminado import NamedTermManager, SingleTermManager, TermSocket, UniqueTermManager
# Tornado's installation notes (linked in the import block) call for the
# selector-based event loop on Windows with Python 3.8+; install it before
# any event loop gets created.
if sys.platform.startswith("win") and sys.version_info >= (3, 8):
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

# Seconds a read may wait before we assume the server has no more
# messages to send.
DONE_TIMEOUT = 1.0

# Per-test timeout in seconds; exported through the environment, where
# tornado.testing reads ASYNC_TEST_TIMEOUT as its default gen_test timeout.
ASYNC_TEST_TIMEOUT = 30
os.environ["ASYNC_TEST_TIMEOUT"] = f"{ASYNC_TEST_TIMEOUT}"

# Terminal-count limit passed to the managers under test.
MAX_TERMS = 3  # Testing thresholds
class TestTermClient:
    """Test connection to a terminal manager.

    Wraps a Tornado websocket client and speaks terminado's JSON message
    protocol: every message is a JSON list whose first item is the message
    type (e.g. ``"stdout"``, ``"stdin"``, ``"setup"``).
    """

    # Keep pytest from collecting this helper as a test class despite its name.
    __test__ = False

    def __init__(self, websocket):
        self.ws = websocket
        # Outstanding read_message() future, reused after a timed-out read.
        self.pending_read = None

    async def read_msg(self):
        """Read one message and decode it from JSON.

        Returns the decoded message, or ``None`` once the connection is closed.
        """
        # Because the Tornado Websocket client has no way to cancel
        # a pending read, we have to keep track of them...
        if self.pending_read is None:
            self.pending_read = self.ws.read_message()
        response = await self.pending_read
        self.pending_read = None
        if response:
            response = json.loads(response)
        return response

    async def read_all_msg(self, timeout=DONE_TIMEOUT):
        """Read messages until a read times out or the connection closes.

        :param timeout: seconds to wait for each individual message.
        :returns: list of decoded messages (never contains ``None``).
        """
        msglist: list = []
        delta = datetime.timedelta(seconds=timeout)
        while True:
            try:
                msg = await tornado.gen.with_timeout(delta, self.read_msg())
            except tornado.gen.TimeoutError:
                return msglist
            if msg is None:
                # Connection closed: stop here instead of handing a None to
                # callers that index into every message.
                return msglist
            msglist.append(msg)

    async def write_msg(self, msg):
        """Send ``msg`` to the server as JSON."""
        await self.ws.write_message(json.dumps(msg))

    async def read_stdout(self, timeout=DONE_TIMEOUT):
        """Read standard output until timeout read reached,
        return stdout and any non-stdout msgs received."""
        msglist = await self.read_all_msg(timeout)
        stdout = "".join([msg[1] for msg in msglist if msg[0] == "stdout"])
        othermsg = [msg for msg in msglist if msg[0] != "stdout"]
        return (stdout, othermsg)

    async def discard_stdout(self, timeout=DONE_TIMEOUT):
        """Read standard output messages, discarding the data as it arrives.

        Stops when a read times out or the connection closes.

        :returns: ``(bytes_discarded, othermsg)`` where ``bytes_discarded`` is
            the summed ``len()`` of the discarded stdout payloads (characters
            of the decoded strings) and ``othermsg`` lists non-stdout messages.
        """
        othermsg: list = []
        bytes_discarded = 0
        delta = datetime.timedelta(seconds=timeout)
        while True:
            try:
                msg = await tornado.gen.with_timeout(delta, self.read_msg())
            except tornado.gen.TimeoutError:
                return bytes_discarded, othermsg
            if msg is None:
                # Connection closed; there is nothing left to discard.
                return bytes_discarded, othermsg
            if msg[0] == "stdout":
                bytes_discarded += len(msg[1])
            else:
                othermsg.append(msg)

    async def write_stdin(self, data):
        """Write to terminal stdin"""
        await self.write_msg(["stdin", data])

    async def get_pid(self):
        """Get process ID of terminal shell process.

        Runs ``echo $$`` in the terminal and parses the number from its output.

        :raises AssertionError: if no PID can be found in the output.
        """
        await self.read_stdout()  # Clear out any pending
        await self.write_stdin("echo $$\r")
        stdout, _extra = await self.read_stdout()
        if os.name == "nt":
            match = re.search(r"echo \$\$\\.*?\\r\\n(\d+)", repr(stdout))
            assert match is not None
            return int(match.group(1))
        # This should work on any OS, but keeping the above Windows special
        # case as I can't verify on Windows.
        for line in stdout.splitlines():
            if re.match(r"\d+$", line):
                return int(line)
        # Previously this fell through to an UnboundLocalError on `pid`.
        raise AssertionError(f"no PID found in terminal output: {stdout!r}")

    def close(self):
        """Close the underlying websocket."""
        self.ws.close()
class TermTestCase(tornado.testing.AsyncHTTPTestCase):
    """Base test case serving named, single and unique terminal managers."""

    # Factory for TestTermClient, because it has to be async
    # See: https://github.com/tornadoweb/tornado/issues/1161
    async def get_term_client(self, path):
        """Open a websocket to ``path`` on the test server and wrap it."""
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d%s" % (port, path)
        request = tornado.httpclient.HTTPRequest(
            url, headers={"Origin": "http://127.0.0.1:%d" % port}
        )
        ws = await tornado.websocket.websocket_connect(request)
        return TestTermClient(ws)

    async def get_term_clients(self, paths):
        """Open one client per path concurrently; results keep ``paths`` order."""
        return await asyncio.gather(*(self.get_term_client(path) for path in paths))

    async def get_pids(self, tm_list):
        """Return the shell PID seen by each client, in order."""
        pids = []
        for tm in tm_list:  # Must be sequential, in case terms are shared
            pids.append(await tm.get_pid())
        return pids

    def tearDown(self):
        """Kill every terminal, then always run the base-class teardown."""
        # Use this test case's own loop explicitly rather than whatever
        # IOLoop happens to be current.
        run = self.io_loop.run_sync
        try:
            run(self.named_tm.kill_all)
            run(self.single_tm.kill_all)
            run(self.unique_tm.kill_all)
        finally:
            # Even if a kill_all fails, the server and loop must be torn down.
            super().tearDown()

    def get_app(self):
        """Build the app exposing /new, /named/<name>, /single and /unique."""
        self.named_tm = NamedTermManager(
            shell_command=["bash"],
            max_terminals=MAX_TERMS,
        )
        self.single_tm = SingleTermManager(shell_command=["bash"])
        self.unique_tm = UniqueTermManager(
            shell_command=["bash"],
            max_terminals=MAX_TERMS,
        )

        named_tm = self.named_tm

        class NewTerminalHandler(tornado.web.RequestHandler):
            """Create a new named terminal, return redirect"""

            def get(self):
                name, _terminal = named_tm.new_named_terminal()
                self.redirect("/named/" + name, permanent=False)

        return tornado.web.Application(
            [
                (r"/new", NewTerminalHandler),
                (r"/named/(\w+)", TermSocket, {"term_manager": self.named_tm}),
                (r"/single", TermSocket, {"term_manager": self.single_tm}),
                (r"/unique", TermSocket, {"term_manager": self.unique_tm}),
            ],
            debug=True,
        )

    # URLs exercised by CommonTests; /single is left out on Windows.
    test_urls = ("/named/term1", "/unique") + (("/single",) if os.name != "nt" else ())
class CommonTests(TermTestCase):
    """Checks shared by every terminal manager listed in ``test_urls``."""

    @tornado.testing.gen_test
    async def test_basic(self):
        # Each endpoint must first send ["setup", {}], then non-empty stdout.
        for path in self.test_urls:
            client = await self.get_term_client(path)
            first = await client.read_msg()
            self.assertEqual(first, ["setup", {}])
            # Check for initial shell prompt
            prompt = await client.read_msg()
            kind = prompt[0]
            self.assertEqual(kind, "stdout")
            self.assertGreater(len(prompt[1]), 0)
            client.close()

    @tornado.testing.gen_test
    async def test_basic_command(self):
        # After `whoami`, stdout must begin with "who" (POSIX) or contain
        # "whoami" (Windows), and no non-stdout messages may arrive.
        for path in self.test_urls:
            client = await self.get_term_client(path)
            await client.read_all_msg()  # drain the startup messages
            await client.write_stdin("whoami\n")
            out, extra = await client.read_stdout()
            if os.name != "nt":
                assert out.startswith("who")
            else:
                assert "whoami" in out
            assert extra == []
            client.close()
class NamedTermTests(TermTestCase):
    """Tests for NamedTermManager, whose terminals are addressed by name."""

    def test_new(self):
        """GET /new redirects (302) to /named/<name> for a freshly made terminal."""
        response = self.fetch("/new", follow_redirects=False)
        self.assertEqual(response.code, 302)
        url = response.headers["Location"]
        # Check that the new terminal was created
        name = url.split("/")[2]
        self.assertIn(name, self.named_tm.terminals)

    @tornado.testing.gen_test
    async def test_namespace(self):
        """Clients on the same name share a shell; different names do not."""
        names = ["/named/1"] * 2 + ["/named/2"] * 2
        tms = await self.get_term_clients(names)
        pids = await self.get_pids(tms)
        self.assertEqual(pids[0], pids[1])
        self.assertEqual(pids[2], pids[3])
        self.assertNotEqual(pids[0], pids[3])

        terminal = self.named_tm.terminals["1"]
        killed = await terminal.terminate(True)
        assert killed
        assert not terminal.ptyproc.isalive()
        assert terminal.ptyproc.closed
        # Plain loop: closing is a side effect, not something to collect.
        for tm in tms:
            tm.close()

    @tornado.testing.gen_test
    @pytest.mark.skipif("linux" not in platform, reason="It only works on Linux")
    async def test_max_terminals(self):
        """Opening a terminal beyond MAX_TERMS gets the connection closed."""
        urls = ["/named/%d" % i for i in range(MAX_TERMS + 1)]
        tms = await self.get_term_clients(urls[:MAX_TERMS])
        # Wait until every allowed terminal's shell answers `echo $$`.
        await self.get_pids(tms)
        # MAX_TERMS+1 should fail
        tm = await self.get_term_client(urls[MAX_TERMS])
        msg = await tm.read_msg()
        self.assertIsNone(msg)  # Connection closed
        tm.close()
        for client in tms:
            client.close()
class SingleTermTests(TermTestCase):
    """Tests for SingleTermManager, where all clients see the same shell."""

    @tornado.testing.gen_test
    async def test_single_process(self):
        """Two clients on /single report the same PID; terminate closes its pty."""
        tms = await self.get_term_clients(["/single", "/single"])
        pids = await self.get_pids(tms)
        self.assertEqual(pids[0], pids[1])

        assert self.single_tm.terminal is not None
        killed = await self.single_tm.terminal.terminate(True)
        assert killed
        assert self.single_tm.terminal.ptyproc.closed
        # Plain loop: closing is a side effect, not something to collect.
        for tm in tms:
            tm.close()
class UniqueTermTests(TermTestCase):
    """Tests for UniqueTermManager, which gives each client its own shell."""

    @tornado.testing.gen_test
    async def test_unique_processes(self):
        """Two clients on /unique report different PIDs."""
        tms = await self.get_term_clients(["/unique", "/unique"])
        pids = await self.get_pids(tms)
        self.assertNotEqual(pids[0], pids[1])
        for tm in tms:
            tm.close()

    @tornado.testing.gen_test
    @pytest.mark.skipif("linux" not in platform, reason="It only works on Linux")
    async def test_max_terminals(self):
        """Enforce MAX_TERMS, and free a slot when a client disconnects."""
        tms = await self.get_term_clients(["/unique"] * MAX_TERMS)
        pids = await self.get_pids(tms)
        self.assertEqual(len(set(pids)), MAX_TERMS)  # All PIDs unique

        # MAX_TERMS+1 should fail
        rejected = await self.get_term_client("/unique")
        msg = await rejected.read_msg()
        self.assertIsNone(msg)  # Connection closed

        # Close one
        tms[0].close()
        msg = await tms[0].read_msg()  # Closed
        self.assertIsNone(msg)

        # Should be able to open back up to MAX_TERMS
        tm = await self.get_term_client("/unique")
        msg = await tm.read_msg()
        self.assertEqual(msg[0], "setup")

        # Close every client still open; previously the rejected client and
        # tms[1:] were never closed.
        for client in (tm, rejected, *tms[1:]):
            client.close()

    @tornado.testing.gen_test
    @pytest.mark.timeout(timeout=ASYNC_TEST_TIMEOUT, method="thread")
    async def test_large_io_doesnt_hang(self):
        """Regression test: output larger than a PTY buffer must not hang."""
        # This is a regression test for an error where Terminado hangs when
        # the PTY buffer size is exceeded. While the buffer size varies from
        # OS to OS, 30KBish seems like a reasonable amount and will trigger
        # this on both OSX and Debian.
        massive_payload = "ten bytes " * 3000
        massive_payload = "echo " + massive_payload + "\n"
        tm = await self.get_term_client("/unique")
        # Clear all startup messages.
        await tm.read_all_msg()
        # Write a payload that doesn't fit in a single PTY buffer.
        await tm.write_stdin(massive_payload)
        # Verify that the server didn't hang when responding, and that
        # we got a reasonable amount of data back (to tell us the read
        # didn't just timeout.
        bytes_discarded, other = await tm.discard_stdout()
        # Echo won't actually output anything on Windows. Match on the
        # prefix: a substring test for "win" also matches "darwin" and
        # silently skipped this check on macOS.
        if not platform.startswith("win"):
            assert bytes_discarded > 10000
        assert other == []
        tm.close()
# Allow running this module directly with `python <file>` in addition to pytest.
if __name__ == "__main__":
    unittest.main()
|