File: autograd_aottest.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (153 lines) | stat: -rw-r--r-- 5,875 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#include "HalideBuffer.h"
#include "HalideRuntime.h"

#include <cmath>
#include <cstdio>

#include "autograd.h"
#include "autograd_grad.h"

using namespace Halide::Runtime;

constexpr int kSize = 64;

int main(int argc, char **argv) {
    int result;

    auto f = [](float a, float b, float c) -> float {
        return 33.f * std::pow(a, 3) +
               22.f * std::pow(b, 2) +
               11.f * c +
               1.f;
    };

    Buffer<float, 1> a(kSize);
    Buffer<float, 1> b(kSize);
    Buffer<float, 1> c(kSize);
    Buffer<float, 1> out(kSize);

    a.for_each_element([&](int x) { a(x) = (float)x; });
    b.for_each_element([&](int x) { b(x) = (float)x; });
    c.for_each_element([&](int x) { c(x) = (float)x; });

    Buffer<uint8_t, 1> lut(256);
    Buffer<uint8_t, 1> lut_indices(kSize);
    Buffer<uint8_t, 1> out_lut(kSize);
    lut.for_each_element([&](int x) { lut(x) = (uint8_t)(x ^ 0xAA); });
    lut_indices.for_each_element([&](int x) { lut_indices(x) = x * 2; });

    result = autograd(a, b, c, lut, lut_indices, out, out_lut);
    if (result != 0) {
        exit(1);
    }
    out.for_each_element([&](int x) {
        float expected = f(a(x), b(x), c(x));
        float actual = out(x);
        assert(expected == actual);
    });
    out_lut.for_each_element([&](int x) {
        uint8_t expected = (uint8_t)(x * 2) ^ 0xAA;
        uint8_t actual = out_lut(x);
        assert(expected == actual);
    });

    Buffer<float, 1> L(kSize);
    L.for_each_element([&](int x) { L(x) = (float)(x - kSize / 2); });

    /*
        The gradient version should have the following args (in this order):
        Inputs:
            input_a
            input_b
            input_c
            lut
            lut_indices
            _grad_loss_for_output     (synthesized)
            _grad_loss_for_output_lut (synthesized)
        Outputs:
            _grad_loss_output_wrt_input_a
            _grad_loss_output_wrt_input_b
            _grad_loss_output_wrt_input_c
            _dummy_grad_loss_output_wrt_lut
            _dummy_grad_loss_output_wrt_lut_indices
            _dummy_grad_loss_output_lut_wrt_input_a
            _dummy_grad_loss_output_lut_wrt_input_b
            _dummy_grad_loss_output_lut_wrt_input_c
            _grad_loss_output_lut_wrt_lut
            _grad_loss_output_lut_wrt_lut_indices

        Note that the outputs with "_dummy" prefixes are placeholder
        outputs that are always filled with zeros; in those cases,
        there is no derivative for the output/input pairing, but we
        emit an output nevertheless so that the function signature
        is always mechanically predictable from the list of inputs and outputs.
    */

    Buffer<float, 1> grad_loss_out_wrt_a(kSize);
    Buffer<float, 1> grad_loss_out_wrt_b(kSize);
    Buffer<float, 1> grad_loss_out_wrt_c(kSize);
    Buffer<float, 1> dummy_grad_loss_output_wrt_lut(kSize);
    Buffer<float, 1> dummy_grad_loss_output_wrt_lut_indices(kSize);
    Buffer<float, 1> dummy_grad_loss_output_lut_wrt_input_a(kSize);
    Buffer<float, 1> dummy_grad_loss_output_lut_wrt_input_b(kSize);
    Buffer<float, 1> dummy_grad_loss_output_lut_wrt_input_c(kSize);
    Buffer<uint8_t, 1> grad_loss_output_lut_wrt_lut(kSize);
    Buffer<uint8_t, 1> grad_loss_output_lut_wrt_lut_indices(kSize);

    result = autograd_grad(/*inputs*/ a, b, c, lut, lut_indices, L, L,
                           /*outputs*/
                           grad_loss_out_wrt_a,
                           grad_loss_out_wrt_b,
                           grad_loss_out_wrt_c,
                           dummy_grad_loss_output_wrt_lut,
                           dummy_grad_loss_output_wrt_lut_indices,
                           dummy_grad_loss_output_lut_wrt_input_a,
                           dummy_grad_loss_output_lut_wrt_input_b,
                           dummy_grad_loss_output_lut_wrt_input_c,
                           grad_loss_output_lut_wrt_lut,
                           grad_loss_output_lut_wrt_lut_indices);
    if (result != 0) {
        exit(1);
    }

    // Although the values are float, all should be exact results,
    // so we don't need to worry about comparing vs. an epsilon
    grad_loss_out_wrt_a.for_each_element([&](int x) {
        // ∂𝐿/∂a = 3a^2 * 33 * L
        float expected = L(x) * std::pow(a(x), 2) * 3.f * 33.f;
        float actual = grad_loss_out_wrt_a(x);
        assert(expected == actual);
    });
    grad_loss_out_wrt_b.for_each_element([&](int x) {
        // ∂𝐿/∂b = b * 44 * L
        float expected = L(x) * b(x) * 44.f;
        float actual = grad_loss_out_wrt_b(x);
        assert(expected == actual);
    });
    grad_loss_out_wrt_c.for_each_element([&](int x) {
        // ∂𝐿/∂c = 11 * L
        float expected = L(x) * 11.f;
        float actual = grad_loss_out_wrt_c(x);
        assert(expected == actual);
    });
    dummy_grad_loss_output_wrt_lut.for_each_value([](float f) { assert(f == 0.f); });
    dummy_grad_loss_output_wrt_lut_indices.for_each_value([](float f) { assert(f == 0.f); });
    dummy_grad_loss_output_lut_wrt_input_a.for_each_value([](float f) { assert(f == 0.f); });
    dummy_grad_loss_output_lut_wrt_input_b.for_each_value([](float f) { assert(f == 0.f); });
    dummy_grad_loss_output_lut_wrt_input_c.for_each_value([](float f) { assert(f == 0.f); });
    grad_loss_output_lut_wrt_lut.for_each_element([&](int x) {
        // TODO: is zero really expected?
        uint8_t expected = 0;
        uint8_t actual = grad_loss_output_lut_wrt_lut(x);
        assert(expected == actual);
    });
    grad_loss_output_lut_wrt_lut_indices.for_each_element([&](int x) {
        // TODO: is zero really expected?
        uint8_t expected = 0;
        uint8_t actual = grad_loss_output_lut_wrt_lut_indices(x);
        assert(expected == actual);
    });

    printf("Success!\n");
    return 0;
}