File: nested_tail_strategies.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (207 lines) | stat: -rw-r--r-- 6,752 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#include "Halide.h"

using namespace Halide;

// Size in bytes of the largest single request seen by my_malloc since this
// counter was last reset. check() zeroes it before every realization and
// inspects it afterwards.
size_t largest_allocation = 0;

// Allocation hook that check() installs as the JIT custom_malloc handler.
//
// Records the largest request seen so far in largest_allocation, then returns
// a pointer rounded to a 32-byte boundary. The pointer that malloc actually
// returned is stored in the pointer-sized slot just below the returned
// address so that my_free can recover it.
//
// user_context is unused. x is the requested size in bytes.
// Returns nullptr if the underlying malloc fails.
void *my_malloc(JITUserContext *user_context, size_t x) {
    largest_allocation = std::max(x, largest_allocation);
    // Over-allocate by 32 bytes so there is slack for both the alignment
    // round-up and the stashed original pointer.
    void *orig = malloc(x + 32);
    if (orig == nullptr) {
        // Without this check the rounding below would turn a null result into
        // address 32 and the bookkeeping store would write through it.
        return nullptr;
    }
    // Advance by 32 and clear the low five bits: the result is the first
    // 32-byte boundary strictly above orig.
    void *ptr = (void *)((((size_t)orig + 32) >> 5) << 5);
    ((void **)ptr)[-1] = orig;
    return ptr;
}

// Deallocation hook that check() installs as the JIT custom_free handler.
//
// ptr must be null or a pointer previously returned by my_malloc; the pointer
// originally obtained from malloc is read back from the slot just below ptr
// and released. user_context is unused.
void my_free(JITUserContext *user_context, void *ptr) {
    // Mirror free(nullptr): there is no stashed pointer to read for a null
    // block, so looking at ptr[-1] would be an invalid access.
    if (ptr == nullptr) {
        return;
    }
    free(((void **)ptr)[-1]);
}

// Realizes `out` at a handful of output sizes with the allocation-tracking
// hooks installed, and aborts the process if any single allocation made
// during a realization exceeds the expected bound.
//
//   out   - pipeline output; realized as a one-dimensional buffer.
//   line  - call-site line number, echoed in the failure message.
//   tails - tail strategies used by the schedule; they decide which sizes
//           are tried and are printed on failure.
void check(Func out, int line, std::vector<TailStrategy> tails) {
    // True when the given strategy appears anywhere in `tails`.
    const auto uses = [&tails](TailStrategy wanted) {
        return std::find(tails.begin(), tails.end(), wanted) != tails.end();
    };

    // The predicated strategies are grouped with the round-up ones here.
    const bool rounds_up =
        uses(TailStrategy::RoundUp) ||
        uses(TailStrategy::RoundUpAndBlend) ||
        uses(TailStrategy::PredicateLoads) ||
        uses(TailStrategy::PredicateStores);
    const bool shifts_inwards =
        uses(TailStrategy::ShiftInwards) ||
        uses(TailStrategy::ShiftInwardsAndBlend);

    // 1024 is a multiple of every split factor, so it should always be exact.
    std::vector<int> extents = {1024};

    if (!rounds_up) {
        // With no round-ups, a size merely larger than the largest split
        // (128) should be fine too.
        extents.push_back(130);

        if (!shifts_inwards) {
            // Only GuardWithIf remains, which also copes with tiny sizes.
            extents.push_back(3);
        }
    }

    auto &handlers = out.jit_handlers();
    handlers.custom_malloc = my_malloc;
    handlers.custom_free = my_free;

    for (const int extent : extents) {
        largest_allocation = 0;
        out.realize({extent});

        // Upper bound: (extent + 1) * 4 bytes, plus three ints of slack.
        const size_t expected = (extent + 1) * 4;
        const size_t tolerance = 3 * sizeof(int);
        const size_t limit = expected + tolerance;
        if (largest_allocation <= limit) {
            continue;
        }

        std::cerr << "Failure on line " << line << "\n"
                  << "with tail strategies: ";
        for (auto t : tails) {
            std::cerr << t << " ";
        }
        std::cerr << "\n allocation of " << largest_allocation
                  << " bytes is too large. Expected " << limit << "\n";
        abort();
    }
}

