"""Run some tests on the microsocks proxy server and client."""

import argparse
import dataclasses
import errno
import pathlib
import socket
import subprocess
import struct
import sys
import time

import utf8_locale

from . import addresses
from . import defs


@dataclasses.dataclass(frozen=True)
class Config(defs.Config):
    """Runtime configuration for the microsocks test tool.

    Extends the shared `defs.Config` (which also carries the `verbose` flag
    set from the command line) with the settings specific to this tool.
    """

    # Path to the microsocks executable under test (`-s` / `--microsocks`).
    msocks: pathlib.Path
    # Environment for the spawned microsocks process, as returned by
    # `utf8_locale.get_utf8_env()`; its LC_ALL names the UTF-8 locale in use.
    utf8_env: dict[str, str]


def _parse_args() -> Config:
    """Build the runtime configuration from the command line.

    The returned Config holds the path to the microsocks binary, the verbosity
    flag, and an environment dictionary that selects a UTF-8 locale.
    """
    ap = argparse.ArgumentParser(prog="msocktest")
    ap.add_argument(
        "-s", "--microsocks",
        type=pathlib.Path, required=True,
        help="the path to the microsocks program to test",
    )
    ap.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="verbose mode; display diagnostic output",
    )

    opts = ap.parse_args()
    locale_env = utf8_locale.get_utf8_env()
    return Config(msocks=opts.microsocks, verbose=opts.verbose, utf8_env=locale_env)


def _do_test_conn_connect(
    cfg: Config, cli_sock: socket.socket, address: addresses.Address, port: int
) -> None:
    """Connect the client socket to the microsocks server, retrying while it starts.

    The server may not be listening yet, so a refused connection is retried,
    up to ten attempts one second apart. Any other connection error is
    propagated. If every attempt is refused, the program exits with an
    error message.

    Args:
        cfg: runtime configuration, used for diagnostic output.
        cli_sock: the client socket to connect.
        address: the address the microsocks server listens on.
        port: the port the microsocks server listens on.
    """
    dest = (address.address, port)
    print(f"Connecting to the microsocks server at {dest!r}")

    # NOTE(review): the same socket is reused for every attempt. POSIX leaves a
    # socket's state unspecified after a failed connect(); Linux allows a retry,
    # but other systems may not. Callers appear to pass an already-bound socket,
    # so a fresh one cannot simply be created here.
    attempts = 10
    for attempt in range(1, attempts + 1):
        try:
            cli_sock.connect(dest)
        except OSError as err:
            # Only "connection refused" means "not listening yet"; anything
            # else is a real failure.
            if err.errno != errno.ECONNREFUSED:
                raise
        else:
            return

        # Do not wait (or claim to be waiting) after the final attempt.
        if attempt < attempts:
            cfg.diag_("Could not connect, waiting for a second")
            time.sleep(1)

    sys.exit(f"Could not connect to the microsocks server at {dest} after ten attempts")


def _expect_read(cfg: Config, sock: socket.socket, expected: bytes, tag: str) -> None:
    """Read from `sock` and exit with an error unless the data equals `expected`.

    TCP is a byte stream, so a single recv() may return only part of what the
    peer sent. Keep reading while the data received so far is still a proper
    prefix of `expected`. Stop on EOF, on a mismatch, or once at least
    `len(expected)` bytes have arrived. Surplus bytes returned in the same
    recv() are still reported as a mismatch. An empty `expected` means the
    next read must report EOF.

    Args:
        cfg: runtime configuration, used for diagnostic output.
        sock: the connected socket to read from.
        expected: the exact bytes that must be received.
        tag: a short description of the data, used in the error message.
    """
    data = b""
    while True:
        chunk = sock.recv(4096)
        data += chunk
        if not chunk or len(data) >= len(expected) or not expected.startswith(data):
            break

    cfg.diag(lambda: f"- got {data!r}")
    if data != expected:
        sys.exit(f"Unexpected {tag}: {expected=!r} {data=!r}")


def _do_test_conn_xfer(
    cfg: Config,
    cli_sock: socket.socket,
    srv_sock: socket.socket,
    svc_listen: addresses.AddrPort,
) -> None:
    """Perform the SOCKS5 negotiation and a short conversation through the proxy.

    The client side (`cli_sock`, already connected to microsocks) requests
    "no authentication" and then asks the proxy to CONNECT to the service at
    `svc_listen`. The service side accepts the proxied connection on the
    listening `srv_sock` and sends a greeting that the client must receive.
    Finally it checks that closing the client is seen as EOF on the service
    side. Any deviation terminates the program via `sys.exit()`.

    Note that this function closes `cli_sock`.
    """
    cfg.diag_("Sending 'none' auth")
    # VER=5, NMETHODS=1, METHODS=[0 (no authentication required)].
    cli_sock.sendall(bytes([5, 1, 0]))

    cfg.diag_("Waiting for the server's auth response")
    _expect_read(cfg, cli_sock, bytes([5, 0]), "auth response")

    # SOCKS5 address type: 1 (IPv4) for AF_INET, otherwise 4 (IPv6).
    family = 1 if svc_listen.address.family == socket.AF_INET else 4
    # VER=5, CMD=1 (CONNECT), RSV=0, ATYP, DST.ADDR, DST.PORT.
    # The port is an unsigned 16-bit big-endian value: ">H". The signed ">h"
    # would raise struct.error for any port above 32767.
    data = (
        bytes([5, 1, 0, family])
        + svc_listen.address.packed
        + struct.pack(">H", svc_listen.svc_port)
    )
    cfg.diag(lambda: f"Sending a SOCKS5 CONNECT request: {data!r}")
    cli_sock.sendall(data)

    cfg.diag_("Waiting for the server's connect response")
    # So microsocks is kind of lazy and always returns an IPv4-like
    # response: ATYP 1, four zeroes for the address.
    _expect_read(cfg, cli_sock, bytes([5, 0, 0, 1] + [0] * 4 + [0] * 2), "connect response")

    cfg.diag_("Accepting a connection from the microsocks server")
    conn_sock, conn_data = srv_sock.accept()
    # The context manager closes the accepted socket even if a check below
    # exits early via sys.exit().
    with conn_sock:
        cfg.diag(
            lambda: f"- accepted a connection on fd {conn_sock.fileno()} from {conn_data!r}"
        )
        if conn_sock.family != svc_listen.address.family:
            sys.exit(
                f"Expected a {svc_listen.address.family} family connection, "
                f"got {conn_sock.family}"
            )

        cfg.diag_("Let's say hello to the client")
        expected = b"Hello"
        conn_sock.sendall(expected)
        cfg.diag_("Let's try to read that from the client side")
        _expect_read(cfg, cli_sock, expected, "client read")

        cfg.diag_("Closing the client connection")
        cli_sock.close()
        cfg.diag_("Let's get an empty read from the server side")
        _expect_read(cfg, conn_sock, b"", "server side empty read")

        cfg.diag_("Closing the server side of the connection")


