# SPDX-FileCopyrightText: Peter Pentchev <roam@ringlet.net>
# SPDX-License-Identifier: BSD-2-Clause
"""Run the remrun tests against a local SSH server instance."""

from __future__ import annotations

import argparse
import contextlib
import dataclasses
import errno
import os
import pathlib
import pwd
import shlex
import shutil
import socket
import stat
import subprocess  # noqa: S404
import sys
import tempfile
import time
import typing

from . import util


if typing.TYPE_CHECKING:
    import logging
    from collections.abc import Iterator
    from typing import Final


# The version string of this test helper.
VERSION: Final = "0.2.4"

# The privilege separation directory that the SSH server must be able to
# chroot into; ensure_privsep_path() creates it if it is missing.
PATH_PRIVSEP: Final = pathlib.Path("/run", "sshd")


@dataclasses.dataclass(frozen=True)
class Account:
    """The (possibly unprivileged) account to run the tests as."""

    # The password database entry for the account.
    pw_ent: pwd.struct_passwd
    # True if the account was named via -u/--unprivileged and the tests must
    # switch to it (see ensure_privsep_path()); False if the tests run as the
    # account we are already running as.
    need_setuid: bool


@dataclasses.dataclass(frozen=True)
class Config:
    """Runtime configuration for the remrun test tool."""

    # The logger used for diagnostic and warning messages.
    log: logging.Logger
    # Set by -H/--no-chmod-homedir: do not temporarily remove the owner's
    # write permission on the test account's home directory.
    no_chmod_homedir: bool
    # The absolute path to the remrun program being tested.
    prog: pathlib.Path
    # The absolute path to run-test.sh if -t/--test-prog was given; if None,
    # the built-in printenv check is run instead.
    test_prog: pathlib.Path | None
    # The account to run the tests as.
    test_account: Account
    # The environment for child processes, initially from util.get_utf8_env()
    # (presumably a UTF-8 locale setup - confirm in util); create_ssh_wrapper()
    # returns a copy with the SSH wrapper directory prepended to PATH.
    utf8_env: dict[str, str]


@dataclasses.dataclass(frozen=True)
class SSHConfig:
    """Information about the generated SSH configuration."""

    # The loopback address (127.0.0.1 or ::1) the SSH server listens on.
    addr: str
    # The TCP port the SSH server listens on.
    port: int
    # The name of the account the client logs in as (the current real uid's).
    username: str
    # The generated home directory holding the client's .ssh/ files and the
    # bin/ directory for the ssh wrapper.
    home: pathlib.Path
    # The generated ssh_config file used by the client.
    client_config: pathlib.Path
    # The generated sshd_config file used by the server.
    server_config: pathlib.Path


def get_test_account(log: logging.Logger, unprivileged: str | None) -> Account:
    """Get the data about the account we are running as or the other one.

    If `unprivileged` is None, look up the account for our effective uid and
    mark it as not needing a switch; otherwise look up the named account and
    mark it as needing a setuid() switch before the tests are run.

    Exits the program with a readable message if the account cannot be found
    in the password database (the `pwd` lookups raise KeyError in that case).
    """
    if unprivileged is None:
        our_uid: Final = os.geteuid()
        log.debug("Getting information about our own account, uid %(uid)d", {"uid": our_uid})
        try:
            pw_ent = pwd.getpwuid(our_uid)
        except KeyError:
            sys.exit(f"No password database entry for our own uid {our_uid}")
    else:
        log.debug(
            "Getting information about another account, username %(name)s",
            {"name": unprivileged},
        )
        try:
            pw_ent = pwd.getpwnam(unprivileged)
        except KeyError:
            sys.exit(f"No such user account: {unprivileged!r}")

    log.debug(
        "Got username %(name)s, uid %(uid)d, home %(home)s",
        {"name": pw_ent.pw_name, "uid": pw_ent.pw_uid, "home": pw_ent.pw_dir},
    )
    return Account(pw_ent=pw_ent, need_setuid=unprivileged is not None)


