File: extern_output_expansion.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (71 lines) | stat: -rw-r--r-- 2,086 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include "Halide.h"
#include <assert.h>
#include <stdio.h>

// out(x) = in(x) * x;
//
// Extern stage used by the pipeline below. Both buffers are expected to be
// one-dimensional int32 buffers (checked with assert). Returns 0.
//
// Bounds-query mode (either host pointer is null):
//   - if `out` is a bounds query, its requested extent is rounded UP to a
//     multiple of 17 (so the stage may produce more than was asked for);
//   - if `in` is a bounds query, the requested input region is set to exactly
//     the (possibly rounded) output region.
// Compute mode (both host pointers set): writes out(x) = in(x) * x for every
// x in out's region, reading `in` at the same coordinates.
extern "C" HALIDE_EXPORT_SYMBOL int extern_stage(halide_buffer_t *in, halide_buffer_t *out) {
    assert(in->type == halide_type_of<int>());
    assert(out->type == halide_type_of<int>());

    if (in->host == nullptr || out->host == nullptr) {
        // We require input size = output size, and just for fun,
        // we'll require that the output size must be a multiple of 17.
        if (out->is_bounds_query()) {
            out->dim[0].extent = ((out->dim[0].extent + 16) / 17) * 17;
        }
        if (in->is_bounds_query()) {
            in->dim[0].extent = out->dim[0].extent;
            in->dim[0].min = out->dim[0].min;
        }
        return 0;
    }

    printf("in: %d %d, out: %d %d\n",
           in->dim[0].min, in->dim[0].extent,
           out->dim[0].min, out->dim[0].extent);

    const int32_t out_min = out->dim[0].min;
    const int32_t out_extent = out->dim[0].extent;
    const int32_t in_min = in->dim[0].min;
    const int32_t in_extent = in->dim[0].extent;
    assert(out_extent % 17 == 0);
    // Every coordinate we write must be readable from the input; the bounds
    // query above requests exactly the output region, so this should hold.
    assert(in_min <= out_min && out_min + out_extent <= in_min + in_extent);

    // Address elements relative to each buffer's own min rather than forming
    // a shifted "origin" pointer (host - min): that pointer lies outside the
    // allocation whenever min > 0, which is undefined pointer arithmetic.
    // Strides are honored instead of assuming a dense layout.
    const int32_t *in_data = reinterpret_cast<const int32_t *>(in->host);
    int32_t *out_data = reinterpret_cast<int32_t *>(out->host);
    const int64_t in_stride = in->dim[0].stride;
    const int64_t out_stride = out->dim[0].stride;
    for (int32_t x = out_min; x < out_min + out_extent; x++) {
        const int64_t in_idx = static_cast<int64_t>(x - in_min) * in_stride;
        const int64_t out_idx = static_cast<int64_t>(x - out_min) * out_stride;
        out_data[out_idx] = in_data[in_idx] * x;
    }
    return 0;
}

using namespace Halide;

// Realizes h(x) = g(x) * 2 over [0, 100), where g is the extern stage
// "extern_stage" applied to f(x) = x * x, so each result should equal
// x * x * x * 2. The same pipeline is checked under two schedules for g:
//   variant 0: g computed per 10-wide tile of h (compute_at h, xo);
//   variant 1: g computed once for the whole output (compute_root).
// Prints "Success!" and returns 0 if every value matches; otherwise prints
// the first mismatch and returns 1.
int main(int argc, char **argv) {

    for (int variant = 0; variant < 2; variant++) {
        Func f, g, h;
        Var x;
        f(x) = x * x;

        g.define_extern("extern_stage", {f}, Int(32), 1);

        h(x) = g(x) * 2;

        // Compute h in 10-wide sections
        Var xo;
        h.split(x, xo, x, 10);
        f.compute_root();
        if (variant == 0) {
            g.compute_at(h, xo);
        } else {
            g.compute_root();
        }

        Buffer<int32_t> result = h.realize({100});

        // Use a name distinct from the schedule-variant counter; this loop
        // previously reused `i` and shadowed the outer loop variable.
        for (int xi = 0; xi < 100; xi++) {
            int32_t correct = xi * xi * xi * 2;
            if (result(xi) != correct) {
                printf("result(%d) = %d instead of %d\n", xi, result(xi), correct);
                return 1;
            }
        }
    }

    printf("Success!\n");
    return 0;
}