# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import copy
import os
import pytest

from configuration import (
    available_ports,
    ALL_TEST_CIPHERS,
    ALL_TEST_CURVES,
    MINIMAL_TEST_CERTS,
    PROTOCOLS,
    TLS13_CIPHERS,
)
from common import ProviderOptions, Protocols, data_bytes
from fixtures import managed_process  # noqa: F401
from providers import Provider, S2N, OpenSSL
from utils import (
    invalid_test_parameters,
    get_parameter_name,
    get_expected_s2n_version,
    to_bytes,
)


@pytest.mark.uncollect_if(func=invalid_test_parameters)
@pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name)
@pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name)
@pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name)
@pytest.mark.parametrize(
    "protocol", [p for p in PROTOCOLS if p != Protocols.TLS13], ids=get_parameter_name
)
@pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name)
@pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name)
@pytest.mark.parametrize("use_ticket", [True, False])
def test_session_resumption_s2n_server(
    managed_process,  # noqa: F811
    cipher,
    curve,
    certificate,
    protocol,
    provider,
    other_provider,
    use_ticket,
):
    """Reconnect a client to an S2N server for every non-TLS1.3 protocol.

    The client (``provider``) is started with ``reconnect=True`` and the S2N
    server is told to exit after six connections. The test passes when both
    processes succeed, the client prints ``Session-ID:`` six times, and the
    server prints the expected protocol version six times.

    ``use_ticket`` is forwarded to the server's ``use_session_ticket`` option
    so the test runs once with it set to True and once with it set to False.
    """
    # Number of connections the server accepts before exiting; the assertions
    # below expect one output line per connection.
    expected_connections = 6

    port = next(available_ports)

    client_options = ProviderOptions(
        mode=Provider.ClientMode,
        port=port,
        cipher=cipher,
        curve=curve,
        insecure=True,
        reconnect=True,
        protocol=protocol,
    )

    server_options = copy.copy(client_options)
    server_options.reconnects_before_exit = expected_connections
    server_options.mode = Provider.ServerMode
    # Assign the boolean itself. A one-element tuple such as ``(use_ticket,)``
    # is always truthy and is never identical to True/False, which would make
    # the ``use_ticket`` parametrization ineffective.
    server_options.use_session_ticket = use_ticket
    server_options.key = certificate.key
    server_options.cert = certificate.cert

    # Passing the type of client and server as a parameter will
    # allow us to use a fixture to enumerate all possibilities.
    server = managed_process(S2N, server_options, timeout=5)
    client = managed_process(provider, client_options, timeout=5)

    # The client should connect and return without error
    for results in client.get_results():
        results.assert_success()
        assert results.stdout.count(to_bytes("Session-ID:")) == expected_connections

    expected_version = get_expected_s2n_version(protocol, provider)

    # S2N should indicate the protocol version in a successful connection.
    for results in server.get_results():
        results.assert_success()
        assert (
            results.stdout.count(
                to_bytes("Actual protocol version: {}".format(expected_version))
            )
            == expected_connections
        )


@pytest.mark.uncollect_if(func=invalid_test_parameters)
@pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name)
@pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name)
@pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name)
@pytest.mark.parametrize(
    "protocol", [p for p in PROTOCOLS if p != Protocols.TLS13], ids=get_parameter_name
)
@pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name)
@pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name)
@pytest.mark.parametrize("use_ticket", [True, False])
def test_session_resumption_s2n_client(
    managed_process,  # noqa: F811
    cipher,
    curve,
    protocol,
    provider,
    other_provider,
    certificate,
    use_ticket,
):
    """Reconnect an S2N client to a ``provider`` server on non-TLS1.3 protocols.

    The S2N client runs with ``reconnect=True`` and ``use_session_ticket`` set
    from ``use_ticket``; the server always has ``use_session_ticket=False`` and
    exits after six connections. The client must report the expected protocol
    version six times and the server must report six finished accepts.
    """
    connection_count = 6
    port = next(available_ports)

    client_options = ProviderOptions(
        mode=Provider.ClientMode,
        port=port,
        protocol=protocol,
        cipher=cipher,
        curve=curve,
        insecure=True,
        reconnect=True,
        use_session_ticket=use_ticket,
    )

    # The server starts from the client's settings and overrides what differs.
    server_options = copy.copy(client_options)
    server_options.mode = Provider.ServerMode
    server_options.use_session_ticket = False
    server_options.reconnects_before_exit = connection_count
    server_options.cert = certificate.cert
    server_options.key = certificate.key

    # Start the server before the client; the provider classes are passed in
    # so that parametrization can enumerate the combinations.
    server = managed_process(provider, server_options, timeout=5)
    client = managed_process(S2N, client_options, timeout=5)

    expected_version = get_expected_s2n_version(protocol, OpenSSL)
    version_line = to_bytes(f"Actual protocol version: {expected_version}")

    # Every client connection prints the negotiated protocol version.
    for client_result in client.get_results():
        client_result.assert_success()
        assert client_result.stdout.count(version_line) == connection_count

    # The server summary line must show up at least once.
    accepts_line = to_bytes("6 server accepts that finished")
    for server_result in server.get_results():
        server_result.assert_success()
        assert accepts_line in server_result.stdout


