"""Find addresses and ports for the microsocks test."""

# The cfg.diag() method will invoke the function at once, or not at all, so the
# lambdas passed to it may safely refer to loop variables.
# pylint: disable=cell-var-from-loop

from __future__ import annotations

import ipaddress
import itertools
import random
import socket
import sys

from typing import NamedTuple, TYPE_CHECKING

import netifaces

if TYPE_CHECKING:
    from collections.abc import Iterable

    from . import defs


class Address(NamedTuple):
    """Information about a single network address on the system.

    Instances compare and sort field by field, so a sorted list of addresses
    is grouped by address family first.
    """

    # The socket address family, e.g. socket.AF_INET or socket.AF_INET6.
    family: int
    # The address in textual form, as reported for the network interface.
    address: str
    # The binary form of the address (the `ipaddress` object's `.packed`).
    packed: bytes


class AddrPort(NamedTuple):
    """An address and two "free" ports to listen on during the test run.

    Also carries the list of client addresses that have been verified to be
    able to connect to both ports. Since `clients` is a list, instances are
    not hashable.
    """

    # The address to listen on.
    address: Address
    # The port for the test service.
    svc_port: int
    # The port for the proxy; find_ports() picks it above svc_port.
    proxy_port: int
    # Addresses that could connect to both ports; find_ports() leaves this
    # empty and find_pairs() fills it in.
    clients: list[Address]


# Either of the two `ipaddress` address classes. This alias is an ordinary
# runtime expression (`from __future__ import annotations` does not defer it),
# and applying `|` to `type[...]` aliases requires Python 3.10 or later.
_IPClassType = type[ipaddress.IPv4Address] | type[ipaddress.IPv6Address]


class _IPFamily(NamedTuple):
    """An IP address family and the corresponding ipaddress class."""

    # The socket address family constant, socket.AF_INET or socket.AF_INET6.
    family: int
    # The class used to parse textual addresses of that family.
    ipcls: _IPClassType


def get_addresses(cfg: defs.Config) -> list[Address]:
    """Get the IPv4 and IPv6 addresses on this system.

    Enumerate the network interfaces via netifaces and collect every AF_INET
    and AF_INET6 address that the `ipaddress` module can parse; addresses
    that fail to parse are reported via cfg.diag_() and skipped.

    Returns the addresses sorted field by field (family first), or an empty
    list if no network interfaces were found at all.
    """
    cfg.diag_("Enumerating the system network interfaces")
    ifaces = netifaces.interfaces()
    cfg.diag(lambda: f"- got {len(ifaces)} interface names")
    if not ifaces:
        return []

    def add_addresses(
        family: int,
        ipcls: _IPClassType,
        addrs: Iterable[str],
    ) -> list[Address]:
        """Create objects for the IPv4/IPv6 addresses found on an interface."""
        res = []
        for addr in addrs:
            try:
                ipaddr = ipcls(addr)
            except ValueError as err:
                cfg.diag_(f"- could not parse the {addr!r} address: {err}")
                continue

            res.append(Address(family=family, address=addr, packed=ipaddr.packed))
            cfg.diag(lambda: f"- added {res[-1]!r}")

        return res

    families: list[_IPFamily] = [
        _IPFamily(socket.AF_INET, ipaddress.IPv4Address),
        _IPFamily(socket.AF_INET6, ipaddress.IPv6Address),
    ]
    # Process the very interface list obtained above instead of enumerating
    # the interfaces a second time, so that the emptiness check and the
    # diagnostic message describe the same set of interfaces used here.
    return sorted(
        itertools.chain.from_iterable(
            add_addresses(
                ipfamily.family,
                ipfamily.ipcls,
                (addr["addr"] for addr in addrs.get(ipfamily.family, [])),
            )
            for addrs, ipfamily in itertools.product(
                (netifaces.ifaddresses(iface) for iface in ifaces),
                families,
            )
        )
    )