def parse_args() -> Config:
    """Parse the command-line arguments.

    Build the runtime configuration: validate the program paths and make them
    absolute (the tests later run from within a temporary directory), set up
    logging, and look up the account to run the tests as.

    Exits the program if the remrun path is not an executable regular file or
    the test suite path (if given) is not a readable regular file.
    """
    parser: Final = argparse.ArgumentParser(prog="run_sshd_test")
    parser.add_argument(
        "-H",
        "--no-chmod-homedir",
        action="store_true",
        help="do not temporarily remove write access from the home directory",
    )
    parser.add_argument(
        "-t",
        "--test-prog",
        type=pathlib.Path,
        help="the path to run-test.sh if it is to be run",
    )
    parser.add_argument(
        "-u",
        "--unprivileged",
        type=str,
        help="the username of the unprivileged account to switch to",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="verbose operation; display diagnostic output",
    )
    parser.add_argument("remrun", type=pathlib.Path, help="the path to the remrun program to test")

    args: Final = parser.parse_args()

    prog: Final = args.remrun.absolute()
    if not prog.is_file() or not os.access(prog, os.R_OK | os.X_OK):
        sys.exit(f"Not an executable regular file: {prog}")

    # The test suite is run as `sh -- <path>`, so it only needs to be readable;
    # check it now instead of failing much later, after the SSH server is up.
    test_script: Final = args.test_prog.absolute() if args.test_prog is not None else None
    if test_script is not None and (
        not test_script.is_file() or not os.access(test_script, os.R_OK)
    ):
        sys.exit(f"Not a readable regular file: {test_script}")

    log: Final = util.build_logger(verbose=args.verbose, quiet=False)
    test_account: Final = get_test_account(log, args.unprivileged)

    return Config(
        log=log,
        no_chmod_homedir=args.no_chmod_homedir,
        prog=prog,
        test_prog=test_script,
        test_account=test_account,
        utf8_env=util.get_utf8_env(),
    )


def find_listening_port(cfg: Config) -> tuple[str, int]:
    """Find a port to listen on at a local address.

    Try 127.0.0.1 first, then ::1; for each address, try TCP ports 8086-8199
    in order and return the first (address, port) pair that a socket can bind
    to.  The probe socket is closed right away; the SSH server binds the port
    itself later on.

    Exits the program if no address/port combination is available.
    """
    for addr, family in (("127.0.0.1", socket.AF_INET), ("::1", socket.AF_INET6)):
        cfg.log.debug("Looking for a port to listen on at %(addr)s", {"addr": addr})
        for port in range(8086, 8200):
            try:
                # The `with` block closes the probe socket whether or not the
                # bind succeeds; creating it may also fail (e.g. no IPv6).
                with socket.socket(
                    family=family,
                    type=socket.SOCK_STREAM,
                    proto=socket.IPPROTO_TCP,
                ) as lsock:
                    lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    lsock.bind((addr, port))
            except OSError as err:
                cfg.log.debug(
                    "- could not bind to %(addr)s:%(port)d: %(err)s",
                    {"addr": addr, "port": port, "err": err},
                )
                # This port is not usable, try the next one.
                continue

            cfg.log.debug("- got %(addr)s:%(port)d", {"addr": addr, "port": port})
            return addr, port

        cfg.log.debug("- could not bind to any of the desired ports at %(addr)s", {"addr": addr})

    sys.exit("Could not find a local address/port to listen on")


