# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import zlib

from cassandra import DriverException
from cassandra.marshal import int32_pack
from cassandra.protocol import write_uint_le, read_uint_le

CRC24_INIT = 0x875060
CRC24_POLY = 0x1974F0B
CRC24_LENGTH = 3
CRC32_LENGTH = 4
CRC32_INITIAL = zlib.crc32(b"\xfa\x2d\x55\xca")


class CrcException(Exception):
    """
    CRC mismatch error.

    TODO: here to avoid import cycles with cassandra.connection. In the next
          major, the exceptions should be declared in a separated exceptions.py
          file.
    """
    pass


def compute_crc24(data, length):
    crc = CRC24_INIT

    for _ in range(length):
        crc ^= (data & 0xff) << 16
        data >>= 8

        for i in range(8):
            crc <<= 1
            if crc & 0x1000000 != 0:
                crc ^= CRC24_POLY

    return crc


def compute_crc32(data, value):
    crc32 = zlib.crc32(data, value)
    return crc32


class SegmentHeader(object):

    payload_length = None
    uncompressed_payload_length = None
    is_self_contained = None

    def __init__(self, payload_length, uncompressed_payload_length, is_self_contained):
        self.payload_length = payload_length
        self.uncompressed_payload_length = uncompressed_payload_length
        self.is_self_contained = is_self_contained

    @property
    def segment_length(self):
        """
        Return the total length of the segment, including the CRC.
        """
        hl = SegmentCodec.UNCOMPRESSED_HEADER_LENGTH if self.uncompressed_payload_length < 1 \
            else SegmentCodec.COMPRESSED_HEADER_LENGTH
        return hl + CRC24_LENGTH + self.payload_length + CRC32_LENGTH


class Segment(object):

    MAX_PAYLOAD_LENGTH = 128 * 1024 - 1

    payload = None
    is_self_contained = None

    def __init__(self, payload, is_self_contained):
        self.payload = payload
        self.is_self_contained = is_self_contained


class SegmentCodec(object):

    COMPRESSED_HEADER_LENGTH = 5
    UNCOMPRESSED_HEADER_LENGTH = 3
    FLAG_OFFSET = 17

    compressor = None
    decompressor = None

    def __init__(self, compressor=None, decompressor=None):
        self.compressor = compressor
        self.decompressor = decompressor

    @property
    def header_length(self):
        return self.COMPRESSED_HEADER_LENGTH if self.compression \
            else self.UNCOMPRESSED_HEADER_LENGTH

    @property
    def header_length_with_crc(self):
        return (self.COMPRESSED_HEADER_LENGTH if self.compression
                else self.UNCOMPRESSED_HEADER_LENGTH) + CRC24_LENGTH

    @property
    def compression(self):
        return self.compressor and self.decompressor

    def compress(self, data):
        # the uncompressed length is already encoded in the header, so
        # we remove it here
        return self.compressor(data)[4:]

    def decompress(self, encoded_data, uncompressed_length):
        return self.decompressor(int32_pack(uncompressed_length) + encoded_data)

    def encode_header(self, buffer, payload_length, uncompressed_length, is_self_contained):
        if payload_length > Segment.MAX_PAYLOAD_LENGTH:
            raise DriverException('Payload length exceed Segment.MAX_PAYLOAD_LENGTH')

        header_data = payload_length

        flag_offset = self.FLAG_OFFSET
        if self.compression:
            header_data |= uncompressed_length << flag_offset
            flag_offset += 17

        if is_self_contained:
            header_data |= 1 << flag_offset

        write_uint_le(buffer, header_data, size=self.header_length)
        header_crc = compute_crc24(header_data, self.header_length)
        write_uint_le(buffer, header_crc, size=CRC24_LENGTH)

    def _encode_segment(self, buffer, payload, is_self_contained):
        """
        Encode a message to a single segment.
        """
        uncompressed_payload = payload
        uncompressed_payload_length = len(payload)

        if self.compression:
            compressed_payload = self.compress(uncompressed_payload)
            if len(compressed_payload) >= uncompressed_payload_length:
                encoded_payload = uncompressed_payload
                uncompressed_payload_length = 0
            else:
                encoded_payload = compressed_payload
        else:
            encoded_payload = uncompressed_payload

        payload_length = len(encoded_payload)
        self.encode_header(buffer, payload_length, uncompressed_payload_length, is_self_contained)
        payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL)
        buffer.write(encoded_payload)
        write_uint_le(buffer, payload_crc)

    def encode(self, buffer, msg):
        """
        Encode a message to one of more segments.
        """
        msg_length = len(msg)

        if msg_length > Segment.MAX_PAYLOAD_LENGTH:
            payloads = []
            for i in range(0, msg_length, Segment.MAX_PAYLOAD_LENGTH):
                payloads.append(msg[i:i + Segment.MAX_PAYLOAD_LENGTH])
        else:
            payloads = [msg]

        is_self_contained = len(payloads) == 1
        for payload in payloads:
            self._encode_segment(buffer, payload, is_self_contained)

    def decode_header(self, buffer):
        header_data = read_uint_le(buffer, self.header_length)

        expected_header_crc = read_uint_le(buffer, CRC24_LENGTH)
        actual_header_crc = compute_crc24(header_data, self.header_length)
        if actual_header_crc != expected_header_crc:
            raise CrcException('CRC mismatch on header {:x}. Received {:x}", computed {:x}.'.format(
                header_data, expected_header_crc, actual_header_crc))

        payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH
        header_data >>= 17

        if self.compression:
            uncompressed_payload_length = header_data & Segment.MAX_PAYLOAD_LENGTH
            header_data >>= 17
        else:
            uncompressed_payload_length = -1

        is_self_contained = (header_data & 1) == 1

        return SegmentHeader(payload_length, uncompressed_payload_length, is_self_contained)

    def decode(self, buffer, header):
        encoded_payload = buffer.read(header.payload_length)
        expected_payload_crc = read_uint_le(buffer)

        actual_payload_crc = compute_crc32(encoded_payload, CRC32_INITIAL)
        if actual_payload_crc != expected_payload_crc:
            raise CrcException('CRC mismatch on payload. Received {:x}", computed {:x}.'.format(
                expected_payload_crc, actual_payload_crc))

        payload = encoded_payload
        if self.compression and header.uncompressed_payload_length > 0:
            payload = self.decompress(encoded_payload, header.uncompressed_payload_length)

        return Segment(payload, header.is_self_contained)
