1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
|
#include <algorithm>
#include <cctype>
#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>

#include "caffe2/utils/cast.h"
namespace caffe2 {

// Checks that cast::GetCastDataType maps the type name stored in the "to"
// argument of a "Cast" OperatorDef to the matching TensorProto_DataType
// value, for every type listed below. Each name is lowercased before it is
// stored, so this exercises the lowercase spellings (presumably verifying
// that the lookup is case-insensitive -- NOTE(review): confirm against
// cast.h).
TEST(CastTest, GetCastDataType) {
  // Builds a "Cast" OperatorDef whose "to" argument is `t` lowercased.
  // `t` is taken by value because it is modified in place.
  auto castOp = [](std::string t) {
    // Ensure lowercase. Go through unsigned char: passing a negative char
    // value to std::tolower is undefined behavior.
    std::transform(t.begin(), t.end(), t.begin(), [](unsigned char c) {
      return static_cast<char>(std::tolower(c));
    });
    auto op = CreateOperatorDef("Cast", "", {}, {});
    AddArgument("to", t, &op);
    return op;
  };

// Expects the (lowercased) spelling of enum suffix `t` to resolve to
// TensorProto_DataType_<t>. No trailing semicolon: each use supplies it.
#define X(t)                  \
  EXPECT_EQ(                  \
      TensorProto_DataType_##t, \
      cast::GetCastDataType(ArgumentHelper(castOp(#t)), "to"))
  X(FLOAT);
  X(INT32);
  X(BYTE);
  X(STRING);
  X(BOOL);
  X(UINT8);
  X(INT8);
  X(UINT16);
  X(INT16);
  X(INT64);
  X(FLOAT16);
  X(DOUBLE);
#undef X
}

} // namespace caffe2
|