def create_ssh_config(cfg: Config, addr: str, port: int, tempd: pathlib.Path) -> SSHConfig:
    """Set up the SSH server's config directories.

    Under `tempd`, create a `home` directory with the client's `.ssh/` files
    (a key pair, a known hosts file, and a config file) and a `server`
    directory with the server's host key, authorized keys file, and
    `sshd_config`.  Both sides are set up for public key authentication of the
    current (real uid) user against `addr`:`port`; password, keyboard-interactive,
    and GSSAPI authentication are disabled.  If the current user is root, the
    root account is also unlocked with `usermod -U`.

    Returns the connection parameters and the paths of the generated files.
    Raises subprocess.CalledProcessError if ssh-keygen, grep, or usermod fail.
    """
    username: Final = pwd.getpwuid(os.getuid()).pw_name

    home_dir: Final = tempd / "home"
    cli_dir: Final = home_dir / ".ssh"
    srv_dir: Final = tempd / "server"

    cli_cfg: Final = cli_dir / "config"
    cli_key: Final = cli_dir / "id"
    cli_known: Final = cli_dir / "known_hosts"

    srv_authkeys: Final = srv_dir / "authorized_keys"
    srv_cfg: Final = srv_dir / "sshd_config"
    srv_key: Final = srv_dir / "ssh_host_key"
    srv_pid: Final = srv_dir / "sshd.pid"

    home_dir.mkdir(mode=0o700)
    cli_dir.mkdir(mode=0o700)
    srv_dir.mkdir(mode=0o700)

    cfg.log.debug("Generating the SSH host key at %(srv_key)s", {"srv_key": srv_key})
    subprocess.check_call(  # noqa: S603
        ["ssh-keygen", "-f", srv_key, "-t", "ed25519", "-N", ""],  # noqa: S607
        env=cfg.utf8_env,
    )

    cfg.log.debug("Generating the SSH client key at %(cli_key)s", {"cli_key": cli_key})
    subprocess.check_call(  # noqa: S603
        ["ssh-keygen", "-f", cli_key, "-t", "ed25519", "-N", ""],  # noqa: S607
        env=cfg.utf8_env,
    )

    cfg.log.debug("Copying the client public key to the authorized keys file")
    srv_authkeys.write_text(
        cli_key.with_suffix(".pub").read_text(encoding="UTF-8"),
        encoding="UTF-8",
    )
    srv_authkeys.chmod(0o600)

    cfg.log.debug("Generating the client known hosts file")
    # For a port other than 22, ssh(1) looks the server up in known_hosts as
    # "[addr]:port"; list that form as well as the bare address so that the
    # strict host key check below can find the server's key.
    cli_known.write_text(
        f"{addr},[{addr}]:{port} " + srv_key.with_suffix(".pub").read_text(encoding="UTF-8"),
        encoding="UTF-8",
    )

    cfg.log.debug("Generating the SSH client config file")
    cli_cfg.write_text(
        f"""
Host *
ForwardAgent no
ForwardX11 no
GlobalKnownHostsFile /dev/null
GSSAPIAuthentication no
HostbasedAuthentication no
IdentitiesOnly yes
IdentityFile {cli_key}
KbdInteractiveAuthentication no
PasswordAuthentication no
Port {port}
PubkeyAuthentication yes
RequestTTY no
StrictHostKeyChecking yes
Tunnel no
UpdateHostKeys no
User {username}
UserKnownHostsFile {cli_known}
VerifyHostKeyDNS no
""",
        encoding="UTF-8",
    )

    cfg.log.debug("Generating the SSH server config file")
    srv_cfg.write_text(
        f"""
AllowUsers {username}
AuthorizedKeysFile {srv_authkeys}
DisableForwarding yes
GSSAPIAuthentication no
HostKey {srv_key}
IgnoreRhosts yes
KbdInteractiveAuthentication no
ListenAddress {addr}
PasswordAuthentication no
PermitRootLogin {'yes' if username == 'root' else 'no'}
PermitTTY no
PidFile {srv_pid}
Port {port}
PubkeyAuthentication yes
StrictModes no
UseDNS no
""",
        encoding="UTF-8",
    )

    # Dump every generated file (including the throwaway passphrase-less
    # private keys) to stdout as a diagnostic aid.
    subprocess.check_call(  # noqa: S603
        ["grep", "-Ere", "^", "."],  # noqa: S607
        cwd=tempd,
        env=cfg.utf8_env,
    )

    if username == "root":
        cfg.log.debug("Unlocking the root account, just in case")
        subprocess.check_call(["usermod", "-U", "root"], env=cfg.utf8_env)  # noqa: S603,S607

    return SSHConfig(
        addr=addr,
        port=port,
        username=username,
        home=home_dir,
        client_config=cli_cfg,
        server_config=srv_cfg,
    )


