"""
More complex tests that try to do more.

Certain tests don't really eliminate incorrect behaviour unless they do quite
a bit. These tests live here, to keep the pain in one place rather than
spreading it across other parts of the test suite.
"""
from __future__ import annotations

import pytest

import h2
import h2.config
import h2.connection


class TestComplexClient:
    """
    Complex tests for client-side stacks.
    """

    example_request_headers = [
        (":authority", "example.com"),
        (":path", "/"),
        (":scheme", "https"),
        (":method", "GET"),
    ]
    example_response_headers = [
        (":status", "200"),
        ("server", "fake-serv/0.1.0"),
    ]

    def test_correctly_count_server_streams(self, frame_factory) -> None:
        """
        A client-side connection tracks its open stream counts correctly in
        both directions: outbound streams it initiates itself and inbound
        streams the server pushes to it.
        """
        # Checking only one direction would not catch a bug that mixes the
        # two tallies up, so both are exercised together here.
        conn = h2.connection.H2Connection()
        conn.initiate_connection()
        inbound = outbound = 0

        def check_counts() -> None:
            # Inbound is read before outbound on every check.
            assert conn.open_inbound_streams == inbound
            assert conn.open_outbound_streams == outbound

        check_counts()

        for request_id in range(1, 15, 2):
            pushed_id = request_id + 1

            # Sending request headers opens an outbound stream.
            conn.send_headers(request_id, self.example_request_headers)
            outbound += 1
            check_counts()

            # A PUSH_PROMISE reserves an inbound stream, but it does not count
            # as open until response headers arrive on it.
            promise = frame_factory.build_push_promise_frame(
                stream_id=request_id,
                promised_stream_id=pushed_id,
                headers=self.example_request_headers,
            )
            conn.receive_data(promise.serialize())
            check_counts()

            pushed_response = frame_factory.build_headers_frame(
                stream_id=pushed_id,
                headers=self.example_response_headers,
            )
            conn.receive_data(pushed_response.serialize())
            inbound += 1
            check_counts()

        for request_id in reversed(range(1, 15, 2)):
            pushed_id = request_id + 1

            # Ending our side only half-closes the stream; it remains open
            # until the server ends its side as well.
            conn.end_stream(request_id)
            check_counts()

            final_response = frame_factory.build_headers_frame(
                stream_id=request_id,
                headers=self.example_response_headers,
                flags=["END_STREAM"],
            )
            conn.receive_data(final_response.serialize())
            outbound -= 1
            check_counts()

            # A pushed stream can only be ended by the server.
            pushed_body = frame_factory.build_data_frame(
                stream_id=pushed_id,
                data=b"the content",
                flags=["END_STREAM"],
            )
            conn.receive_data(pushed_body.serialize())
            inbound -= 1
            check_counts()

        assert conn.open_inbound_streams == 0
        assert conn.open_outbound_streams == 0


