File: lesson_17_predicated_rdom.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (227 lines) | stat: -rw-r--r-- 8,270 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
// Halide tutorial lesson 17: Reductions over non-rectangular domains

// This lesson demonstrates how to define updates that iterate over
// subsets of a reduction domain using predicates.

// On linux, you can compile and run it like so:
// g++ lesson_17*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -lpthread -ldl -o lesson_17 -std=c++17
// LD_LIBRARY_PATH=<path/to/libHalide.so> ./lesson_17

// On os x:
// g++ lesson_17*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.dylib> -lHalide -o lesson_17 -std=c++17
// DYLD_LIBRARY_PATH=<path/to/libHalide.dylib> ./lesson_17

// If you have the entire Halide source tree, you can also build it by
// running:
//    make tutorial_lesson_17_predicated_rdom
// in a shell with the current directory at the top of the halide
// source tree.

#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {

    // In lesson 9, we learned how to use RDom to define a "reduction
    // domain" to use in a Halide update definition. The domain
    // defined by an RDom, however, is always rectangular, and the
    // update occurs at every point in that rectangular domain. In
    // some cases, we might want to iterate over some non-rectangular
    // domain, e.g. a circle. We can achieve this behavior by using
    // the RDom::where directive.
    //
    // Each example below builds a Halide pipeline, computes the same
    // result with plain C loops, and returns -1 on the first mismatch.

    {
        // Starting with this pure definition:
        Func circle("circle");
        Var x("x"), y("y");
        circle(x, y) = x + y;

        // Say we want an update that doubles the values inside a
        // circular region centered at (3, 3) with a radius of roughly 3
        // (the predicate below keeps every point whose squared distance
        // from the center is at most 10). To do this, we first define
        // the minimal bounding box over the circular region using an
        // RDom.
        RDom r(0, 7, 0, 7);

        // The bounding box does not have to be minimal. In fact, the
        // box can be of any size, as long as it covers the region we'd
        // like to update. However, the tighter the bounding box, the
        // tighter the generated loop bounds will be. Halide will
        // tighten the loop bounds automatically when possible, but in
        // general, it is better to define a minimal bounding box.

        // Then, we use RDom::where to define the predicate over that
        // bounding box, such that the update is performed only if the
        // given predicate evaluates to true, i.e. within the circular
        // region.
        r.where((r.x - 3) * (r.x - 3) + (r.y - 3) * (r.y - 3) <= 10);

        // After defining the predicate, we then define the update.
        circle(r.x, r.y) *= 2;

        Buffer<int> halide_result = circle.realize({7, 7});

        // See figures/lesson_17_rdom_circular.mp4 for a visualization of
        // what this did.

        // The equivalent C is:
        int c_result[7][7];
        for (int y = 0; y < 7; y++) {
            for (int x = 0; x < 7; x++) {
                c_result[y][x] = x + y;
            }
        }
        for (int r_y = 0; r_y < 7; r_y++) {
            for (int r_x = 0; r_x < 7; r_x++) {
                // Update is only performed if the predicate evaluates to true.
                if ((r_x - 3) * (r_x - 3) + (r_y - 3) * (r_y - 3) <= 10) {
                    c_result[r_y][r_x] *= 2;
                }
            }
        }

        // Check the results match:
        for (int y = 0; y < 7; y++) {
            for (int x = 0; x < 7; x++) {
                if (halide_result(x, y) != c_result[y][x]) {
                    printf("halide_result(%d, %d) = %d instead of %d\n",
                           x, y, halide_result(x, y), c_result[y][x]);
                    return -1;
                }
            }
        }
    }

    {
        // We can also define multiple predicates over an RDom. Let's
        // say now we want the update to happen within some triangular
        // region. To do this we define three predicates, where each
        // corresponds to one side of the triangle.
        Func triangle("triangle");
        Var x("x"), y("y");
        triangle(x, y) = x + y;
        // First, let's define the minimal bounding box over the triangular
        // region.
        RDom r(0, 8, 0, 10);
        // Next, let's add the three predicates to the RDom using
        // multiple calls to RDom::where.
        r.where(r.x + r.y > 5);
        r.where(3 * r.y - 2 * r.x < 15);
        r.where(4 * r.x - r.y < 20);

        // We can also pack the multiple predicates into one like so:
        // r.where((r.x + r.y > 5) && (3*r.y - 2*r.x < 15) && (4*r.x - r.y < 20));

        // Then define the update.
        triangle(r.x, r.y) *= 2;

        Buffer<int> halide_result = triangle.realize({10, 10});

        // See figures/lesson_17_rdom_triangular.mp4 for a
        // visualization of what this did.

        // The equivalent C is:
        int c_result[10][10];
        for (int y = 0; y < 10; y++) {
            for (int x = 0; x < 10; x++) {
                c_result[y][x] = x + y;
            }
        }
        for (int r_y = 0; r_y < 10; r_y++) {
            for (int r_x = 0; r_x < 8; r_x++) {
                // Update is only performed if the predicate evaluates to true.
                if ((r_x + r_y > 5) && (3 * r_y - 2 * r_x < 15) && (4 * r_x - r_y < 20)) {
                    c_result[r_y][r_x] *= 2;
                }
            }
        }

        // Check the results match:
        for (int y = 0; y < 10; y++) {
            for (int x = 0; x < 10; x++) {
                if (halide_result(x, y) != c_result[y][x]) {
                    printf("halide_result(%d, %d) = %d instead of %d\n",
                           x, y, halide_result(x, y), c_result[y][x]);
                    return -1;
                }
            }
        }
    }

    {
        // The predicate is not limited to the RDom's variables only
        // (r.x, r.y, ...).  It can also refer to free variables in
        // the update definition, and even make calls to other Funcs,
        // or make recursive calls to the same Func. For example:
        Func f("f"), g("g");
        Var x("x"), y("y");
        f(x, y) = 2 * x + y;
        g(x, y) = x + y;

        // This RDom's predicates depend on the initial value of 'f'.
        RDom r1(0, 5, 0, 5);
        r1.where(f(r1.x, r1.y) >= 4);
        r1.where(f(r1.x, r1.y) <= 7);
        f(r1.x, r1.y) /= 10;

        // Computing 'f' at root means its pure definition and its
        // update both finish before 'g' reads it, so the predicate on
        // r2 below sees the final (updated) values of 'f'.
        f.compute_root();

        // While this one involves calls to another Func.
        RDom r2(1, 3, 1, 3);
        r2.where(f(r2.x, r2.y) < 1);
        g(r2.x, r2.y) += 17;

        Buffer<int> halide_result_g = g.realize({5, 5});

        // See figures/lesson_17_rdom_calls_in_predicate.mp4 for a
        // visualization of what this did.

        // The equivalent C for 'f' is:
        int c_result_f[5][5];
        for (int y = 0; y < 5; y++) {
            for (int x = 0; x < 5; x++) {
                c_result_f[y][x] = 2 * x + y;
            }
        }
        for (int r1_y = 0; r1_y < 5; r1_y++) {
            for (int r1_x = 0; r1_x < 5; r1_x++) {
                // Update is only performed if the predicate evaluates to true.
                if ((c_result_f[r1_y][r1_x] >= 4) && (c_result_f[r1_y][r1_x] <= 7)) {
                    c_result_f[r1_y][r1_x] /= 10;
                }
            }
        }

        // And, the equivalent C for 'g' is. The loops mirror RDom r2,
        // which starts at 1 with extent 3 in both dimensions, and the
        // predicate reads the already-updated c_result_f:
        int c_result_g[5][5];
        for (int y = 0; y < 5; y++) {
            for (int x = 0; x < 5; x++) {
                c_result_g[y][x] = x + y;
            }
        }
        for (int r2_y = 1; r2_y < 4; r2_y++) {
            for (int r2_x = 1; r2_x < 4; r2_x++) {
                // Update is only performed if the predicate evaluates to true.
                if (c_result_f[r2_y][r2_x] < 1) {
                    c_result_g[r2_y][r2_x] += 17;
                }
            }
        }

        // Check the results match:
        for (int y = 0; y < 5; y++) {
            for (int x = 0; x < 5; x++) {
                if (halide_result_g(x, y) != c_result_g[y][x]) {
                    printf("halide_result_g(%d, %d) = %d instead of %d\n",
                           x, y, halide_result_g(x, y), c_result_g[y][x]);
                    return -1;
                }
            }
        }
    }

    printf("Success!\n");

    return 0;
}