def create_ssh_wrapper(cfg: Config, ssh_cfg: SSHConfig) -> Config:
    """Create the SSH wrapper that uses the generated config and keys."""
    home_bin = ssh_cfg.home / "bin"
    home_bin.mkdir(mode=0o700)

    cfg.log.debug("Determining the full path of the real SSH executable")
    match subprocess.check_output(  # noqa: S603
        ["sh", "-c", "command -v ssh"],  # noqa: S607
        encoding="UTF-8",
        env=cfg.utf8_env,
    ).splitlines():
        case [single]:
            ssh_prog: Final = pathlib.Path(single)

        case other:
            sys.exit(f"Expected `command -v ssh` to output exactly one line, got {other!r}")

    if not ssh_prog.is_file() or not os.access(ssh_prog, os.R_OK | os.X_OK):
        sys.exit(f"Expected `command -v ssh` to point to an executable file, got {ssh_prog!r}")

    cfg.log.debug("Generating the SSH wrapper")
    ssh_wrapper: Final = home_bin / "ssh"
    ssh_wrapper.write_text(
        f"""#!/bin/sh

exec {shlex.quote(str(ssh_prog))} -F {shlex.quote(str(ssh_cfg.client_config))} "$@"
""",
        encoding="UTF-8",
    )
    ssh_wrapper.chmod(0o700)

    utf8_env: Final = dict(cfg.utf8_env)
    match utf8_env.get("PATH"):
        case None:
            utf8_env["PATH"] = str(home_bin)

        case opath:
            utf8_env["PATH"] = f"{home_bin}:{opath}"

    cfg.log.debug("Checking that the SSH-specific environment is sane")
    match subprocess.check_output(  # noqa: S603
        ["sh", "-c", "command -v ssh"],  # noqa: S607
        encoding="UTF-8",
        env=utf8_env,
    ).splitlines():
        case [single] if single == str(ssh_wrapper):
            pass

        case other_which:
            sys.exit(f"Expected `command -v ssh` to output {ssh_wrapper}, got {other_which!r}")

    return dataclasses.replace(cfg, utf8_env=utf8_env)


def check_ssh_connection(cfg: Config, ssh_cfg: SSHConfig) -> None:
    """Once the SSH server has been started, check that the client can connect to it."""
    cfg.log.debug("Checking that our SSH client and server both work")
    match subprocess.check_output(  # noqa: S603
        ["sh", "-c", f"ssh -- {ssh_cfg.addr} printenv SSH_CONNECTION"],  # noqa: S607
        encoding="UTF-8",
        env=cfg.utf8_env,
    ).splitlines():
        case [single] if single.split()[-2:] == [ssh_cfg.addr, str(ssh_cfg.port)]:
            pass

        case other:
            sys.exit(
                f"Expected `printenv SSH_CONNECTION` to end with '{ssh_cfg.addr} {ssh_cfg.port}', "
                f"got {other!r}",
            )


def _sshd_search_path(current_path: str | None) -> str:
    """Build the directory list to look for sshd in: the current PATH, then /usr/sbin.

    An unset or empty PATH yields just "/usr/sbin", so that no empty component
    (which would mean the current directory) ends up in the result.
    """
    if not current_path:
        return "/usr/sbin"
    return f"{current_path}:/usr/sbin"


@contextlib.contextmanager
def start_sshd(cfg: Config, ssh_cfg: SSHConfig) -> Iterator[subprocess.Popen[str]]:
    """Start an SSH server listening at the specified address and port.

    Look for an `sshd` executable in the current search path and then in
    /usr/sbin, start it as `sshd -D -e -f <server config>` with the environment
    from `cfg.utf8_env`, and yield the process.  When the context is left, kill
    the server if it is still running and reap it.

    Exits the program if no sshd executable can be found.
    """
    proc = None
    try:
        cfg.log.debug("Looking for an SSH server executable")
        # shutil.which() expects a list of directories, not of executable paths.
        sshd: Final = shutil.which("sshd", path=_sshd_search_path(os.environ.get("PATH")))
        if sshd is None:
            sys.exit("No sshd in the search path or /usr/sbin")

        cfg.log.debug("Starting an SSH server: %(sshd)s", {"sshd": sshd})
        proc = subprocess.Popen(  # noqa: S603
            [sshd, "-D", "-e", "-f", ssh_cfg.server_config],
            encoding="UTF-8",
            env=cfg.utf8_env,
        )
        cfg.log.debug("- got SSH server process %(pid)d", {"pid": proc.pid})
        yield proc
    finally:
        if proc is not None:
            # poll() returns None while the process runs; an exit code of 0 is
            # falsy as well, so compare against None explicitly.
            if proc.poll() is None:
                cfg.log.debug("Killing the SSH server")
                proc.kill()
            cfg.log.debug("The SSH server is done, code %(code)d", {"code": proc.wait()})


