File: multiple_outputs_extern.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (100 lines) | stat: -rw-r--r-- 3,175 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include "Halide.h"
#include <stdio.h>

// Extern stage computing out(x) = in1(-x) + in2(-x) over a 1-D region,
// i.e. both inputs are flipped about the origin and summed.
//
// The function is called in two modes:
//  - Bounds query: an input reports is_bounds_query(). No data is touched;
//    the region of that input needed to produce the requested output region
//    is written into its dim[0].
//  - Compute: all buffers carry data. in1 and out must be uint8, in2 must be
//    int32, and all three must be dense (stride 1) in dimension 0.
//
// Returns 0 on success and -1 if any precondition is violated. The checks
// are ordinary runtime checks rather than assert() so they are still
// enforced when the file is compiled with NDEBUG.
extern "C" HALIDE_EXPORT_SYMBOL int flip_x_and_sum(halide_buffer_t *in1, halide_buffer_t *in2, halide_buffer_t *out) {
    const int min = out->dim[0].min;
    const int extent = out->dim[0].extent;
    const int max = min + extent - 1;

    // Flipping about the origin maps output range [min, max] onto input
    // range [-max, -min].
    const int flipped_min = -max;
    const int flipped_max = -min;

    const bool in1_is_query = in1->is_bounds_query();
    const bool in2_is_query = in2->is_bounds_query();

    if (in1_is_query || in2_is_query) {
        // Bounds inference mode: mutate only those input buffers that are
        // actually being queried.
        printf("Doing flip_x_and_sum bounds inference over [%d %d]\n", min, max);
        if (in1_is_query) {
            in1->dim[0].min = flipped_min;
            in1->dim[0].extent = extent;
        }
        if (in2_is_query) {
            in2->dim[0].min = flipped_min;
            in2->dim[0].extent = extent;
        }
        // We don't mutate the output buffer, because we can handle
        // any size output.
        return 0;
    }

    // A null host pointer that is not a bounds query cannot be read from
    // here; report it instead of silently returning success without
    // producing any output.
    if (in1->host == nullptr || in2->host == nullptr) {
        printf("flip_x_and_sum: input buffer has no host data\n");
        return -1;
    }

    if (!(in1->type == halide_type_of<uint8_t>()) ||
        !(in2->type == halide_type_of<int32_t>()) ||
        !(out->type == halide_type_of<uint8_t>())) {
        printf("flip_x_and_sum: unexpected buffer element type\n");
        return -1;
    }

    printf("Computing flip_x_and_sum over [%d %d]\n", min, max);

    // Check the inputs are as large as we expected. They should
    // be, if the above bounds inference code is right.
    const auto covers_flipped_range = [&](const halide_buffer_t *buf) {
        return buf->dim[0].min <= flipped_min &&
               buf->dim[0].min + buf->dim[0].extent > flipped_max;
    };
    if (!covers_flipped_range(in1) || !covers_flipped_range(in2)) {
        printf("flip_x_and_sum: input does not cover [%d %d]\n", flipped_min, flipped_max);
        return -1;
    }

    // Check the strides are what we want.
    if (in1->dim[0].stride != 1 || in2->dim[0].stride != 1 || out->dim[0].stride != 1) {
        printf("flip_x_and_sum: buffers must have stride 1 in dimension 0\n");
        return -1;
    }

    // Index every buffer relative to its own min rather than forming a
    // pointer to the (possibly out-of-allocation) coordinate origin.
    uint8_t *dst = out->host;
    const uint8_t *src1 = in1->host;
    const int32_t *src2 = reinterpret_cast<const int32_t *>(in2->host);
    const int in1_min = in1->dim[0].min;
    const int in2_min = in2->dim[0].min;

    // Do the flip. The sum is computed as int and deliberately truncated
    // to uint8 on store.
    for (int i = min; i <= max; i++) {
        dst[i - min] = static_cast<uint8_t>(src1[-i - in1_min] + src2[-i - in2_min]);
    }

    return 0;
}

using namespace Halide;

// Test driver: builds a pipeline whose middle stage g is the extern function
// "flip_x_and_sum", realizes both h and g as pipeline outputs, and verifies
// both output buffers. Prints "Success!" and returns 0 when every value
// matches; returns 1 on the first mismatch.
int main(int argc, char **argv) {
    Func f, g, h;
    Var x;

    // Make some input data in the range [-99, 0]
    Buffer<uint8_t> input(100);
    input.set_min(-99);
    lambda(x, cast<uint8_t>(x * x)).realize(input);

    // Sanity-check the input with an explicit test so it is not compiled
    // out under NDEBUG.
    if (input(-99) != (uint8_t)(-99 * -99)) {
        printf("input(-99) = %d instead of %d\n", input(-99), (uint8_t)(-99 * -99));
        return 1;
    }

    f(x) = x * x;

    // g is the extern stage; it consumes the input buffer and the Func f.
    std::vector<ExternFuncArgument> args(2);
    args[0] = input;
    args[1] = f;
    g.define_extern("flip_x_and_sum", args, UInt(8), 1);

    h(x) = g(x) * 2;

    f.compute_root();
    Var xi;
    h.vectorize(x, 8).unroll(x, 2).split(x, x, xi, 4).parallel(x);

    // Both h and g are outputs of the pipeline.
    Pipeline p({h, g});

    Buffer<uint8_t> h_buf(100), g_buf(100);
    p.realize({h_buf, g_buf});

    for (int i = 0; i < 100; i++) {
        // h(i) = 2 * g(i), wrapped to 8 bits.
        uint8_t correct = 4 * i * i;
        if (h_buf(i) != correct) {
            printf("result(%d) = %d instead of %d\n", i, h_buf(i), correct);
            return 1;
        }

        // g(i) = input(-i) + f(-i) = uint8(i*i) + i*i, truncated to uint8,
        // which equals 2*i*i modulo 256.
        const uint8_t correct_g = static_cast<uint8_t>(2 * i * i);
        if (g_buf(i) != correct_g) {
            printf("g(%d) = %d instead of %d\n", i, g_buf(i), correct_g);
            return 1;
        }
    }

    printf("Success!\n");
    return 0;
}