"""
Tests that exercise the HTTP Upgrade functionality of
hyper-h2, ensuring that clients and servers can upgrade their plaintext
HTTP/1.1 connections to HTTP/2.
"""
from __future__ import annotations

import base64

import pytest

import h2.config
import h2.connection
import h2.errors
import h2.events
import h2.exceptions
from h2.utilities import utf8_encode_headers

# A minimal GET request header block, given as native ``str`` pairs.
EXAMPLE_REQUEST_HEADERS = [
    (":authority", "example.com"),
    (":path", "/"),
    (":scheme", "https"),
    (":method", "GET"),
]
# The same header block as ``bytes`` pairs. It is derived from the str form
# above so the two variants used by the parametrized tests always match.
EXAMPLE_REQUEST_HEADERS_BYTES = [
    (name.encode(), value.encode()) for name, value in EXAMPLE_REQUEST_HEADERS
]


class TestClientUpgrade:
    """
    Client-side checks for the HTTP/1.1 -> HTTP/2 upgrade flow.
    """

    example_response_headers = [
        (b":status", b"200"),
        (b"server", b"fake-serv/0.1.0"),
    ]

    @staticmethod
    def _upgraded_client() -> h2.connection.H2Connection:
        """
        Build a default (client-side) connection that has begun the upgrade,
        then discard whatever it queued for sending during that step.
        """
        conn = h2.connection.H2Connection()
        conn.initiate_upgrade_connection()
        conn.clear_outbound_data_buffer()
        return conn

    def test_returns_http2_settings(self, frame_factory) -> None:
        """
        initiate_upgrade_connection() hands back the body of the connection's
        SETTINGS frame, base64url-encoded.
        """
        conn = h2.connection.H2Connection()
        header_value = conn.initiate_upgrade_connection()

        # The encoded value must carry no trailing "=" padding...
        assert not header_value.endswith(b"=")

        # ...which is harmless: a SETTINGS payload never needs padding, so
        # the unpadded value still decodes to the exact frame body.
        decoded = base64.urlsafe_b64decode(header_value)
        settings_body = frame_factory.build_settings_frame(
            settings=conn.local_settings,
        ).serialize_body()
        assert decoded == settings_body

    def test_emits_preamble(self, frame_factory) -> None:
        """
        initiate_upgrade_connection() queues the client connection preamble
        followed by a SETTINGS frame.
        """
        conn = h2.connection.H2Connection()
        conn.initiate_upgrade_connection()

        outbound = conn.data_to_send()
        preamble = frame_factory.preamble()
        assert outbound.startswith(preamble)

        settings_frame = frame_factory.build_settings_frame(
            settings=conn.local_settings,
        )
        assert outbound[len(preamble):] == settings_frame.serialize()

    def test_can_receive_response(self, frame_factory) -> None:
        """
        Once upgraded, a response on stream 1 is processed normally.
        """
        conn = self._upgraded_client()

        response_headers = frame_factory.build_headers_frame(
            stream_id=1,
            headers=self.example_response_headers,
        )
        response_body = frame_factory.build_data_frame(
            stream_id=1,
            data=b"some data",
            flags=["END_STREAM"],
        )
        wire = response_headers.serialize() + response_body.serialize()
        events = conn.receive_data(wire)

        expected_types = (
            h2.events.ResponseReceived,
            h2.events.DataReceived,
            h2.events.StreamEnded,
        )
        assert len(events) == len(expected_types)
        for event, expected_type in zip(events, expected_types):
            assert isinstance(event, expected_type)

        response_event, data_event = events[0], events[1]
        assert response_event.headers == self.example_response_headers
        assert data_event.data == b"some data"
        assert all(event.stream_id == 1 for event in events)

        # Nothing should be queued in reply to a complete response.
        assert not conn.data_to_send()

    @pytest.mark.parametrize(
        "headers", [EXAMPLE_REQUEST_HEADERS, EXAMPLE_REQUEST_HEADERS_BYTES],
    )
    def test_can_receive_pushed_stream(self, frame_factory, headers) -> None:
        """
        Once upgraded, a PUSH_PROMISE on stream 1 is accepted.
        """
        conn = self._upgraded_client()

        promise = frame_factory.build_push_promise_frame(
            stream_id=1,
            promised_stream_id=2,
            headers=headers,
        )
        events = conn.receive_data(promise.serialize())
        assert len(events) == 1

        pushed = events[0]
        assert isinstance(pushed, h2.events.PushedStreamReceived)
        # Event headers are always reported as bytes, whatever was sent.
        assert pushed.headers == utf8_encode_headers(headers)
        assert pushed.parent_stream_id == 1
        assert pushed.pushed_stream_id == 2

    @pytest.mark.parametrize(
        "headers", [EXAMPLE_REQUEST_HEADERS, EXAMPLE_REQUEST_HEADERS_BYTES],
    )
    def test_cannot_send_headers_stream_1(self, frame_factory, headers) -> None:
        """
        Once upgraded, the client may not send HEADERS on stream 1.
        """
        conn = self._upgraded_client()

        with pytest.raises(h2.exceptions.ProtocolError):
            conn.send_headers(stream_id=1, headers=headers)

    def test_cannot_send_data_stream_1(self, frame_factory) -> None:
        """
        Once upgraded, the client may not send DATA on stream 1.
        """
        conn = self._upgraded_client()

        with pytest.raises(h2.exceptions.ProtocolError):
            conn.send_data(stream_id=1, data=b"some data")