def create_test_script(cfg: Config, tempd: pathlib.Path) -> pathlib.Path:
    """Create the test script that runs printenv with some arguments.

    The script is written to `tempd/test_printenv` and made executable by its
    owner only (mode 0700).  When run, it prints the USER and SSH_CONNECTION
    environment variables.

    Returns the path to the script.
    """
    cfg.log.debug("Creating the test printenv script to run on the other side")
    script_lines: Final = ("#!/bin/sh", "", "printenv USER SSH_CONNECTION", "")
    script_path: Final = tempd / "test_printenv"
    script_path.write_text("\n".join(script_lines), encoding="UTF-8")
    script_path.chmod(0o700)
    return script_path


def test_remrun(
    cfg: Config,
    ssh_cfg: SSHConfig,
    test_printenv: pathlib.Path,
    *,
    remote_tmp: pathlib.Path | None = None,
) -> None:
    """Run remrun a couple of times, examine its output.

    First have remrun run the `test_printenv` script as a file, then feed the
    script's contents to remrun on its standard input (`-`).  Each time, the
    output must be exactly two lines: the expected username, then a line whose
    last two words are the server's address and port.  If `remote_tmp` is
    given, it is passed to remrun via `-T`.

    Exits the program on unexpected output; raises subprocess.CalledProcessError
    if remrun itself fails.
    """
    assert cfg.test_prog is None

    tmp_args: Final = [] if remote_tmp is None else ["-T", str(remote_tmp)]
    server_tail: Final = [ssh_cfg.addr, str(ssh_cfg.port)]

    def looks_right(lines: list[str]) -> bool:
        """Check for a username line followed by a connection info line, nothing more."""
        if len(lines) != 2:
            return False
        user_line, conn_line = lines
        return user_line == ssh_cfg.username and conn_line.split()[-2:] == server_tail

    cfg.log.debug("Now running remrun with our client against our server")
    file_run_lines: Final = subprocess.check_output(  # noqa: S603
        [cfg.prog, *tmp_args, "--", ssh_cfg.addr, test_printenv],
        encoding="UTF-8",
        env=cfg.utf8_env,
    ).splitlines()
    if not looks_right(file_run_lines):
        sys.exit(
            f"Expected `remrun test_printenv` to output {ssh_cfg.username!r} and "
            f"something ending in {ssh_cfg.addr!r} {ssh_cfg.port!r}, got {file_run_lines!r}",
        )

    stdin_run: Final = subprocess.run(  # noqa: S603
        [str(cfg.prog), *tmp_args, "--", ssh_cfg.addr, "-"],
        capture_output=True,
        check=True,
        encoding="UTF-8",
        env=cfg.utf8_env,
        input=test_printenv.read_text(encoding="UTF-8"),
    )
    if not stdin_run.stdout:
        sys.exit("`remrun -` did not output anything")

    stdin_run_lines: Final = stdin_run.stdout.splitlines()
    if not looks_right(stdin_run_lines):
        sys.exit(
            f"Expected `remrun -` to output {ssh_cfg.username!r} and "
            f"something ending in {ssh_cfg.addr!r} {ssh_cfg.port!r}, "
            f"got {stdin_run_lines!r}",
        )