// Test driver. Builds small producer-consumer pipelines whose stages are split
// with different combinations of tail strategies and hands each pipeline to
// check(). An optional integer seed may be passed as argv[1]. Returns 0 both
// on success and when the test is skipped.
int main(int argc, char **argv) {
    if (get_jit_target_from_environment().arch == Target::WebAssembly) {
        printf("[SKIP] WebAssembly JIT does not support custom allocators.\n");
        return 0;
    }

    // We'll randomly subsample these tests, because otherwise there are too many of them.
    // The seed is taken from argv[1] when present, otherwise from the current
    // time, and is printed so that a particular run can be repeated.
    std::mt19937 rng(0);
    int seed = argc > 1 ? atoi(argv[1]) : time(nullptr);
    rng.seed(seed);
    std::cout << "Nested tail strategies seed: " << seed << "\n";

    // Test random compositions of tail strategies in simple
    // producer-consumer pipelines. The bounds being tight sometimes
    // depends on the simplifier being able to cancel out things.

    // Strategies used below for splits whose inner variable is split again,
    // or whose outer variable is the one being split.
    TailStrategy tails[] = {
        TailStrategy::RoundUp,
        TailStrategy::GuardWithIf,
        TailStrategy::ShiftInwards,
        TailStrategy::RoundUpAndBlend,
        TailStrategy::ShiftInwardsAndBlend};

    // Same list plus PredicateLoads and PredicateStores.
    // NOTE(review): presumably the predicated strategies are only applicable
    // to an innermost split - confirm against the TailStrategy documentation.
    TailStrategy innermost_tails[] = {
        TailStrategy::RoundUp,
        TailStrategy::GuardWithIf,
        TailStrategy::PredicateLoads,
        TailStrategy::PredicateStores,
        TailStrategy::ShiftInwards,
        TailStrategy::RoundUpAndBlend,
        TailStrategy::ShiftInwardsAndBlend};

    // Two stages. First stage computed at tiles of second.
    // Every (t1, t2) pair is exercised; this section is not subsampled.
    for (auto t1 : innermost_tails) {
        for (auto t2 : innermost_tails) {
            Func in, f, g;
            Var x;

            in(x) = x;
            f(x) = in(x);
            g(x) = f(x);

            Var xo, xi;
            g.split(x, xo, xi, 64, t1);
            f.compute_at(g, xo).split(x, xo, xi, 8, t2);
            in.compute_root();

            check(g, __LINE__, {t1, t2});
        }
    }

    // Three stages. First stage computed at tiles of second, second
    // stage computed at tiles of third.
    for (auto t1 : innermost_tails) {
        for (auto t2 : innermost_tails) {
            for (auto t3 : innermost_tails) {
                // Keep a combination only when the low three bits of the
                // draw are zero, i.e. roughly one in eight.
                if ((rng() & 7) != 0) {
                    continue;
                }

                Func in("in"), f("f"), g("g"), h("h");
                Var x;

                in(x) = x;
                f(x) = in(x);
                g(x) = f(x);
                h(x) = g(x);

                Var xo, xi;
                h.split(x, xo, xi, 64, t1);
                g.compute_at(h, xo).split(x, xo, xi, 16, t2);
                f.compute_at(g, xo).split(x, xo, xi, 4, t3);
                in.compute_root();

                check(h, __LINE__, {t1, t2, t3});
            }
        }
    }

    // Three stages. First stage computed at tiles of third, second
    // stage computed at smaller tiles of third.
    for (auto t1 : tails) {
        for (auto t2 : innermost_tails) {
            for (auto t3 : innermost_tails) {
                // Roughly one combination in eight is kept.
                if ((rng() & 7) != 0) {
                    continue;
                }

                Func in, f, g, h;
                Var x;

                in(x) = x;
                f(x) = in(x);
                g(x) = f(x);
                h(x) = g(x);

                Var xo, xi, xii, xio;
                // The output is split by 128 with t1, then the inner variable
                // is split again by 64 with no explicit tail strategy.
                h.split(x, xo, xi, 128, t1).split(xi, xio, xii, 64);
                g.compute_at(h, xio).split(x, xo, xi, 8, t2);
                f.compute_at(h, xo).split(x, xo, xi, 8, t3);
                in.compute_root();

                check(h, __LINE__, {t1, t2, t3});
            }
        }
    }

    // Same as above, but the splits on the output are composed in
    // reverse order so we don't get a perfect split on the inner one
    // (but can handle smaller outputs).
    for (auto t1 : innermost_tails) {
        for (auto t2 : tails) {
            for (auto t3 : innermost_tails) {
                for (auto t4 : tails) {
                    // Four nested loops give many more combinations, so only
                    // roughly one in 64 is kept here.
                    if ((rng() & 63) != 0) {
                        continue;
                    }

                    Func in("in"), f("f"), g("g"), h("h");
                    Var x;

                    in(x) = x;
                    f(x) = in(x);
                    g(x) = f(x);
                    h(x) = g(x);

                    Var xo, xi, xoo, xoi;
                    // Split by 64 with t1 first, then split the outer
                    // variable by 2 with t2.
                    h.split(x, xo, xi, 64, t1).split(xo, xoo, xoi, 2, t2);
                    g.compute_at(h, xoi).split(x, xo, xi, 8, t3);
                    f.compute_at(h, xoo).split(x, xo, xi, 8, t4);
                    in.compute_root();

                    check(h, __LINE__, {t1, t2, t3, t4});
                }
            }
        }
    }

    printf("Success!\n");
    return 0;
}