"""
In both ssqi() (in _util.c) and sum_indices() (in util.py), we divide our
bitarray into equally sized blocks in order to calculate the sum of active
indices.  We use the same trick but for different reasons:

  (a) in ssqi(), we want to loop over bytes (blocks of 8 bits) and use
      lookup tables (for sum z_j [**2])

  (b) in sum_indices() we want to loop over smaller bitarrays (blocks)
      in order to keep the summation in ssqi() from overflowing

The trick is to write

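    x_j = y_j + z_j        where  y_j = y  : if bit j is active
                                        0  : otherwise

                           and    z_j = j  : if bit j is active
                                        0  : otherwise

for each block.  Here, j is the index within each block (that is, j is in
range(block size)), y is the index of the block's first bit within the
whole bitarray, and x_j is the index of bit j within the whole bitarray
if that bit is active (0 otherwise).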
Using the above, we get:

    sum x_j   =   k * y  +  sum z_j

where k is the bit count (per block).  And:

    sum x_j**2   =   k * y**2  +  2 * y * sum z_j  +  sum z_j**2

These are the sums for each block and their sum (over all blocks) is what
we are interested in.

                   (a)  ssqi()          (b)  sum_indices()
------------------------------------------------------------
block              c (char)             block (bitarray)
block size         8                    n
i                  byte index           block index
y                  8 * i                n * i
k                  count_table[c]       block.count()
z1 = sum z_j       sum_table[c]         _ssqi(block)
z2 = sum z_j**2    sum_sqr_table[c]     _ssqi(block, 2)
"""
import unittest
from random import getrandbits, randint, randrange, sample

from bitarray.util import zeros, ones, urandom, _ssqi, sum_indices
from bitarray.test_util import SumIndicesUtil


# Bitarray lengths in bits, all powers of two.  Each trailing comment gives
# the length in bits followed by the matching buffer size in bytes.
N19 = 2 ** 19  # 512 Kbit =  64 KB
N20 = 2 ** 20  #   1 Mbit = 128 KB
N21 = 2 ** 21  #   2 Mbit = 256 KB
N22 = 2 ** 22  #   4 Mbit = 512 KB
N23 = 2 ** 23  #   8 Mbit =   1 MB
N28 = 2 ** 28  # 256 Mbit =  32 MB
N30 = 2 ** 30  #   1 Gbit = 128 MB
N31 = 2 ** 31  #   2 Gbit = 256 MB
N32 = 2 ** 32  #   4 Gbit = 512 MB
N33 = 2 ** 33  #   8 Gbit =   1 GB

# largest value an unsigned 64-bit integer can hold
MAX_UINT64 = 2 ** 64 - 1


def sum_range(n):
    "Return sum(range(n)) for n >= 0, using the closed form n * (n - 1) / 2"
    # for integer n, one of n and n - 1 is even, so the division is exact
    product = n * (n - 1)
    return product // 2

def sum_sqr_range(n):
    "Return sum(i * i for i in range(n)) for n >= 0, using the closed form"
    # n * (n - 1) * (2 * n - 1) is a multiple of 6 for every integer n
    numerator = n * (n - 1) * (2 * n - 1)
    return numerator // 6


class SumRangeTests(unittest.TestCase):
    """Verify the closed-form helpers sum_range() and sum_sqr_range()."""

    def test_sum_range(self):
        # closed form against the literal sum, for every size below 1000
        for size in range(1000):
            expected = sum(range(size))
            self.assertEqual(sum_range(size), expected)

    def test_sum_sqr_range(self):
        # closed form against the literal sum of squares
        for size in range(1000):
            expected = sum(i * i for i in range(size))
            self.assertEqual(sum_sqr_range(size), expected)

    def test_mode(self):
        # Both closed forms share the shape  factor * n * (n - 1) // 6,
        # where only the factor depends on the mode.
        helpers = {1: sum_range, 2: sum_sqr_range}
        for size in range(1000):
            for mode, helper in helpers.items():
                if mode == 1:
                    factor = 3
                else:
                    factor = 2 * size - 1
                combined = factor * (size * (size - 1)) // 6
                self.assertEqual(combined, helper(size))

    def test_o2(self):
        # The sum of squares can be derived from the plain sum:
        # multiply by (2 * n - 1) and divide by 3, without remainder.
        for size in range(1000):
            plain_sum = size * (size - 1) // 2
            quotient, remainder = divmod(plain_sum * (2 * size - 1), 3)
            self.assertEqual(remainder, 0)
            self.assertEqual(quotient, sum_sqr_range(size))


class ExampleImplementationTests(unittest.TestCase):
    """
    Pure-Python walk-through of the blockwise summation described in the
    module docstring, compared against sum_indices().
    """

    def sum_indices(self, a, mode=1):
        """
        Return the sum of the indices (mode=1) or of the squared indices
        (mode=2) of all active bits in bitarray a.

        The bitarray is handled in blocks of 503 bits, and each partial
        result is asserted against a brute-force sum.
        """
        block_size = 503  # in bits
        total = 0
        # offset is the index (within a) of the first bit of each block
        for offset in range(0, len(a), block_size):
            block = a[offset : offset + block_size]
            # The last block may be truncated, so the actual block length
            # is used below rather than block_size.
            length = len(block)

            count = block.count()
            inner_sum = _ssqi(block)
            # inner_sum only involves indices j within the block
            self.assertEqual(
                inner_sum, sum(j for j in range(length) if block[j]))

            if mode == 1:
                block_total = count * offset + inner_sum
            else:
                inner_sqr_sum = _ssqi(block, 2)
                block_total = ((count * offset + 2 * inner_sum) * offset
                               + inner_sqr_sum)

            # block_total is the sum [of squares] of the indices of this
            # block, where t are indices of the full bitarray a
            self.assertEqual(
                block_total,
                sum(t ** mode for t in range(offset, offset + length)
                    if a[t]))

            total += block_total

        return total

    def test_sum_indices(self):
        # random bitarrays of random length, random mode
        for _ in range(100):
            a = urandom(randrange(10_000))
            mode = randint(1, 2)
            self.assertEqual(self.sum_indices(a, mode), sum_indices(a, mode))


class SSQI_Tests(SumIndicesUtil):

    # Note carefully that the limits that are calculated and tested here
    # are limits used in internal function _ssqi().
    # The public Python function sum_indices() does NOT impose any limits
    # on the size of bitarrays it can compute.

    def test_calculate_limits(self):
        # calculation of limits used in ssqi() (in _util.c)
        for f, limit in [(sum_range, 6_074_001_000),
                         (sum_sqr_range, 3_810_778)]:
            # Binary search for the largest n for which f(n) still fits
            # into an unsigned 64-bit integer.
            # Invariant:  f(lo) <= MAX_UINT64 < f(hi)
            lo = 0
            hi = MAX_UINT64
            while hi > lo + 1:
                mid = (lo + hi) // 2
                if f(mid) > MAX_UINT64:
                    hi = mid
                else:
                    lo = mid
            # The result is the converged lower bound lo, not the last
            # midpoint probed (which may just as well have become hi).
            self.assertEqual(hi, lo + 1)
            self.assertTrue(f(lo) <= MAX_UINT64)
            self.assertTrue(f(lo + 1) > MAX_UINT64)
            self.assertEqual(lo, limit)

    def test_overflow(self):
        # _ssqi() is limited to bitarrays of about 6 Gbit (4 Mbit mode=2).
        # This limit is never reached because sum_indices() uses
        # a much smaller block size for practical reasons.
        # Note that the mode=1 case creates a bitarray of 6_074_001_000
        # bits, that is roughly 760 MB.
        for mode, f, n in [(1, sum_range, 6_074_001_000),
                           (2, sum_sqr_range, 3_810_778)]:
            a = ones(n)
            # at the limit, the result still fits into a uint64
            self.assertTrue(f(len(a)) <= MAX_UINT64)
            self.assertEqual(_ssqi(a, mode), f(n))
            # a single additional active bit pushes the sum over the limit
            a.append(1)
            self.assertTrue(f(len(a)) > MAX_UINT64)
            self.assertRaises(OverflowError, _ssqi, a, mode)

    def test_sparse(self):
        for _ in range(500):
            # n >= 2 ensures n // 2 >= 1, such that randrange() below
            # is never called with an empty range
            n = randint(2, 3_810_778)
            k = randrange(min(1_000, n // 2))
            mode = randint(1, 2)
            freeze = getrandbits(1)
            inv = getrandbits(1)
            self.check_sparse(_ssqi, n, k, mode, freeze, inv)


class SumIndicesTests(SumIndicesUtil):

    # Tests of the public function sum_indices(), on bitarrays with up
    # to 2 Gbit.  The actual checks are done by methods inherited from
    # SumIndicesUtil.

    def test_urandom(self):
        self.check_urandom(sum_indices, 1_000_003)

    def test_random_sample(self):
        # few active (or inactive) bits in a bitarray of 2 Gbit
        n = N31
        for k in 1, 31, 503:
            mode = randint(1, 2)
            freeze = getrandbits(1)
            inv = getrandbits(1)
            self.check_sparse(sum_indices, n, k, mode, freeze, inv)

    def test_ones(self):
        # k=0 together with inv=True
        for m in range(19, 32):
            n = randrange(1 << m)
            mode = randint(1, 2)
            freeze = getrandbits(1)
            self.check_sparse(sum_indices, n, 0, mode, freeze, inv=True)

    def test_sum_random(self):
        for _ in range(50):
            # n >= 2 ensures n // 2 >= 1.  Without the lower bound, n may
            # be 0 or 1, and randrange(0) below would raise ValueError.
            n = randrange(2, 1 << randrange(19, 32))
            k = randrange(min(1_000, n // 2))
            mode = randint(1, 2)
            freeze = getrandbits(1)
            inv = getrandbits(1)
            self.check_sparse(sum_indices, n, k, mode, freeze, inv)


class VarianceTests(unittest.TestCase):
    """
    Compute the variance of the indices of active bits from sum_indices()
    and compare it with a direct computation on the list of indices.
    """

    def variance(self, a, mu=None):
        """
        Return the variance of the indices of active bits in bitarray a
        around mu (default: the mean index), using only a.count(),
        sum_indices(a) and sum_indices(a, 2).
        """
        index_sum = sum_indices(a)
        count = a.count()
        if mu is None:
            mu = index_sum / count
        # sum((x - mu)**2) / k  ==  (sum(x**2) - 2 * mu * sum(x)) / k + mu**2
        index_sqr_sum = sum_indices(a, 2)
        return (index_sqr_sum - 2 * mu * index_sum) / count + mu * mu

    def variance_values(self, values, mu=None):
        """
        Return the variance of values around mu (default: their mean),
        computed directly from the definition.
        """
        count = len(values)
        if mu is None:
            mu = sum(values) / count
        squared_deviations = ((x - mu) ** 2 for x in values)
        return sum(squared_deviations) / count

    def test_variance(self):
        for _ in range(1_000):
            n = randrange(1, 1_000)
            k = randint(1, max(1, n // 2))
            indices = sample(range(n), k)
            a = zeros(n)
            a[indices] = 1
            mean = sum(indices) / len(indices)
            # default mu, the explicit mean, and an arbitrary fixed mu
            for mu in (None, mean, 20.5):
                self.assertAlmostEqual(self.variance(a, mu),
                                       self.variance_values(indices, mu))


def test_ones():
    """
    Check sum_indices() in both modes on bitarrays with all bits set,
    printing the results.  The lengths are the two limits of _ssqi(),
    each limit plus one, and 8 Gbit as well as 16 Gbit.
    """
    lengths = (
        3_810_778,
        3_810_779,
        6_074_001_000,
        6_074_001_001,
        N33,
        2 * N33,
    )
    for n in lengths:
        a = ones(n)
        print("n =    %32d  %6.2f Gbit    %6.2f GB" % (n, n / N30, n / N33))
        print("2^64 = %32d" % (1 << 64))

        # sum of indices
        index_sum = sum_indices(a)
        print("sum =  %32d" % index_sum)
        assert index_sum == sum_range(n)

        # sum of squared indices
        index_sqr_sum = sum_indices(a, 2)
        print("sum2 = %32d" % index_sqr_sum)
        assert index_sqr_sum == sum_sqr_range(n)

        print()

    print("OK")


if __name__ == "__main__":
    import sys

    # With --ones on the command line, only the standalone check on large
    # all-ones bitarrays is run.  Otherwise, the unittest runner takes over.
    if '--ones' not in sys.argv:
        unittest.main()
    else:
        test_ones()
        sys.exit()
