#pragma once

#include <memory>
#include <string_view>
#include <vector>

#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <torch/torch.h>

#include <torch/nativert/executor/Executor.h>
#include <torch/nativert/graph/Graph.h>
#include <torch/nativert/kernels/KernelHandlerRegistry.h>

namespace torch::nativert {
/*
 * A lightweight version of ModelRunner that executes a model in interpreter
 * mode, given a textual graph with no weights/attributes. A usage sketch
 * follows the class definition.
 */
class SimpleTestModelRunner {
 public:
  SimpleTestModelRunner(
      const std::string_view source,
      const ExecutorConfig& config) {
    register_kernel_handlers();
    graph_ = stringToGraph(source);
    weights_ = std::make_shared<Weights>(graph_.get());
    executor_ = std::make_unique<Executor>(config, graph_, weights_);
  }

  std::vector<c10::IValue> run(const std::vector<c10::IValue>& inputs) const {
    return executor_->execute(inputs);
  }

  ProfileMetrics benchmarkIndividualNodes(
      const std::vector<c10::IValue>& inputs) const {
    return executor_->benchmarkIndividualNodes({inputs}, 10, 10);
  }

 private:
  std::shared_ptr<Graph> graph_;
  std::unique_ptr<Executor> executor_;
  std::shared_ptr<Weights> weights_;
};
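
// Usage sketch (illustrative, not compiled): construct a runner from a textual
// graph and an ExecutorConfig, then feed it IValue inputs. The graph source
// below only illustrates the kind of string stringToGraph expects; the exact
// operator names and format are assumptions, not taken from this file.
//
//   TEST(SimpleTestModelRunnerExample, AddTensors) {
//     static constexpr std::string_view kSource = R"(
//       graph(%a, %b):
//         %out = aten.add.Tensor(self=%a, other=%b, alpha=1)
//         return (%out)
//     )";
//     ExecutorConfig config;
//     SimpleTestModelRunner runner(kSource, config);
//     auto outputs = runner.run({at::randn({2, 3}), at::randn({2, 3})});
//     EXPECT_EQ(outputs.size(), 1);
//   }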

inline void compareIValue(
    const c10::IValue& expected,
    const c10::IValue& actual,
    bool native = false) {
  if (expected.isTensor()) {
    EXPECT_TRUE(actual.isTensor());
    EXPECT_TRUE(torch::allclose(
        expected.toTensor(),
        actual.toTensor(),
        1e-5,
        1e-8,
        /*equal_nan=*/true));
    if (!native) {
      EXPECT_EQ(expected.toTensor().strides(), actual.toTensor().strides());
    }
  } else if (expected.isTuple()) {
    EXPECT_TRUE(actual.isTuple());
    auto expected_tuple = expected.toTupleRef().elements();
    auto actual_tuple = actual.toTupleRef().elements();
    ASSERT_EQ(expected_tuple.size(), actual_tuple.size());
    for (size_t i = 0; i < expected_tuple.size(); i++) {
      compareIValue(expected_tuple[i], actual_tuple[i], native);
    }
  } else if (expected.isList()) {
    EXPECT_TRUE(actual.isList());
    auto expected_list = expected.toList();
    auto actual_list = actual.toList();
    ASSERT_EQ(expected_list.size(), actual_list.size());
    for (size_t i = 0; i < expected_list.size(); i++) {
      compareIValue(expected_list[i], actual_list[i], native);
    }
  } else if (expected.isGenericDict()) {
    EXPECT_TRUE(actual.isGenericDict());
    auto expected_dict = expected.toGenericDict();
    auto actual_dict = actual.toGenericDict();
    EXPECT_EQ(expected_dict.size(), actual_dict.size());
    for (auto& expected_kv : expected_dict) {
      auto actual_kv = actual_dict.find(expected_kv.key());
      ASSERT_FALSE(actual_kv == actual_dict.end());
      compareIValue(expected_kv.value(), actual_kv->value(), native);
    }
  } else {
    // Fall back to default comparison from IValue
    EXPECT_TRUE(expected == actual);
  }
}
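
// Example (illustrative): compareIValue recurses through nested containers, so
// a test can compare structured outputs directly. Tuple::create and the tensor
// ops used here are standard c10/ATen APIs.
//
//   TEST(CompareIValueExample, NestedTuple) {
//     auto t = at::randn({2, 2});
//     std::vector<c10::IValue> elems{t, t + 1};
//     auto expected = c10::ivalue::Tuple::create(elems);
//     auto actual = c10::ivalue::Tuple::create(
//         std::vector<c10::IValue>{t.clone(), t + 1});
//     compareIValue(expected, actual);
//   }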

inline void compareIValues(
    const std::vector<c10::IValue>& expected,
    const std::vector<c10::IValue>& actual,
    bool native = false) {
  ASSERT_EQ(expected.size(), actual.size());
  for (size_t i = 0; i < expected.size(); i++) {
    compareIValue(expected[i], actual[i], native);
  }
}

inline void testStaticKernelEqualityInternal(
    const SimpleTestModelRunner& modelRunner,
    const SimpleTestModelRunner& staticModelRunner,
    const std::vector<c10::IValue>& args,
    bool native = false) {
  auto expected = modelRunner.run(args);

  auto output = staticModelRunner.run(args);
  compareIValues(expected, output, native);

  // Run again to exercise the static kernels once output IValues are cached
  // in the execution frame.
  auto output2 = staticModelRunner.run(args);
  compareIValues(expected, output2, native);
}

inline void testStaticKernelEquality(
    const std::string_view source,
    const std::vector<c10::IValue>& args,
    bool native = false) {
  ExecutorConfig config;

  config.enableStaticCPUKernels = false;
  SimpleTestModelRunner model(source, config);

  config.enableStaticCPUKernels = true;
  SimpleTestModelRunner staticKernelModel(source, config);

  testStaticKernelEqualityInternal(model, staticKernelModel, args, native);
}
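
// Example (illustrative): a single call checks that the static CPU kernel for
// an op matches the fallback interpreter path. The graph text is an assumed
// format for stringToGraph, not taken from this file.
//
//   TEST(StaticKernelExample, Relu) {
//     static constexpr std::string_view kSource = R"(
//       graph(%x):
//         %out = aten.relu.default(self=%x)
//         return (%out)
//     )";
//     testStaticKernelEquality(kSource, {at::randn({4, 4})});
//   }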

inline void testGraphABEquality(
    const std::string_view graph_a,
    const std::string_view graph_b,
    const std::vector<c10::IValue>& args,
    const ExecutorConfig& config = {},
    bool native = false) {
  SimpleTestModelRunner model_a(graph_a, config);
  auto expected = model_a.run(args);

  SimpleTestModelRunner model_b(graph_b, config);
  auto output = model_b.run(args);

  compareIValues(expected, output, native);
}

inline void testGraphABPerf(
    const std::string_view graph_a,
    const std::string_view graph_b,
    const std::vector<c10::IValue>& args,
    const ExecutorConfig& config = {}) {
  SimpleTestModelRunner model_a(graph_a, config);
  auto resultA = model_a.benchmarkIndividualNodes(args);

  SimpleTestModelRunner model_b(graph_b, config);
  auto resultB = model_b.benchmarkIndividualNodes(args);

  // graph_b is expected to be the faster variant.
  ASSERT_GT(resultA.totalTime, resultB.totalTime);
}
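
// Example (illustrative): compare an unfused graph against a fused variant for
// both numerics and per-node latency. The graph strings are placeholders; fill
// in two graphs in the stringToGraph format that compute the same outputs.
//
//   TEST(GraphABExample, FusionIsEquivalentAndFaster) {
//     static constexpr std::string_view kGraphA = R"( /* unfused graph */ )";
//     static constexpr std::string_view kGraphB = R"( /* fused graph */ )";
//     std::vector<c10::IValue> args{at::randn({8, 8})};
//     testGraphABEquality(kGraphA, kGraphB, args);
//     testGraphABPerf(kGraphA, kGraphB, args);
//   }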
} // namespace torch::nativert