File: configure_generator.cpp

package info (click to toggle)
halide 14.0.0-3
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 49,124 kB
  • sloc: cpp: 238,722; makefile: 4,303; python: 4,047; java: 1,575; sh: 1,384; pascal: 211; xml: 165; javascript: 43; ansic: 34
file content (109 lines) | stat: -rw-r--r-- 3,881 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include "Halide.h"

namespace {

/// Generator that exercises the configure() hook.
///
/// Besides the statically declared `input`, `bias` and `output`, configure()
/// adds further inputs and outputs via add_input()/add_output() and fixes the
/// type and dimensionality of the otherwise untyped `input` and `output`
/// buffers. generate() then combines every input into `output` and mirrors
/// the result into the two dynamically added outputs.
class Configure : public Halide::Generator<Configure> {
public:
    // Number of uint8 2-D buffer inputs ("extra_0", "extra_1", ...) that
    // configure() adds. configure() asserts that this is 3.
    GeneratorParam<int> num_extra_buffer_inputs{"num_extra_buffer_inputs", 3};

    // Declared without type or dimensions; both are supplied in configure().
    Input<Buffer<>> input{"input"};
    Input<int> bias{"bias"};

    // Declared without type or dimensions; both are supplied in configure().
    Output<Buffer<>> output{"output"};

    /// Adds the dynamic inputs/outputs and finalizes the type and dimensions
    /// of `input` and `output`. generate() asserts that this ran exactly once.
    void configure() {
        configure_calls++;

        // It's fine to examine GeneratorParams in the configure() method.
        assert(num_extra_buffer_inputs == 3);

        // Pointers returned by add_input() are managed by the Generator;
        // user code must not free them. We can stash them in member variables
        // as-is or in containers, like so:
        for (int i = 0; i < num_extra_buffer_inputs; ++i) {
            auto *extra = add_input<Buffer<uint8_t, 2>>("extra_" + std::to_string(i));
            extra_buffer_inputs.push_back(extra);
        }

        typed_extra_buffer_input = add_input<Buffer<int16_t, 2>>("typed_extra_buffer_input");

        extra_func_input = add_input<Func>("extra_func_input", UInt(16), 3);

        extra_scalar_input = add_input<int>("extra_scalar_input");

        extra_dynamic_scalar_input = add_input<Expr>("extra_dynamic_scalar_input", Int(8));

        extra_buffer_output = add_output<Buffer<float, 3>>("extra_buffer_output");

        extra_func_output = add_output<Func>("extra_func_output", Float(64), 2);

        // This is ok: you can't *examine* an Input or Output here, but you can call
        // set_type() iff the type is unspecified. (This allows you to base the type on,
        // e.g., the value in get_target(), or the value of any GeneratorParam.)
        input.set_type(Int(32));
        output.set_type(Int(32));

        // Ditto for set_dimensions.
        input.set_dimensions(3);
        output.set_dimensions(3);

        // Will fail: it is not legal to call set_type on an Input or Output that
        // already has a type specified.
        // bias.set_type(Int(32));

        // Will fail: it is not legal to examine Inputs in the configure() method
        // assert(input.dimensions() == 3);

        // Will fail: it is not legal to examine Inputs in the configure() method
        // Expr b = bias;
        // assert(b.defined());

        // Will fail: it is not legal to examine Outputs in the configure() method
        // Func o = output;
        // assert(output.defined());
    }

    /// Defines `output` as `input + bias` plus the sum of every dynamically
    /// added input, then fills the two dynamically added outputs from `output`.
    void generate() {
        assert(configure_calls == 1);

        // Will fail: it is not legal to call set_type(), etc from anywhere but configure().
        // input.set_type(Int(32));
        // input.set_dimensions(3);

        // Attempting to call add_input() outside of the configure method will fail.
        // auto *this_will_fail = add_input<Buffer<>>("untyped_uint8", UInt(8), 2);

        // The type passed to add_input<Expr>() in configure() is observable here.
        assert((*extra_dynamic_scalar_input).type() == Int(8));

        Var x, y, c;

        // Accumulate every extra input as int. The 2-D buffers are sampled at
        // (x, y); the 3-D Func input is sampled at (x, y, c).
        Expr extra_sum = cast<int>(0);
        for (int i = 0; i < num_extra_buffer_inputs; ++i) {
            extra_sum += cast<int>((*extra_buffer_inputs[i])(x, y));
        }
        extra_sum += cast<int>((*typed_extra_buffer_input)(x, y));
        extra_sum += cast<int>((*extra_func_input)(x, y, c));
        extra_sum += *extra_scalar_input + *extra_dynamic_scalar_input;

        output(x, y, c) = input(x, y, c) + bias + extra_sum;

        // The 3-D float output mirrors `output`; the 2-D double output takes
        // only the c == 0 slice.
        (*extra_buffer_output)(x, y, c) = cast<float>(output(x, y, c));
        (*extra_func_output)(x, y) = cast<double>(output(x, y, 0));
    }

private:
    // Incremented on each configure() call; checked in generate().
    int configure_calls = 0;

    // Non-owning pointers to the Inputs/Outputs added in configure().
    // They start out null so that use before configure() has run does not
    // read an indeterminate pointer value.
    std::vector<Input<Buffer<uint8_t, 2>> *> extra_buffer_inputs;
    Input<Buffer<int16_t, 2>> *typed_extra_buffer_input = nullptr;
    Input<Func> *extra_func_input = nullptr;
    Input<int> *extra_scalar_input = nullptr;
    Input<Expr> *extra_dynamic_scalar_input = nullptr;

    Output<Buffer<float, 3>> *extra_buffer_output = nullptr;
    Output<Func> *extra_func_output = nullptr;
};

}  // namespace

// Register the Configure generator class under the registry name "configure".
HALIDE_REGISTER_GENERATOR(Configure, configure)