def test_prog(cfg: Config, ssh_cfg: SSHConfig, *, remote_tmp: pathlib.Path | None = None) -> None:
    """Run the run-test.sh test suite within our environment.

    The suite is run as `sh -- <test_prog> <remrun>`.  Its environment is
    `cfg.utf8_env` plus REMRUN_TEST_HOSTSPEC set to the server address.
    REMRUN_TEST_REMOTE_TMP is set to `remote_tmp` if one is given and
    removed otherwise.

    Raises subprocess.CalledProcessError if the test suite fails.
    """
    assert cfg.test_prog is not None
    cfg.log.debug("Running the %(prog)s testsuite", {"prog": cfg.test_prog})
    run_env: Final = {**cfg.utf8_env, "REMRUN_TEST_HOSTSPEC": ssh_cfg.addr}

    if remote_tmp is not None:
        run_env["REMRUN_TEST_REMOTE_TMP"] = str(remote_tmp)
    else:
        # Do not let a value inherited from our own environment leak through.
        run_env.pop("REMRUN_TEST_REMOTE_TMP", None)

    subprocess.check_call(["sh", "--", cfg.test_prog, cfg.prog], env=run_env)  # noqa: S603,S607


def ensure_privsep_path(cfg: Config, tempd: pathlib.Path) -> bool:
    """Make sure the SSH server will be able to chroot into /run/sshd.

    Create PATH_PRIVSEP if it does not exist.  If the tests are to be run as
    another account, also hand `tempd` over to that account and fork.  The
    parent waits for the child and returns False, so it does not run the tests
    itself.  The child drops its privileges to the test account and returns
    True, so it goes on to run the tests.  Without an account switch, just
    return True.

    Returns True if the calling process should run the tests.  Exits the
    program if the ownership change or the privilege drop fails, or (in the
    parent) if the child process did not exit successfully.
    """
    if not PATH_PRIVSEP.is_dir():
        cfg.log.debug("Creating the %(path)s directory", {"path": PATH_PRIVSEP})
        PATH_PRIVSEP.mkdir(mode=0o744, parents=True)

    if not cfg.test_account.need_setuid:
        return True

    name, uid, gid = (
        cfg.test_account.pw_ent.pw_name,
        cfg.test_account.pw_ent.pw_uid,
        cfg.test_account.pw_ent.pw_gid,
    )
    cfg.log.debug(
        "Changing the ownership of %(tempd)s to %(uid)d:%(gid)d for %(unpriv)r",
        {"tempd": tempd, "uid": uid, "gid": gid, "unpriv": name},
    )
    try:
        os.chown(tempd, uid, gid)
    except OSError as err:
        sys.exit(
            f"Could not chown() {tempd} to {uid}:{gid} for {name!r}: {err}",
        )

    child_pid: Final = os.fork()
    if child_pid != 0:
        cfg.log.debug(
            "Skipping the tests, waiting for process %(child_pid)d to end",
            {"child_pid": child_pid},
        )
        res: Final = os.waitpid(child_pid, 0)
        cfg.log.debug(
            "Process %(child_pid)d ended, code %(res)r",
            {"child_pid": child_pid, "res": res},
        )
        if res[1]:
            # res[1] is a raw wait status; decode it into an exit code
            # (or a negative signal number).
            sys.exit(
                "The unprivileged child process failed; "
                f"code {os.waitstatus_to_exitcode(res[1])}",
            )
        return False

    cfg.log.debug("Trying to setuid() to %(unpriv)r", {"unpriv": name})
    try:
        # Replace the supplementary groups first (only possible while still
        # privileged), so that the child does not keep e.g. root's groups.
        if os.geteuid() == 0:
            os.initgroups(name, gid)
        os.setgid(gid)
        os.setuid(uid)
    except OSError as err:
        sys.exit(
            f"Could not initgroups()/setgid()/setuid() to {uid}:{gid} for {name!r}: {err}",
        )

    return True


@contextlib.contextmanager
def create_temp_dir(cfg: Config) -> Iterator[pathlib.Path]:
    """Create a temporary directory, remove it at the end.

    The directory is created within the current working directory and is
    yielded as an absolute path.  On exit it is removed recursively, but only
    if the process still has the real uid it had when the directory was
    created.  A process that has switched to another account in the meantime
    (e.g. the child forked by ensure_privsep_path()) leaves it alone.
    """
    initial_uid: Final = os.getuid()
    created: str | None = None
    try:
        created = tempfile.mkdtemp(prefix="run_sshd_test.", dir=".")
        tempd: Final = pathlib.Path(created).absolute()
        cfg.log.debug(
            "Using %(tempd)s as a temporary directory, initial uid %(initial_uid)s",
            {"tempd": tempd, "initial_uid": initial_uid},
        )
        yield tempd
    finally:
        if created is not None:
            current_pid = os.getpid()
            current_uid = os.getuid()
            if current_uid == initial_uid:
                cfg.log.debug(
                    "Removing %(tempd)s in process %(current_pid)d",
                    {"tempd": tempd, "current_pid": current_pid},
                )
                shutil.rmtree(tempd)
            else:
                cfg.log.debug(
                    "Not removing %(tempd)s in process %(current_pid)d: "
                    "uid %(current_uid)d != %(initial_uid)d",
                    {
                        "tempd": tempd,
                        "current_pid": current_pid,
                        "current_uid": current_uid,
                        "initial_uid": initial_uid,
                    },
                )


