1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
|
import contextlib
import ssl
import threading
import urllib.parse
from websockets.sync.router import *
from websockets.sync.server import *
def get_uri(server, secure=None):
    """Return the ``ws://`` or ``wss://`` base URI on which *server* listens.

    :param server: a server object exposing its listening ``socket``.
    :param secure: force the scheme; when ``None``, ``wss`` is used if the
        listening socket is an :class:`ssl.SSLSocket`, ``ws`` otherwise.
    :raises ValueError: if the socket address is not a ``(host, port, ...)``
        tuple (e.g. a Unix domain socket, whose address is a path).
    """
    if secure is None:
        secure = isinstance(server.socket, ssl.SSLSocket)  # hack
    protocol = "wss" if secure else "ws"
    sockname = server.socket.getsockname()
    if not isinstance(sockname, tuple):
        raise ValueError(f"cannot build a ws URI from address {sockname!r}")
    # IPv4 addresses are (host, port); IPv6 ones are (host, port, flowinfo,
    # scope_id), so keep only the first two fields.
    host, port = sockname[:2]
    if ":" in host:
        # IPv6 literals must be enclosed in brackets in a URI authority.
        host = f"[{host}]"
    return f"{protocol}://{host}:{port}"
def handler(ws):
    """Serve a test connection according to the path of its opening request.

    Only the path component of the request target is considered (the query
    string is ignored):

    - ``/``: eval shell -- every incoming message is evaluated and ``str()``
      of the result is sent back.
    - ``/crash``: raise ``RuntimeError``.
    - ``/no-op``: return without doing anything.
    - anything else: raise ``AssertionError``.
    """
    path = urllib.parse.urlparse(ws.request.path).path
    if path == "/crash":
        raise RuntimeError
    if path == "/no-op":
        return
    if path != "/":
        raise AssertionError(f"unexpected path: {path}")
    # NOTE(security): eval() executes client-supplied code in this frame, so
    # expressions can refer to ``ws``. Keep the local names unchanged. This is
    # acceptable only in a test fixture; never expose it to untrusted peers.
    for expr in ws:
        value = eval(expr)
        ws.send(str(value))
class EvalShellMixin:
    """Assertion helper for tests that talk to the eval-shell endpoint.

    Must be combined with a class that provides ``assertEqual`` (presumably
    ``unittest.TestCase``).
    """

    def assertEval(self, client, expr, value):
        """Send *expr* through *client* and assert the reply equals *value*."""
        client.send(expr)
        reply = client.recv()
        self.assertEqual(reply, value)
@contextlib.contextmanager
def run_server_or_router(
    serve_or_route,
    handler_or_url_map,
    host="localhost",
    port=0,
    **kwargs,
):
    """Run a server or router in a background thread for the span of a ``with``.

    ``serve_or_route(handler_or_url_map, host, port, **kwargs)`` must return a
    context manager yielding a server with ``serve_forever()``, ``shutdown()``
    and an assignable ``handler`` attribute. The server is yielded to the
    caller.

    On exit, the server is shut down and its serving thread is joined. The
    thread that handled the most recent connection, if any, is joined too.
    """
    with serve_or_route(handler_or_url_map, host, port, **kwargs) as server:
        # HACK: since the sync server doesn't track connections (yet), we record
        # a reference to the thread handling the most recent connection, then we
        # can wait for that thread to terminate when exiting the context.
        handler_thread = None
        original_handler = server.handler

        def tracking_handler(sock, addr):
            nonlocal handler_thread
            handler_thread = threading.current_thread()
            original_handler(sock, addr)

        # Install the wrapper before serve_forever() starts. Swapping it
        # afterwards races with connections accepted in the meantime, which
        # would then go untracked. (Like the tracking itself, this assumes
        # the server looks up ``server.handler`` for each connection.)
        server.handler = tracking_handler

        thread = threading.Thread(target=server.serve_forever)
        thread.start()
        try:
            yield server
        finally:
            server.shutdown()
            thread.join()
            # HACK: wait for the thread handling the most recent connection.
            if handler_thread is not None:
                handler_thread.join()
def run_server(handler=handler, **kwargs):
    """Run a ``serve`` server for *handler* in a background thread.

    Returns the context manager from ``run_server_or_router``. Extra keyword
    arguments (``host``, ``port``, ...) are forwarded to it.
    """
    return run_server_or_router(
        serve_or_route=serve,
        handler_or_url_map=handler,
        **kwargs,
    )
def run_router(url_map, **kwargs):
    """Run a ``route`` router built from *url_map* in a background thread.

    Returns the context manager from ``run_server_or_router``. Extra keyword
    arguments are forwarded to it.
    """
    return run_server_or_router(
        serve_or_route=route,
        handler_or_url_map=url_map,
        **kwargs,
    )
@contextlib.contextmanager
def run_unix_server_or_router(
    path,
    unix_serve_or_route,
    handler_or_url_map,
    **kwargs,
):
    """Run a Unix-socket server or router in a background thread.

    ``unix_serve_or_route(handler_or_url_map, path, **kwargs)`` must return a
    context manager yielding a server with ``serve_forever()`` and
    ``shutdown()``. The server is yielded to the caller. On exit, it is shut
    down and its serving thread is joined before the server context closes.
    Unlike the TCP variant, connection-handler threads are not waited for.
    """
    server_context = unix_serve_or_route(handler_or_url_map, path, **kwargs)
    with server_context as server:
        serving_thread = threading.Thread(target=server.serve_forever)
        serving_thread.start()
        try:
            yield server
        finally:
            # Stop the serving loop first, then wait for its thread to return.
            server.shutdown()
            serving_thread.join()
def run_unix_server(path, handler=handler, **kwargs):
    """Run a ``unix_serve`` server on socket *path* in a background thread.

    Returns the context manager from ``run_unix_server_or_router``. Extra
    keyword arguments are forwarded to it.
    """
    return run_unix_server_or_router(
        path=path,
        unix_serve_or_route=unix_serve,
        handler_or_url_map=handler,
        **kwargs,
    )
def run_unix_router(path, url_map, **kwargs):
    """Run a ``unix_route`` router for *url_map* on socket *path* in a thread.

    Returns the context manager from ``run_unix_server_or_router``. Extra
    keyword arguments are forwarded to it.
    """
    return run_unix_server_or_router(
        path=path,
        unix_serve_or_route=unix_route,
        handler_or_url_map=url_map,
        **kwargs,
    )
|