File: compute_outermost.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (52 lines) | stat: -rw-r--r-- 1,329 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

// Builds a two-pass 3x3 blur of `in` and returns the final stage.
//
// The first stage sums three horizontally adjacent samples of `in`; the
// second stage sums three vertically adjacent samples of the first stage
// and divides the total by 9.
Func blur(Func in) {
    Func horizontal, vertical;
    Var col, row;

    // Horizontal pass: unnormalised 3-tap sum along the first dimension.
    Expr left = in(col - 1, row);
    Expr centre = in(col, row);
    Expr right = in(col + 1, row);
    horizontal(col, row) = left + centre + right;

    // Vertical pass: 3-tap sum along the second dimension, then normalise
    // by the total number of taps (3 * 3).
    Expr above = horizontal(col, row - 1);
    Expr middle = horizontal(col, row);
    Expr below = horizontal(col, row + 1);
    vertical(col, row) = (above + middle + below) / 9;

    // Schedule the horizontal pass at whatever level the vertical pass
    // itself ends up being computed at; the caller decides that later.
    // As a consequence the horizontal pass is also part of any
    // specializations of the vertical pass.
    horizontal.compute_at(vertical, Var::outermost());

    return vertical;
}

int main(int argc, char **argv) {
    Func fn1, fn2;
    Var x, y;

    fn1(x, y) = x + y;
    fn2(x, y) = 2 * x + 3 * y;

    Func blur_fn1 = blur(fn1);
    Func blur_fn2 = blur(fn2);

    Func out;
    out(x, y) = blur_fn1(x, y) + blur_fn2(x, y);

    Var xi, yi, t;
    out.tile(x, y, xi, yi, 16, 16).fuse(x, y, t).parallel(t);
    blur_fn1.compute_at(out, t);
    blur_fn2.compute_at(out, t);

    Buffer<int> result = out.realize({256, 256});
    for (int y = 0; y < 256; y++) {
        for (int x = 0; x < 256; x++) {
            int correct = 3 * x + 4 * y;
            if (result(x, y) != correct) {
                printf("result(%d, %d) = %d instead of %d\n",
                       x, y, result(x, y), correct);
                return 1;
            }
        }
    }

    printf("Success!\n");
    return 0;
}