def wait_for_sshd_banner(
    cfg: Config,
    ssh_cfg: SSHConfig,
    *,
    attempts: int = 20,
    delay: float = 0.5,
    timeout: float = 5.0,
) -> None:
    """Try to connect to the SSH server's port, expect a banner.

    Make up to `attempts` connection attempts, sleeping `delay` seconds before
    each one; a refused connection means that the server is not listening yet.
    Once connected, read until the first newline (or until 4096 bytes have
    arrived or the server closes the connection) and check that the data looks
    like an SSH identification string.  Every socket operation is limited to
    `timeout` seconds, so a server that accepts a connection but never sends
    anything cannot hang the test run.

    Exits the program if no connection could be made, if no banner arrives in
    time, or if the banner is not an SSH one.  Connection errors other than
    "connection refused" are re-raised.
    """
    match socket.getaddrinfo(
        ssh_cfg.addr,
        ssh_cfg.port,
        type=socket.SOCK_STREAM,
        proto=socket.IPPROTO_TCP,
        flags=socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
    )[0]:
        case (s_family, s_type, s_proto, _, (s_address, s_port, *_)):
            pass

        case other_ainfo:
            sys.exit(f"getaddrinfo({ssh_cfg.addr!r}, {ssh_cfg.port!r}) returned {other_ainfo!r}")

    cfg.log.debug(
        "Waiting for the SSH server at %(s_address)s:%(s_port)d "
        "(address family %(s_family)s) to start accepting connections",
        {"s_address": s_address, "s_port": s_port, "s_family": s_family.name},
    )

    for _ in range(attempts):
        time.sleep(delay)
        cfg.log.debug(
            "Trying to connect to %(s_address)s port %(s_port)d...",
            {"s_address": s_address, "s_port": s_port},
        )
        with socket.socket(s_family, s_type, s_proto) as sock:
            sock.settimeout(timeout)
            try:
                sock.connect((s_address, s_port))
            except OSError as err:
                if err.errno != errno.ECONNREFUSED:
                    raise
                print("Connection refused, will retry")
                continue

            cfg.log.debug("Connected!")
            raw = bytearray()
            try:
                # A single recv() may return only part of the banner line.
                while b"\n" not in raw and len(raw) < 4096:
                    chunk = sock.recv(4096 - len(raw))
                    if not chunk:
                        break
                    raw += chunk
            except TimeoutError:
                sys.exit(
                    "Timed out waiting for an SSH banner from the server started at "
                    f"{s_address}:{s_port}",
                )
            data = raw.decode("ISO-8859-15")
            cfg.log.debug("Got banner %(data)r", {"data": data})
            if not data.startswith("SSH-") or "\n" not in data:
                sys.exit(
                    f"Expected an SSH banner from the server started at {s_address}:{s_port}, "
                    f"got {data!r}",
                )
            return

    sys.exit(f"Could not connect to {s_address}:{s_port} for {attempts * delay:g} seconds")


def get_homedir_writable_mode(cfg: Config, homedir: pathlib.Path) -> int | None:
    """If the home directory is writable, get the interesting bits of its access mode.

    Returns the permission bits (mode & 0o7777) of `homedir` if it is owned by
    the test account and has the owner-write bit set.  Returns None if it does
    not exist, belongs to someone else, or lacks the owner-write bit.
    """
    try:
        info: Final = homedir.stat()
    except FileNotFoundError:
        return None

    # Group writability is deliberately not examined; the test environments
    # are not expected to be set up that elaborately.
    owned_by_test_account = info.st_uid == cfg.test_account.pw_ent.pw_uid
    owner_may_write = bool(info.st_mode & stat.S_IWUSR)
    return stat.S_IMODE(info.st_mode) if owned_by_test_account and owner_may_write else None


