from itertools import islice
from random import randrange


# Lookup table: bitcount_table[b] is the number of 1 bits in the byte value b.
bitcount_table = [format(byte, 'b').count('1') for byte in range(256)]


def read_n(n, stream):
    """Read `n` items from `stream` and combine them into one integer.

    The first item read becomes the least significant byte (little-endian
    order).  For n <= 0 nothing is read and 0 is returned.  StopIteration
    from an exhausted stream is propagated to the caller.
    """
    value = 0
    for shift in range(0, 8 * n, 8):
        value |= next(stream) << shift
    return value

def sc_decode_header(stream):
    """Read the stream header and return the tuple (endian, nbits).

    Layout of the first byte:
      bits 5..7   must be zero, otherwise ValueError is raised
      bit 4       bit-endianness: set -> 'big', clear -> 'little'
      bits 0..3   number of bytes that follow and hold `nbits`
                  (least significant byte first, see read_n())
    """
    first = next(stream)
    reserved = first & 0xe0
    if reserved:
        raise ValueError("invalid header: 0x%02x" % first)

    if first & 0x10:
        endian = 'big'
    else:
        endian = 'little'

    nbits = read_n(first & 0x0f, stream)
    return endian, nbits

def sc_decode_block(stream, stats):
    """Read one block from `stream` and record it in `stats`.

    Only the block head (and, for types 2..4, the count byte following it)
    is interpreted.  The block data is skipped, except that for type 0
    blocks the 1 bits of the raw bytes are counted when `stats` has a
    'count' entry.

    Arguments:
      stream   iterator over the integer byte values of the encoded stream,
               positioned at a block head
      stats    dict with a 'blocks' list (one counter per block type 0..4)
               and optionally a 'count' list of the same shape; both are
               updated in place

    Returns False when the stop byte (0x00) was read, True otherwise.

    Raises ValueError for an invalid (including negative) block head.
    Raises StopIteration when the stream ends while reading the head, the
    count byte or - when counting - the raw bytes of a type 0 block.  Block
    data which is merely skipped is not checked for truncation here; the
    next attempt to read a block head raises StopIteration instead.
    """
    head = next(stream)

    # Without this check a negative head yields a negative block size.  That
    # is rejected only as a side effect of islice() below, and not at all
    # when counting (where the size to skip is reset to 0).
    if head < 0:
        raise ValueError("Invalid block head: %d" % head)

    if head < 0xa0:                          # type 0 - 0x00 .. 0x9f
        if head == 0:  # stop byte
            return False
        n = 0
        # heads up to 32 give the number of raw bytes directly, larger
        # heads encode multiples of 32 (33 -> 64, 34 -> 96, ...)
        k = head if head <= 32 else 32 * (head - 31)
    elif head < 0xc0:                        # type 1 - 0xa0 .. 0xbf
        n = 1
        k = head - 0xa0
    elif 0xc2 <= head <= 0xc4:               # type 2 .. 4 - 0xc2 .. 0xc4
        n = head - 0xc0
        k = next(stream)                     # index count byte
    else:
        raise ValueError("Invalid block head: 0x%02x" % head)

    stats['blocks'][n] += 1

    # consume block data
    nconsume = max(1, n) * k   # size of block data to consume below
    if stats.get('count'):
        if n == 0:
            # Explicit loop on purpose: inside a generator expression a
            # StopIteration from next(stream) would be turned into a
            # RuntimeError (PEP 479), so a truncated stream would raise a
            # different exception depending on whether counting is enabled.
            ones = 0
            for _ in range(k):
                ones += bitcount_table[next(stream)]
            stats['count'][0] += ones
            nconsume = 0       # the raw bytes have been read already
        else:
            stats['count'][n] += k

    # advance the iterator by nconsume items without storing them
    next(islice(stream, nconsume, nconsume), None)

    return True

def sc_stat(stream, count=False):
    """sc_stat(stream, count=False) -> dict

Walk through a compressed byte stream (as generated by `sc_encode()`) and
return statistics about it.  The result has the keys:

  'endian'   bit-endianness taken from the header ('little' or 'big')
  'nbits'    number of bits taken from the header
  'blocks'   list of length 5 with the number of blocks of each type

When `count` is true, the result also has a 'count' key: a list of
length 5 holding, for each block type, the sum of what the blocks of that
type contribute (1 bits in the raw bytes for type 0, number of entries for
types 1 to 4).
"""
    it = iter(stream)
    byteorder, length = sc_decode_header(it)

    result = dict(endian=byteorder, nbits=length, blocks=[0] * 5)
    if count:
        result['count'] = [0] * 5

    # sc_decode_block() returns False once it has read the stop byte
    more = True
    while more:
        more = sc_decode_block(it, result)

    return result

