// Correctness test: an extern C stage (flip_x_and_sum) that takes part in
// Halide bounds inference and is consumed by a vectorized, parallel Func.
#include "Halide.h"
#include <stdio.h>
// Extern stage used by g in main(). For every x in the output's range it
// computes out(x) = in1(-x) + in2(-x): both inputs are mirrored about the
// origin and summed, and the sum is truncated to uint8.
//
// The function runs in one of two modes:
//  * Bounds query: at least one input reports is_bounds_query(). For each
//    such input we record the region we must read to produce out's
//    [min, max], then return. The output buffer is not modified, because
//    this stage can handle any output size.
//  * Compute: no input is a bounds query. We validate types, coverage and
//    strides, then fill out.
//
// Returns 0 on success, or -1 if the buffers are not what this stage
// expects. A nonzero return is how an extern stage reports failure.
extern "C" HALIDE_EXPORT_SYMBOL int flip_x_and_sum(halide_buffer_t *in1, halide_buffer_t *in2, halide_buffer_t *out) {
    const int min = out->dim[0].min;
    const int max = out->dim[0].min + out->dim[0].extent - 1;
    const int extent = out->dim[0].extent;
    // Mirroring [min, max] about zero gives [-max, -min].
    const int flipped_min = -max;
    const int flipped_max = -min;

    // The same predicate both selects the mode and picks which inputs to
    // mutate, so the two decisions can never disagree.
    if (in1->is_bounds_query() || in2->is_bounds_query()) {
        printf("Doing flip_x_and_sum bounds inference over [%d %d]\n", min, max);
        if (in1->is_bounds_query()) {
            in1->dim[0].min = flipped_min;
            in1->dim[0].extent = extent;
        }
        if (in2->is_bounds_query()) {
            in2->dim[0].min = flipped_min;
            in2->dim[0].extent = extent;
        }
        return 0;
    }

    // These are real checks rather than assert()s. Under NDEBUG an assert
    // disappears, and a mismatch would turn into out-of-bounds reads/writes
    // instead of an error reported back to the caller.
    if (in1->type != halide_type_of<uint8_t>() ||
        in2->type != halide_type_of<int32_t>() ||
        out->type != halide_type_of<uint8_t>()) {
        fprintf(stderr, "flip_x_and_sum: unexpected buffer types\n");
        return -1;
    }
    printf("Computing flip_x_and_sum over [%d %d]\n", min, max);

    // A buffer that is not a bounds query can still have a null host
    // pointer (e.g. data held only on a device); we can only read host data.
    if (in1->host == nullptr || in2->host == nullptr || out->host == nullptr) {
        fprintf(stderr, "flip_x_and_sum: missing host allocation\n");
        return -1;
    }

    // Each input must cover the mirrored range [flipped_min, flipped_max].
    // This holds if the region requested in the bounds-query branch was honored.
    const auto covers_flipped_range = [&](const halide_buffer_t *b) {
        return b->dim[0].min <= flipped_min &&
               b->dim[0].min + b->dim[0].extent > flipped_max;
    };
    if (!covers_flipped_range(in1) || !covers_flipped_range(in2)) {
        fprintf(stderr, "flip_x_and_sum: input does not cover [%d %d]\n", flipped_min, flipped_max);
        return -1;
    }

    // The loop below assumes densely packed 1-D data.
    if (in1->dim[0].stride != 1 || in2->dim[0].stride != 1 || out->dim[0].stride != 1) {
        fprintf(stderr, "flip_x_and_sum: expected unit strides\n");
        return -1;
    }

    uint8_t *dst = out->host;
    const uint8_t *src1 = in1->host;
    const int32_t *src2 = reinterpret_cast<const int32_t *>(in2->host);
    const int in1_min = in1->dim[0].min;
    const int in2_min = in2->dim[0].min;

    // Index each buffer relative to its own min rather than forming a
    // pointer to coordinate 0, which may lie outside the allocation.
    for (int i = min; i <= max; i++) {
        dst[i - min] = static_cast<uint8_t>(src1[-i - in1_min] + src2[-i - in2_min]);
    }
    return 0;
}
using namespace Halide;
// Builds a pipeline in which g is produced by the extern stage
// flip_x_and_sum, with the in-memory buffer `input` and the Func f as its
// inputs. h consumes g under a vectorized/unrolled/split/parallel schedule.
// Both h and g are realized as pipeline outputs over [0, 99] and verified.
// Returns 0 on success and 1 on the first mismatch.
int main(int argc, char **argv) {
    Func f, g, h;
    Var x;

    // Make some input data in the range [-99, 0]
    Buffer<uint8_t> input(100);
    input.set_min(-99);
    lambda(x, cast<uint8_t>(x * x)).realize(input);
    // Checked explicitly rather than with assert() so that the sanity check
    // also runs in NDEBUG builds.
    const uint8_t expected_input = static_cast<uint8_t>(-99 * -99);
    if (input(-99) != expected_input) {
        printf("input(-99) = %d instead of %d\n", input(-99), expected_input);
        return 1;
    }

    f(x) = x * x;

    std::vector<ExternFuncArgument> args(2);
    args[0] = input;
    args[1] = f;
    g.define_extern("flip_x_and_sum", args, UInt(8), 1);

    h(x) = g(x) * 2;

    f.compute_root();
    Var xi;
    h.vectorize(x, 8).unroll(x, 2).split(x, x, xi, 4).parallel(x);

    Pipeline p({h, g});
    Buffer<uint8_t> h_buf(100), g_buf(100);
    p.realize({h_buf, g_buf});

    for (int i = 0; i < 100; i++) {
        // h(i) = 2 * g(i) = 4 * i * i, truncated to uint8.
        const uint8_t correct = static_cast<uint8_t>(4 * i * i);
        if (h_buf(i) != correct) {
            printf("result(%d) = %d instead of %d\n", i, h_buf(i), correct);
            return 1;
        }
    }

    // g is a pipeline output too, so check its buffer directly:
    // g(i) = input(-i) + f(-i) = i*i + i*i, truncated to uint8.
    for (int i = 0; i < 100; i++) {
        const uint8_t correct_g = static_cast<uint8_t>(2 * i * i);
        if (g_buf(i) != correct_g) {
            printf("g(%d) = %d instead of %d\n", i, g_buf(i), correct_g);
            return 1;
        }
    }

    printf("Success!\n");
    return 0;
}