File: abstractgeneratortest_generator.cpp

package info (click to toggle)
halide 21.0.0-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 55,752 kB
  • sloc: cpp: 289,334; ansic: 22,751; python: 7,486; makefile: 4,299; sh: 2,508; java: 1,549; javascript: 282; pascal: 207; xml: 127; asm: 9
file content (151 lines) | stat: -rw-r--r-- 4,720 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#include "Halide.h"

#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

using namespace Halide::Internal;

namespace Halide {
namespace {

// Note to reader: this test is meant as a simple way to verify that arbitrary
// implementations of AbstractGenerator work properly. That said, we recommend
// that you don't imitate this code; AbstractGenerator is an *internal*
// abstraction, intended for Halide to build on internally. If you use AbstractGenerator
// directly, you'll almost certainly have more work to do maintaining your code
// on your own.

// Name under which this generator is registered; name() reports the same string.
constexpr const char *AbstractGeneratorTestName = "abstractgeneratortest";

// Parses `s` as an int via stream extraction. std::stoi() is deliberately
// avoided: any parse failure must go through _halide_user_assert instead.
// Parsing succeeds only if extraction succeeds and no characters at all
// (including trailing whitespace) remain after the number.
int string_to_int(const std::string &s) {
    using Traits = std::istringstream::traits_type;

    std::istringstream in(s);
    int value = 0;
    in >> value;

    // Extraction must succeed, and the stream must then be exhausted.
    const bool parsed_whole_string = !in.fail() && in.get() == Traits::eof();
    _halide_user_assert(parsed_whole_string) << "Unable to parse: " << s;
    return value;
}

class AbstractGeneratorTest : public AbstractGenerator {
    // Boilerplate
    const GeneratorContext context_;

    // Constants (aka GeneratorParams)
    GeneratorParamsMap constants_ = {
        {"scaling", "2"},
    };

    // Inputs
    ImageParam input_{Int(32), 2, "input"};
    Param<int32_t> offset_{"offset"};

    // Outputs
    Func output_{"output"};

    // Misc
    Pipeline pipeline_;

public:
    explicit AbstractGeneratorTest(const GeneratorContext &context)
        : context_(context) {
    }

    std::string name() override {
        return AbstractGeneratorTestName;
    }

    GeneratorContext context() const override {
        return context_;
    }

    std::vector<ArgInfo> arginfos() override {
        return {
            {"input", ArgInfoDirection::Input, ArgInfoKind::Buffer, {Int(32)}, 2},
            {"offset", ArgInfoDirection::Input, ArgInfoKind::Scalar, {Int(32)}, 0},
            {"output", ArgInfoDirection::Output, ArgInfoKind::Buffer, {Int(32)}, 2},
        };
    }

    bool allow_out_of_order_inputs_and_outputs() const override {
        return false;
    }

    void set_generatorparam_value(const std::string &name, const std::string &value) override {
        _halide_user_assert(!pipeline_.defined());
        _halide_user_assert(constants_.count(name) == 1) << "Unknown Constant: " << name;
        constants_[name] = value;
    }

    void set_generatorparam_value(const std::string &name, const LoopLevel &value) override {
        _halide_user_assert(!pipeline_.defined());
        _halide_user_assert(constants_.count(name) == 1) << "Unknown Constant: " << name;
        _halide_user_assert(false) << "This Generator has no LoopLevel constants.";
    }

    Pipeline build_pipeline() override {
        _halide_user_assert(!pipeline_.defined());

        const int scaling = string_to_int(constants_.at("scaling"));

        Var x, y;
        output_(x, y) = input_(x, y) * scaling + offset_;
        output_.compute_root();

        pipeline_ = output_;
        return pipeline_;
    }

    std::vector<Parameter> input_parameter(const std::string &name) override {
        _halide_user_assert(pipeline_.defined());
        if (name == "input") {
            return {input_.parameter()};
        }
        if (name == "offset") {
            return {offset_.parameter()};
        }
        _halide_user_assert(false) << "Unknown input: " << name;
        return {};
    }

    std::vector<Func> output_func(const std::string &name) override {
        _halide_user_assert(pipeline_.defined());
        if (name == "output") {
            return {output_};
        }
        _halide_user_assert(false) << "Unknown output: " << name;
        return {};
    }

    void bind_input(const std::string &name, const std::vector<Parameter> &v) override {
        _halide_user_assert(false) << "OOPS";
    }

    void bind_input(const std::string &name, const std::vector<Func> &v) override {
        _halide_user_assert(false) << "OOPS";
    }

    void bind_input(const std::string &name, const std::vector<Expr> &v) override {
        _halide_user_assert(false) << "OOPS";
    }

    bool emit_cpp_stub(const std::string & /*stub_file_path*/) override {
        // not supported
        return false;
    }

    bool emit_hlpipe(const std::string & /*hlpipe_file_path*/) override {
        // not supported
        return false;
    }
};

// Registers AbstractGeneratorTest under AbstractGeneratorTestName. The factory
// builds a new instance bound to the given GeneratorContext and transfers
// ownership to the caller (no naked `new`).
RegisterGenerator register_something(AbstractGeneratorTestName,
                                     [](const GeneratorContext &context) -> AbstractGeneratorPtr {
                                         return std::make_unique<AbstractGeneratorTest>(context);
                                     });

}  // namespace
}  // namespace Halide