# ---------------------------------------------------------------------------

import unittest

from bitarray import bitarray
from bitarray.util import sc_encode, sc_decode


class Tests(unittest.TestCase):
    """Check sc_stat() on hand-written streams and on sc_encode() output."""

    def test_empty(self):
        """A header followed directly by the stop byte contains no blocks."""
        data = b"\x01\x00\0"
        expected = {'endian': 'little',
                    'nbits': 0,
                    'blocks': [0, 0, 0, 0, 0]}
        self.assertEqual(sc_stat(data), expected)
        self.assertEqual(sc_decode(data), bitarray())

    def test_zeros_explitcit(self):
        """Blocks of each type which do not contribute any set bit."""
        cases = [
            (b"\x11\x08\0",         [0, 0, 0, 0, 0]),
            (b"\x11\x08\x01\x00\0", [1, 0, 0, 0, 0]),
            (b"\x11\x08\xa0\0",     [0, 1, 0, 0, 0]),
            (b"\x11\x08\xc2\x00\0", [0, 0, 1, 0, 0]),
            (b"\x11\x08\xc3\x00\0", [0, 0, 0, 1, 0]),
            (b"\x11\x08\xc4\x00\0", [0, 0, 0, 0, 1]),
        ]
        for data, expected_blocks in cases:
            result = sc_stat(data, count=True)
            self.assertEqual(result['blocks'], expected_blocks)
            self.assertEqual(result['count'], [0] * 5)
            self.assertEqual(sc_decode(data), bitarray(8))

    def test_untouch(self):
        """sc_decode() leaves the items after the stop byte in the iterator."""
        it = iter(b"\x01\x07\x01\x73\0XYZ")
        decoded = sc_decode(it)
        self.assertEqual(decoded, bitarray("1100111"))
        self.assertEqual(next(it), ord('X'))

    def test_random(self):
        """The per-type counts add up to the number of set bits."""
        size = 20_000_000
        bits = bitarray(size)
        for exponent in range(0, 21, 2):
            indices = [randrange(size) for _ in range(2 ** exponent)]
            bits[indices] = 1
            result = sc_stat(sc_encode(bits), count=True)
            self.assertEqual(sum(result['count']), bits.count())

    def test_end_of_stream(self):
        """Streams ending before the stop byte raise StopIteration."""
        truncated = [b'', b'\x00', b'\x01', b'\x02\x77',
                     b'\x01\x04\x01', b'\x01\x04\xa1', b'\x01\x04\xa0']
        for data in truncated:
            with self.assertRaises(StopIteration):
                sc_stat(data)
            with self.assertRaises(StopIteration):
                sc_decode(data)

    def test_values(self):
        """Stream items may be given as a list of integers."""
        items = [0x11, 3, 1, 32, 0]
        self.assertEqual(sc_decode(items), bitarray("001"))
        expected = {'endian': 'big',
                    'nbits': 3,
                    'blocks': [1, 0, 0, 0, 0]}
        self.assertEqual(sc_stat(items), expected)

        # replace the stop byte by items which are not valid byte values
        invalid = [
            (ValueError, (-1, 256)),
            (TypeError, (None, "F", Ellipsis, [])),
        ]
        for exc_type, values in invalid:
            for value in values:
                items[-1] = value
                with self.assertRaises(exc_type):
                    sc_stat(items)

    def test_example(self):
        """Block statistics of a fixed, deterministic example bitarray."""
        size = 1 << 26
        bits = bitarray(size, 'little')
        bits[:1 << 16] = 1
        for divisor in range(2, 1 << 16):
            bits[size // divisor] = 1

        encoded = sc_encode(bits)
        result = sc_stat(encoded, True)
        self.assertEqual(result['blocks'], [2, 147, 3, 1, 1])
        self.assertEqual(result['count'], [1 << 16, 374, 427, 220, 2])
        self.assertEqual(bits, sc_decode(encoded))

        # same bits in reversed order
        bits.reverse()
        encoded = sc_encode(bits)
        self.assertEqual(sc_stat(encoded)['blocks'], [2, 256, 254, 3, 0])
        self.assertEqual(bits, sc_decode(encoded))


# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