@pytest.mark.uncollect_if(func=invalid_test_parameters)
@pytest.mark.parametrize("cipher", TLS13_CIPHERS, ids=get_parameter_name)
@pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name)
@pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name)
@pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name)
@pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name)
@pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name)
def test_tls13_session_resumption_s2n_server(
    managed_process,  # noqa: F811
    tmp_path,
    cipher,
    curve,
    certificate,
    protocol,
    provider,
    other_provider,
):
    """Resume a TLS1.3 session against an S2N server using a saved ticket.

    Phase one: the client connects with ``-sess_out`` so the session ticket is
    written under ``tmp_path``; the server must not report a resumed session.
    Phase two: on a fresh port the client connects with ``-sess_in``; the
    client must not log reading a server certificate and the server must
    report ``Resumed session`` together with the expected protocol version.
    """
    # The session ticket lives in pytest's per-test temporary directory.
    path_to_ticket = str(tmp_path / "ticket.pem")

    close_marker_bytes = data_bytes(10)
    close_marker = str(close_marker_bytes)

    def start_server_and_client():
        """Launch the S2N server first, then the client, with current options."""
        started_server = managed_process(
            S2N, server_options, timeout=5, send_marker=S2N.get_send_marker()
        )
        started_client = managed_process(
            provider, client_options, timeout=5, close_marker=close_marker
        )
        return started_server, started_client

    port = str(next(available_ports))

    client_options = ProviderOptions(
        mode=Provider.ClientMode,
        port=port,
        protocol=protocol,
        cipher=cipher,
        curve=curve,
        insecure=True,
        reconnect=False,
        extra_flags=["-sess_out", path_to_ticket],
    )

    server_options = copy.copy(client_options)
    server_options.mode = Provider.ServerMode
    server_options.extra_flags = None
    server_options.use_session_ticket = True
    server_options.data_to_send = close_marker_bytes
    server_options.cert = certificate.cert
    server_options.key = certificate.key

    # ---- Phase one: full handshake, client stores the ticket ----
    server, client = start_server_and_client()

    for client_result in client.get_results():
        client_result.assert_success()
        # The client should have received a session ticket
        assert b"Post-Handshake New Session Ticket arrived:" in client_result.stdout

    for server_result in server.get_results():
        server_result.assert_success()
        # Nothing to resume yet, so this must be a full handshake
        assert b"Resumed session" not in server_result.stdout

    # ---- Phase two: client presents the stored ticket ----
    assert os.path.exists(path_to_ticket)
    client_options.extra_flags = ["-sess_in", path_to_ticket]

    port = str(next(available_ports))
    client_options.port = port
    server_options.port = port

    server, client = start_server_and_client()

    s2n_version = get_expected_s2n_version(protocol, provider)

    # A resumed session means the client never reads a server certificate
    certificate_read_line = to_bytes("SSL_connect:SSLv3/TLS read server certificate")
    for client_result in client.get_results():
        client_result.assert_success()
        assert certificate_read_line not in client_result.stderr

    # The server should indicate a session has been resumed
    version_line = to_bytes(f"Actual protocol version: {s2n_version}")
    for server_result in server.get_results():
        server_result.assert_success()
        assert b"Resumed session" in server_result.stdout
        assert version_line in server_result.stdout


