File: selection.py

package info (click to toggle)
pypy3 7.3.19%2Bdfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 212,236 kB
  • sloc: python: 2,098,316; ansic: 540,565; sh: 21,462; asm: 14,419; cpp: 4,451; makefile: 4,209; objc: 761; xml: 530; exp: 499; javascript: 314; pascal: 244; lisp: 45; csh: 12; awk: 4
file content (353 lines) | stat: -rw-r--r-- 13,581 bytes parent folder | download | duplicates (8)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
from pypy.interpreter.error import oefmt
from rpython.rlib.listsort import make_timsort_class
from rpython.rlib.objectmodel import specialize
from rpython.rlib.rarithmetic import widen
from rpython.rlib.rawstorage import raw_storage_getitem, raw_storage_setitem, \
        free_raw_storage, alloc_raw_storage
from rpython.rlib.unroll import unrolling_iterable
from rpython.rtyper.lltypesystem import rffi, lltype
from pypy.module.micronumpy import descriptor, types, constants as NPY
from pypy.module.micronumpy.base import W_NDimArray
from pypy.module.micronumpy.iterators import AllButAxisIter

# Size in bytes of lltype.Signed; used below as the byte stride of the index
# buffer when argsorting a 1-d array.
INT_SIZE = rffi.sizeof(lltype.Signed)

# Every item type for which sort/argsort functions are generated: the float,
# complex and integer types.  Each entry is indexable with the item-type class
# at position 0; the cache classes below unpack entries as
# (class, comparison-kind) pairs.
all_types = (types.all_float_types + types.all_complex_types +
             types.all_int_types)
# Drop Float16 (and any subclass of it) from the supported set.
all_types = [i for i in all_types if not issubclass(i[0], types.Float16)]
# Wrapped in unrolling_iterable for the ``for tp in all_types`` loops below.
all_types = unrolling_iterable(all_types)


