File: combinations_squashed_order_n_3_to_5.py

package info (click to toggle)
python-awkward 2.8.10-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 25,140 kB
  • sloc: python: 182,845; cpp: 33,828; sh: 432; makefile: 21; javascript: 8
file content (117 lines) | stat: -rw-r--r-- 3,771 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import cupy as cp

def root2(a):
    """Return the largest m with m*(m - 1)/2 <= a, as a floating-point value.

    This is the positive root of m*(m - 1)/2 = a, rounded down:
    m = (1 + sqrt(8*a + 1))/2.
    """
    discriminant = 8*a + 1
    return cp.floor((cp.sqrt(discriminant) + 1)/2)


def root3(a):
    """Return the largest m with m*(m - 1)*(m - 2)/6 <= a, as floats.

    Entries with a <= 0 are answered directly with 2; positive entries use
    the closed-form real root of the cubic m*(m - 1)*(m - 2)/6 = a,
    rounded down.  ``a`` must be an array (it is mask-indexed).
    """
    positive = a > 0
    vals = a[positive]
    cube_root = cp.power(27*vals + cp.sqrt(3)*cp.sqrt(243*vals**2 - 1), 1./3)

    result = cp.full(a.shape, 2.0)
    # The 1e-12 nudge compensates for rounding error just below an exact
    # integer root (good to 1000 choose 3).
    result[positive] = cp.floor(
        cp.power(3, -2./3)*cube_root + cp.power(3, -1./3)/cube_root + 1 + 1e-12
    )
    return result


def root4(a):
    """Return the largest m with m*(m-1)*(m-2)*(m-3)/24 <= a, as floats.

    Substituting y = m*(m - 3) turns the quartic into y*(y + 2) = 24*a, so
    y = sqrt(24*a + 1) - 1 and m = (3 + sqrt(4*y + 9))/2.
    """
    # good to (at least) 100 choose 4
    inner = cp.sqrt(24*a + 1)
    outer = cp.sqrt(4*inner + 5)
    return cp.floor((outer + 3)/2)


def repeat(x, repeats):
    """Repeat each element of ``x`` by the matching count in ``repeats``.

    Equivalent to ``numpy.repeat(x, repeats)`` for a per-element array of
    non-negative integer counts, implemented with array operations only.

    Parameters:
        x: array whose leading-axis elements are repeated.
        repeats: array of non-negative integer counts, one per element of
            ``x``.

    Returns:
        ``x`` indexed by the computed source positions; its length is
        ``repeats.sum()``.  Empty when the counts sum to zero or when
        ``repeats`` itself is empty.
    """
    all_stops = cp.cumsum(repeats)
    if all_stops.size == 0:
        return x[:0]
    total = int(all_stops[-1])
    # Output slot j is taken from the first element whose cumulative stop
    # exceeds j, i.e. the number of stops <= j.  Unlike scattering counts
    # at the stop positions, this never touches an index equal to
    # ``total``, which arises whenever trailing counts are zero.
    parents = cp.searchsorted(all_stops, cp.arange(total), side="right")
    return x[parents]


def _binomial(m, r):
    """Return the elementwise binomial coefficient C(m, r) for integer r >= 1.

    ``m`` is expected to hold non-negative integers.  The product is built
    one factor at a time: C(m, j)*(m - j) is always divisible by j + 1, so
    each floor division is exact and intermediates stay smaller than the
    full falling factorial.
    """
    result = m
    for j in range(1, r):
        result = result*(m - j)//(j + 1)
    return result


def argchoose(starts, stops, n, absolute=False, replacement=False):
    """Index tuples of every n-element combination within each sublist.

    Sublist ``s`` spans ``starts[s]:stops[s]``.  For each sublist, every way
    of choosing ``n`` distinct positions is emitted as one row of strictly
    increasing indices.  Rows of the same sublist are contiguous and the
    sublists appear in input order.

    Parameters:
        starts: integer array of sublist start positions.
        stops: integer array of sublist stop positions (exclusive).
        n: combination order; 2 through 5 are supported.
        absolute: when true, ``starts`` of the owning sublist is added to
            every index; otherwise indices are relative to the sublist.
        replacement: accepted for interface compatibility but currently
            ignored; combinations are always without replacement.

    Returns:
        Integer array of shape ``(total, n)`` where ``total`` is the sum of
        C(stops - starts, n) over all sublists.

    Raises:
        NotImplementedError: if ``n > 5``.
        ValueError: if ``n <= 1``.
    """
    if n > 5:
        raise NotImplementedError
    if n <= 1:
        raise ValueError("Choosing 0 or 1 items is trivial")

    counts = stops - starts
    counts_comb = _binomial(counts, n)

    # offsets delimit the output rows per sublist; offsets2 delimit the
    # input elements per sublist.
    offsets = cp.cumsum(cp.concatenate((cp.array([0]), counts_comb)))
    offsets2 = cp.cumsum(cp.concatenate((cp.array([0]), counts)))
    indices = cp.arange(int(offsets[-1]))
    indices2 = cp.arange(int(offsets2[-1]))

    # Owning sublist of each output row / input element: the number of
    # sublist end offsets that are <= the position.
    parents = cp.searchsorted(offsets[1:], indices, side="right")
    parents2 = cp.searchsorted(offsets2[1:], indices2, side="right")
    local = indices2 - offsets2[parents2]

    # Rank of each output row within its sublist.  The rank is decoded in
    # the combinatorial number system, largest index first.
    rank = indices - offsets[parents]

    # An element at local position p is the largest member of C(p, n - 1)
    # combinations, so repeating positions by that count yields the last
    # column directly.
    largest = repeat(local, _binomial(local, n - 1))
    columns = [largest]
    rank = rank - _binomial(largest, n)

    # Remaining columns: invert C(m, r) <= rank with the closed-form roots.
    # Those return floats, so cast back to keep the result an index array.
    inverse = {2: root2, 3: root3, 4: root4}
    for r in range(n - 1, 1, -1):
        column = inverse[r](rank).astype(rank.dtype)
        columns.append(column)
        rank = rank - _binomial(column, r)

    # Whatever is left of the rank is the smallest index.
    columns.append(rank)
    columns.reverse()

    if absolute:
        starts_parents = starts[parents]
        columns = [column + starts_parents for column in columns]

    return cp.vstack(columns).T


# Demo input.  ``content`` and ``counts`` are not referenced by the calls
# below; only the start/stop boundaries are passed on.
content = cp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])

counts = cp.array([4, 0, 3, 1, 5])
starts = cp.array([0, 4, 4, 7, 8])
stops = cp.array([4, 4, 7, 8, 13])

# Print the combination index table for each demonstrated order.
for order, heading in (
    (3, "argcombinations (n = 3):\n"),
    (4, "argcombinations (n = 4):\n"),
    (5, "argcombinations (n = 5):\n"),
):
    result = argchoose(starts, stops, order)
    print(heading, result)