def run_homedir_tests(cfg: Config, ssh_cfg: SSHConfig, tempd: pathlib.Path) -> None:
    """Run the tests in the various homedir-related configurations.

    If the test account's home directory is owned by it and writable, run the
    tests once letting remrun use the home directory, then once more with an
    explicit remote temporary directory.  Unless `cfg.no_chmod_homedir` is set,
    the owner write permission on the home directory is removed for that
    second run and restored afterwards, even if the run fails.  If the home
    directory is not usable, only the explicit-temporary-directory run is done.
    """

    def prepare_remote_tmp(*, use_home: bool) -> pathlib.Path | None:
        """Reset tempd/remote-tmp; recreate it (mode 0700) unless the home dir is to be used."""
        scratch: Final = tempd / "remote-tmp"
        if scratch.is_dir():
            shutil.rmtree(scratch)
        if use_home:
            return None
        scratch.mkdir(mode=0o700)
        return scratch

    def do_run(*, use_home: bool) -> None:
        """Run either the external test suite or the built-in printenv check once."""
        remote_tmp: Final = prepare_remote_tmp(use_home=use_home)
        cfg.log.debug(
            "Running tests with use_home=%(use_home)s and remote_tmp=%(remote_tmp)r",
            {"use_home": use_home, "remote_tmp": remote_tmp},
        )
        if cfg.test_prog is not None:
            test_prog(cfg, ssh_cfg, remote_tmp=remote_tmp)
            return
        test_remrun(cfg, ssh_cfg, create_test_script(cfg, tempd), remote_tmp=remote_tmp)

    homedir: Final = pathlib.Path(cfg.test_account.pw_ent.pw_dir)
    writable_mode: Final = get_homedir_writable_mode(cfg, homedir)
    cfg.log.debug(
        "Preparing to run tests with homedir %(homedir)r and writable mode %(mode)r",
        {"homedir": homedir, "mode": writable_mode},
    )

    if writable_mode is None:
        cfg.log.warning(
            "The %(home)s directory is not writable by %(name)s, skipping the use-home test",
            {"home": homedir, "name": cfg.test_account.pw_ent.pw_name},
        )
        do_run(use_home=False)
        return

    do_run(use_home=True)

    if cfg.no_chmod_homedir:
        cfg.log.info("Not removing the owner write permission on %(home)s", {"home": homedir})
        do_run(use_home=False)
        return

    cfg.log.warning(
        "Temporarily removing the owner write permission on %(home)s",
        {"home": homedir},
    )
    homedir.chmod(writable_mode & ~stat.S_IWUSR)
    try:
        do_run(use_home=False)
    finally:
        cfg.log.warning(
            "Restoring the owner write permission on %(home)s",
            {"home": homedir},
        )
        homedir.chmod(writable_mode)


def main() -> None:
    """Parse command-line options, start the SSH server, run tests.

    Everything happens within a fresh temporary directory that is also the
    current working directory while the tests run.  The process changes to "/"
    before create_temp_dir() gets to clean that directory up.
    """
    cfg = parse_args()

    with create_temp_dir(cfg) as tempd:
        try:
            os.chdir(tempd)
            if not ensure_privsep_path(cfg, tempd):
                # ensure_privsep_path() returns False in a process that only
                # waited for a forked child to run the tests; nothing left to do.
                return

            addr, port = find_listening_port(cfg)
            ssh_cfg: Final = create_ssh_config(cfg, addr, port, tempd)
            cfg = create_ssh_wrapper(cfg, ssh_cfg)

            with start_sshd(cfg, ssh_cfg):
                wait_for_sshd_banner(cfg, ssh_cfg)
                check_ssh_connection(cfg, ssh_cfg)
                run_homedir_tests(cfg, ssh_cfg, tempd)

            print("The remrun tool seems to be operational")
        finally:
            os.chdir("/")


if __name__ == "__main__":
    main()
