File: stream_compaction.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (53 lines) | stat: -rw-r--r-- 1,569 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

// Stream-compaction correctness test.
//
// Builds a zero-one function f over [0, n), takes its running sum, and uses
// that sum to scatter the coordinates where f is one into a dense prefix of
// the `ones` buffer. The result is then checked:
//   * bins [0, count) must hold the one-coordinates in increasing order, and
//   * every remaining bin must still hold the -1 sentinel.
// Returns 1 (after printing the first mismatch) on failure, 0 on success.
int main(int argc, char **argv) {
    // Size of the input domain [0, n). The RDom extent, the clamp bound and
    // the output extent all derive from this so they stay consistent.
    const int n = 1000;

    // A zero-one function:
    Func f;
    Var x;
    f(x) = select((x % 7 == 0) || (x % 5 == 0), 1, 0);
    f.compute_root();

    // Take the cumulative sum. To do this part in parallel see the parallel_reductions test.
    // cum_sum(k) is the number of ones among f(0) .. f(k - 1).
    Func cum_sum;
    cum_sum(x) = 0;
    RDom r(0, n);
    cum_sum(r + 1) = f(r) + cum_sum(r);
    cum_sum.compute_root();

    // Write out the coordinates of all the ones. We'd use Tuples in the 2d case.
    Func ones;
    ones(x) = -1;  // Initialize to -1 as a sentinel.

    // Figure out which bin each coordinate should go into. Need a
    // clamp so that Halide knows how much space to allocate for ones.
    Expr bin = clamp(cum_sum(r), 0, n);

    // In this context, undef means skip writing when f(r) != 1
    ones(bin) = select(f(r) == 1, r, undef<int>());

    // This is actually safe to parallelize, because 'bin' is going to
    // be one-to-one with 'r' when f(r) == 1, but that's too subtle
    // for Halide to prove:
    ones.update().allow_race_conditions().parallel(r, 50);

    // Bins 0..n inclusive, matching the clamp range above.
    Buffer<int> result = ones.realize({n + 1});

    // Check the compacted prefix: the coordinates in [0, n) that are
    // multiples of 5 or 7 must appear in increasing order.
    int i = 0;
    for (int expected = 0; expected < n; expected++) {
        if ((expected % 5 != 0) && (expected % 7 != 0)) {
            continue;
        }
        if (result(i) != expected) {
            printf("result(%d) = %d instead of %d\n", i, result(i), expected);
            return 1;
        }
        i++;
    }

    // Check the tail: every bin that received no coordinate must still hold
    // the -1 sentinel. Without this, a stray write past the compacted prefix
    // would go unnoticed.
    for (; i < result.width(); i++) {
        if (result(i) != -1) {
            printf("result(%d) = %d instead of %d\n", i, result(i), -1);
            return 1;
        }
    }

    printf("Success!\n");
    return 0;
}