"""Module that takes care of sending and receiving DNS messages as a mock client."""

import errno
import socket
import struct
import time
from typing import Optional, Tuple, Union

import dns.message
import dns.inet


# Default time budget, in seconds, for receive operations; also applied as the
# socket timeout on sockets created by setup_socket().
SOCKET_OPERATION_TIMEOUT = 5
# Buffer size passed to recvfrom() for datagram sockets: 65535 bytes, which is
# also the largest length a 16-bit DNS-over-TCP length prefix can express.
RECEIVE_MESSAGE_SIZE = (1 << 16) - 1
# Seconds to pause when throttling (0.1 s is the delay used by the ENOBUFS
# retry loops in this module).
THROTTLE_BY = 0.1


def handle_socket_timeout(sock: socket.socket, deadline: float):
    """
    Set the timeout of *sock* to the time left until *deadline*.

    *deadline* is an absolute point on the time.monotonic() clock. When that
    point has already been reached, the socket is left untouched and
    RuntimeError("Server took too long to respond") is raised instead.
    """
    time_left = deadline - time.monotonic()
    if time_left > 0:
        sock.settimeout(time_left)
        return
    raise RuntimeError("Server took too long to respond")


def recv_n_bytes_from_tcp(stream: socket.socket, n: int, deadline: float) -> bytes:
    """
    Read exactly *n* bytes from a stream socket.

    *deadline* is an absolute point on the time.monotonic() clock. Before
    every recv() call the socket timeout is re-armed with the time left via
    handle_socket_timeout().

    Returns the *n* bytes read (``b""`` when *n* is 0).

    Raises:
        RuntimeError: raised by handle_socket_timeout() when the deadline has
            passed before all bytes arrived.
        OSError: the peer closed the connection before *n* bytes arrived.
            Errors raised by recv() itself (including socket.timeout) are
            propagated unchanged.
    """
    # Collect chunks and join once, rather than repeatedly concatenating bytes.
    chunks = []
    remaining = n
    while remaining:
        handle_socket_timeout(stream, deadline)
        chunk = stream.recv(remaining)
        # Empty bytes from socket.recv mean that socket is closed
        if not chunk:
            raise OSError(
                f"Connection closed by peer with {remaining} of {n} bytes still expected")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def recvfrom_blob(sock: socket.socket,
                  timeout: int = SOCKET_OPERATION_TIMEOUT) -> Tuple[bytes, str]:
    """
    Receive one raw DNS message from a TCP/UDP socket.

    For datagram sockets a single datagram of at most RECEIVE_MESSAGE_SIZE
    bytes is read. For stream sockets the 2-byte big-endian length prefix is
    read first and then exactly that many bytes of message.

    Returns a ``(data, addr)`` pair. For stream sockets ``addr`` is the first
    element of ``sock.getpeername()``. For datagram sockets it is the address
    object returned by ``sock.recvfrom()`` as-is, whose format depends on the
    address family (e.g. a ``(host, port)`` tuple for IPv4) and is therefore
    not necessarily a ``str``.

    Raises:
        RuntimeError: no complete message arrived within *timeout* seconds.
        NotImplementedError: the socket is neither datagram nor stream.
        OSError: any socket error other than ENOBUFS. ENOBUFS is retried after
            a THROTTLE_BY pause until the deadline expires.
    """

    # deadline is always time.monotonic
    deadline = time.monotonic() + timeout

    while True:
        try:
            if sock.type & socket.SOCK_DGRAM:
                handle_socket_timeout(sock, deadline)
                data, addr = sock.recvfrom(RECEIVE_MESSAGE_SIZE)
            elif sock.type & socket.SOCK_STREAM:
                # First 2 bytes of TCP packet are the size of the message
                # See https://tools.ietf.org/html/rfc1035#section-4.2.2
                data = recv_n_bytes_from_tcp(sock, 2, deadline)
                msg_len = struct.unpack_from("!H", data)[0]
                data = recv_n_bytes_from_tcp(sock, msg_len, deadline)
                addr = sock.getpeername()[0]
            else:
                raise NotImplementedError(f"[recvfrom_blob]: unknown socket type '{sock.type}'")
            return data, addr
        except socket.timeout as ex:
            raise RuntimeError("Server took too long to respond") from ex
        except OSError as ex:
            if ex.errno != errno.ENOBUFS:
                raise
            # Out of buffer space: back off briefly and retry. The deadline
            # check in handle_socket_timeout() bounds the number of retries.
            # NOTE(review): a retry on a stream socket starts again at the
            # length prefix; confirm ENOBUFS cannot occur mid-message.
            time.sleep(THROTTLE_BY)


