# SPDX-License-Identifier: GPL-3.0-or-later

from contextlib import contextmanager, ContextDecorator
import os
import subprocess
from typing import Any, Dict, Iterator, Optional

import dns
import dns.rcode
import pytest

from kresd import CERTS_DIR, Forward, Kresd, make_kresd, make_port
import utils


# Static hints loaded into the forward-target kresd instance (see
# kresd_tls_client); every test name resolves to the IPv4 loopback address.
HINTS = dict.fromkeys(
    ('0.foo.', '1.foo.', '2.foo.', '3.foo.'),
    '127.0.0.1',
)


def resolve_hint(sock, qname):
    """Send a query for ``qname`` over ``sock`` and verify the hinted answer.

    The query is built by ``utils.get_msgbuff`` and the reply parsed by
    ``utils.receive_parse_answer``. The reply must carry the same message ID,
    have rcode NOERROR, and its first record in ``answer.answer`` must hold the
    address configured for ``qname`` in ``HINTS``.
    """
    query_wire, query_id = utils.get_msgbuff(qname)
    sock.sendall(query_wire)
    reply = utils.receive_parse_answer(sock)

    assert reply.id == query_id
    assert reply.rcode() == dns.rcode.NOERROR

    first_record = reply.answer[0][0]
    assert first_record.address == HINTS[qname]


