import socket
import ssl
from contextlib import ExitStack
from threading import Thread
from typing import ContextManager, NoReturn

import pytest
from trustme import CA

from anyio import (
    BrokenResourceError,
    EndOfStream,
    Event,
    connect_tcp,
    create_task_group,
    create_tcp_listener,
)
from anyio.abc import AnyByteStream, SocketAttribute, SocketStream
from anyio.streams.tls import TLSAttribute, TLSListener, TLSStream

# Apply the ``anyio`` marker to every test in this module so the async test
# functions below are driven by anyio's pytest integration.
pytestmark = [pytest.mark.anyio]


class TestTLSStream:
    """Tests for ``TLSStream`` against a blocking, thread-hosted TLS server.

    Each test binds a TLS-wrapped listening socket to an ephemeral port on
    127.0.0.1, serves a single connection from a background thread and drives
    the client side through ``TLSStream.wrap``.  Both the listening socket and
    every accepted connection get a 1 second timeout, so a misbehaving client
    makes the server thread fail instead of blocking the test run forever.
    The listening socket is always closed and the thread always joined, even
    when the client side of a test fails.
    """

    async def test_send_receive(
        self, server_context: ssl.SSLContext, client_context: ssl.SSLContext
    ) -> None:
        """Bytes sent by the client come back reversed from the server."""

        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            conn.settimeout(1)
            data = conn.recv(10)
            conn.sendall(data[::-1])
            conn.close()

        server_sock = server_context.wrap_socket(
            socket.socket(), server_side=True, suppress_ragged_eofs=False
        )
        server_sock.settimeout(1)
        server_sock.bind(("127.0.0.1", 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(
                    stream, hostname="localhost", ssl_context=client_context
                )
                await wrapper.send(b"hello")
                response = await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert response == b"olleh"

    async def test_extra_attributes(
        self, server_context: ssl.SSLContext, client_context: ssl.SSLContext
    ) -> None:
        """The wrapper passes socket attributes through and adds TLS ones."""

        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            with conn:
                conn.settimeout(1)
                # Hold the connection open until the client's final byte.
                conn.recv(1)

        server_context.set_alpn_protocols(["h2"])
        client_context.set_alpn_protocols(["h2"])

        server_sock = server_context.wrap_socket(
            socket.socket(), server_side=True, suppress_ragged_eofs=True
        )
        server_sock.settimeout(1)
        server_sock.bind(("127.0.0.1", 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(
                    stream,
                    hostname="localhost",
                    ssl_context=client_context,
                    standard_compatible=False,
                )
                async with wrapper:
                    # Every public socket attribute must match the value
                    # reported by the underlying transport stream.
                    for name, attribute in SocketAttribute.__dict__.items():
                        if not name.startswith("_"):
                            assert wrapper.extra(attribute) == stream.extra(attribute)

                    assert wrapper.extra(TLSAttribute.alpn_protocol) == "h2"
                    assert isinstance(
                        wrapper.extra(TLSAttribute.channel_binding_tls_unique), bytes
                    )
                    assert isinstance(wrapper.extra(TLSAttribute.cipher), tuple)
                    assert isinstance(
                        wrapper.extra(TLSAttribute.peer_certificate), dict
                    )
                    assert isinstance(
                        wrapper.extra(TLSAttribute.peer_certificate_binary), bytes
                    )
                    assert wrapper.extra(TLSAttribute.server_side) is False
                    assert isinstance(wrapper.extra(TLSAttribute.shared_ciphers), list)
                    assert isinstance(
                        wrapper.extra(TLSAttribute.ssl_object), ssl.SSLObject
                    )
                    assert wrapper.extra(TLSAttribute.standard_compatible) is False
                    assert wrapper.extra(TLSAttribute.tls_version).startswith("TLSv")
                    # Release the server thread, which is blocked in recv(1).
                    await wrapper.send(b"\x00")
        finally:
            server_thread.join()
            server_sock.close()

    async def test_unwrap(
        self, server_context: ssl.SSLContext, client_context: ssl.SSLContext
    ) -> None:
        """``unwrap()`` drops TLS so later bytes arrive as plaintext."""

        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            conn.settimeout(1)
            conn.sendall(b"encrypted")
            unencrypted = conn.unwrap()
            unencrypted.sendall(b"unencrypted")
            unencrypted.close()

        server_sock = server_context.wrap_socket(
            socket.socket(), server_side=True, suppress_ragged_eofs=False
        )
        server_sock.settimeout(1)
        server_sock.bind(("127.0.0.1", 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(
                    stream, hostname="localhost", ssl_context=client_context
                )
                msg1 = await wrapper.receive()
                unwrapped_stream, msg2 = await wrapper.unwrap()
                # unwrap() returns any bytes that were already buffered.  The
                # plaintext message may not have been among them yet, in which
                # case the rest is read from the now-unencrypted stream.
                if msg2 != b"unencrypted":
                    msg2 += await unwrapped_stream.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert msg1 == b"encrypted"
        assert msg2 == b"unencrypted"

    @pytest.mark.skipif(not ssl.HAS_ALPN, reason="ALPN support not available")
    async def test_alpn_negotiation(
        self, server_context: ssl.SSLContext, client_context: ssl.SSLContext
    ) -> None:
        """Both peers agree on the single ALPN protocol they have in common."""

        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            conn.settimeout(1)
            selected_alpn_protocol = conn.selected_alpn_protocol()
            assert selected_alpn_protocol is not None
            # Report the server's view back so the client can compare.
            conn.sendall(selected_alpn_protocol.encode())
            conn.close()

        server_context.set_alpn_protocols(["dummy1", "dummy2"])
        client_context.set_alpn_protocols(["dummy2", "dummy3"])

        server_sock = server_context.wrap_socket(
            socket.socket(), server_side=True, suppress_ragged_eofs=False
        )
        server_sock.settimeout(1)
        server_sock.bind(("127.0.0.1", 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(
                    stream, hostname="localhost", ssl_context=client_context
                )
                assert wrapper.extra(TLSAttribute.alpn_protocol) == "dummy2"
                server_alpn_protocol = await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert server_alpn_protocol == b"dummy2"

    @pytest.mark.parametrize(
        "server_compatible, client_compatible",
        [
            pytest.param(True, True, id="both_standard"),
            pytest.param(True, False, id="server_standard"),
            pytest.param(False, True, id="client_standard"),
            pytest.param(False, False, id="neither_standard"),
        ],
    )
    async def test_ragged_eofs(
        self,
        server_context: ssl.SSLContext,
        client_context: ssl.SSLContext,
        server_compatible: bool,
        client_compatible: bool,
    ) -> None:
        """Check every combination of standard-compatible shutdown handling.

        * A standard-compatible client talking to a non-compatible server
          raises ``BrokenResourceError``.
        * A non-compatible client talking to a standard-compatible server
          makes the server's ``unwrap()`` fail with a non-timeout ``OSError``.
        * All other combinations complete cleanly on both sides.
        """
        server_exc = None

        def serve_sync() -> None:
            nonlocal server_exc
            conn, _ = server_sock.accept()
            try:
                conn.settimeout(1)
                conn.sendall(b"hello")
                if server_compatible:
                    conn.unwrap()
            except BaseException as exc:
                # Record the failure for the main thread to assert on.
                server_exc = exc
            finally:
                conn.close()

        # ExitStack() acts as a no-op context manager here.
        client_cm: ContextManager = ExitStack()
        if client_compatible and not server_compatible:
            client_cm = pytest.raises(BrokenResourceError)

        server_sock = server_context.wrap_socket(
            socket.socket(),
            server_side=True,
            suppress_ragged_eofs=not server_compatible,
        )
        server_sock.settimeout(1)
        server_sock.bind(("127.0.0.1", 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(
                    stream,
                    hostname="localhost",
                    ssl_context=client_context,
                    standard_compatible=client_compatible,
                )
                with client_cm:
                    assert await wrapper.receive() == b"hello"
                    await wrapper.aclose()
        finally:
            server_thread.join()
            server_sock.close()

        if not client_compatible and server_compatible:
            assert isinstance(server_exc, OSError)
            assert not isinstance(server_exc, socket.timeout)
        else:
            assert server_exc is None

    async def test_ragged_eof_on_receive(
        self, server_context: ssl.SSLContext, client_context: ssl.SSLContext
    ) -> None:
        """A non-compatible client maps an abrupt close to ``EndOfStream``."""
        server_exc = None

        def serve_sync() -> None:
            nonlocal server_exc
            conn, _ = server_sock.accept()
            try:
                conn.settimeout(1)
                conn.sendall(b"hello")
            except BaseException as exc:
                server_exc = exc
            finally:
                # Close without unwrap(), i.e. without a TLS shutdown.
                conn.close()

        server_sock = server_context.wrap_socket(
            socket.socket(), server_side=True, suppress_ragged_eofs=True
        )
        server_sock.settimeout(1)
        server_sock.bind(("127.0.0.1", 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()
        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(
                    stream,
                    hostname="localhost",
                    ssl_context=client_context,
                    standard_compatible=False,
                )
                assert await wrapper.receive() == b"hello"
                with pytest.raises(EndOfStream):
                    await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert server_exc is None

    async def test_receive_send_after_eof(
        self, server_context: ssl.SSLContext, client_context: ssl.SSLContext
    ) -> None:
        """``receive()`` raises ``EndOfStream`` after the server's TLS shutdown."""

        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            # accept() hands back a blocking socket even though the listener
            # has a timeout.  Without this, unwrap() could wait forever for
            # the client's close_notify and hang the join() below.
            conn.settimeout(1)
            conn.sendall(b"hello")
            conn.unwrap()
            conn.close()

        server_sock = server_context.wrap_socket(
            socket.socket(), server_side=True, suppress_ragged_eofs=False
        )
        server_sock.settimeout(1)
        server_sock.bind(("127.0.0.1", 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            # The outer context manager also closes the TCP stream if the
            # TLS handshake inside TLSStream.wrap() fails.
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                async with await TLSStream.wrap(
                    stream, hostname="localhost", ssl_context=client_context
                ) as wrapper:
                    assert await wrapper.receive() == b"hello"
                    with pytest.raises(EndOfStream):
                        await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

    @pytest.mark.parametrize(
        "force_tlsv12",
        [
            pytest.param(
                False,
                marks=[
                    pytest.mark.skipif(
                        not getattr(ssl, "HAS_TLSv1_3", False),
                        reason="No TLS 1.3 support",
                    )
                ],
            ),
            pytest.param(True),
        ],
        ids=["tlsv13", "tlsv12"],
    )
    async def test_send_eof_not_implemented(
        self, server_context: ssl.SSLContext, ca: CA, force_tlsv12: bool
    ) -> None:
        """``send_eof()`` raises ``NotImplementedError`` on TLS 1.2 and 1.3.

        The error message differs per version: TLS 1.2 is rejected as too old,
        while TLS 1.3 reports that the feature is not implemented yet.
        """

        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            # Bound unwrap()'s wait for the client's close_notify; the
            # accepted socket does not inherit the listener's timeout.
            conn.settimeout(1)
            conn.sendall(b"hello")
            conn.unwrap()
            conn.close()

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        ca.configure_trust(client_context)
        if force_tlsv12:
            expected_pattern = r"send_eof\(\) requires at least TLSv1.3"
            if hasattr(ssl, "TLSVersion"):
                client_context.maximum_version = ssl.TLSVersion.TLSv1_2
            else:  # Python 3.6
                client_context.options |= ssl.OP_NO_TLSv1_3
        else:
            expected_pattern = (
                r"send_eof\(\) has not yet been implemented for TLS streams"
            )

        server_sock = server_context.wrap_socket(
            socket.socket(), server_side=True, suppress_ragged_eofs=False
        )
        server_sock.settimeout(1)
        server_sock.bind(("127.0.0.1", 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync, daemon=True)
        server_thread.start()

        try:
            # The outer context manager also closes the TCP stream if the
            # TLS handshake inside TLSStream.wrap() fails.
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                async with await TLSStream.wrap(
                    stream, hostname="localhost", ssl_context=client_context
                ) as wrapper:
                    assert await wrapper.receive() == b"hello"
                    with pytest.raises(NotImplementedError) as exc_info:
                        await wrapper.send_eof()

                    exc_info.match(expected_pattern)
        finally:
            server_thread.join()
            server_sock.close()


class TestTLSListener:
    """Tests for ``TLSListener``."""

    async def test_handshake_fail(self, server_context: ssl.SSLContext) -> None:
        """A peer that never completes the handshake reaches the error hook.

        A plain TCP client connects and disconnects without speaking TLS.  The
        connection handler must never run.  Instead, ``handle_handshake_error``
        is called with a ``BrokenResourceError`` and the raw ``SocketStream``.
        """

        def handler(stream: object) -> NoReturn:
            pytest.fail("This function should never be called in this scenario")

        exception = None

        class CustomTLSListener(TLSListener):
            @staticmethod
            async def handle_handshake_error(
                exc: BaseException, stream: AnyByteStream
            ) -> None:
                nonlocal exception
                # Keep the default handling, then record what was reported.
                await TLSListener.handle_handshake_error(exc, stream)
                assert isinstance(stream, SocketStream)
                exception = exc
                event.set()

        event = Event()
        listener = await create_tcp_listener(local_host="127.0.0.1")
        tls_listener = CustomTLSListener(listener, server_context)
        async with tls_listener, create_task_group() as tg:
            tg.start_soon(tls_listener.serve, handler)
            # Connect and immediately disconnect.  The with-block closes the
            # socket even if connect() raises.
            with socket.socket() as sock:
                sock.connect(listener.extra(SocketAttribute.local_address))
            await event.wait()
            tg.cancel_scope.cancel()

        assert isinstance(exception, BrokenResourceError)
