File: extract_concat_bits.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (152 lines) | stat: -rw-r--r-- 5,055 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#include "Halide.h"

using namespace Halide;

class CountOps : public Internal::IRMutator {
    Expr visit(const Internal::Reinterpret *op) override {
        std::cerr << Expr(op) << " " << op->type.lanes() << " " << op->value.type().lanes() << "\n";
        if (op->type.lanes() != op->value.type().lanes()) {
            std::cerr << "Got one\n";
            reinterprets++;
        }
        return Internal::IRMutator::visit(op);
    }

    Expr visit(const Internal::Call *op) override {
        if (op->is_intrinsic(Internal::Call::concat_bits)) {
            concats++;
        } else if (op->is_intrinsic(Internal::Call::extract_bits)) {
            extracts++;
        }
        return Internal::IRMutator::visit(op);
    }

public:
    int extracts = 0, concats = 0, reinterprets = 0;
};

// Scenario 1: reinterpret an array of a wide type (uint32_t) as a larger array
// of a narrow type (uint8_t) via extract_bits.
//
// When `vectorize` is true, the lowered IR seen by the custom lowering pass
// must contain no extract_bits call and at least one Reinterpret that changes
// the lane count. The realized output is validated in both modes.
//
// Returns true on success; prints a diagnostic to stdout and returns false on
// failure.
static bool test_extract_bits_narrowing(bool vectorize) {
    Func f, g;
    Var x;

    f(x) = cast<uint32_t>(x);

    // Reinterpret to a narrower type.
    g(x) = extract_bits<uint8_t>(f(x / 4), 8 * (x % 4));

    f.compute_root();

    if (vectorize) {
        f.vectorize(x, 8);
        // The align_bounds directive is critical so that the x%4 term above collapses.
        g.align_bounds(x, 4).vectorize(x, 32);

        // An alternative to the align_bounds call:
        // g.output_buffer().dim(0).set_min(0);
    }

    // The counter is passed with a null deleter, so it must stay alive until
    // realize() has run the lowering passes.
    CountOps counter;
    g.add_custom_lowering_pass(&counter, nullptr);

    Buffer<uint8_t> out = g.realize({1024});
    std::cerr << counter.extracts << " " << counter.reinterprets << " " << counter.concats << "\n";

    if (vectorize) {
        if (counter.extracts > 0) {
            printf("Saw an unwanted extract_bits call in lowered code\n");
            return false;
        } else if (counter.reinterprets == 0) {
            printf("Did not see a vector reinterpret in lowered code\n");
            return false;
        }
    }

    const uint32_t width = static_cast<uint32_t>(out.width());
    for (uint32_t i = 0; i < width; i++) {
        // Expected value: the 32-bit value i / 4 shifted right by 8 * (i % 4)
        // bits and truncated to 8 bits.
        uint8_t correct = (i / 4) >> (8 * (i % 4));
        if (out(i) != correct) {
            // i is unsigned, so it is printed with %u rather than %d.
            printf("out(%u) = %d instead of %d\n", static_cast<unsigned>(i), out(i), correct);
            return false;
        }
    }
    return true;
}

// Scenario 2: reinterpret an array of a narrow type (uint8_t) as a smaller
// array of a wide type (uint32_t) via concat_bits.
//
// In both the scalar and the vectorized schedule, the lowered IR seen by the
// custom lowering pass must contain no concat_bits call and at least one
// Reinterpret that changes the lane count. The realized output is then
// validated byte by byte.
//
// Returns true on success; prints a diagnostic to stdout and returns false on
// failure.
static bool test_concat_bits_widening(bool vectorize) {
    Func f, g;
    Var x;

    f(x) = cast<uint8_t>(x);

    g(x) = concat_bits({f(4 * x), f(4 * x + 1), f(4 * x + 2), f(4 * x + 3)});

    f.compute_root();

    if (vectorize) {
        f.vectorize(x, 32);
        g.vectorize(x, 8);
    }

    // The counter is passed with a null deleter, so it must stay alive until
    // realize() has run the lowering passes.
    CountOps counter;
    g.add_custom_lowering_pass(&counter, nullptr);

    Buffer<uint32_t> out = g.realize({64});

    if (counter.concats > 0) {
        printf("Saw an unwanted concat_bits call in lowered code\n");
        return false;
    } else if (counter.reinterprets == 0) {
        printf("Did not see a vector reinterpret in lowered code\n");
        return false;
    }

    for (int i = 0; i < 64; i++) {
        for (int b = 0; b < 4; b++) {
            // Byte b (counting from the least significant byte) of out(i)
            // should be f(4 * i + b), i.e. 4 * i + b truncated to 8 bits.
            uint8_t correct = i * 4 + b;
            uint8_t result = (out(i) >> (b * 8)) & 0xff;
            if (result != correct) {
                printf("out(%d) byte %d = %d instead of %d\n", i, b, result, correct);
                return false;
            }
        }
    }
    return true;
}

// Scenario 3: cases that aren't expected to fold into reinterprets. Each case
// compares an extract_bits/concat_bits expression against an equivalent
// expression built from casts, shifts and masks over x in [0, 1024).
//
// On the first mismatch this prints a diagnostic to stderr and terminates the
// process with exit status 1.
static void test_non_folding_cases() {
    Func f;
    Var x("x");
    f(x) = cast<uint16_t>(x);

    // Realizes (a == b) as a uint8_t image and requires it to be non-zero
    // everywhere.
    auto check = [&](const Expr &a, const Expr &b) {
        Func g;
        g(x) = cast<uint8_t>(a == b);
        Buffer<uint8_t> out = g.realize({1024});
        for (int i = 0; i < out.width(); i++) {
            if (out(i) == 0) {
                std::cerr << "Mismatch between: " << a << " and " << b << " when x == " << i << "\n";
                exit(1);
            }
        }
    };

    // concat_bits is little-endian
    check(concat_bits({f(x), cast<uint16_t>(37)}), cast<uint32_t>(f(x)) + (37 << 16));
    check(concat_bits({cast<uint16_t>(0), f(x), cast<uint16_t>(0), cast<uint16_t>(0)}), cast(UInt(64), f(x)) << 16);

    // extract_bits is equivalent to right shifting and then casting to a narrower type
    check(extract_bits<uint8_t>(f(x), 3), cast<uint8_t>(f(x) >> 3));

    // Extract bits zero-fills out-of-range bits
    check(extract_bits<uint16_t>(f(x), 3), f(x) >> 3);
    check(extract_bits<int16_t>(f(x), 8), (f(x) >> 8) & 0xff);
    check(extract_bits<uint8_t>(f(x), -1), cast<uint8_t>(f(x)) << 1);

    // MSB of the mantissa of an ieee float
    check(extract_bits<uint8_t>(cast<float>(f(x)), 15), cast<uint8_t>(reinterpret<uint32_t>(cast<float>(f(x))) >> 15));
}

// Test driver. Runs the extract_bits and concat_bits reinterpret scenarios
// with scalar and then vectorized schedules, followed by the non-folding
// equivalence checks. Returns 0 and prints "Success!" when everything passes,
// and a non-zero status otherwise. Command-line arguments are unused.
int main(int argc, char **argv) {
    for (bool vectorize : {false, true}) {
        if (!test_extract_bits_narrowing(vectorize)) {
            return 1;
        }
    }

    for (bool vectorize : {false, true}) {
        if (!test_concat_bits_widening(vectorize)) {
            return 1;
        }
    }

    // Also test cases that aren't expected to fold into reinterprets
    test_non_folding_cases();

    printf("Success!\n");
    return 0;
}