#!/usr/bin/env python

import asyncio
import os
import signal
import statistics
import tracemalloc

from websockets.asyncio.server import serve
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory


# Number of client connections in a run. handler() multiplies it by INTERVAL
# to decide how long to keep each connection open.
CLIENTS = 20
# Presumably the delay between consecutive client connections; the client
# side is not shown here — confirm against the client script.
INTERVAL = 0.1  # seconds

# permessage-deflate tuning passed to ServerPerMessageDeflateFactory:
# window bits (used for both server and client) and the memLevel setting.
WB = 12
ML = 5

# Traced memory in bytes; handler() appends one entry per connection.
MEM_SIZE = []


async def handler(ws):
    """Echo two messages back, then record the memory traced so far.

    After the second round trip, the current size of traced allocations
    (the first element of ``tracemalloc.get_traced_memory()``) is appended
    to MEM_SIZE. Tracing is then stopped and restarted so the next
    measurement begins from a fresh set of traces.
    """
    # Two receive-then-send round trips, in the same order as before.
    for _ in range(2):
        await ws.send(await ws.recv())

    current_bytes, _peak_bytes = tracemalloc.get_traced_memory()
    MEM_SIZE.append(current_bytes)

    # stop() discards the collected traces; start() begins tracing anew.
    tracemalloc.stop()
    tracemalloc.start()

    # Hold connection open until the end of the test.
    await asyncio.sleep(CLIENTS * INTERVAL)


async def server():
    """Serve the benchmark on localhost:8765 until SIGTERM is received.

    permessage-deflate is offered with WB window bits for both server and
    client and a memLevel of ML. Memory tracing with tracemalloc starts
    once the server is listening and lasts until the server is closed.
    """
    deflate = ServerPerMessageDeflateFactory(
        server_max_window_bits=WB,
        client_max_window_bits=WB,
        compress_settings={"memLevel": ML},
    )
    async with serve(
        handler, "localhost", 8765, extensions=[deflate]
    ) as ws_server:
        pid = os.getpid()
        print("Stop the server with:")
        print(f"kill -TERM {pid}")
        print()

        # SIGTERM closes the server, which lets wait_closed() below return.
        event_loop = asyncio.get_running_loop()
        event_loop.add_signal_handler(signal.SIGTERM, ws_server.close)

        tracemalloc.start()
        await ws_server.wait_closed()


asyncio.run(server())


# Report the mean and standard deviation of per-connection traced memory.
# First connection incurs non-representative setup costs.
if MEM_SIZE:
    del MEM_SIZE[0]

if len(MEM_SIZE) < 2:
    # statistics.stdev() raises StatisticsError with fewer than two data
    # points, and an empty run would have failed on the del above; report
    # the shortfall instead of crashing after the benchmark has finished.
    print(
        f"Not enough samples to compute statistics "
        f"({len(MEM_SIZE)} after discarding the first connection)."
    )
else:
    print(f"µ = {statistics.mean(MEM_SIZE) / 1024:.1f} KiB")
    print(f"σ = {statistics.stdev(MEM_SIZE) / 1024:.1f} KiB")