def _test_conn(
    cfg: Config,
    proxy_listen: addresses.AddrPort,
    svc_listen: addresses.AddrPort,
) -> None:
    """Check that a client can reach a service through a microsocks server.

    `svc_listen` describes the service side:
      - a listening socket is bound to its address and `svc_port`;
      - its first client address is passed to microsocks as `-b` (the address
        for the proxy's outgoing connections).

    `proxy_listen` describes the proxy side:
      - microsocks listens on its address and `svc_port`;
      - the test client binds to its first client address and `proxy_port`.

    microsocks is terminated once the exchange is over. If anything goes wrong,
    including a `sys.exit()` from a helper, it is killed and the exception is
    re-raised.
    """
    print(f"Client at {proxy_listen.clients[0]} port {proxy_listen.proxy_port}")
    print(f"Proxy at {proxy_listen.address} port {proxy_listen.svc_port}")
    print(f"Proxy client at {svc_listen.clients[0]} port {svc_listen.proxy_port}")
    print(f"Server at {svc_listen.address} port {svc_listen.svc_port}")

    print("Creating the server listening socket")
    with addresses.bind_to(cfg, svc_listen.address, svc_listen.svc_port) as listener:
        cfg.diag(lambda: f"Server socket: {listener!r}")
        listener.listen(1)

        print("Spawning the microsocks server process")
        msocks_cmd = [
            cfg.msocks,
            "-i", proxy_listen.address.address,
            "-p", str(proxy_listen.svc_port),
            "-b", svc_listen.clients[0].address,
        ]
        with subprocess.Popen(
            msocks_cmd,
            bufsize=0,
            encoding="UTF-8",
            env=cfg.utf8_env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        ) as msocks_proc:
            cfg.diag(lambda: f"- spawned microsocks server at {msocks_proc.pid}")

            try:  # pylint: disable=too-many-try-statements
                print("Creating the client socket")
                with addresses.bind_to(
                    cfg, proxy_listen.clients[0], proxy_listen.proxy_port
                ) as client:
                    cfg.diag(lambda: f"Client socket: {client!r}")
                    _do_test_conn_connect(
                        cfg, client, proxy_listen.address, proxy_listen.svc_port
                    )
                    cfg.diag(lambda: f"- connected: {client!r}")
                    _do_test_conn_xfer(cfg, client, listener, svc_listen)

                # Orderly shutdown: SIGTERM, then collect output and exit code.
                cfg.diag_("Stopping the microsocks server")
                msocks_proc.terminate()
                cfg.diag_("Waiting for the microsocks server process to exit")
                outputs = msocks_proc.communicate()
                print(repr(outputs))
                cfg.diag(lambda: f"microsocks server exit code {msocks_proc.wait()}")
            except BaseException as err:  # pylint: disable=broad-except
                cfg.diag_(f"Killing the microsocks server because of an exception: {err}")
                msocks_proc.kill()
                raise


def main() -> None:
    """Program entry point: discover usable address pairs, then test microsocks on them."""
    cfg = _parse_args()
    cfg.diag(lambda: f"Using {cfg.utf8_env['LC_ALL']} as a UTF-8 locale.")

    cfg.diag_("Starting up")
    local_addrs = addresses.get_addresses(cfg)
    addr_ports = addresses.find_ports(cfg, local_addrs)
    apairs = addresses.find_pairs(cfg, addr_ports)

    # Report every server address and its candidate clients, grouped by family.
    print("Connectivity information: address family, server, clients:")
    for fam, server_ports in sorted(apairs.items()):
        print(fam)
        for srv in server_ports:
            print(f"\t{srv.address.address}")
            for cli in srv.clients:
                print(f"\t\t{cli.address}")

    print("Picking two pairs for each address family, if possible")
    chosen = addresses.pick_pairs(cfg, apairs)
    print(f"Testing {len(chosen)} combination(s)")
    for proxy_side, svc_side in chosen:
        _test_conn(cfg, proxy_side, svc_side)


# Run the tests when executed directly. Because of the relative imports at the
# top of this file, it must be run within its package (e.g. via `python -m`),
# not as a standalone script path.
if __name__ == "__main__":
    main()
