import json
import os
import struct
import sys
import tempfile
import time
import timeit

from ._aux import posix_redirect_output, update_sys_path
from .discovery import disc_benchmarks
from .run import _run

wall_timer = timeit.default_timer


def recvall(sock, size):
    """
    Receives exactly `size` bytes from a socket.

    #### Parameters
    **sock** (`socket`)
    : The socket from which the data will be received. Only its
    `recv(bufsize)` method is used.

    **size** (`int`)
    : The total number of bytes to be received from the socket.

    #### Returns
    **data** (`bytes`)
    : The data received from the socket. Its length equals `size` (an empty
    byte string is returned without reading when `size` is not positive).

    #### Raises
    **RuntimeError**
    : If `recv` returns no data (the peer closed the connection) before `size`
    bytes could be received. The message includes the requested size and the
    partial data received so far.

    #### Notes
    `recv` may return fewer bytes than requested, so it is called in a loop,
    each time asking only for the bytes still missing. Chunks are accumulated
    in a `bytearray` so that every chunk is appended in place instead of
    re-copying the whole payload on each iteration.
    """
    received = bytearray()
    while len(received) < size:
        chunk = sock.recv(size - len(received))
        if not chunk:
            # An empty read means the other end closed the connection.
            data = bytes(received)
            raise RuntimeError(
                "did not receive data from socket " f"(size {size}, got only {data!r})"
            )
        received += chunk
    return bytes(received)