class TestComplexServer:
    """
    Complex tests for server-side stacks.
    """

    example_request_headers = [
        (b":authority", b"example.com"),
        (b":path", b"/"),
        (b":scheme", b"https"),
        (b":method", b"GET"),
    ]
    example_response_headers = [
        (b":status", b"200"),
        (b"server", b"fake-serv/0.1.0"),
    ]
    server_config = h2.config.H2Configuration(client_side=False)

    def test_correctly_count_server_streams(self, frame_factory) -> None:
        """
        A server-side connection tracks its open stream counts correctly in
        both directions: inbound requests from the client and outbound
        streams it pushes itself.
        """
        # Exercising only one direction could hide a bug that confuses the
        # two tallies, so both are driven together here.
        conn = h2.connection.H2Connection(config=self.server_config)
        conn.receive_data(frame_factory.preamble())
        inbound = outbound = 0

        def check_counts() -> None:
            # Inbound is read before outbound on every check.
            assert conn.open_inbound_streams == inbound
            assert conn.open_outbound_streams == outbound

        check_counts()

        for request_id in range(1, 15, 2):
            pushed_id = request_id + 1

            # A request HEADERS frame from the client opens an inbound stream.
            request = frame_factory.build_headers_frame(
                headers=self.example_request_headers,
                stream_id=request_id,
            )
            conn.receive_data(request.serialize())
            inbound += 1
            check_counts()

            # Promising a push creates an outbound stream, but it is not
            # counted as open until we send response headers on it.
            conn.push_stream(request_id, pushed_id, self.example_request_headers)
            check_counts()

            conn.send_headers(pushed_id, self.example_response_headers)
            outbound += 1
            check_counts()

        for request_id in reversed(range(1, 15, 2)):
            pushed_id = request_id + 1

            # The client ending its side only half-closes the request stream.
            end_of_request = frame_factory.build_data_frame(
                data=b"",
                flags=["END_STREAM"],
                stream_id=request_id,
            )
            conn.receive_data(end_of_request.serialize())
            check_counts()

            # Once we end our side too, the request stream is fully closed.
            conn.send_data(request_id, b"", end_stream=True)
            inbound -= 1
            check_counts()

            # Pushed streams, by contrast, are closed by our side alone.
            conn.send_data(
                stream_id=pushed_id,
                data=b"",
                end_stream=True,
            )
            outbound -= 1
            check_counts()

        assert conn.open_inbound_streams == 0
        assert conn.open_outbound_streams == 0


