File: sort.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (211 lines) | stat: -rw-r--r-- 6,699 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
#include "Halide.h"
#include "halide_benchmark.h"
#include <algorithm>
#include <cstdio>

using namespace Halide;
using namespace Halide::Tools;

// Pure variables shared by every Func defined in this file:
// x indexes elements (or the position within a row), y indexes rows.
Var x("x");
Var y("y");

// Sort the first `size` elements of a 1-D Func with a bitonic sorting network.
//
// Every compare-and-exchange stage is its own compute_root Func named
// "bitonic_pass". Stages whose chunks span more than 2 * 128 elements are
// parallelized over chunks.
//
// input: 1-D Func whose values are compared with min/max.
// size:  number of elements to sort; the network assumes it is a power of two
//        (NOTE(review): non-power-of-two sizes are not handled).
// Returns the last stage of the network, or `input` itself when size <= 1
// (no stage is needed, and returning an undefined Func would break callers).
Func bitonic_sort(Func input, int size) {
    Func prev = input;

    Var xo("xo"), xi("xi");

    for (int pass_size = 1; pass_size < size; pass_size <<= 1) {
        for (int chunk_size = pass_size; chunk_size > 0; chunk_size >>= 1) {
            Func next("bitonic_pass");
            Expr chunk_start = (x / (2 * chunk_size)) * (2 * chunk_size);
            Expr chunk_end = (x / (2 * chunk_size) + 1) * (2 * chunk_size);
            Expr chunk_middle = chunk_start + chunk_size;
            Expr chunk_index = x - chunk_start;

            Expr partner;
            if (pass_size == chunk_size && pass_size > 1) {
                // Flipped pass: element i of the chunk is compared with its
                // mirror image, element (2 * chunk_size - 1 - i).
                partner = 2 * chunk_middle - x - 1;
                // We need a clamp here to help out bounds inference
                partner = clamp(partner, chunk_start, chunk_end - 1);
            } else {
                // Regular pass: compare with the element chunk_size away
                // inside the same chunk.
                partner = chunk_start + (chunk_index + chunk_size) % (chunk_size * 2);
            }

            // The lower half of each chunk keeps the minimum of the pair,
            // the upper half keeps the maximum.
            next(x) = select(x < chunk_middle,
                             min(prev(x), prev(partner)),
                             max(prev(x), prev(partner)));

            if (pass_size > 1) {
                next.split(x, xo, xi, 2 * chunk_size);
            }
            // chunk_size > 128 implies pass_size > 1, so xo exists here.
            if (chunk_size > 128) {
                next.parallel(xo);
            }
            next.compute_root();
            prev = next;
        }
    }

    // After the loop prev is the final stage; if the loop never ran it is the input.
    return prev;
}

// Sort the first `total_size` elements of a 1-D Func with a bottom-up merge sort.
//
// The input is first gathered into rows of four that are each sorted by a
// small sorting network. Adjacent pairs of sorted rows are then merged, doubling
// the row length each step, until row 0 holds all total_size elements.
// Merges of chunks up to parallel_work_size are computed per row of
// `parallel_stage`, which is compute_root and parallel over rows; larger
// merges are compute_root.
//
// NOTE(review): assumes total_size is a power of two and at least 4; the
// network always reads groups of four inputs and the merge loop only visits
// chunk sizes 4, 8, 16, ...
Func merge_sort(Func input, int total_size) {
    Func result;

    const int parallel_work_size = 512;

    Func parallel_stage("parallel_stage");

    // First gather the input into a 2D array of width four where each row is sorted
    {
        assert(input.dimensions() == 1);
        // Use a small sorting network
        Expr a0 = input(4 * y);
        Expr a1 = input(4 * y + 1);
        Expr a2 = input(4 * y + 2);
        Expr a3 = input(4 * y + 3);

        Expr b0 = min(a0, a1);
        Expr b1 = max(a0, a1);
        Expr b2 = min(a2, a3);
        Expr b3 = max(a2, a3);

        a0 = min(b0, b2);
        a1 = max(b0, b2);
        a2 = min(b1, b3);
        a3 = max(b1, b3);

        b0 = a0;
        b1 = min(a1, a2);
        b2 = max(a1, a2);
        b3 = a3;

        result(x, y) = select(x == 0, b0,
                              select(x == 1, b1,
                                     select(x == 2, b2, b3)));

        result.compute_at(parallel_stage, y).bound(x, 0, 4).unroll(x);
    }

    // Now build up to the total size, merging each pair of rows
    for (int chunk_size = 4; chunk_size < total_size; chunk_size *= 2) {
        // "result" contains the sorted halves
        assert(result.dimensions() == 2);

        // Merge pairs of rows from the partial result
        Func merge_rows("merge_rows");
        RDom r(0, chunk_size * 2);

        // The first dimension of merge_rows is within the chunk, and the
        // second dimension is the chunk index.  Keeps track of two
        // pointers we're merging from and an output value.
        merge_rows(x, y) = Tuple(0, 0, cast(input.value().type(), 0));

        Expr candidate_a = merge_rows(r - 1, y)[0];
        Expr candidate_b = merge_rows(r - 1, y)[1];
        Expr valid_a = candidate_a < chunk_size;
        Expr valid_b = candidate_b < chunk_size;
        Expr value_a = result(clamp(candidate_a, 0, chunk_size - 1), 2 * y);
        Expr value_b = result(clamp(candidate_b, 0, chunk_size - 1), 2 * y + 1);
        merge_rows(r, y) = select(valid_a && ((value_a < value_b) || !valid_b),
                                  Tuple(candidate_a + 1, candidate_b, value_a),
                                  Tuple(candidate_a, candidate_b + 1, value_b));

        if (chunk_size <= parallel_work_size) {
            merge_rows.compute_at(parallel_stage, y);
        } else {
            merge_rows.compute_root();
        }

        if (chunk_size == parallel_work_size) {
            parallel_stage(x, y) = merge_rows(x, y)[2];
            parallel_stage.compute_root().parallel(y);
            result = parallel_stage;
        } else {
            result = lambda(x, y, merge_rows(x, y)[2]);
        }
    }

    // For total_size <= parallel_work_size the loop never reaches the chunk
    // size that defines parallel_stage, yet the stages above were scheduled
    // compute_at it. Define it from the final result so the pipeline is valid.
    if (!parallel_stage.defined()) {
        parallel_stage(x, y) = result(x, y);
        parallel_stage.compute_root().parallel(y);
        result = parallel_stage;
    }

    // Convert back to 1D
    return lambda(x, result(x, 0));
}

