1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
|
#include "Halide.h"
#include <stdio.h>
using namespace Halide;
using namespace Halide::Internal;
// Custom lowering pass that runs loop_carry on a statement (with a caller-chosen
// budget of carried values) and then simplifies the result.
class LoopCarryWrapper final : public IRMutator {
    // Re-expose the base-class mutate(const Expr &) overload. Overriding only
    // the Stmt overload below would otherwise hide it (-Woverloaded-virtual).
    using IRMutator::mutate;

    // Maximum number of values loop_carry is allowed to carry across loop
    // iterations.
    int register_count_;

    // Rewrites the whole statement handed to this pass in one shot. It does not
    // recurse through IRMutator's per-node visit() methods.
    Stmt mutate(const Stmt &stmt) override {
        return simplify(loop_carry(stmt, register_count_));
    }

public:
    // register_count: upper bound on the number of carried values passed
    // straight through to loop_carry.
    explicit LoopCarryWrapper(int register_count)
        : register_count_(register_count) {
    }
};
int main(int argc, char **argv) {
Func input;
Func g;
Func h;
Func f;
Var x, y, xo, yo, xi, yi;
input(x, y) = x + y;
Expr sum_expr = 0;
for (int ix = -100; ix <= 100; ix++) {
// Generate two chains of sums, but only one of them will be carried.
sum_expr += input(x, y + ix);
sum_expr += input(x + 13, y + 2 * ix);
}
g(x, y) = sum_expr;
h(x, y) = g(x, y) + 12;
f(x, y) = h(x, y);
// Make a maximum number of the carried values very large for the purpose
// of this test.
constexpr int kMaxRegisterCount = 1024;
f.add_custom_lowering_pass(new LoopCarryWrapper(kMaxRegisterCount));
const int size = 128;
f.compute_root()
.bound(x, 0, size)
.bound(y, 0, size);
h.compute_root()
.tile(x, y, xo, yo, xi, yi, 16, 16, TailStrategy::RoundUp);
g.compute_at(h, xo)
.reorder(y, x)
.vectorize(x, 4);
input.compute_root();
f.realize({size, size});
printf("Success!\n");
return 0;
}
|