class TestContinuationFrames:
    """
    Tests for the relatively complex CONTINUATION frame logic.
    """

    example_request_headers = [
        (b":authority", b"example.com"),
        (b":path", b"/"),
        (b":scheme", b"https"),
        (b":method", b"GET"),
    ]
    server_config = h2.config.H2Configuration(client_side=False)

    def _build_continuation_sequence(self, headers, block_size, frame_factory):
        """
        Build a HEADERS frame whose encoded header block is split into
        ``block_size``-byte pieces. The first piece goes in the HEADERS frame
        and each remaining piece goes in a CONTINUATION frame.

        The HEADERS frame carries END_STREAM, and only the last frame in the
        returned list carries END_HEADERS. If the whole header block fits in
        a single piece, the list contains just the HEADERS frame.
        """
        f = frame_factory.build_headers_frame(headers)
        header_data = f.data
        chunks = [
            header_data[x:x+block_size]
            for x in range(0, len(header_data), block_size)
        ]
        f.data = chunks.pop(0)
        # Replace whatever flags the factory set (presumably including
        # END_HEADERS) so the block is terminated only by the last frame.
        f.flags = {"END_STREAM"}
        frames = [f]
        frames.extend(
            frame_factory.build_continuation_frame(c) for c in chunks
        )
        # Mark whichever frame is actually last. The original code assumed at
        # least one CONTINUATION frame, so it crashed with an IndexError when
        # the block fitted in one piece.
        frames[-1].flags.add("END_HEADERS")
        return frames

    def test_continuation_frame_basic(self, frame_factory) -> None:
        """
        Test that we correctly decode a header block split across continuation
        frames.
        """
        c = h2.connection.H2Connection(config=self.server_config)
        c.initiate_connection()
        c.receive_data(frame_factory.preamble())

        frames = self._build_continuation_sequence(
            headers=self.example_request_headers,
            block_size=5,
            frame_factory=frame_factory,
        )
        data = b"".join(f.serialize() for f in frames)
        events = c.receive_data(data)

        assert len(events) == 2
        first_event, second_event = events

        assert isinstance(first_event, h2.events.RequestReceived)
        assert first_event.headers == self.example_request_headers
        assert first_event.stream_id == 1

        assert isinstance(second_event, h2.events.StreamEnded)
        assert second_event.stream_id == 1

    @pytest.mark.parametrize("stream_id", [3, 1])
    def test_continuation_cannot_interleave_headers(self,
                                                    frame_factory,
                                                    stream_id) -> None:
        """
        We cannot interleave a new headers block with a CONTINUATION sequence.
        """
        c = h2.connection.H2Connection(config=self.server_config)
        c.initiate_connection()
        c.receive_data(frame_factory.preamble())
        c.clear_outbound_data_buffer()

        frames = self._build_continuation_sequence(
            headers=self.example_request_headers,
            block_size=5,
            frame_factory=frame_factory,
        )
        assert len(frames) > 2  # This is mostly defensive.

        bogus_frame = frame_factory.build_headers_frame(
            headers=self.example_request_headers,
            stream_id=stream_id,
            flags=["END_STREAM"],
        )
        # With more than two frames, this lands after the initial HEADERS
        # frame and before the final CONTINUATION frame.
        frames.insert(len(frames) - 2, bogus_frame)
        data = b"".join(f.serialize() for f in frames)

        with pytest.raises(h2.exceptions.ProtocolError) as e:
            c.receive_data(data)

        assert "invalid frame" in str(e.value).lower()

    def test_continuation_cannot_interleave_data(self, frame_factory) -> None:
        """
        We cannot interleave a data frame with a CONTINUATION sequence.
        """
        c = h2.connection.H2Connection(config=self.server_config)
        c.initiate_connection()
        c.receive_data(frame_factory.preamble())
        c.clear_outbound_data_buffer()

        frames = self._build_continuation_sequence(
            headers=self.example_request_headers,
            block_size=5,
            frame_factory=frame_factory,
        )
        assert len(frames) > 2  # This is mostly defensive.

        bogus_frame = frame_factory.build_data_frame(
            data=b"hello",
            stream_id=1,
        )
        frames.insert(len(frames) - 2, bogus_frame)
        data = b"".join(f.serialize() for f in frames)

        with pytest.raises(h2.exceptions.ProtocolError) as e:
            c.receive_data(data)

        assert "invalid frame" in str(e.value).lower()

    def test_continuation_cannot_interleave_unknown_frame(self, frame_factory) -> None:
        """
        We cannot interleave an unknown frame with a CONTINUATION sequence.
        """
        c = h2.connection.H2Connection(config=self.server_config)
        c.initiate_connection()
        c.receive_data(frame_factory.preamble())
        c.clear_outbound_data_buffer()

        frames = self._build_continuation_sequence(
            headers=self.example_request_headers,
            block_size=5,
            frame_factory=frame_factory,
        )
        assert len(frames) > 2  # This is mostly defensive.

        bogus_frame = frame_factory.build_data_frame(
            data=b"hello",
            stream_id=1,
        )
        # Retype the DATA frame to 88, a frame type HTTP/2 does not define.
        bogus_frame.type = 88
        frames.insert(len(frames) - 2, bogus_frame)
        data = b"".join(f.serialize() for f in frames)

        with pytest.raises(h2.exceptions.ProtocolError) as e:
            c.receive_data(data)

        assert "invalid frame" in str(e.value).lower()

    def test_continuation_frame_multiple_blocks(self, frame_factory) -> None:
        """
        Test that we correctly decode several header blocks split across
        continuation frames.
        """
        c = h2.connection.H2Connection(config=self.server_config)
        c.initiate_connection()
        c.receive_data(frame_factory.preamble())

        for stream_id in range(1, 7, 2):
            frames = self._build_continuation_sequence(
                headers=self.example_request_headers,
                block_size=2,
                frame_factory=frame_factory,
            )
            for frame in frames:
                frame.stream_id = stream_id

            data = b"".join(f.serialize() for f in frames)
            events = c.receive_data(data)

            assert len(events) == 2
            first_event, second_event = events

            assert isinstance(first_event, h2.events.RequestReceived)
            assert first_event.headers == self.example_request_headers
            assert first_event.stream_id == stream_id

            assert isinstance(second_event, h2.events.StreamEnded)
            assert second_event.stream_id == stream_id