def make_argsort_function(space, itemtype, comp_type, count=1):
    """Build an ``argsort(arr, space, w_axis)`` function for one item type.

    ``itemtype`` supplies the low-level storage type ``itemtype.T``.
    ``comp_type`` is ``'int'``, ``'float'`` or ``'complex'`` and selects how a
    raw value is converted before it is compared.  ``count`` is the number of
    consecutive ``T`` values that make up one element (the caches in this
    module pass 2 for complex types and use the default of 1 otherwise).

    The returned function sorts (value, index) pairs with a timsort class
    created by ``make_timsort_class`` and returns the array of indexes.
    """
    TP = itemtype.T
    step = rffi.sizeof(TP)
    # Byte size of one element when elements are stored back to back.
    packed_stride = step * count

    class Repr(object):
        """Strided view over a values buffer and a parallel index buffer.

        Element ``idx`` lives at byte offset ``idx * stride_size + start`` in
        ``values`` and its index at ``idx * index_stride_size + index_start``
        in ``indexes``.
        """

        def __init__(self, index_stride_size, stride_size, size, values,
                     indexes, index_start, start):
            self.index_stride_size = index_stride_size
            self.stride_size = stride_size
            self.index_start = index_start
            self.start = start
            self.size = size
            self.values = values
            self.indexes = indexes

        def getitem(self, idx):
            """Return ``(value, index)`` for position ``idx``.

            ``value`` is a scalar when ``count < 2`` and a list of ``count``
            components otherwise.
            """
            if count < 2:
                v = raw_storage_getitem(TP, self.values, idx * self.stride_size
                                    + self.start)
            else:
                v = []
                for i in range(count):
                    _v = raw_storage_getitem(TP, self.values, idx * self.stride_size
                                    + self.start + step * i)
                    v.append(_v)
            # comp_type is fixed for each generated function, so only one of
            # these branches is ever taken for a given Repr class.
            if comp_type == 'int':
                v = widen(v)
            elif comp_type == 'float':
                v = float(v)
            elif comp_type == 'complex':
                v = [float(v[0]),float(v[1])]
            else:
                raise NotImplementedError('cannot reach')
            return (v, raw_storage_getitem(lltype.Signed, self.indexes,
                                           idx * self.index_stride_size +
                                           self.index_start))

        def setitem(self, idx, item):
            """Store the ``(value, index)`` pair ``item`` at position ``idx``.

            Both the values buffer and the index buffer are written.
            """
            if count < 2:
                raw_storage_setitem(self.values, idx * self.stride_size +
                                self.start, rffi.cast(TP, item[0]))
            else:
                i = 0
                for val in item[0]:
                    raw_storage_setitem(self.values, idx * self.stride_size +
                                self.start + i*step, rffi.cast(TP, val))
                    i += 1
            raw_storage_setitem(self.indexes, idx * self.index_stride_size +
                                self.index_start, item[1])

    class ArgArrayRepWithStorage(Repr):
        """Scratch ``Repr`` that owns freshly allocated buffers.

        The buffers hold ``size`` elements starting at offset 0 and are
        released in ``__del__``.  The ``index_stride_size`` argument is
        accepted but not used: the index buffer always uses the element size
        of the long dtype as its stride.
        """

        def __init__(self, index_stride_size, stride_size, size):
            start = 0
            dtype = descriptor.get_dtype_cache(space).w_longdtype
            indexes = dtype.itemtype.malloc(size * dtype.elsize)
            values = alloc_raw_storage(size * stride_size,
                                            track_allocation=False)
            Repr.__init__(self, dtype.elsize, stride_size,
                          size, values, indexes, start, start)

        def __del__(self):
            free_raw_storage(self.indexes, track_allocation=False)
            free_raw_storage(self.values, track_allocation=False)

    # Accessor functions handed to make_timsort_class; the sort only touches
    # its "list" through these.
    def arg_getitem(lst, item):
        return lst.getitem(item)

    def arg_setitem(lst, item, value):
        lst.setitem(item, value)

    def arg_length(lst):
        return lst.size

    def arg_getitem_slice(lst, start, stop):
        """Copy elements ``start..stop-1`` of ``lst`` into a scratch Repr."""
        # The scratch copy is laid out compactly instead of reusing the
        # source stride.  The source stride is the array's byte stride along
        # the sorted axis, which can be negative or zero (reversed or
        # broadcast views) and would then give an invalid allocation size
        # and out-of-bounds writes, or be far larger than one element for a
        # non-contiguous axis.  The copy is only accessed via getitem/setitem,
        # so its internal layout is free to differ from the source.
        retval = ArgArrayRepWithStorage(lst.index_stride_size, packed_stride,
                stop-start)
        for i in range(stop-start):
            retval.setitem(i, lst.getitem(i+start))
        return retval

    if count < 2:
        def arg_lt(a, b):
            # Does numpy do <= ?
            # ``x != x`` is true only for NaN: a non-NaN value sorts before
            # a NaN one.
            return a[0] < b[0] or b[0] != b[0] and a[0] == a[0]
    else:
        def arg_lt(a, b):
            # First pass: order by NaN-ness, component by component, so that
            # values with a NaN component sort after those without.
            for i in range(count):
                if b[0][i] != b[0][i] and a[0][i] == a[0][i]:
                    return True
                elif b[0][i] == b[0][i] and a[0][i] != a[0][i]:
                    return False
            # Second pass: lexicographic comparison of the components.
            for i in range(count):
                if a[0][i] < b[0][i]:
                    return True
                elif a[0][i] > b[0][i]:
                    return False
            # Does numpy do True?
            return False

    ArgSort = make_timsort_class(arg_getitem, arg_setitem, arg_length,
                                 arg_getitem_slice, arg_lt)

    def argsort(arr, space, w_axis):
        """Return an index array that sorts ``arr`` along ``w_axis``.

        ``w_axis`` may be ``space.w_None`` (sort the flattened array), the
        interpreter-level ``None`` (last axis) or a wrapped integer.  For
        arrays with more than one dimension an out-of-range axis raises an
        IndexError; for 1-d arrays the axis value is not consulted.

        NOTE(review): ``Repr.setitem`` also writes the values buffer, so the
        storage of ``arr`` is reordered as a side effect; presumably callers
        pass a copy -- confirm against the caller.
        """
        if w_axis is space.w_None:
            # note that it's fine to pass None here as we're not going
            # to pass the result around (None is the link to base in slices)
            if arr.get_size() > 0:
                arr = arr.reshape(None, [arr.get_size()])
            axis = 0
        elif w_axis is None:
            axis = -1
        else:
            axis = space.int_w(w_axis)
        # create array of indexes
        dtype = descriptor.get_dtype_cache(space).w_longdtype
        index_arr = W_NDimArray.from_shape(space, arr.get_shape(), dtype)
        with index_arr.implementation as storage, arr as arr_storage:
            if len(arr.get_shape()) == 1:
                # Seed the index buffer with 0..n-1, then sort values and
                # indexes together.
                for i in range(arr.get_size()):
                    raw_storage_setitem(storage, i * INT_SIZE, i)
                r = Repr(INT_SIZE, arr.strides[0], arr.get_size(), arr_storage,
                         storage, 0, arr.start)
                ArgSort(r).sort()
            else:
                shape = arr.get_shape()
                if axis < 0:
                    axis = len(shape) + axis
                if axis < 0 or axis >= len(shape):
                    raise oefmt(space.w_IndexError, "Wrong axis %d", axis)
                # Walk every 1-d lane along ``axis``; the two iterators are
                # advanced in lockstep over the values and the indexes.
                arr_iter = AllButAxisIter(arr, axis)
                arr_state = arr_iter.reset()
                index_impl = index_arr.implementation
                index_iter = AllButAxisIter(index_impl, axis)
                index_state = index_iter.reset()
                stride_size = arr.strides[axis]
                index_stride_size = index_impl.strides[axis]
                axis_size = arr.shape[axis]
                while not arr_iter.done(arr_state):
                    # Seed this lane's indexes with 0..axis_size-1.
                    for i in range(axis_size):
                        raw_storage_setitem(storage, i * index_stride_size +
                                            index_state.offset, i)
                    r = Repr(index_stride_size, stride_size, axis_size,
                         arr_storage, storage, index_state.offset, arr_state.offset)
                    ArgSort(r).sort()
                    arr_state = arr_iter.next(arr_state)
                    index_state = index_iter.next(index_state)
            return index_arr

    return argsort