class Proxy(ContextDecorator):
    """Run an external proxy executable for the duration of a ``with`` block.

    Subclasses set ``EXECUTABLE`` to the program to launch and may extend
    ``get_args()`` with extra command-line options. On entry the process is
    spawned with stdout/stderr discarded; on exit it is terminated and reaped.
    """

    EXECUTABLE = ''
    # Seconds to wait for the proxy to exit after terminate() before killing it.
    TERMINATE_TIMEOUT = 5

    def __init__(
                self,
                local_ip: str = '127.0.0.1',
                local_port: Optional[int] = None,
                upstream_ip: str = '127.0.0.1',
                upstream_port: Optional[int] = None
            ) -> None:
        self.local_ip = local_ip
        self.local_port = local_port
        self.upstream_ip = upstream_ip
        self.upstream_port = upstream_port
        # Handle of the running proxy process; None while not running.
        self.proxy: Optional[subprocess.Popen] = None

    def get_args(self):
        """Return the command-line options for the local and upstream endpoints.

        ``--lport`` / ``--uport`` are only emitted when the port is set.
        """
        args = ['--local', self.local_ip]
        if self.local_port is not None:
            args += ['--lport', str(self.local_port)]
        args += ['--upstream', self.upstream_ip]
        if self.upstream_port is not None:
            args += ['--uport', str(self.upstream_port)]
        return args

    def __enter__(self):
        """Spawn the proxy; skip the current test if the binary cannot be run."""
        args = [self.EXECUTABLE] + self.get_args()
        # Echo the command line so it appears in the captured test output.
        print(' '.join(args))

        try:
            self.proxy = subprocess.Popen(
                args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        except OSError:
            # Popen reports a missing / non-executable program via OSError
            # (FileNotFoundError, PermissionError). It never raises
            # CalledProcessError; that is only used by the check_* helpers.
            pytest.skip("proxy '{}' failed to run (did you compile it?)"
                        .format(self.EXECUTABLE))

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop the proxy and wait for it, so no zombie process is left behind.

        Exceptions raised inside the ``with`` block are not suppressed.
        """
        proc, self.proxy = self.proxy, None
        if proc is None:
            return
        proc.terminate()
        try:
            proc.wait(timeout=self.TERMINATE_TIMEOUT)
        except subprocess.TimeoutExpired:
            # The proxy ignored the termination request; force it down.
            proc.kill()
            proc.wait()


class TLSProxy(Proxy):
    """Proxy running the ``tlsproxy`` program (TLS on the local side).

    In ``kresd_tls_client`` it sits between a TLS client and a plain-TCP
    upstream. ``certname`` selects ``<CERTS_DIR>/<certname>.cert.pem`` and
    ``<certname>.key.pem``; pass ``None`` to omit ``--cert``/``--key``.
    ``close``, ``rehandshake`` and ``force_tls13`` are forwarded as the
    ``--close N``, ``--rehandshake`` and ``--tls13`` options; their exact
    semantics are defined by the tlsproxy program itself.
    """

    EXECUTABLE = 'tlsproxy'

    def __init__(  # noqa: PLR0913
                self,
                local_ip: str = '127.0.0.1',
                local_port: Optional[int] = None,
                upstream_ip: str = '127.0.0.1',
                upstream_port: Optional[int] = None,
                certname: Optional[str] = 'tt',
                close: Optional[int] = None,
                rehandshake: bool = False,
                force_tls13: bool = False
            ) -> None:
        super().__init__(local_ip, local_port, upstream_ip, upstream_port)
        self.cert_path = None
        self.key_path = None
        if certname is not None:
            base_path = os.path.join(CERTS_DIR, certname)
            self.cert_path = base_path + '.cert.pem'
            self.key_path = base_path + '.key.pem'
        self.close = close
        self.rehandshake = rehandshake
        self.force_tls13 = force_tls13

    def get_args(self):
        """Extend the base endpoint options with TLS certificate and mode flags."""
        args = super().get_args()
        for option, path in (('--cert', self.cert_path), ('--key', self.key_path)):
            if path is not None:
                args += [option, path]
        if self.close is not None:
            args += ['--close', str(self.close)]
        if self.rehandshake:
            args.append('--rehandshake')
        if self.force_tls13:
            args.append('--tls13')
        return args


@contextmanager
def kresd_tls_client(
            workdir: str,
            proxy: TLSProxy,
            kresd_tls_client_kwargs: Optional[Dict[Any, Any]] = None,
            kresd_fwd_target_kwargs: Optional[Dict[Any, Any]] = None
        ) -> Iterator[Kresd]:
    """kresd_tls_client --(tls)--> tlsproxy --(tcp)--> kresd_fwd_target

    Start a forward-target kresd serving ``HINTS`` and check it answers over
    TCP. Then run ``proxy`` in front of it and yield a second kresd configured
    to forward over TLS to the proxy.

    ``proxy.local_port`` and ``proxy.upstream_port`` are overwritten. Each
    instance gets its own subdirectory of ``workdir`` (``kresd_fwd_target``
    and ``kresd_tls_client``), which must not exist yet. The optional kwargs
    dicts are passed to ``make_kresd`` for the client and the forward target
    respectively.
    """
    allowed_ips = {'127.0.0.1', '::1'}
    assert proxy.local_ip in allowed_ips, "only localhost IPs supported for proxy"
    assert proxy.upstream_ip in allowed_ips, "only localhost IPs are supported for proxy"

    if kresd_tls_client_kwargs is None:
        kresd_tls_client_kwargs = {}
    if kresd_fwd_target_kwargs is None:
        kresd_fwd_target_kwargs = {}

    # run forward target instance
    dir1 = os.path.join(workdir, 'kresd_fwd_target')
    os.makedirs(dir1)

    with make_kresd(dir1, hints=HINTS, **kresd_fwd_target_kwargs) as kresd_fwd_target:
        # Sanity check that the target answers hints over plain TCP. Close the
        # probe socket afterwards so it does not stay open for the whole test.
        sock = kresd_fwd_target.ip_tcp_socket()
        try:
            resolve_hint(sock, next(iter(HINTS)))
        finally:
            sock.close()

        proxy.local_port = make_port('127.0.0.1', '::1')
        proxy.upstream_port = kresd_fwd_target.port

        with proxy:
            # run test kresd instance
            dir2 = os.path.join(workdir, 'kresd_tls_client')
            os.makedirs(dir2)
            forward = Forward(
                proto='tls', ip=proxy.local_ip, port=proxy.local_port,
                hostname='transport-test-server.com', ca_file=proxy.cert_path)
            with make_kresd(dir2, forward=forward, **kresd_tls_client_kwargs) as kresd:
                yield kresd
