File: sort.py

package info (click to toggle)
theano 1.0.3+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bullseye, buster, sid
  • size: 30,752 kB
  • sloc: python: 141,182; ansic: 9,505; makefile: 259; sh: 214; pascal: 81
file content (375 lines) | stat: -rw-r--r-- 12,720 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
from __future__ import absolute_import, print_function, division
import os
from string import Template

import numpy as np

import theano
from theano import Apply
from theano.tensor import as_tensor_variable
from theano.tensor.sort import TopKOp

from .basic_ops import (GpuKernelBase, Kernel, infer_context_name,
                        as_gpuarray_variable, gpuarray_helper_inc_dir)
from .opt import register_opt, op_lifter, register_opt2
from .type import GpuArrayType

try:
    import pygpu
    import pygpu.gpuarray as ga
except ImportError as e:
    # To make sure theano is importable
    pass


# TODO GPU sort / argsort
class GpuTopKOp(GpuKernelBase, TopKOp):
    '''Implements TopKOp on gpu

    Currently the output seems sorted, but we do not test it. So, as on
    the CPU, we only support sorted=False for now.

    Only a CUDA implementation exists: requesting kernels or C code for a
    non-CUDA context raises NotImplementedError.

    Parameters
    ----------
    axis : int
        Axis along which the top-k selection is done. Negative values are
        normalized with ``axis % ndim`` when generating C code.
    sorted : bool
        Must be False; ``sorted=True`` raises NotImplementedError.
    idx_dtype : str
        Dtype of the returned indices.
    return_values : bool
        Whether the op outputs the selected values.
    return_indices : bool
        Whether the op outputs the indices of the selected values.

    '''
    __props__ = TopKOp.__props__
    _f16_ok = True

    def __init__(
        self, axis=-1,
        sorted=True,
        idx_dtype='int64',
        return_values=True,
        return_indices=True
    ):
        # The kernels are not validated to produce sorted output, so refuse
        # to promise it.
        if sorted:
            raise NotImplementedError(
                "GpuTopK currently is not sure to give sorted output even if they look sorted..")
        GpuKernelBase.__init__(self)
        TopKOp.__init__(
            self, axis=axis,
            sorted=sorted,
            idx_dtype=idx_dtype,
            return_values=return_values,
            return_indices=return_indices)

    def perform(self, node, inputs, output_storage, params):
        # GPU-only op: there is no Python implementation.
        raise NotImplementedError()

    def c_headers(self):
        return ['gpuarray_api.h', 'gpuarray_helper.h', 'numpy_compat.h']

    def c_header_dirs(self):
        # This module's directory, the gpuarray helper headers and pygpu's
        # own include directory.
        return [
            os.path.dirname(__file__),
            gpuarray_helper_inc_dir(),
            pygpu.get_include()]

    def c_code_cache_version(self):
        # Bump whenever the generated C code or the kernel sources change.
        # (5: skip the kernel launch when there are no slices to process.)
        return (5,)

    def gpu_kernels(self, node, nodename):
        """Build the three top-k kernels for this node.

        The kernels are k_topk_dense, k_topk_dense_large and
        k_topk_dense_xlarge. The last two share the same source file and
        differ only in the ``$count_t`` substitution ('int' vs 'long long').
        c_code picks one of them at run time from the size of the
        selection axis.
        """
        # load kernel source
        device_type = node.inputs[0].type.context.kind
        kernel_ext = {b'cuda': '.cu', b'opencl': '.cl'}[device_type]
        common_ext = {b'cuda': '.cuh', b'opencl': '.h'}[device_type]

        # prepare "$" macros
        if device_type == b'cuda':
            ndim = node.inputs[0].ndim
            dstv_strides_code = ''.join('ssize_t dstv_strides_%d, ' % i for i in range(ndim))
            dsti_strides_code = ''.join('ssize_t dsti_strides_%d, ' % i for i in range(ndim))
            src_strides_code = ''.join('ssize_t src_strides_%d, ' % i for i in range(ndim))
            # Per non-axis dimension i (1..ndim-1): peel the coordinate off
            # the flat slice id ``gid`` and advance the src/dst pointers by it.
            # ``.format`` fills the optional dst lines first; ``%`` then
            # expands ``%(i)d`` (and turns ``%%`` into ``%``) for each i.
            set_slice_code = '''
        gidx = gid %% dims_%(i)d;
        gid /= dims_%(i)d;
        {dstv};
        {dsti};
        src = ptr_add(src, gidx*src_strides_%(i)d);\n'''.format(
                dstv='dstv = ptr_add(dstv, gidx*dstv_strides_%(i)d)' if self.return_values else '',
                dsti='dsti = ptr_add(dsti, gidx*dsti_strides_%(i)d)' if self.return_indices else '')
            set_slice_code = ''.join(
                set_slice_code % dict(i=j) for j in range(1, ndim))
            # Finally apply each array's base offset.
            if self.return_values:
                set_slice_code += """
                dstv = ptr_add(dstv, dstv_offset);
                """
            if self.return_indices:
                set_slice_code += """
                dsti = ptr_add(dsti, dsti_offset);
                """
            set_slice_code += """
                src = ptr_add(src, src_offset);
            """
            flags = Kernel.get_flags(node.inputs[0].dtype)
            subs = dict(
                inp_t=ga.dtype_to_ctype(node.inputs[0].dtype),
                out_t=ga.dtype_to_ctype(self.idx_dtype),
                dims=''.join('size_t dims_%d, ' % i for i in range(1, ndim)),
                dstv='INPUT_TYPE *dstv,' if self.return_values else '',
                dstv_offset='size_t dstv_offset,' if self.return_values else '',
                dsti='INDEX_TYPE *dsti,' if self.return_indices else '',
                dsti_offset='size_t dsti_offset,' if self.return_indices else '',
                dstv_strides=dstv_strides_code if self.return_values else '',
                dsti_strides=dsti_strides_code if self.return_indices else '',
                src_strides=src_strides_code,
                set_slice=set_slice_code,
                write_value=int(self.return_values),
                write_index=int(self.return_indices),
                ndim=str(ndim)
                )
        elif device_type == b'opencl':
            raise NotImplementedError()

        # setup parameters
        # NOTE: this order must match the argument order of the
        # k_topk_dense*_call invocations generated in c_code.
        param_types = [ga.SIZE] * (ndim - 1)  # dims
        for _ in range(self.return_values + self.return_indices):
            param_types.append(ga.GpuArray)  # dst*
            param_types.append(ga.SIZE)  # offset
            param_types.extend([ga.SSIZE] * ndim)  # dst*_strides
        param_types.append(ga.SIZE)  # k
        param_types.append(ga.GpuArray)  # src
        param_types.append(ga.SIZE)  # offset
        param_types.extend([ga.SSIZE] * ndim)  # src_strides
        param_types.append(ga.SIZE)  # size

        # load and compile kernels
        with open(os.path.join(
            os.path.dirname(__file__), 'c_code', 'topk_common' + common_ext
        )) as f:
            common_src = f.read()

        kernels = []

        def build_kernel(fname, kname, subs):
            # Prepend the shared helpers to the kernel source, then expand
            # the "$" placeholders with string.Template.
            with open(os.path.join(
                os.path.dirname(__file__), 'c_code', fname)
            ) as f:
                kernel_src = f.read()
            ker = Kernel(
                code=("#include <cluda.h>\n" +
                      Template(common_src + kernel_src).substitute(**subs)),
                name=kname,
                params=param_types,
                flags=flags,
                objvar=kname + nodename)
            return ker

        subs['count_t'] = 'int'
        kernels.append(
            build_kernel('topk_dense' + kernel_ext, 'k_topk_dense', subs))
        subs['kname'] = 'k_topk_dense_large'
        kernels.append(
            build_kernel('topk_dense_large' + kernel_ext, 'k_topk_dense_large', subs))
        # Same source as the "large" kernel, with a wider counter type.
        subs['count_t'] = 'long long'
        subs['kname'] = 'k_topk_dense_xlarge'
        kernels.append(
            build_kernel('topk_dense_large' + kernel_ext, 'k_topk_dense_xlarge', subs))
        return kernels

    def c_code(self, node, nodename, inps, outs, sub):
        """Emit the C code that validates k, allocates outputs and launches
        one of the top-k kernels.

        One thread block handles one slice (all elements along ``axis`` for
        a fixed position in the other dimensions). The kernel variant is
        chosen from the length of ``axis``:

        * above 2**31 elements: k_topk_dense_xlarge;
        * longer than the max threads per block: k_topk_dense_large;
        * otherwise: k_topk_dense.

        If there are no slices at all (some non-axis dimension is 0), no
        kernel is launched.
        """
        context = node.inputs[0].type.context
        if context.kind != b'cuda':
            raise NotImplementedError(
                '%s: We only have CUDA '
                'implementation so far.' % self.__class__.__name__)
        x, k = inps
        inp_dtc = ga.dtype_to_typecode(node.inputs[0].dtype)
        # Output order matches make_node: values first, then indices.
        if not self.return_indices:
            yv, = outs
        elif self.return_values:
            yv, yi = outs
        else:
            yi, = outs
        out_dtype_s = self.idx_dtype
        out_dtc = ga.dtype_to_typecode(out_dtype_s)
        fail = sub['fail']
        ctx = sub['params']
        k_dtype = node.inputs[1].type.dtype_specs()[1]
        # max threads per block
        MAX_TPB = context.maxlsize0
        # max blocks per grid
        MAX_BPG = context.maxgsize0
        WARP_SIZE = 32

        ndim = node.inputs[0].ndim
        # Move the selection axis to the front. The kernels take its stride
        # first, and take the sizes of the remaining dimensions only.
        reordered_axes = list(range(ndim))
        axis = self.axis % ndim
        del(reordered_axes[axis])
        reordered_axes = [axis] + reordered_axes
        dims = ''.join('dims[%d], ' % i for i in reordered_axes[1:])
        prep_output = ''
        if self.return_values:
            def_dvstrides = 'const ssize_t *dvstrides = PyGpuArray_STRIDES(%s)' % yv
            params_dv = '%s->ga.data, %s->ga.offset,\n' % (yv, yv)
            params_dv += ''.join('dvstrides[%d], ' % i for i in reordered_axes)
            prep_output += '''
    if (0 != theano_prep_output(
        &%(yv)s, %(ndim)d, odims,
        %(inp_dtc)s, GA_C_ORDER, %(ctx)s)) {
        %(fail)s;
    }\n''' % locals()
        else:
            def_dvstrides = params_dv = ''

        if self.return_indices:
            def_distrides = 'const ssize_t *distrides = PyGpuArray_STRIDES(%s)' % yi
            params_di = '%s->ga.data, %s->ga.offset,\n' % (yi, yi)
            params_di += ''.join('distrides[%d], ' % i for i in reordered_axes)
            prep_output += '''
    if (0 != theano_prep_output(
        &%(yi)s, %(ndim)d, odims,
        %(out_dtc)s, GA_C_ORDER, %(ctx)s)) {
        %(fail)s;
    }\n''' % locals()
        else:
            def_distrides = params_di = ''
        sstrides = ', '.join('sstrides[%d]' % i for i in reordered_axes)
        code = '''
{
    const ssize_t k_ = ((%(k_dtype)s*)(PyArray_DATA(%(k)s)))[0];
    const size_t *dims = PyGpuArray_DIMS(%(x)s);
    size_t odims[%(ndim)d];
    for (int i=0; i<%(ndim)d; i++)
        odims[i] = dims[i];

    odims[%(axis)d] = k_>=0 ? k_ : -k_;

    if (0 == odims[%(axis)d]) {
        PyErr_SetString(
            PyExc_ValueError,
            "topk: kth must not be zero");
        %(fail)s;
    } else if (dims[%(axis)d] < odims[%(axis)d]) {
        PyErr_SetString(
            PyExc_ValueError,
            "topk: kth cannot be larger than the size of specified axis %(axis)d");
        %(fail)s;
    }
    %(prep_output)s

    size_t grid_size=1, block_size=1;
    for (int i=0; i<%(ndim)d; ++i) {
        if (i!=%(axis)d)
            grid_size *= dims[i];
        else
            block_size = dims[i];
    }
    // round up to multiples of warp size
    block_size = ((block_size + %(WARP_SIZE)d - 1) / %(WARP_SIZE)d) * %(WARP_SIZE)d;

    if (grid_size > %(MAX_BPG)d) {
        PyErr_SetString(
            PyExc_ValueError,
            "topk: too many slices to work with, expected <= %(MAX_BPG)d");
        %(fail)s;
    }

    %(def_dvstrides)s;
    %(def_distrides)s;
    const ssize_t *sstrides = PyGpuArray_STRIDES(%(x)s);

    int err = GA_NO_ERROR;
    if (grid_size == 0) {
        // No slice to process (a non-axis dimension is 0): the outputs
        // prepared above are empty, and a zero-sized grid is not a valid
        // launch configuration, so skip the kernel launch entirely.
    } else if (dims[%(axis)d] > (1u << 31)) {
        block_size = %(MAX_TPB)d;
        err = k_topk_dense_xlarge_call(
                1, &grid_size, &block_size, 0,
                %(dims)s
                %(params_dv)s
                %(params_di)s
                k_,
                %(x)s->ga.data,
                %(x)s->ga.offset,
                %(sstrides)s,
                dims[%(axis)d]
        );
    } else if (block_size > %(MAX_TPB)d) {
        block_size = %(MAX_TPB)d;
        err = k_topk_dense_large_call(
                1, &grid_size, &block_size, 0,
                %(dims)s
                %(params_dv)s
                %(params_di)s
                k_,
                %(x)s->ga.data,
                %(x)s->ga.offset,
                %(sstrides)s,
                dims[%(axis)d]
        );
    } else {
        err = k_topk_dense_call(
                1, &grid_size, &block_size, 0,
                %(dims)s
                %(params_dv)s
                %(params_di)s
                k_,
                %(x)s->ga.data,
                %(x)s->ga.offset,
                %(sstrides)s,
                dims[%(axis)d]
        );
    }
    if (err != GA_NO_ERROR) {
        PyErr_SetString(
            PyExc_RuntimeError,
            "topk: gpu kernel failed to execute");
        %(fail)s;
    }
}
        '''
        return code % locals()

    def make_node(self, inp, kth):
        # Outputs: the values (same type as inp) if requested, then the
        # indices (idx_dtype, same broadcastable pattern) if requested.
        ctx_name = infer_context_name(inp)
        inp = as_gpuarray_variable(inp, ctx_name)
        kth = as_tensor_variable(kth)
        bcast = inp.type.broadcastable
        outs = []
        if self.return_values:
            outs.append(inp.type())
        if self.return_indices:
            outs.append(GpuArrayType(
                dtype=self.idx_dtype,
                broadcastable=bcast,
                context_name=ctx_name)())
        return Apply(self, [inp, kth], outs)

    def get_params(self, node):
        # The GPU context of the input; exposed to c_code as sub['params'].
        return node.inputs[0].type.context