def argsort_array(arr, space, w_axis):
    """Dispatch ``arr`` to the argsort function generated for its item type.

    Returns whatever the selected argsort function returns.  Raises a
    NotImplementedError at application level when the item type of ``arr``
    matches none of the entries in ``all_types``.

    NOTE(review): unlike ``sort_array`` there is no check for non-native
    byte order here -- confirm whether that is intentional.
    """
    # Fetching the cache is what populates the per-type argsort functions.
    argsort_cache = space.fromcache(ArgSortCache)
    item_type = arr.dtype.itemtype
    for type_entry in all_types:
        if isinstance(item_type, type_entry[0]):
            argsort_func = argsort_cache._lookup(type_entry)
            return argsort_func(arr, space, w_axis)
    # XXX this should probably be changed
    raise oefmt(space.w_NotImplementedError,
                "sorting of non-numeric types '%s' is not implemented",
                arr.dtype.get_name())


def make_sort_function(space, itemtype, comp_type, count=1):
    """Build an in-place ``sort(arr, space, w_axis)`` function for one item type.

    ``itemtype`` supplies the low-level storage type ``itemtype.T``.
    ``comp_type`` is ``'int'``, ``'float'`` or ``'complex'`` and selects how a
    raw value is converted before it is compared.  ``count`` is the number of
    consecutive ``T`` values that make up one element (the caches in this
    module pass 2 for complex types and use the default of 1 otherwise).

    The returned function sorts the array's storage with a timsort class
    created by ``make_timsort_class`` and returns nothing.
    """
    TP = itemtype.T
    step = rffi.sizeof(TP)
    # Byte size of one element when elements are stored back to back.
    packed_stride = step * count

    class Repr(object):
        """Strided view over a values buffer.

        Element ``idx`` lives at byte offset ``idx * stride_size + start``.
        """

        def __init__(self, stride_size, size, values, start):
            self.stride_size = stride_size
            self.start = start
            self.size = size
            self.values = values

        def getitem(self, item):
            """Return the value at position ``item``.

            The result is a scalar when ``count < 2`` and a list of ``count``
            components otherwise.
            """
            if count < 2:
                v = raw_storage_getitem(TP, self.values, item * self.stride_size
                                    + self.start)
            else:
                v = []
                for i in range(count):
                    _v = raw_storage_getitem(TP, self.values, item * self.stride_size
                                    + self.start + step * i)
                    v.append(_v)
            # comp_type is fixed for each generated function, so only one of
            # these branches is ever taken for a given Repr class.
            if comp_type == 'int':
                v = widen(v)
            elif comp_type == 'float':
                v = float(v)
            elif comp_type == 'complex':
                v = [float(v[0]),float(v[1])]
            else:
                raise NotImplementedError('cannot reach')
            return v

        def setitem(self, idx, item):
            """Store ``item`` (scalar or list of components) at ``idx``."""
            if count < 2:
                raw_storage_setitem(self.values, idx * self.stride_size +
                                self.start, rffi.cast(TP, item))
            else:
                i = 0
                for val in item:
                    raw_storage_setitem(self.values, idx * self.stride_size +
                                self.start + i*step, rffi.cast(TP, val))
                    i += 1

    class ArgArrayRepWithStorage(Repr):
        """Scratch ``Repr`` that owns a freshly allocated values buffer.

        The buffer holds ``size`` elements starting at offset 0 and is
        released in ``__del__``.
        """

        def __init__(self, stride_size, size):
            start = 0
            values = alloc_raw_storage(size * stride_size,
                                            track_allocation=False)
            Repr.__init__(self, stride_size,
                          size, values, start)

        def __del__(self):
            free_raw_storage(self.values, track_allocation=False)

    # Accessor functions handed to make_timsort_class; the sort only touches
    # its "list" through these.
    def arg_getitem(lst, item):
        return lst.getitem(item)

    def arg_setitem(lst, item, value):
        lst.setitem(item, value)

    def arg_length(lst):
        return lst.size

    def arg_getitem_slice(lst, start, stop):
        """Copy elements ``start..stop-1`` of ``lst`` into a scratch Repr."""
        # The scratch copy is laid out compactly instead of reusing the
        # source stride.  The source stride is the array's byte stride along
        # the sorted axis, which can be negative or zero (reversed or
        # broadcast views) and would then give an invalid allocation size
        # and out-of-bounds writes, or be far larger than one element for a
        # non-contiguous axis.  The copy is only accessed via getitem/setitem,
        # so its internal layout is free to differ from the source.
        retval = ArgArrayRepWithStorage(packed_stride, stop-start)
        for i in range(stop-start):
            retval.setitem(i, lst.getitem(i+start))
        return retval

    if count < 2:
        def arg_lt(a, b):
            # handles NAN and INF
            # ``x != x`` is true only for NaN: a non-NaN value sorts before
            # a NaN one.
            return a < b or b != b and a == a
    else:
        def arg_lt(a, b):
            # First pass: order by NaN-ness, component by component, so that
            # values with a NaN component sort after those without.
            for i in range(count):
                if b[i] != b[i] and a[i] == a[i]:
                    return True
                elif b[i] == b[i] and a[i] != a[i]:
                    return False
            # Second pass: lexicographic comparison of the components.
            for i in range(count):
                if a[i] < b[i]:
                    return True
                elif a[i] > b[i]:
                    return False
            # Does numpy do True?
            return False

    ArgSort = make_timsort_class(arg_getitem, arg_setitem, arg_length,
                                 arg_getitem_slice, arg_lt)

    def sort(arr, space, w_axis):
        """Sort the storage of ``arr`` in place along ``w_axis``.

        ``w_axis`` may be ``space.w_None`` (sort the flattened array), the
        interpreter-level ``None`` (last axis) or a wrapped integer.  For
        arrays with more than one dimension an out-of-range axis raises an
        IndexError; for 1-d arrays the axis value is not consulted.
        """
        if w_axis is space.w_None:
            # note that it's fine to pass None here as we're not going
            # to pass the result around (None is the link to base in slices)
            arr = arr.reshape(None, [arr.get_size()])
            axis = 0
        elif w_axis is None:
            axis = -1
        else:
            axis = space.int_w(w_axis)
        with arr as storage:
            if len(arr.get_shape()) == 1:
                r = Repr(arr.strides[0], arr.get_size(), storage,
                         arr.start)
                ArgSort(r).sort()
            else:
                shape = arr.get_shape()
                if axis < 0:
                    axis = len(shape) + axis
                if axis < 0 or axis >= len(shape):
                    raise oefmt(space.w_IndexError, "Wrong axis %d", axis)
                # Sort every 1-d lane along ``axis`` independently.
                arr_iter = AllButAxisIter(arr, axis)
                arr_state = arr_iter.reset()
                stride_size = arr.strides[axis]
                axis_size = arr.shape[axis]
                while not arr_iter.done(arr_state):
                    r = Repr(stride_size, axis_size, storage, arr_state.offset)
                    ArgSort(r).sort()
                    arr_state = arr_iter.next(arr_state)

    return sort