def bind_to(cfg: defs.Config, addr: Address, port: int) -> socket.socket:
    """Bind to the specified port on the specified address.

    Create a TCP socket of the address's family, enable SO_REUSEADDR on it,
    and bind it to (addr.address, port); pass port 0 to let the kernel pick
    one. The returned socket is bound but not yet listening.

    Raises:
        OSError: if the socket cannot be created, configured, or bound;
            the failure is reported via cfg.diag_() before re-raising.
    """
    try:
        sock = socket.socket(addr.family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    except OSError as err:
        cfg.diag_(f"Could not create a family {addr.family} socket: {err}")
        raise

    try:
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except OSError as err:
            cfg.diag_(f"Could not set the reuse-address option: {err}")
            raise

        try:
            sock.bind((addr.address, port))
        except OSError as err:
            cfg.diag_(f"Could not bind to port {port} on {addr.address}: {err}")
            raise
    except BaseException:
        # Never leak the descriptor, not even on failures that are not
        # OSError (e.g. OverflowError for a port number above 65535).
        sock.close()
        raise

    return sock


def _find_available_port(cfg: defs.Config, addr: Address, port: int) -> int | None:
    """Find a port to listen on at the specified address.

    First make sure the address can be bound to at all (using a kernel-chosen
    port), then try up to 100 ports above `port`, advancing by a random step
    of 10 to 30 before each attempt, so the result is always > `port`.

    Returns the first port that could be bound to, or None if the address is
    unusable, the valid port range was exhausted, or every attempt failed.
    """
    # The highest valid TCP port number; bind() rejects anything above it
    # with OverflowError rather than OSError.
    max_port = 65535

    try:
        sock = bind_to(cfg, addr, 0)
    except OSError:
        # The address itself cannot be bound to; no point in probing ports.
        return None

    with sock:
        cfg.diag(lambda: f"  - bound to a random port: {sock.getsockname()!r}")

    for _ in range(100):
        port += random.randint(10, 30)
        if port > max_port:
            cfg.diag(lambda: f"- went past the highest port number: {port}")
            break

        cfg.diag(lambda: f"- trying {port}")
        try:
            sock = bind_to(cfg, addr, port)
        except OSError:
            continue

        cfg.diag_("  - success!")
        sock.close()
        return port

    cfg.diag_("- could not find an available port at all...")
    return None


def find_ports(cfg: defs.Config, addrs: list[Address], first_port: int = 6374) -> list[AddrPort]:
    """Find two ports per network address to listen on.

    For each address, look for a service port above `first_port`, then for a
    proxy port above the service port (so the two never coincide).

    Returns one entry (with an empty `clients` list) per address for which
    both ports were found; addresses where either search failed are skipped.
    """
    res: list[AddrPort] = []
    for addr in addrs:
        cfg.diag(
            lambda: f"Looking for a service port to listen on for "
            f"{addr.address} family {addr.family}"
        )
        svc_port = _find_available_port(cfg, addr, first_port)
        if svc_port is None:
            cfg.diag(lambda: f"Could not find a service port on {addr.address}")
            continue

        cfg.diag(lambda: f"Looking for a proxy port to listen on for {addr.address}")
        proxy_port = _find_available_port(cfg, addr, svc_port)
        if proxy_port is None:
            cfg.diag(lambda: f"Could not find a proxy port on {addr.address}")
            continue

        res.append(
            AddrPort(
                address=addr,
                svc_port=svc_port,
                proxy_port=proxy_port,
                clients=[],
            )
        )
        cfg.diag(lambda: f"Added {res[-1]!r}")

    return res


def _check_connect(cfg: defs.Config, server: socket.socket, client: Address) -> bool:
    """Report whether a socket on the `client` address can reach `server`.

    `server` is expected to be listening already. A new socket is bound to a
    kernel-chosen port on the client address and connected to the server's
    address; the connection is then accepted on the server side. The check
    passes only if each end's local address matches the other end's peer
    address. Both the client socket and the accepted one are closed before
    returning; connect()/accept() failures yield False.
    """
    cfg.diag(lambda: f"- checking whether {client.address} can connect to {server!r}")
    with bind_to(cfg, client, 0) as outgoing:
        cfg.diag(lambda: f"  - got client socket {outgoing!r}")

        try:
            outgoing.connect(server.getsockname())
        except OSError as err:
            cfg.diag_(f"  - failed to connect: {err}")
            return False

        try:
            incoming, peer_addr = server.accept()
        except OSError as err:
            cfg.diag_(f"  - failed to accept the connection: {err}")
            return False

        cfg.diag(lambda: f"  - got socket {incoming!r} data {peer_addr!r}")

        with incoming:
            names_differ = (
                incoming.getsockname() != outgoing.getpeername()
                or incoming.getpeername() != outgoing.getsockname()
            )
            if names_differ:
                cfg.diag(lambda: f"  - get*name() mismatch between {incoming!r} and {outgoing!r}")
                return False

            cfg.diag_("  - success!")
            return True


def find_pairs(cfg: defs.Config, ports: list[AddrPort]) -> dict[int, list[AddrPort]]:
    """Figure out which addresses can connect to which other addresses.

    Group `ports` by address family. Within each family, listen on each
    entry's service and proxy ports in turn and record which of the other
    same-family addresses can connect to both of them.

    Returns a mapping from address family to the entries (with `clients`
    filled in) that at least one other address could connect to; families
    left without any such entry are omitted. Families appear in the order
    in which they are first seen in `ports`.
    """

    def find_single(port: AddrPort, others: Iterable[AddrPort]) -> AddrPort:
        """Find which clients can connect to the specified server port."""
        cfg.diag(lambda: f"Checking whether we can connect to {port.address.address}")
        with bind_to(cfg, port.address, port.svc_port) as svc_sock:
            svc_sock.listen(10)
            with bind_to(cfg, port.address, port.proxy_port) as proxy_sock:
                proxy_sock.listen(10)
                return port._replace(  # noqa: SLF001
                    clients=[
                        other.address
                        for other in others
                        if _check_connect(cfg, svc_sock, other.address)
                        and _check_connect(cfg, proxy_sock, other.address)
                    ]
                )

    # Group by family without requiring `ports` to be sorted: itertools.groupby()
    # would split unsorted input into several same-family groups, and the later
    # ones would silently overwrite the earlier ones in the result.
    by_family: dict[int, list[AddrPort]] = {}
    for port in ports:
        by_family.setdefault(port.address.family, []).append(port)

    res: dict[int, list[AddrPort]] = {}
    for family, lports in by_family.items():
        found: list[AddrPort] = []
        for port in lports:
            # A server entry is never checked as a client of itself.
            others = [other for other in lports if other != port]
            checked = find_single(port, others)
            if checked.clients:
                found.append(checked)

        if found:
            res[family] = found

    return res


def pick_pairs(
    cfg: defs.Config, apairs: dict[int, list[AddrPort]]
) -> list[tuple[AddrPort, AddrPort]]:
    """Pick two (maybe the same) addresses for each family.

    For each address family in `apairs` (as produced by find_pairs()), pick
    two server entries, each narrowed down to exactly one client address:
    - with a single server entry, both picks use that server, with the first
      two clients in sorted order (or the only client twice); because the two
      picks would then share the same address and ports, new ports are looked
      up for the second one;
    - with two or more entries, the first two entries are used.

    Returns a (first, second) tuple for every ordered combination of two of
    the families, taking the first pick of one family and the second pick
    of the other.

    Exits the program via sys.exit() if no additional ports can be found or
    on an internal inconsistency.
    """

    def reorder(server: Address, clients: list[Address]) -> list[Address]:
        """Sort the addresses, put the server's own address at the end."""
        return [addr for addr in sorted(clients) if addr != server] + [
            addr for addr in clients if addr == server
        ]

    res: dict[int, tuple[AddrPort, AddrPort]] = {}
    for family, pairs in apairs.items():
        if len(pairs) == 1:
            cfg.diag(lambda: f"Considering a single set for {family!r}")
            first = pairs[0]
            clients = reorder(first.address, first.clients)
            r_first = first._replace(clients=[clients[0]])  # noqa: SLF001
            r_second = first._replace(  # noqa: SLF001
                clients=[clients[0] if len(clients) == 1 else clients[1]]
            )
        else:
            cfg.diag(lambda: f"Considering two sets for {family!r}")
            first, second = pairs[0], pairs[1]
            c_first = reorder(first.address, first.clients)
            c_second = reorder(second.address, second.clients)
            r_first = first._replace(clients=[c_first[0]])  # noqa: SLF001
            r_second = second._replace(clients=[c_second[0]])  # noqa: SLF001

        if len(r_first.clients) != 1 or len(r_second.clients) != 1:
            sys.exit(
                f"Internal error: unexpected number of clients: "
                f"{r_first.clients=!r} {r_second.clients=!r}"
            )
        if (r_first.address, {r_first.svc_port, r_first.proxy_port}) == (
            r_second.address,
            {r_second.svc_port, r_second.proxy_port},
        ):
            # So basically we only have a single address to work with...
            cfg.diag(lambda: f"Looking for more ports to listen on at {r_second.address}")
            more_ports = find_ports(
                cfg,
                [r_second.address],
                first_port=max(r_second.svc_port, r_second.proxy_port) + 1,
            )
            if not more_ports:
                sys.exit(
                    f"Could not find more ports to listen on at {r_second.address.address}"
                )
            r_second = r_second._replace(  # noqa: SLF001
                svc_port=more_ports[0].svc_port, proxy_port=more_ports[0].proxy_port
            )
            if {r_first.svc_port, r_first.proxy_port} == {
                r_second.svc_port,
                r_second.proxy_port,
            }:
                sys.exit(
                    f"Internal error: duplicate port pairs: "
                    f"{r_first.svc_port!r} {r_first.proxy_port!r} "
                    f"{r_second.svc_port!r} {r_second.proxy_port!r}"
                )

        res[family] = (r_first, r_second)
        cfg.diag(lambda: f"Address family {family!r}: picked {res[family]!r}")

    return [
        (res[first][0], res[second][1])
        for first, second in itertools.product(res.keys(), res.keys())
    ]