@pytest.mark.uncollect_if(func=invalid_test_parameters)
@pytest.mark.parametrize("cipher", TLS13_CIPHERS, ids=get_parameter_name)
@pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name)
@pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name)
@pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name)
@pytest.mark.parametrize("provider", [OpenSSL, S2N], ids=get_parameter_name)
@pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name)
def test_tls13_session_resumption_s2n_client(
    managed_process,  # noqa: F811
    cipher,
    curve,
    certificate,
    protocol,
    provider,
    other_provider,
):
    """Check that an S2N client resumes TLS1.3 sessions when reconnecting.

    The S2N client runs with ``reconnect=True`` and session tickets enabled
    against a ``provider`` server (OpenSSL or S2N). One full handshake followed
    by five resumed connections is expected; the server-side assertions depend
    on which provider is acting as the server.
    """
    # With the reconnect option s2nc reconnects automatically five times, so
    # the expectation is one full connection plus five resumed ones.
    num_full_connections = 1
    num_resumed_connections = 5
    total_connections = num_resumed_connections + num_full_connections

    port = str(next(available_ports))

    client_options = ProviderOptions(
        mode=Provider.ClientMode,
        port=port,
        protocol=protocol,
        cipher=cipher,
        curve=curve,
        insecure=True,
        reconnect=True,
        use_session_ticket=True,
    )

    server_options = copy.copy(client_options)
    server_options.mode = Provider.ServerMode
    server_options.reconnects_before_exit = total_connections
    server_options.cert = certificate.cert
    server_options.key = certificate.key

    server = managed_process(provider, server_options, timeout=5)
    client = managed_process(S2N, client_options, timeout=5)

    s2n_version = get_expected_s2n_version(protocol, provider)
    version_line = to_bytes(f"Actual protocol version: {s2n_version}")

    # s2nc reports every resumed connection in its output
    for client_result in client.get_results():
        client_result.assert_success()
        assert client_result.stdout.count(b"Resumed session") == num_resumed_connections
        assert version_line in client_result.stdout

    server_accepts_str = f"{total_connections} server accepts that finished"

    for server_result in server.get_results():
        server_result.assert_success()

        if provider is S2N:
            resumed = server_result.stdout.count(b"Resumed session")
            assert resumed == num_resumed_connections
            assert version_line in server_result.stdout
            continue

        assert to_bytes(server_accepts_str) in server_result.stdout
        # s_server writes a certificate message only for the full connection
        certificates_written = server_result.stderr.count(
            b"SSL_accept:SSLv3/TLS write certificate"
        )
        assert certificates_written == num_full_connections


@pytest.mark.uncollect_if(func=invalid_test_parameters)
@pytest.mark.parametrize("cipher", TLS13_CIPHERS, ids=get_parameter_name)
@pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name)
@pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name)
@pytest.mark.parametrize("protocol", [Protocols.TLS13], ids=get_parameter_name)
@pytest.mark.parametrize("provider", [OpenSSL], ids=get_parameter_name)
@pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name)
def test_s2nd_falls_back_to_full_connection(
    managed_process,  # noqa: F811
    tmp_path,
    cipher,
    curve,
    certificate,
    protocol,
    provider,
    other_provider,
):
    """Verify that s2nd performs a full handshake for an unrecognized ticket.

    This test sets up a full connection with an OpenSSL client and server to
    obtain a valid OpenSSL session ticket. Then, the OpenSSL client attempts to
    send the received session ticket to an s2n server to resume a session.
    s2nd will fall back to a full connection as it does not recognize the
    session ticket.
    """
    port = str(next(available_ports))

    # Use temp directory to store session tickets
    p = tmp_path / "ticket.pem"
    path_to_ticket = str(p)

    client_options = ProviderOptions(
        mode=Provider.ClientMode,
        port=port,
        cipher=cipher,
        curve=curve,
        insecure=True,
        reconnect=False,
        extra_flags=["-sess_out", path_to_ticket],
        data_to_send=data_bytes(4069),
        protocol=protocol,
    )

    server_options = copy.copy(client_options)
    server_options.mode = Provider.ServerMode
    server_options.key = certificate.key
    server_options.cert = certificate.cert
    # The "-sess_out" flag is only meant for the client.
    server_options.extra_flags = None

    server = managed_process(provider, server_options)

    # The send_marker prevents us from closing stdin until the client receives
    # the session ticket. If we close stdin too early, then the client will exit
    # the echo loop and stop listening for new messages.
    client = managed_process(
        provider, client_options, send_marker="TLS read server session ticket"
    )

    # The client should have received a session ticket
    for results in client.get_results():
        results.assert_success()
        assert b"Post-Handshake New Session Ticket arrived:" in results.stdout

    for results in server.get_results():
        results.assert_success()
        # Server should have sent certificate message as this is a full connection
        assert b"SSL_accept:SSLv3/TLS write certificate" in results.stderr

    # Client inputs received session ticket to resume a session
    assert os.path.exists(path_to_ticket)
    client_options.extra_flags = ["-sess_in", path_to_ticket]

    # Use a fresh port for the second pair of processes.
    port = str(next(available_ports))
    client_options.port = port
    server_options.port = port

    # Switch providers so now s2n is the server
    server = managed_process(S2N, server_options, timeout=5)
    client = managed_process(provider, client_options, timeout=5)

    s2n_version = get_expected_s2n_version(protocol, provider)

    # Client has read server certificate because this is a full connection
    for results in client.get_results():
        results.assert_success()
        assert (
            to_bytes("SSL_connect:SSLv3/TLS read server certificate") in results.stderr
        )

    # The server should indicate a session has not been resumed
    for results in server.get_results():
        results.assert_success()
        assert b"Resumed session" not in results.stdout
        assert (
            to_bytes("Actual protocol version: {}".format(s2n_version))
            in results.stdout
        )