class TestServerUpgrade:
    """
    Server-side checks for the HTTP/1.1 -> HTTP/2 upgrade flow.
    """

    example_response_headers = [
        (b":status", b"200"),
        (b"server", b"fake-serv/0.1.0"),
    ]
    server_config = h2.config.H2Configuration(client_side=False)

    def _upgraded_server(
        self, client_preamble: bytes | None = None,
    ) -> h2.connection.H2Connection:
        """
        Build a server connection that has begun the upgrade. Optionally feed
        it ``client_preamble``, then discard whatever it queued for sending.
        """
        conn = h2.connection.H2Connection(config=self.server_config)
        conn.initiate_upgrade_connection()
        if client_preamble is not None:
            conn.receive_data(client_preamble)
        conn.clear_outbound_data_buffer()
        return conn

    def test_returns_nothing(self, frame_factory) -> None:
        """
        On the server side, initiate_upgrade_connection() returns None.
        """
        conn = h2.connection.H2Connection(config=self.server_config)
        # An HTTP2-Settings header value, as sent by curl.
        result = conn.initiate_upgrade_connection(b"AAMAAABkAAQAAP__")
        assert result is None

    def test_emits_preamble(self, frame_factory) -> None:
        """
        initiate_upgrade_connection() queues the server's opening SETTINGS
        frame and nothing else.
        """
        conn = h2.connection.H2Connection(config=self.server_config)
        conn.initiate_upgrade_connection()

        outbound = conn.data_to_send()
        settings_frame = frame_factory.build_settings_frame(
            settings=conn.local_settings,
        )
        assert outbound == settings_frame.serialize()

    def test_can_send_response(self, frame_factory) -> None:
        """
        Once upgraded, the server can answer on stream 1.
        """
        conn = self._upgraded_server()

        conn.send_headers(stream_id=1, headers=self.example_response_headers)
        conn.send_data(stream_id=1, data=b"some data", end_stream=True)

        expected_frames = (
            frame_factory.build_headers_frame(
                stream_id=1,
                headers=self.example_response_headers,
            ),
            frame_factory.build_data_frame(
                stream_id=1,
                data=b"some data",
                flags=["END_STREAM"],
            ),
        )
        expected_wire = b"".join(frame.serialize() for frame in expected_frames)
        assert conn.data_to_send() == expected_wire

    @pytest.mark.parametrize(
        "headers", [EXAMPLE_REQUEST_HEADERS, EXAMPLE_REQUEST_HEADERS_BYTES],
    )
    def test_can_push_stream(self, frame_factory, headers) -> None:
        """
        Once upgraded, the server can promise a pushed stream off stream 1.
        """
        conn = self._upgraded_server()

        conn.push_stream(
            stream_id=1,
            promised_stream_id=2,
            request_headers=headers,
        )

        expected_promise = frame_factory.build_push_promise_frame(
            stream_id=1,
            promised_stream_id=2,
            headers=headers,
        )
        assert conn.data_to_send() == expected_promise.serialize()

    @pytest.mark.parametrize(
        "headers", [EXAMPLE_REQUEST_HEADERS, EXAMPLE_REQUEST_HEADERS_BYTES],
    )
    def test_cannot_receive_headers_stream_1(self, frame_factory, headers) -> None:
        """
        Once upgraded, HEADERS arriving on stream 1 are rejected with
        RST_STREAM(STREAM_CLOSED).
        """
        conn = self._upgraded_server(frame_factory.preamble())

        incoming = frame_factory.build_headers_frame(
            stream_id=1,
            headers=headers,
        )
        conn.receive_data(incoming.serialize())

        reset = frame_factory.build_rst_stream_frame(
            stream_id=1,
            error_code=h2.errors.ErrorCodes.STREAM_CLOSED,
        )
        assert conn.data_to_send() == reset.serialize()

    def test_cannot_receive_data_stream_1(self, frame_factory) -> None:
        """
        Once upgraded, DATA arriving on stream 1 is rejected with
        RST_STREAM(STREAM_CLOSED).
        """
        conn = self._upgraded_server(frame_factory.preamble())

        incoming = frame_factory.build_data_frame(
            stream_id=1,
            data=b"some data",
        )
        conn.receive_data(incoming.serialize())

        reset = frame_factory.build_rst_stream_frame(
            stream_id=1,
            error_code=h2.errors.ErrorCodes.STREAM_CLOSED,
        )
        assert conn.data_to_send() == reset.serialize()

    def test_client_settings_are_applied(self, frame_factory) -> None:
        """
        Settings carried in the client's HTTP2-Settings header are applied by
        the server without an explicit SETTINGS ACK.
        """
        server = h2.connection.H2Connection(config=self.server_config)
        client = h2.connection.H2Connection()

        # Sanity check: before the upgrade, the server's view of the peer's
        # settings differs from what the client advertises. Otherwise the
        # final assertion would prove nothing.
        assert client.local_settings != server.remote_settings

        # Pass the client's header value straight to the server.
        server.initiate_upgrade_connection(client.initiate_upgrade_connection())

        # RFC 7540 § 3.2.1 says "explicit acknowledgement" of these settings
        # is "not necessary". We read that as "must not be sent". So the
        # server should have queued only its own SETTINGS frame (no ACK), and
        # its remote settings must now mirror the client's.
        server_settings = frame_factory.build_settings_frame(
            server.local_settings,
        )
        assert server.data_to_send() == server_settings.serialize()
        assert client.local_settings == server.remote_settings
