1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
|
#include <caffe2/core/common.h>
#include <caffe2/core/test_utils.h>
#include <caffe2/core/workspace.h>
#include <caffe2/opt/onnxifi_transformer.h>
#include <caffe2/utils/proto_utils.h>
#include <gtest/gtest.h>
using namespace caffe2::testing;
using namespace caffe2;
namespace {
// Builds a NetDef holding a single `op_type` operator and seeds `ws` with the
// "Compressed" blob that operator reads.
//
// Operator inputs, in order: "Data", "Weight" (only when `has_weight`), "Idx",
// "Lengths", "Compressed". Its single output is "Out".
// "Compressed" is created in `ws` as a 1-element int32 tensor whose value is 0
// when `fallback` is true and 1 otherwise.
NetDef createTest(
    const std::string& op_type,
    Workspace* ws,
    bool has_weight,
    bool fallback) {
  // Flag blob consumed by the op under test.
  auto* flag_blob = ws->CreateBlob("Compressed");
  auto* flag_tensor =
      BlobGetMutableTensor(flag_blob, {1}, at::dtype<int32_t>());
  flag_tensor->mutable_data<int32_t>()[0] = fallback ? 0 : 1;

  const std::vector<std::string> op_inputs = has_weight
      ? std::vector<std::string>{"Data", "Weight", "Idx", "Lengths", "Compressed"}
      : std::vector<std::string>{"Data", "Idx", "Lengths", "Compressed"};

  NetDef result;
  NetMutator(&result).newOp(op_type, op_inputs, {"Out"});
  return result;
}
void check(
const NetDef& net,
const std::string& op_type,
bool has_weight,
bool fallback) {
const static std::unordered_map<string, string> slss = {
{"SparseLengthsSum4BitRowwiseSparse", "SparseLengthsSumFused4BitRowwise"},
{"SparseLengthsWeightedSum4BitRowwiseSparse",
"SparseLengthsWeightedSumFused4BitRowwise"},
{"SparseLengthsSum8BitRowwiseSparse", "SparseLengthsSumFused8BitRowwise"},
{"SparseLengthsWeightedSum8BitRowwiseSparse",
"SparseLengthsWeightedSumFused8BitRowwise"},
{"SparseLengthsSum2BitRowwiseSparse", "SparseLengthsSumFused2BitRowwise"},
{"SparseLengthsWeightedSum2BitRowwiseSparse",
"SparseLengthsWeightedSumFused2BitRowwise"}};
if (fallback) {
EXPECT_EQ(net.op_size(), 1);
EXPECT_EQ(net.op(0).type(), slss.at(op_type));
EXPECT_EQ(net.op(0).input_size(), has_weight ? 4 : 3);
EXPECT_EQ(net.op(0).output_size(), 1);
EXPECT_EQ(net.op(0).input(0), "Data");
EXPECT_EQ(net.op(0).input(has_weight ? 2 : 1), "Idx");
EXPECT_EQ(net.op(0).input(has_weight ? 3 : 2), "Lengths");
if (has_weight) {
EXPECT_EQ(net.op(0).input(1), "Weight");
}
EXPECT_EQ(net.op(0).output(0), "Out");
} else {
EXPECT_EQ(net.op_size(), 2);
EXPECT_EQ(net.op(0).type(), "SparseLengthsSumSparseLookup");
EXPECT_EQ(net.op(0).input_size(), has_weight ? 4 : 3);
EXPECT_EQ(net.op(0).output_size(), has_weight ? 3 : 2);
EXPECT_EQ(net.op(0).input(0), "Idx");
EXPECT_EQ(net.op(0).input(1), "Lengths");
EXPECT_EQ(net.op(0).input(2), "Compressed");
EXPECT_EQ(net.op(0).output(0), "Idx_decomp");
EXPECT_EQ(net.op(0).output(1), "Lengths_decomp");
if (has_weight) {
EXPECT_EQ(net.op(0).input(3), "Weight");
EXPECT_EQ(net.op(0).output(2), "Weight_decomp");
}
EXPECT_EQ(net.op(1).type(), slss.at(op_type));
EXPECT_EQ(net.op(1).input_size(), has_weight ? 4 : 3);
EXPECT_EQ(net.op(1).output_size(), 1);
EXPECT_EQ(net.op(1).input(0), "Data");
EXPECT_EQ(net.op(1).input(has_weight ? 2 : 1), "Idx_decomp");
EXPECT_EQ(net.op(1).input(has_weight ? 3 : 2), "Lengths_decomp");
if (has_weight) {
EXPECT_EQ(net.op(1).input(1), "Weight_decomp");
}
EXPECT_EQ(net.op(1).output(0), "Out");
}
}
} // namespace
// Runs every *RowwiseSparse SLS variant (unweighted/weighted x 2/4/8-bit)
// through splitSparseLengthsSumSparse twice. The first run sets the
// "Compressed" flag to 0 (fallback case); the second sets it to 1 (split case).
TEST(splitSparseLengthsSumSparse, sweep) {
  // Plain arrays: std::vector<bool>'s bit-proxy specialization buys nothing here.
  const bool kHasWeights[] = {true, false};
  const bool kFallbacks[] = {true, false};
  const int kBits[] = {2, 4, 8};
  for (const bool has_weight : kHasWeights) {
    for (const int bit : kBits) {
      std::string op_type = "SparseLengths";
      op_type += has_weight ? "WeightedSum" : "Sum";
      op_type += caffe2::to_string(bit);
      op_type += "BitRowwiseSparse";
      for (const bool fallback : kFallbacks) {
        // Tag any failure reported by check() with the parameter combination
        // that produced it. has_weight and bit are already encoded in op_type.
        SCOPED_TRACE(op_type + (fallback ? " (fallback)" : " (split)"));
        Workspace ws;
        auto net = createTest(op_type, &ws, has_weight, fallback);
        splitSparseLengthsSumSparse(&net, ws);
        check(net, op_type, has_weight, fallback);
      }
    }
  }
}
|