@pytest.mark.uncollect_if(func=invalid_test_parameters)
@pytest.mark.parametrize("cipher", ALL_TEST_CIPHERS, ids=get_parameter_name)
@pytest.mark.parametrize("curve", ALL_TEST_CURVES, ids=get_parameter_name)
@pytest.mark.parametrize("certificate", MINIMAL_TEST_CERTS, ids=get_parameter_name)
@pytest.mark.parametrize(
    "protocol", [p for p in PROTOCOLS if p < Protocols.TLS13], ids=get_parameter_name
)
@pytest.mark.parametrize("provider", [OpenSSL, S2N], ids=get_parameter_name)
@pytest.mark.parametrize("other_provider", [S2N], ids=get_parameter_name)
def test_session_resumption_s2n_client_tls13_server_not_tls13(
    managed_process,  # noqa: F811
    cipher,
    curve,
    protocol,
    provider,
    other_provider,
    certificate,
):
    """Resume sessions between a TLS1.3 S2N client and a pre-TLS1.3 server.

    The S2N client is configured for TLS1.3 with session tickets and
    ``reconnect=True``; the ``provider`` server is configured with the
    parametrized (older) ``protocol``. One full connection and five resumed
    connections are expected at the version implied by the server's protocol.
    """
    # With the reconnect option s2nc reconnects automatically five times, so
    # the expectation is one full connection plus five resumed ones.
    num_full_connections = 1
    num_resumed_connections = 5
    total_connections = num_resumed_connections + num_full_connections

    port = next(available_ports)

    client_options = ProviderOptions(
        mode=Provider.ClientMode,
        port=port,
        protocol=Protocols.TLS13,
        cipher=cipher,
        curve=curve,
        insecure=True,
        reconnect=True,
        use_session_ticket=True,
    )

    server_options = ProviderOptions(
        mode=Provider.ServerMode,
        port=port,
        protocol=protocol,
        cipher=cipher,
        curve=curve,
        insecure=True,
        reconnect=True,
        use_session_ticket=False,
        reconnects_before_exit=total_connections,
        cert=certificate.cert,
        key=certificate.key,
    )

    # Start the server before the client; the provider classes are passed in
    # so that parametrization can enumerate the combinations.
    server = managed_process(provider, server_options, timeout=5)
    client = managed_process(S2N, client_options, timeout=5)

    expected_version = get_expected_s2n_version(protocol, provider)
    version_line = to_bytes(f"Actual protocol version: {expected_version}")
    server_accepts_str = f"{total_connections} server accepts that finished"

    for client_result in client.get_results():
        client_result.assert_success()
        assert client_result.stdout.count(b"Resumed session") == num_resumed_connections
        assert version_line in client_result.stdout

    for server_result in server.get_results():
        server_result.assert_success()

        if provider is not S2N:
            # A non-S2N server is checked through its accept summary line.
            assert to_bytes(server_accepts_str) in server_result.stdout
            continue

        resumed = server_result.stdout.count(b"Resumed session")
        assert resumed == num_resumed_connections
        assert version_line in server_result.stdout