def sort_array(arr, space, w_axis, w_order):
    """Dispatch ``arr`` to the sort function generated for its item type.

    ``w_order`` is accepted but not used by this function.  Raises a
    NotImplementedError at application level when the dtype has the opposite
    byte order, or when the item type of ``arr`` matches none of the entries
    in ``all_types``.
    """
    # Fetching the cache is what populates the per-type sort functions.
    sort_cache = space.fromcache(SortCache)
    item_type = arr.dtype.itemtype
    if arr.dtype.byteorder == NPY.OPPBYTE:
        raise oefmt(space.w_NotImplementedError,
                    "sorting of non-native byteorder not supported yet")
    for type_entry in all_types:
        if isinstance(item_type, type_entry[0]):
            sort_func = sort_cache._lookup(type_entry)
            return sort_func(arr, space, w_axis)
    # XXX this should probably be changed
    raise oefmt(space.w_NotImplementedError,
                "sorting of non-numeric types '%s' is not implemented",
                arr.dtype.get_name())


class ArgSortCache(object):
    """Per-space table of argsort functions, one per supported item type.

    ``cache`` maps an item-type class to the function produced by
    ``make_argsort_function``; ``_lookup`` takes an ``all_types`` entry and
    returns the function stored for its class.
    """
    # Class-level default.  ``__init__`` sets the flag on the instance, so
    # this class attribute itself stays False.
    built = False

    def __init__(self, space):
        if self.built:
            return
        self.built = True
        sorters = {}
        for item_cls, comp_kind in all_types._items:
            # Complex elements are made of two stored components; every
            # other kind uses a single one.
            n_components = 2 if comp_kind == 'complex' else 1
            sorters[item_cls] = make_argsort_function(space, item_cls,
                                                      comp_kind, n_components)
        self.cache = sorters

        @specialize.memo()
        def lookup(tp):
            return sorters[tp[0]]
        self._lookup = lookup


class SortCache(object):
    """Per-space table of in-place sort functions, one per supported item type.

    ``cache`` maps an item-type class to the function produced by
    ``make_sort_function``; ``_lookup`` takes an ``all_types`` entry and
    returns the function stored for its class.
    """
    # Class-level default.  ``__init__`` sets the flag on the instance, so
    # this class attribute itself stays False.
    built = False

    def __init__(self, space):
        if self.built:
            return
        self.built = True
        sorters = {}
        for item_cls, comp_kind in all_types._items:
            # Complex elements are made of two stored components; every
            # other kind uses a single one.
            n_components = 2 if comp_kind == 'complex' else 1
            sorters[item_cls] = make_sort_function(space, item_cls,
                                                   comp_kind, n_components)
        self.cache = sorters

        @specialize.memo()
        def lookup(tp):
            return sorters[tp[0]]
        self._lookup = lookup