def _run_server(args):
    """
    Runs a server that executes benchmarks based on the received commands.

    #### Parameters
    **args** (`tuple`)
    : A tuple containing the benchmark directory and socket name.

    - `benchmark_dir` (`str`): The directory where the benchmarks are located.
    - `socket_name` (`str`): The name of the UNIX socket to be used for
      communication.

    #### Raises
    **RuntimeError**
    : If the received command contains unknown data, or if the connection is
    closed before a complete command has been received (see `recvall`).

    **KeyError**
    : If the received command lacks `action` or, for a benchmark run, any of
    `benchmark_id`, `params_str`, `profile_path`, `result_file`, `timeout`,
    `cwd`.

    #### Notes
    This function creates a server that listens on a UNIX domain socket. Each
    accepted connection carries exactly one command; the connection is closed
    once the command has been handled.

    Messages in both directions are framed as an 8-byte little-endian unsigned
    length followed by that many bytes of UTF-8 encoded JSON.

    The command is a JSON object whose `action` key selects what happens:

    - `"quit"`: the server stops running.
    - `"preimport"`: all benchmarks in the benchmark directory are discovered
      (import errors are ignored) while output is redirected to a temporary
      file. The contents of that file are sent back as a JSON string.
    - anything else: the command is treated as a request to run one
      benchmark. The process forks; the child changes to `cwd` and runs the
      benchmark with its output redirected to a temporary file, while the
      parent polls for the child to finish. If `timeout` (seconds, or `None`
      for no limit) elapses, the child is sent SIGTERM and, on subsequent
      polls, SIGKILL. The reply is a JSON object with `out` (captured output)
      and `errcode` (`-256` on timeout, otherwise a subprocess-style return
      code: exit status, or the negated signal number).

    The server keeps accepting commands until it receives a "quit" command or
    a KeyboardInterrupt. The listening socket is closed when the function
    exits, whether normally or through an exception.
    """
    import signal
    import socket

    (
        benchmark_dir,
        socket_name,
    ) = args

    update_sys_path(benchmark_dir)

    # Socket I/O
    # The `with` block closes the listening socket on every exit path (quit,
    # KeyboardInterrupt, propagated error, or a failing bind/listen).
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
        s.bind(socket_name)
        s.listen(1)

        # Read and act on commands from socket
        while True:
            stdout_file = None

            try:
                conn, addr = s.accept()
            except KeyboardInterrupt:
                break

            try:
                # Temporary file capturing the output produced by this command;
                # only the path is needed, so the descriptor is closed at once.
                fd, stdout_file = tempfile.mkstemp()
                os.close(fd)

                # Read command (8-byte little-endian length, then the payload)
                (read_size,) = struct.unpack("<Q", recvall(conn, 8))
                command_text = recvall(conn, read_size)
                command_text = command_text.decode("utf-8")

                # Parse command
                command = json.loads(command_text)
                action = command.pop("action")

                if action == "quit":
                    break
                elif action == "preimport":
                    # Import benchmark suite before forking.
                    # Capture I/O to a file during import.
                    with posix_redirect_output(stdout_file, permanent=False):
                        for _ in disc_benchmarks(
                            benchmark_dir, ignore_import_errors=True
                        ):
                            pass

                    # Report result
                    with open(stdout_file, errors="replace") as f:
                        out = f.read()
                    out = json.dumps(out)
                    out = out.encode("utf-8")
                    conn.sendall(struct.pack("<Q", len(out)))
                    conn.sendall(out)
                    continue

                benchmark_id = command.pop("benchmark_id")
                params_str = command.pop("params_str")
                profile_path = command.pop("profile_path")
                result_file = command.pop("result_file")
                timeout = command.pop("timeout")
                cwd = command.pop("cwd")

                # Every recognised key has been popped; leftovers are an error.
                if command:
                    raise RuntimeError(
                        f"Command contained unknown data: {command_text !r}"
                    )

                # Spawn benchmark
                run_args = (
                    benchmark_dir,
                    benchmark_id,
                    params_str,
                    profile_path,
                    result_file,
                )
                pid = os.fork()
                if pid == 0:
                    # Child process. Everything happens inside try/finally so
                    # that the child always terminates via os._exit and can
                    # never fall through into the parent's cleanup code or
                    # server loop, even if closing the inherited handles fails.
                    exitcode = 1
                    try:
                        conn.close()
                        sys.stdin.close()
                        with posix_redirect_output(stdout_file, permanent=True):
                            try:
                                os.chdir(cwd)
                                _run(run_args)
                                exitcode = 0
                            except BaseException:
                                import traceback

                                # Goes to the redirected output, so the parent
                                # reports it back in "out".
                                traceback.print_exc()
                    finally:
                        os._exit(exitcode)

                # Wait for results
                # (Poll in a loop is simplest --- also used by subprocess.py)
                start_time = wall_timer()
                is_timeout = False
                # Sleep interval starts tiny and grows tenfold per poll, capped
                # at 1 ms below.
                time2sleep = 1e-15
                while True:
                    res, status = os.waitpid(pid, os.WNOHANG)
                    if res != 0:
                        break

                    if timeout is not None and wall_timer() > start_time + timeout:
                        # Timeout: ask politely first (SIGTERM), then force
                        # (SIGKILL) on later polls if the child is still alive.
                        if is_timeout:
                            os.kill(pid, signal.SIGKILL)
                        else:
                            os.kill(pid, signal.SIGTERM)
                        is_timeout = True
                    time2sleep *= 1e1
                    time.sleep(min(time2sleep, 0.001))

                # Report result
                with open(stdout_file, errors="replace") as f:
                    out = f.read()

                # Emulate subprocess
                if os.WIFSIGNALED(status):
                    retcode = -os.WTERMSIG(status)
                elif os.WIFEXITED(status):
                    retcode = os.WEXITSTATUS(status)
                elif os.WIFSTOPPED(status):
                    retcode = -os.WSTOPSIG(status)
                else:
                    # shouldn't happen, but fail silently
                    retcode = -128

                # -256 is the error code reported for a timed-out benchmark.
                info = {"out": out, "errcode": -256 if is_timeout else retcode}

                result_text = json.dumps(info)
                result_text = result_text.encode("utf-8")

                conn.sendall(struct.pack("<Q", len(result_text)))
                conn.sendall(result_text)
            except KeyboardInterrupt:
                break
            finally:
                # Runs for every outcome, including `break` and `continue`.
                conn.close()
                if stdout_file is not None:
                    os.unlink(stdout_file)
