1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
|
#include "Halide.h"
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
using namespace Halide::Internal;
namespace Halide {
namespace {
// Note to reader: this test is meant as a simple way to verify that arbitrary
// implementations of AbstractGenerator work properly. That said, we recommend
// that you don't imitate this code; AbstractGenerator is an *internal*
// abstraction, intended for Halide to build on internally. If you use
// AbstractGenerator directly, you'll almost certainly have more work to do
// maintaining your code on your own.
// Name under which this Generator is registered (see register_something)
// and which AbstractGeneratorTest::name() reports.
constexpr const char *AbstractGeneratorTestName = "abstractgeneratortest";
// Parses `s` as an int, failing a user assertion ("Unable to parse: <s>")
// unless the whole string is consumed by the parse; e.g. "12abc" or "" fail.
// std::stoi() would also work, but it throws on failure and silently ignores
// trailing characters; here we explicitly want to assert-fail instead.
int string_to_int(const std::string &s) {
    int parsed = 0;
    std::istringstream stream(s);
    // Success requires both a valid int extraction and that nothing at all
    // (not even whitespace) follows it.
    const bool ok = static_cast<bool>(stream >> parsed) &&
                    stream.get() == std::char_traits<char>::eof();
    _halide_user_assert(ok) << "Unable to parse: " << s;
    return parsed;
}
class AbstractGeneratorTest : public AbstractGenerator {
// Boilerplate
const GeneratorContext context_;
// Constants (aka GeneratorParams)
GeneratorParamsMap constants_ = {
{"scaling", "2"},
};
// Inputs
ImageParam input_{Int(32), 2, "input"};
Param<int32_t> offset_{"offset"};
// Outputs
Func output_{"output"};
// Misc
Pipeline pipeline_;
public:
explicit AbstractGeneratorTest(const GeneratorContext &context)
: context_(context) {
}
std::string name() override {
return AbstractGeneratorTestName;
}
GeneratorContext context() const override {
return context_;
}
std::vector<ArgInfo> arginfos() override {
return {
{"input", ArgInfoDirection::Input, ArgInfoKind::Buffer, {Int(32)}, 2},
{"offset", ArgInfoDirection::Input, ArgInfoKind::Scalar, {Int(32)}, 0},
{"output", ArgInfoDirection::Output, ArgInfoKind::Buffer, {Int(32)}, 2},
};
}
bool allow_out_of_order_inputs_and_outputs() const override {
return false;
}
void set_generatorparam_value(const std::string &name, const std::string &value) override {
_halide_user_assert(!pipeline_.defined());
_halide_user_assert(constants_.count(name) == 1) << "Unknown Constant: " << name;
constants_[name] = value;
}
void set_generatorparam_value(const std::string &name, const LoopLevel &value) override {
_halide_user_assert(!pipeline_.defined());
_halide_user_assert(constants_.count(name) == 1) << "Unknown Constant: " << name;
_halide_user_assert(false) << "This Generator has no LoopLevel constants.";
}
Pipeline build_pipeline() override {
_halide_user_assert(!pipeline_.defined());
const int scaling = string_to_int(constants_.at("scaling"));
Var x, y;
output_(x, y) = input_(x, y) * scaling + offset_;
output_.compute_root();
pipeline_ = output_;
return pipeline_;
}
std::vector<Parameter> input_parameter(const std::string &name) override {
_halide_user_assert(pipeline_.defined());
if (name == "input") {
return {input_.parameter()};
}
if (name == "offset") {
return {offset_.parameter()};
}
_halide_user_assert(false) << "Unknown input: " << name;
return {};
}
std::vector<Func> output_func(const std::string &name) override {
_halide_user_assert(pipeline_.defined());
if (name == "output") {
return {output_};
}
_halide_user_assert(false) << "Unknown output: " << name;
return {};
}
void bind_input(const std::string &name, const std::vector<Parameter> &v) override {
_halide_user_assert(false) << "OOPS";
}
void bind_input(const std::string &name, const std::vector<Func> &v) override {
_halide_user_assert(false) << "OOPS";
}
void bind_input(const std::string &name, const std::vector<Expr> &v) override {
_halide_user_assert(false) << "OOPS";
}
bool emit_cpp_stub(const std::string & /*stub_file_path*/) override {
// not supported
return false;
}
bool emit_hlpipe(const std::string & /*hlpipe_file_path*/) override {
// not supported
return false;
}
};
// Registers AbstractGeneratorTest under AbstractGeneratorTestName. Each time
// the factory lambda is invoked it constructs a new AbstractGeneratorTest
// bound to the supplied context; std::make_unique avoids a naked `new`.
RegisterGenerator register_something(AbstractGeneratorTestName,
                                     [](const GeneratorContext &context) -> AbstractGeneratorPtr {
                                         return std::make_unique<AbstractGeneratorTest>(context);
                                     });
} // namespace
} // namespace Halide
|