class TestContinuationFramesPushPromise:
    """
    Tests for the relatively complex CONTINUATION frame logic working with
    PUSH_PROMISE frames.
    """

    example_request_headers = [
        (b":authority", b"example.com"),
        (b":path", b"/"),
        (b":scheme", b"https"),
        (b":method", b"GET"),
    ]
    example_response_headers = [
        (b":status", b"200"),
        (b"server", b"fake-serv/0.1.0"),
    ]

    def _build_continuation_sequence(self, headers, block_size, frame_factory,
                                     promised_stream_id=2):
        """
        Build a PUSH_PROMISE frame on stream 1, promising
        ``promised_stream_id``, with its encoded header block split into
        ``block_size``-byte pieces. The first piece goes in the PUSH_PROMISE
        frame and each remaining piece goes in a CONTINUATION frame.

        Only the last frame in the returned list carries END_HEADERS. If the
        whole header block fits in a single piece, the list contains just the
        PUSH_PROMISE frame.
        """
        f = frame_factory.build_push_promise_frame(
            stream_id=1, promised_stream_id=promised_stream_id, headers=headers,
        )
        header_data = f.data
        chunks = [
            header_data[x:x+block_size]
            for x in range(0, len(header_data), block_size)
        ]
        f.data = chunks.pop(0)
        # Clear whatever flags the factory set (presumably including
        # END_HEADERS) so the block is terminated only by the last frame.
        # PUSH_PROMISE defines no END_STREAM flag (only END_HEADERS and
        # PADDED), so there is nothing else to set here.
        f.flags = set()
        frames = [f]
        frames.extend(
            frame_factory.build_continuation_frame(c) for c in chunks
        )
        # Mark whichever frame is actually last, so a block that fits in one
        # piece does not crash on an empty CONTINUATION list.
        frames[-1].flags.add("END_HEADERS")
        return frames

    def test_continuation_frame_basic_push_promise(self, frame_factory) -> None:
        """
        Test that we correctly decode a header block split across continuation
        frames when that header block is initiated with a PUSH_PROMISE.
        """
        c = h2.connection.H2Connection()
        c.initiate_connection()
        c.send_headers(stream_id=1, headers=self.example_request_headers)

        frames = self._build_continuation_sequence(
            headers=self.example_request_headers,
            block_size=5,
            frame_factory=frame_factory,
        )
        data = b"".join(f.serialize() for f in frames)
        events = c.receive_data(data)

        assert len(events) == 1
        event = events[0]

        assert isinstance(event, h2.events.PushedStreamReceived)
        assert event.headers == self.example_request_headers
        assert event.parent_stream_id == 1
        assert event.pushed_stream_id == 2

    @pytest.mark.parametrize("stream_id", [3, 1, 2])
    def test_continuation_cannot_interleave_headers_pp(self,
                                                       frame_factory,
                                                       stream_id) -> None:
        """
        We cannot interleave a new headers block with a CONTINUATION sequence
        when the headers block is based on a PUSH_PROMISE frame.
        """
        c = h2.connection.H2Connection()
        c.initiate_connection()
        c.send_headers(stream_id=1, headers=self.example_request_headers)

        frames = self._build_continuation_sequence(
            headers=self.example_request_headers,
            block_size=5,
            frame_factory=frame_factory,
        )
        assert len(frames) > 2  # This is mostly defensive.

        bogus_frame = frame_factory.build_headers_frame(
            headers=self.example_response_headers,
            stream_id=stream_id,
            flags=["END_STREAM"],
        )
        # With more than two frames, this lands after the PUSH_PROMISE frame
        # and before the final CONTINUATION frame.
        frames.insert(len(frames) - 2, bogus_frame)
        data = b"".join(f.serialize() for f in frames)

        with pytest.raises(h2.exceptions.ProtocolError) as e:
            c.receive_data(data)

        assert "invalid frame" in str(e.value).lower()

    def test_continuation_cannot_interleave_data(self, frame_factory) -> None:
        """
        We cannot interleave a data frame with a CONTINUATION sequence when
        that sequence began with a PUSH_PROMISE frame.
        """
        c = h2.connection.H2Connection()
        c.initiate_connection()
        c.send_headers(stream_id=1, headers=self.example_request_headers)

        frames = self._build_continuation_sequence(
            headers=self.example_request_headers,
            block_size=5,
            frame_factory=frame_factory,
        )
        assert len(frames) > 2  # This is mostly defensive.

        bogus_frame = frame_factory.build_data_frame(
            data=b"hello",
            stream_id=1,
        )
        frames.insert(len(frames) - 2, bogus_frame)
        data = b"".join(f.serialize() for f in frames)

        with pytest.raises(h2.exceptions.ProtocolError) as e:
            c.receive_data(data)

        assert "invalid frame" in str(e.value).lower()

    def test_continuation_cannot_interleave_unknown_frame(self, frame_factory) -> None:
        """
        We cannot interleave an unknown frame with a CONTINUATION sequence when
        that sequence began with a PUSH_PROMISE frame.
        """
        c = h2.connection.H2Connection()
        c.initiate_connection()
        c.send_headers(stream_id=1, headers=self.example_request_headers)

        frames = self._build_continuation_sequence(
            headers=self.example_request_headers,
            block_size=5,
            frame_factory=frame_factory,
        )
        assert len(frames) > 2  # This is mostly defensive.

        bogus_frame = frame_factory.build_data_frame(
            data=b"hello",
            stream_id=1,
        )
        # Retype the DATA frame to 88, a frame type HTTP/2 does not define.
        bogus_frame.type = 88
        frames.insert(len(frames) - 2, bogus_frame)
        data = b"".join(f.serialize() for f in frames)

        with pytest.raises(h2.exceptions.ProtocolError) as e:
            c.receive_data(data)

        assert "invalid frame" in str(e.value).lower()

    @pytest.mark.parametrize("evict", [True, False])
    def test_stream_remotely_closed_disallows_push_promise(self,
                                                           evict,
                                                           frame_factory) -> None:
        """
        Streams closed normally by the remote peer disallow PUSH_PROMISE
        frames, and cause a GOAWAY.
        """
        c = h2.connection.H2Connection()
        c.initiate_connection()
        c.send_headers(
            stream_id=1,
            headers=self.example_request_headers,
            end_stream=True,
        )

        f = frame_factory.build_headers_frame(
            stream_id=1,
            headers=self.example_response_headers,
            flags=["END_STREAM"],
        )
        c.receive_data(f.serialize())
        c.clear_outbound_data_buffer()

        if evict:
            # This is annoyingly stateful, but enumerating the list of open
            # streams will force us to flush state.
            assert not c.open_outbound_streams

        f = frame_factory.build_push_promise_frame(
            stream_id=1,
            promised_stream_id=2,
            headers=self.example_request_headers,
        )

        with pytest.raises(h2.exceptions.ProtocolError):
            c.receive_data(f.serialize())

        f = frame_factory.build_goaway_frame(
            last_stream_id=0,
            error_code=h2.errors.ErrorCodes.PROTOCOL_ERROR,
        )
        assert c.data_to_send() == f.serialize()

    def test_continuation_frame_multiple_push_promise(self, frame_factory) -> None:
        """
        Test that we correctly decode header blocks split across continuation
        frames when those header blocks are initiated with a PUSH_PROMISE, for
        more than one pushed stream.
        """
        c = h2.connection.H2Connection()
        c.initiate_connection()
        c.send_headers(stream_id=1, headers=self.example_request_headers)

        for promised_stream_id in range(2, 8, 2):
            frames = self._build_continuation_sequence(
                headers=self.example_request_headers,
                block_size=2,
                frame_factory=frame_factory,
                promised_stream_id=promised_stream_id,
            )
            data = b"".join(f.serialize() for f in frames)
            events = c.receive_data(data)

            assert len(events) == 1
            event = events[0]

            assert isinstance(event, h2.events.PushedStreamReceived)
            assert event.headers == self.example_request_headers
            assert event.parent_stream_id == 1
            assert event.pushed_stream_id == promised_stream_id
