1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
|
import asyncio
import socket
import urllib.parse
def get_host_port(server):
    """Return the ``getsockname()`` address of the server's first IPv4 socket.

    For an ``AF_INET`` socket this is a ``(host, port)`` pair.

    Raises:
        AssertionError: if none of ``server.sockets`` is an IPv4 socket.
    """
    ipv4_sockets = (
        candidate
        for candidate in server.sockets
        if candidate.family == socket.AF_INET
    )
    first = next(ipv4_sockets, None)
    if first is None:  # pragma: no branch
        raise AssertionError("expected at least one IPv4 socket")
    return first.getsockname()
def get_uri(server, secure=None):
    """Build the ``ws://`` or ``wss://`` URI at which *server* can be reached.

    When *secure* is None, it is inferred from whether the underlying server
    has an SSL context. The host and port come from ``get_host_port``, so the
    URI points at the server's first IPv4 socket.
    """
    if secure is None:
        # HACK: reads a private attribute of the wrapped server object.
        secure = server.server._ssl_context is not None
    if secure:
        scheme = "wss"
    else:
        scheme = "ws"
    address, bound_port = get_host_port(server)
    return f"{scheme}://{address}:{bound_port}"
async def handler(ws):
    """Serve a test connection, dispatching on the request path.

    Only the path component of ``ws.request.path`` is used, so any query
    string is ignored. Routes:

    - ``/``: eval shell. Each incoming message is passed to ``eval`` and
      ``str()`` of the result is sent back.
    - ``/crash``: raises ``RuntimeError`` from the handler.
    - ``/no-op``: returns immediately without exchanging messages.
    - ``/delay``: reads a number of seconds from the next message, closes the
      connection, then keeps the handler running for that long.
    - anything else: raises ``AssertionError``.

    NOTE(review): ``eval`` executes client-supplied text with this function's
    globals and locals. That is deliberate for tests, but this handler must
    never be exposed to untrusted peers. Because expressions can see locals
    such as ``ws``, ``path``, ``expr`` and ``value`` (the previous result),
    renaming or removing those locals can change what tests observe.
    """
    path = urllib.parse.urlparse(ws.request.path).path
    if path == "/":
        # The default path is an eval shell.
        async for expr in ws:
            value = eval(expr)
            await ws.send(str(value))
    elif path == "/crash":
        # Presumably exercises the server's handling of exceptions raised by
        # a connection handler — TODO confirm against the tests using it.
        raise RuntimeError
    elif path == "/no-op":
        pass
    elif path == "/delay":
        delay = float(await ws.recv())
        await ws.close()
        # Keep the handler running after the connection is closed; presumably
        # lets tests check how the server treats still-running handlers.
        await asyncio.sleep(delay)
    else:
        raise AssertionError(f"unexpected path: {path}")
# Positional arguments for serve(handler, "localhost", 0), shared so that each
# test does not have to repeat them. Port 0 lets the OS pick a free port; the
# port actually bound can be read back with get_host_port().
args = (handler, "localhost", 0)
class EvalShellMixin:
    """Test-case mixin for talking to the eval shell.

    The host class must provide ``assertEqual`` (for example
    ``unittest.TestCase``).
    """

    async def assertEval(self, client, expr, value):
        """Send *expr* over *client* and assert the next reply equals *value*."""
        await client.send(expr)
        reply = await client.recv()
        self.assertEqual(reply, value)
|