def recvfrom_msg(sock: socket.socket,
                 timeout: int = SOCKET_OPERATION_TIMEOUT) -> Tuple[dns.message.Message, str]:
    """
    Receive one DNS message from *sock* and parse it.

    The wire data and sender address come from recvfrom_blob(); the data is
    decoded with dns.message.from_wire() using one_rr_per_rrset=True.
    Returns the ``(message, addr)`` pair.
    """
    wire, sender = recvfrom_blob(sock, timeout=timeout)
    return dns.message.from_wire(wire, one_rr_per_rrset=True), sender


def sendto_msg(sock: socket.socket, message: bytes, addr: Optional[str] = None) -> None:
    """
    Send an already serialized DNS message over a UDP or TCP socket.

    Datagram sockets send *message* as-is: with send() when *addr* is None,
    otherwise with sendto() to *addr*. Stream sockets send the message
    preceded by its length as a 2-byte big-endian integer; *addr* is ignored.

    An OSError with errno ECONNREFUSED is silently ignored; every other
    OSError propagates. NotImplementedError is raised for any other socket
    type.
    """
    try:
        if sock.type & socket.SOCK_DGRAM:
            if addr is not None:
                sock.sendto(message, addr)
            else:
                sock.send(message)
            return
        if not sock.type & socket.SOCK_STREAM:
            raise NotImplementedError(f"[sendto_msg]: unknown socket type '{sock.type}'")
        length_prefix = struct.pack("!H", len(message))
        sock.sendall(length_prefix + message)
    except OSError as err:
        # Reference: http://lkml.iu.edu/hypermail/linux/kernel/0002.3/0709.html
        if err.errno == errno.ECONNREFUSED:
            return
        raise


def setup_socket(address: str,
                 port: int,
                 tcp: bool = False,
                 src_address: Optional[str] = None) -> socket.socket:
    """
    Create a socket connected to ``(address, port)``.

    The address family is derived from *address* with
    dns.inet.af_for_address(). A stream socket (with TCP_NODELAY set) is
    created when *tcp* is true, a datagram socket otherwise. If *src_address*
    is given the socket is bound to it with port 0 before connecting. The
    socket timeout is set to SOCKET_OPERATION_TIMEOUT seconds before the
    connect() call.

    Returns the connected socket; the caller is responsible for closing it.
    If any step after socket creation fails, the socket is closed and the
    original exception is re-raised.
    """
    family = dns.inet.af_for_address(address)
    sock = socket.socket(family, socket.SOCK_STREAM if tcp else socket.SOCK_DGRAM)
    try:
        if tcp:
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True)
        if src_address is not None:
            sock.bind((src_address, 0))  # random source port
        sock.settimeout(SOCKET_OPERATION_TIMEOUT)
        sock.connect((address, port))
    except BaseException:
        # Do not leak the file descriptor when configuration or connect fails.
        sock.close()
        raise
    return sock


def send_query(sock: socket.socket, query: Union[dns.message.Message, bytes]) -> None:
    """
    Send a DNS query over *sock*, retrying while the send fails with ENOBUFS.

    *query* may be wire-format bytes, which are sent unchanged, or an object
    that is serialized with its to_wire() method. Sending is delegated to
    sendto_msg(). On ENOBUFS the call sleeps THROTTLE_BY seconds and tries
    again, with no upper bound on the number of retries. Any other OSError
    propagates.
    """
    message = query if isinstance(query, bytes) else query.to_wire()
    while True:
        try:
            sendto_msg(sock, message)
            return
        except OSError as ex:
            if ex.errno != errno.ENOBUFS:
                raise
            # ENOBUFS, throttle sending
            time.sleep(THROTTLE_BY)


def get_answer(sock: socket.socket, timeout: int = SOCKET_OPERATION_TIMEOUT) -> bytes:
    """ Compatibility function: return only the data part of recvfrom_blob(). """
    return recvfrom_blob(sock, timeout=timeout)[0]


def get_dns_message(sock: socket.socket,
                    timeout: int = SOCKET_OPERATION_TIMEOUT) -> dns.message.Message:
    """
    Receive one answer via get_answer() and parse it.

    The bytes are decoded with dns.message.from_wire() using its default
    options.
    """
    wire = get_answer(sock, timeout=timeout)
    return dns.message.from_wire(wire)
