// Halide tutorial lesson 17: Reductions over non-rectangular domains
// This lesson demonstrates how to define updates that iterate over
// subsets of a reduction domain using predicates.
// On linux, you can compile and run it like so:
// g++ lesson_17*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -lpthread -ldl -o lesson_17 -std=c++17
// LD_LIBRARY_PATH=<path/to/libHalide.so> ./lesson_17
// On os x:
// g++ lesson_17*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -o lesson_17 -std=c++17
// DYLD_LIBRARY_PATH=<path/to/libHalide.dylib> ./lesson_17
// If you have the entire Halide source tree, you can also build it by
// running:
// make tutorial_lesson_17_predicated_rdom
// in a shell with the current directory at the top of the halide
// source tree.
#include "Halide.h"
#include <stdio.h>
using namespace Halide;
// Runs three predicated-RDom examples (a circular region, a triangular
// region, and predicates that call Funcs). Each example realizes a
// Halide pipeline, recomputes the same result with plain C loops, and
// compares the two. Returns 0 after printing "Success!", or -1 as soon
// as any Halide value differs from its C reference.
int main(int argc, char **argv) {
    // In lesson 9, we learned how to use RDom to define a "reduction
    // domain" to use in a Halide update definition. The domain
    // defined by an RDom, however, is always rectangular, and the
    // update occurs at every point in that rectangular domain. In
    // some cases, we might want to iterate over some non-rectangular
    // domain, e.g. a circle. We can achieve this behavior by using
    // the RDom::where directive.

    {
        // Starting with this pure definition:
        Func circle("circle");
        Var x("x"), y("y");
        circle(x, y) = x + y;

        // Say we want an update that multiplies by two the values inside
        // a circular region centered at (3, 3) with a radius of roughly 3
        // (the predicate below accepts squared distances up to 10). To do
        // this, we first define the minimal bounding box over the
        // circular region using an RDom.
        RDom r(0, 7, 0, 7);

        // The bounding box does not have to be minimal. In fact, the
        // box can be of any size, as long as it covers the region we'd
        // like to update. However, the tighter the bounding box, the
        // tighter the generated loop bounds will be. Halide will
        // tighten the loop bounds automatically when possible, but in
        // general, it is better to define a minimal bounding box.

        // Then, we use RDom::where to define the predicate over that
        // bounding box, such that the update is performed only if the
        // given predicate evaluates to true, i.e. within the circular
        // region.
        r.where((r.x - 3) * (r.x - 3) + (r.y - 3) * (r.y - 3) <= 10);

        // After defining the predicate, we then define the update.
        circle(r.x, r.y) *= 2;

        Buffer<int> halide_result = circle.realize({7, 7});

        // See figures/lesson_17_rdom_circular.mp4 for a visualization of
        // what this did.

        // The equivalent C is:
        int c_result[7][7];
        for (int y = 0; y < 7; y++) {
            for (int x = 0; x < 7; x++) {
                c_result[y][x] = x + y;
            }
        }
        for (int r_y = 0; r_y < 7; r_y++) {
            for (int r_x = 0; r_x < 7; r_x++) {
                // Update is only performed if the predicate evaluates to true.
                if ((r_x - 3) * (r_x - 3) + (r_y - 3) * (r_y - 3) <= 10) {
                    c_result[r_y][r_x] *= 2;
                }
            }
        }

        // Check the results match:
        for (int y = 0; y < 7; y++) {
            for (int x = 0; x < 7; x++) {
                if (halide_result(x, y) != c_result[y][x]) {
                    printf("halide_result(%d, %d) = %d instead of %d\n",
                           x, y, halide_result(x, y), c_result[y][x]);
                    return -1;
                }
            }
        }
    }

    {
        // We can also define multiple predicates over an RDom. Let's
        // say now we want the update to happen within some triangular
        // region. To do this we define three predicates, where each
        // corresponds to one side of the triangle.
        Func triangle("triangle");
        Var x("x"), y("y");
        triangle(x, y) = x + y;

        // First, let's define the minimal bounding box over the triangular
        // region.
        RDom r(0, 8, 0, 10);

        // Next, let's add the three predicates to the RDom using
        // multiple calls to RDom::where.
        r.where(r.x + r.y > 5);
        r.where(3 * r.y - 2 * r.x < 15);
        r.where(4 * r.x - r.y < 20);

        // We can also pack the multiple predicates into one like so:
        // r.where((r.x + r.y > 5) && (3*r.y - 2*r.x < 15) && (4*r.x - r.y < 20));

        // Then define the update.
        triangle(r.x, r.y) *= 2;

        Buffer<int> halide_result = triangle.realize({10, 10});

        // See figures/lesson_17_rdom_triangular.mp4 for a
        // visualization of what this did.

        // The equivalent C is:
        int c_result[10][10];
        for (int y = 0; y < 10; y++) {
            for (int x = 0; x < 10; x++) {
                c_result[y][x] = x + y;
            }
        }
        // The reduction loops follow the RDom extents (8 in x, 10 in y),
        // not the 10x10 size of the realized output.
        for (int r_y = 0; r_y < 10; r_y++) {
            for (int r_x = 0; r_x < 8; r_x++) {
                // Update is only performed if the predicate evaluates to true.
                if ((r_x + r_y > 5) && (3 * r_y - 2 * r_x < 15) && (4 * r_x - r_y < 20)) {
                    c_result[r_y][r_x] *= 2;
                }
            }
        }

        // Check the results match:
        for (int y = 0; y < 10; y++) {
            for (int x = 0; x < 10; x++) {
                if (halide_result(x, y) != c_result[y][x]) {
                    printf("halide_result(%d, %d) = %d instead of %d\n",
                           x, y, halide_result(x, y), c_result[y][x]);
                    return -1;
                }
            }
        }
    }

    {
        // The predicate is not limited to the RDom's variables only
        // (r.x, r.y, ...). It can also refer to free variables in
        // the update definition, and even make calls to other Funcs,
        // or make recursive calls to the same Func. For example:
        Func f("f"), g("g");
        Var x("x"), y("y");
        f(x, y) = 2 * x + y;
        g(x, y) = x + y;

        // This RDom's predicates depend on the initial value of 'f'.
        RDom r1(0, 5, 0, 5);
        r1.where(f(r1.x, r1.y) >= 4);
        r1.where(f(r1.x, r1.y) <= 7);
        f(r1.x, r1.y) /= 10;

        f.compute_root();

        // While this one involves calls to another Func.
        RDom r2(1, 3, 1, 3);
        r2.where(f(r2.x, r2.y) < 1);
        g(r2.x, r2.y) += 17;

        Buffer<int> halide_result_g = g.realize({5, 5});

        // See figures/lesson_17_rdom_calls_in_predicate.mp4 for a
        // visualization of what this did.

        // The equivalent C for 'f' is:
        int c_result_f[5][5];
        for (int y = 0; y < 5; y++) {
            for (int x = 0; x < 5; x++) {
                c_result_f[y][x] = 2 * x + y;
            }
        }
        for (int r1_y = 0; r1_y < 5; r1_y++) {
            for (int r1_x = 0; r1_x < 5; r1_x++) {
                // Update is only performed if the predicate evaluates to true.
                if ((c_result_f[r1_y][r1_x] >= 4) && (c_result_f[r1_y][r1_x] <= 7)) {
                    c_result_f[r1_y][r1_x] /= 10;
                }
            }
        }

        // And, the equivalent C for 'g' is:
        int c_result_g[5][5];
        for (int y = 0; y < 5; y++) {
            for (int x = 0; x < 5; x++) {
                c_result_g[y][x] = x + y;
            }
        }
        // r2 starts at 1 with extent 3 in each dimension, i.e. it covers
        // the indices 1..3.
        for (int r2_y = 1; r2_y < 4; r2_y++) {
            for (int r2_x = 1; r2_x < 4; r2_x++) {
                // Update is only performed if the predicate evaluates to true.
                if (c_result_f[r2_y][r2_x] < 1) {
                    c_result_g[r2_y][r2_x] += 17;
                }
            }
        }

        // Check the results match:
        for (int y = 0; y < 5; y++) {
            for (int x = 0; x < 5; x++) {
                if (halide_result_g(x, y) != c_result_g[y][x]) {
                    printf("halide_result_g(%d, %d) = %d instead of %d\n",
                           x, y, halide_result_g(x, y), c_result_g[y][x]);
                    return -1;
                }
            }
        }
    }

    printf("Success!\n");
    return 0;
}
|