# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from io import BytesIO

from cassandra import DriverException
from cassandra.segment import Segment, CrcException
from cassandra.connection import segment_codec_no_compression, segment_codec_lz4


def to_bits(b):
    return '{:08b}'.format(b)

class SegmentCodecTest(unittest.TestCase):

    small_msg = b'b' * 50
    max_msg = b'b' * Segment.MAX_PAYLOAD_LENGTH
    large_msg = b'b' * (Segment.MAX_PAYLOAD_LENGTH + 1)

    @staticmethod
    def _header_to_bits(data):
        # unpack a header to bits
        # data should be the little endian bytes sequence
        if len(data) > 6:  # compressed
            data = data[:5]
            bits = ''.join([to_bits(b) for b in reversed(data)])
            # return the compressed payload length, the uncompressed payload length,
            # the self contained flag and the padding as bits
            return bits[23:40] + bits[6:23] + bits[5:6] + bits[:5]
        else:  # uncompressed
            data = data[:3]
            bits = ''.join([to_bits(b) for b in reversed(data)])
            # return the payload length, the self contained flag and
            # the padding as bits
            return bits[7:24] + bits[6:7] + bits[:6]

    def test_encode_uncompressed_header(self):
        buffer = BytesIO()
        segment_codec_no_compression.encode_header(buffer, len(self.small_msg), -1, True)
        self.assertEqual(buffer.tell(), 6)
        self.assertEqual(
            self._header_to_bits(buffer.getvalue()),
            "00000000000110010" + "1" + "000000")

    @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
    def test_encode_compressed_header(self):
        buffer = BytesIO()
        compressed_length = len(segment_codec_lz4.compress(self.small_msg))
        segment_codec_lz4.encode_header(buffer, compressed_length, len(self.small_msg), True)

        self.assertEqual(buffer.tell(), 8)
        self.assertEqual(
            self._header_to_bits(buffer.getvalue()),
            "{:017b}".format(compressed_length) + "00000000000110010" + "1" + "00000")

    def test_encode_uncompressed_header_with_max_payload(self):
        buffer = BytesIO()
        segment_codec_no_compression.encode_header(buffer, len(self.max_msg), -1, True)
        self.assertEqual(buffer.tell(), 6)
        self.assertEqual(
            self._header_to_bits(buffer.getvalue()),
            "11111111111111111" + "1" + "000000")

    def test_encode_header_fails_if_payload_too_big(self):
        buffer = BytesIO()
        for codec in [c for c in [segment_codec_no_compression, segment_codec_lz4] if c is not None]:
            with self.assertRaises(DriverException):
                codec.encode_header(buffer, len(self.large_msg), -1, False)

    def test_encode_uncompressed_header_not_self_contained_msg(self):
        buffer = BytesIO()
        # simulate the first chunk with the max size
        segment_codec_no_compression.encode_header(buffer, len(self.max_msg), -1, False)
        self.assertEqual(buffer.tell(), 6)
        self.assertEqual(
            self._header_to_bits(buffer.getvalue()),
            ("11111111111111111"
             "0"  # not self contained
             "000000"))

    @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
    def test_encode_compressed_header_with_max_payload(self):
        buffer = BytesIO()
        compressed_length = len(segment_codec_lz4.compress(self.max_msg))
        segment_codec_lz4.encode_header(buffer, compressed_length, len(self.max_msg), True)
        self.assertEqual(buffer.tell(), 8)
        self.assertEqual(
            self._header_to_bits(buffer.getvalue()),
            "{:017b}".format(compressed_length) + "11111111111111111" + "1" + "00000")

    @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
    def test_encode_compressed_header_not_self_contained_msg(self):
        buffer = BytesIO()
        # simulate the first chunk with the max size
        compressed_length = len(segment_codec_lz4.compress(self.max_msg))
        segment_codec_lz4.encode_header(buffer, compressed_length, len(self.max_msg), False)
        self.assertEqual(buffer.tell(), 8)
        self.assertEqual(
            self._header_to_bits(buffer.getvalue()),
            ("{:017b}".format(compressed_length) +
             "11111111111111111"
             "0"  # not self contained
             "00000"))

    def test_decode_uncompressed_header(self):
        buffer = BytesIO()
        segment_codec_no_compression.encode_header(buffer, len(self.small_msg), -1, True)
        buffer.seek(0)
        header = segment_codec_no_compression.decode_header(buffer)
        self.assertEqual(header.uncompressed_payload_length, -1)
        self.assertEqual(header.payload_length, len(self.small_msg))
        self.assertEqual(header.is_self_contained, True)

    @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
    def test_decode_compressed_header(self):
        buffer = BytesIO()
        compressed_length = len(segment_codec_lz4.compress(self.small_msg))
        segment_codec_lz4.encode_header(buffer, compressed_length, len(self.small_msg), True)
        buffer.seek(0)
        header = segment_codec_lz4.decode_header(buffer)
        self.assertEqual(header.uncompressed_payload_length, len(self.small_msg))
        self.assertEqual(header.payload_length, compressed_length)
        self.assertEqual(header.is_self_contained, True)

    def test_decode_header_fails_if_corrupted(self):
        buffer = BytesIO()
        segment_codec_no_compression.encode_header(buffer, len(self.small_msg), -1, True)
        # corrupt one byte
        buffer.seek(buffer.tell()-1)
        buffer.write(b'0')
        buffer.seek(0)

        with self.assertRaises(CrcException):
            segment_codec_no_compression.decode_header(buffer)

    def test_decode_uncompressed_self_contained_segment(self):
        buffer = BytesIO()
        segment_codec_no_compression.encode(buffer, self.small_msg)

        buffer.seek(0)
        header = segment_codec_no_compression.decode_header(buffer)
        segment = segment_codec_no_compression.decode(buffer, header)

        self.assertEqual(header.is_self_contained, True)
        self.assertEqual(header.uncompressed_payload_length, -1)
        self.assertEqual(header.payload_length, len(self.small_msg))
        self.assertEqual(segment.payload, self.small_msg)

    @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
    def test_decode_compressed_self_contained_segment(self):
        buffer = BytesIO()
        segment_codec_lz4.encode(buffer, self.small_msg)

        buffer.seek(0)
        header = segment_codec_lz4.decode_header(buffer)
        segment = segment_codec_lz4.decode(buffer, header)

        self.assertEqual(header.is_self_contained, True)
        self.assertEqual(header.uncompressed_payload_length, len(self.small_msg))
        self.assertGreater(header.uncompressed_payload_length, header.payload_length)
        self.assertEqual(segment.payload, self.small_msg)

    def test_decode_multi_segments(self):
        buffer = BytesIO()
        segment_codec_no_compression.encode(buffer, self.large_msg)

        buffer.seek(0)
        # We should have 2 segments to read
        headers = []
        segments = []
        headers.append(segment_codec_no_compression.decode_header(buffer))
        segments.append(segment_codec_no_compression.decode(buffer, headers[0]))
        headers.append(segment_codec_no_compression.decode_header(buffer))
        segments.append(segment_codec_no_compression.decode(buffer, headers[1]))

        self.assertTrue(all([h.is_self_contained is False for h in headers]))
        decoded_msg = segments[0].payload + segments[1].payload
        self.assertEqual(decoded_msg, self.large_msg)

    @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
    def test_decode_fails_if_corrupted(self):
        buffer = BytesIO()
        segment_codec_lz4.encode(buffer, self.small_msg)
        buffer.seek(buffer.tell()-1)
        buffer.write(b'0')
        buffer.seek(0)
        header = segment_codec_lz4.decode_header(buffer)
        with self.assertRaises(CrcException):
            segment_codec_lz4.decode(buffer, header)

    @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed')
    def test_decode_tiny_msg_not_compressed(self):
        buffer = BytesIO()
        segment_codec_lz4.encode(buffer, b'b')
        buffer.seek(0)
        header = segment_codec_lz4.decode_header(buffer)
        segment = segment_codec_lz4.decode(buffer, header)
        self.assertEqual(header.uncompressed_payload_length, 0)
        self.assertEqual(header.payload_length, 1)
        self.assertEqual(segment.payload, b'b')