class ValuesEqApproxNoOrder():
    """
    Approximate equality check that ignores element order along one axis.

    Each operand is sorted along ``axis`` with ``np.sort``. The sorted
    arrays are then passed to ``theano.tensor.type.values_eq_approx``.
    Two results that hold the same elements in a different order along
    that axis therefore compare equal.
    """

    def __init__(self, axis):
        # Axis along which the ordering of elements is disregarded.
        self.axis = axis

    def __call__(self, val1, val2):
        ax = self.axis
        return theano.tensor.type.values_eq_approx(
            np.sort(val1, axis=ax),
            np.sort(val2, axis=ax))


@register_opt('fast_compile')
@op_lifter([TopKOp], cuda_only=True)
@register_opt2([TopKOp], 'fast_compile')
def local_gpua_topkop(op, ctx_name, inputs, outputs):
    """Replace an unsorted TopKOp with the equivalent GpuTopKOp.

    Returns None (no replacement) when ``op.sorted`` is true, because
    GpuTopKOp only accepts ``sorted=False``. Every output of the GPU op is
    tagged with an order-insensitive ``values_eq_approx`` along ``op.axis``,
    because the GPU op does not promise sorted output.
    """
    inp, kth = inputs
    inp = as_gpuarray_variable(inp, ctx_name)
    if op.sorted:
        # GpuTopKOp refuses sorted=True; keep the original op.
        return
    new_outputs = GpuTopKOp(
        axis=op.axis,
        sorted=op.sorted,
        idx_dtype=op.idx_dtype,
        return_values=op.return_values,
        return_indices=op.return_indices,
    )(inp, kth, return_list=True)
    checker = ValuesEqApproxNoOrder(op.axis)
    for var in new_outputs:
        var.tag.values_eq_approx = checker
    return new_outputs