File: hoist_loop_invariant_if_statements.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (52 lines) | stat: -rw-r--r-- 1,493 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include "Halide.h"

using namespace Halide;
using namespace Halide::Internal;

int main(int argc, char **argv) {
    Func f, g;
    Var x, y;
    Param<bool> p;

    // In Stmt IR, if statements can be injected by GuardWithIf, RDom
    // predicates, specializations, and uses of undef. There are
    // various situations where an if statement can end up further
    // inside a loop nest than strictly necessary. Here's one:

    f(x, y) = select(p, x + y, undef<int>());
    g(x, y) = select(p, f(x, y), undef<int>());
    f.compute_at(g, x);

    // Both f and g get an if statement for p, which could instead be
    // a single combined top-level if statement. Trim-no-ops is
    // supposed to lift the if statement out of the loops to the top
    // level. Let's check if it worked.

    class Checker : public IRMutator {
        bool in_loop = false;
        Stmt visit(const For *op) override {
            ScopedValue<bool> old(in_loop, true);
            return IRMutator::visit(op);
        }
        Stmt visit(const IfThenElse *op) override {
            if_in_loop |= in_loop;
            return IRMutator::visit(op);
        }

    public:
        bool if_in_loop = false;
    } checker;

    g.add_custom_lowering_pass(&checker, []() {});

    p.set(true);
    g.realize({1024, 1024});

    if (checker.if_in_loop) {
        printf("Found an if statement inside a loop. This was not supposed to happen\n");
        return 1;
    }

    printf("Success!\n");
    return 0;
}