// Benchmark bitonic sort, merge sort and std::sort on N pseudo-random ints,
// print the timings, and verify both Halide results against std::sort.
// Returns 0 on success (or when skipped under WebAssembly) and 1 on the
// first mismatch.
int main(int argc, char **argv) {
    Target target = get_jit_target_from_environment();
    if (target.arch == Target::WebAssembly) {
        printf("[SKIP] Performance tests are meaningless and/or misleading under WebAssembly interpreter.\n");
        return 0;
    }

    const int N = 1 << 10;

    Buffer<int> data(N);
    for (int i = 0; i < N; i++) {
        data(i) = rand() & 0xfffff;
    }
    Func input = lambda(x, data(x));

    printf("Bitonic sort...\n");
    Func f = bitonic_sort(input, N);
    f.bound(x, 0, N);
    f.compile_jit();
    printf("Running...\n");
    Buffer<int> bitonic_sorted(N);
    f.realize(bitonic_sorted);
    double t_bitonic = benchmark([&]() {
        f.realize(bitonic_sorted);
    });

    printf("Merge sort...\n");
    f = merge_sort(input, N);
    f.bound(x, 0, N);
    f.compile_jit();
    printf("Running...\n");
    Buffer<int> merge_sorted(N);
    f.realize(merge_sorted);
    double t_merge = benchmark([&]() {
        f.realize(merge_sorted);
    });

    Buffer<int> correct(N);
    printf("std::sort...\n");
    // benchmark() calls the lambda repeatedly. Re-copy the unsorted input on
    // every call so each call sorts the same data the Halide pipelines read;
    // sorting `correct` in place would time std::sort on already-sorted data
    // after the first call. Raw pointers from data() also avoid forming a
    // reference one past the end of the buffer.
    double t_std = benchmark([&]() {
        std::copy(data.data(), data.data() + N, correct.data());
        std::sort(correct.data(), correct.data() + N);
    });

    printf("Times:\n"
           "bitonic sort: %fms \n"
           "merge sort: %fms \n"
           "std::sort %fms\n",
           t_bitonic * 1e3, t_merge * 1e3, t_std * 1e3);

    if (N <= 100) {
        for (int i = 0; i < N; i++) {
            printf("%8d %8d %8d\n",
                   correct(i), bitonic_sorted(i), merge_sorted(i));
        }
    }

    for (int i = 0; i < N; i++) {
        if (bitonic_sorted(i) != correct(i)) {
            printf("bitonic sort failed: %d -> %d instead of %d\n", i, bitonic_sorted(i), correct(i));
            return 1;
        }
        if (merge_sorted(i) != correct(i)) {
            printf("merge sort failed: %d -> %d instead of %d\n", i, merge_sorted(i), correct(i));
            return 1;
        }
    }

    printf("Success!\n");
    return 0;
}