File: test_sum_indices.py

package info (click to toggle)
python-bitarray 3.6.1-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 1,288 kB
  • sloc: python: 11,456; ansic: 7,657; makefile: 73; sh: 6
file content (274 lines) | stat: -rw-r--r-- 8,851 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
"""
In both ssqi() (in _util.c) and sum_indices() (in util.py), we divide our
bitarray into equally sized blocks in order to calculate the sum of active
indices.  We use the same trick but for different reasons:

  (a) in ssqi(), we want to loop over bytes (blocks of 8 bits) and use
      lookup tables (for sum z_j [**2])

  (b) in sum_indices() we want to loop over blocks of smaller bitarrays
      in order to keep the summation in ssqi() from overflowing

The trick is to write

    x_j = y_j + z_j        where  y_j = y  : if bit j is active
                                        0  : otherwise

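for each block.  Here, j is the index within each block.
That is, j is in range(block size).  y is the offset of the block within
the full bitarray, z_j is j if bit j is active (and 0 otherwise), and x_j
is therefore the index of bit j in the full bitarray if that bit is active
(and 0 otherwise).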
Using the above, we get:

    sum x_j   =   k * y  +  sum z_j

where k is the bit count (per block).  And:

    sum x_j**2   =   k * y**2  +  2 * sum z_j * y  +  sum z_j**2

These are the sums for each block and their sum (over all blocks) is what
we are interested in.

                   (a)  ssqi()          (b)  sum_indices()
------------------------------------------------------------
block              c (char)             block (bitarray)
block size         8                    n
i                  byte index           block index
y                  8 * i                n * i
k                  count_table[c]       block.count()
z1 = sum z_j       sum_table[c]         _ssqi(block)
z2 = sum z_j**2    sum_sqr_table[c]     _ssqi(block, 2)
"""
import unittest
from random import getrandbits, randint, randrange, sample

from bitarray.util import zeros, ones, urandom, _ssqi, sum_indices
from bitarray.test_util import SumIndicesUtil


# Bitarray lengths in bits (powers of two), with the equivalent buffer size.
N19 = 2 ** 19  # 512 Kbit =  64 KB
N20 = 2 ** 20  #   1 Mbit = 128 KB
N21 = 2 ** 21  #   2 Mbit = 256 KB
N22 = 2 ** 22  #   4 Mbit = 512 KB
N23 = 2 ** 23  #   8 Mbit =   1 MB
N28 = 2 ** 28  # 256 Mbit =  32 MB
N30 = 2 ** 30  #   1 Gbit = 128 MB
N31 = 2 ** 31  #   2 Gbit = 256 MB
N32 = 2 ** 32  #   4 Gbit = 512 MB
N33 = 2 ** 33  #   8 Gbit =   1 GB

# Largest value representable in an unsigned 64-bit integer.
MAX_UINT64 = 2 ** 64 - 1


def sum_range(n):
    """Return sum(range(n)), computed in closed form."""
    # n * n - n == n * (n - 1) is a product of two consecutive integers,
    # hence always even, so the floor division is exact.
    return (n * n - n) // 2

def sum_sqr_range(n):
    """Return sum(i * i for i in range(n)), computed in closed form."""
    # Closed form for the squares of 0 .. n - 1.  The product below is
    # always a multiple of 6, so the floor division is exact.
    numerator = (n - 1) * n * (2 * n - 1)
    return numerator // 6


class SumRangeTests(unittest.TestCase):
    """Check the closed-form helpers sum_range() and sum_sqr_range()."""

    def test_sum_range(self):
        # keep a running total, so that `expected` is always sum(range(n))
        expected = 0
        for n in range(1000):
            self.assertEqual(sum_range(n), expected)
            expected += n

    def test_sum_sqr_range(self):
        # same as above, but accumulating the squares
        expected = 0
        for n in range(1000):
            self.assertEqual(sum_sqr_range(n), expected)
            expected += n * n

    def test_mode(self):
        # Both closed forms can be written as  factor * n * (n - 1) // 6
        # where only the factor depends on the mode.
        closed_forms = {1: sum_range, 2: sum_sqr_range}
        for n in range(1000):
            for mode, func in closed_forms.items():
                if mode == 1:
                    factor = 3
                else:
                    factor = 2 * n - 1
                self.assertEqual(factor * n * (n - 1) // 6, func(n))

    def test_o2(self):
        # The sum of squares follows from the plain sum (o1) by multiplying
        # with 2 * n - 1 and dividing by 3 - the division is always exact.
        for n in range(1000):
            o1 = n * (n - 1) // 2
            product = o1 * (2 * n - 1)
            self.assertEqual(product % 3, 0)
            self.assertEqual(product // 3, sum_sqr_range(n))


class ExampleImplementationTests(unittest.TestCase):
    """Block-wise example implementation, compared against sum_indices()."""

    def sum_indices(self, a, mode=1):
        """Return the sum of active indices of bitarray `a`, block by block.

        With mode 1 the indices themselves are summed, otherwise their
        squares.  The partial results of every block are verified against
        a brute-force calculation along the way.
        """
        block_size = 503  # in bits
        total = 0
        # Visit the offset of every block.  The last block may be shorter
        # than block_size, which is why len(block) is used below.
        for offset in range(0, len(a), block_size):
            block = a[offset : offset + block_size]
            count = block.count()

            # sum of active indices, relative to the start of the block
            inner_sum = _ssqi(block)
            brute_inner = sum(j for j in range(len(block)) if block[j])
            self.assertEqual(inner_sum, brute_inner)

            if mode == 1:
                block_sum = count * offset + inner_sum
            else:
                inner_sqr_sum = _ssqi(block, 2)
                block_sum = (count * offset * offset
                             + 2 * inner_sum * offset
                             + inner_sqr_sum)

            # block_sum is the sum [of squares] of indices for this block,
            # here recomputed from indices of the full bitarray a
            stop = offset + len(block)
            brute_block = sum(t ** mode for t in range(offset, stop) if a[t])
            self.assertEqual(block_sum, brute_block)

            total += block_sum

        return total

    def test_sum_indices(self):
        # random bitarrays of random length, random mode
        for _ in range(100):
            a = urandom(randrange(10_000))
            mode = randint(1, 2)
            actual = self.sum_indices(a, mode)
            self.assertEqual(actual, sum_indices(a, mode))


class SSQI_Tests(SumIndicesUtil):
    """Tests for the internal function _ssqi() and its overflow limits."""

    # Note carefully that the limits that are calculated and tested here
    # are limits used in internal function _ssqi().
    # The public Python function sum_indices() does NOT impose any limits
    # on the size of bitarrays it can compute.

    def test_calculate_limits(self):
        # calculation of limits used in ssqi() (in _util.c)
        for f, limit in [(sum_range, 6_074_001_000),
                         (sum_sqr_range, 3_810_778)]:
            # Bisect for the largest argument whose result still fits into
            # an unsigned 64-bit integer.
            # Invariant:  f(lo) <= MAX_UINT64 < f(hi)
            lo = 0
            hi = MAX_UINT64
            while hi > lo + 1:
                mid = (lo + hi) // 2
                if f(mid) > MAX_UINT64:
                    hi = mid
                else:
                    lo = mid
            # Check lo, not the last midpoint probed: when the final probe
            # overflows, that midpoint equals hi and is beyond the limit.
            self.assertLess(f(lo), MAX_UINT64)
            self.assertGreater(f(lo + 1), MAX_UINT64)
            self.assertEqual(lo, limit)

    def test_overflow(self):
        # _ssqi() is limited to bitarrays of about 6 Gbit (4 Mbit mode=2).
        # This limit is never reached because sum_indices() uses
        # a much smaller block size for practical reasons.
        # Note that the mode=1 case allocates a bitarray of roughly 760 MB.
        for mode, f, n in [(1, sum_range, 6_074_001_000),
                           (2, sum_sqr_range, 3_810_778)]:
            # all bits set, exactly at the limit: result still fits
            a = ones(n)
            self.assertTrue(f(len(a)) <= MAX_UINT64)
            self.assertEqual(_ssqi(a, mode), f(n))
            # a single additional set bit pushes the sum beyond 64 bits
            a.append(1)
            self.assertTrue(f(len(a)) > MAX_UINT64)
            self.assertRaises(OverflowError, _ssqi, a, mode)

    def test_sparse(self):
        # random lengths up to the mode=2 limit, so that either mode is
        # valid for _ssqi()
        for _ in range(500):
            n = randint(2, 3_810_778)
            # n >= 2 guarantees a non-empty range for randrange() here
            k = randrange(min(1_000, n // 2))
            mode = randint(1, 2)
            freeze = getrandbits(1)
            inv = getrandbits(1)
            self.check_sparse(_ssqi, n, k, mode, freeze, inv)


class SumIndicesTests(SumIndicesUtil):
    """Tests for the public function sum_indices() on large bitarrays.

    The actual checks are carried out by the check_urandom() and
    check_sparse() helpers inherited from SumIndicesUtil.
    """

    def test_urandom(self):
        self.check_urandom(sum_indices, 1_000_003)

    def test_random_sample(self):
        # few bits in a 2 Gbit bitarray
        n = N31
        for k in 1, 31, 503:
            mode = randint(1, 2)
            freeze = getrandbits(1)
            inv = getrandbits(1)
            self.check_sparse(sum_indices, n, k, mode, freeze, inv)

    def test_ones(self):
        # k=0 together with inv=True
        # (presumably giving a bitarray with all bits set - see check_sparse)
        for m in range(19, 32):
            n = randrange(1 << m)
            mode = randint(1, 2)
            freeze = getrandbits(1)
            self.check_sparse(sum_indices, n, 0, mode, freeze, inv=True)

    def test_sum_random(self):
        for _ in range(50):
            n = randrange(1 << randrange(19, 32))
            # n may (rarely) be 0 or 1, such that n // 2 == 0.  As
            # randrange(0) raises ValueError, ensure the upper bound is
            # at least 1, which gives k = 0 in this case.
            k = randrange(max(1, min(1_000, n // 2)))
            mode = randint(1, 2)
            freeze = getrandbits(1)
            inv = getrandbits(1)
            self.check_sparse(sum_indices, n, k, mode, freeze, inv)


class VarianceTests(unittest.TestCase):
    """Use sum_indices() to calculate the variance of active indices."""

    def variance(self, a, mu=None):
        """Return the variance of the active indices of bitarray `a`.

        The variance is taken about `mu`, which defaults to the mean of
        the active indices.  Only the bit count and the two sums provided
        by sum_indices() are used.
        """
        first = sum_indices(a)
        count = a.count()
        if mu is None:
            mu = first / count
        second = sum_indices(a, 2)
        # expansion of  sum((x - mu) ** 2) / count
        return (second - 2 * mu * first) / count + mu * mu

    def variance_values(self, values, mu=None):
        """Return the variance of `values` about `mu` (default: their mean),
        calculated directly from the definition."""
        k = len(values)
        if mu is None:
            mu = sum(values) / k
        squared_deviations = [(x - mu) ** 2 for x in values]
        return sum(squared_deviations) / k

    def test_variance(self):
        for _ in range(1_000):
            n = randrange(1, 1_000)
            k = randint(1, max(1, n // 2))
            indices = sample(range(n), k)
            a = zeros(n)
            a[indices] = 1
            mean = sum(indices) / len(indices)
            # default mu, the explicit mean, and an arbitrary fixed value
            for mu in (None, mean, 20.5):
                self.assertAlmostEqual(self.variance(a, mu),
                                       self.variance_values(indices, mu))


def test_ones():
    """Check sum_indices() on bitarrays with all bits set and print results.

    The lengths lie right at, and well beyond, the limits of the internal
    function _ssqi().  This is run by invoking the script with --ones.
    """
    lengths = (
        3_810_778,
        3_810_779,
        6_074_001_000,
        6_074_001_001,
        N33,
        2 * N33,
    )
    # (extra arguments to sum_indices, output format, closed form)
    checks = (
        ((), "sum =  %32d", sum_range),
        ((2,), "sum2 = %32d", sum_sqr_range),
    )
    for n in lengths:
        a = ones(n)
        print("n =    %32d  %6.2f Gbit    %6.2f GB" % (n, n / N30, n / N33))
        print("2^64 = %32d" % (1 << 64))
        for extra_args, fmt, closed_form in checks:
            res = sum_indices(a, *extra_args)
            print(fmt % res)
            assert res == closed_form(n)

        print()

    print("OK")


if __name__ == "__main__":
    # The --ones option runs the manual check on huge bitarrays instead of
    # the unittest suite.
    import sys
    if '--ones' not in sys.argv:
        unittest.main()
    else:
        test_ones()
        sys.exit()