1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
|
#include <cmath>
#include <cstdint>

#include "caffe2/core/tensor.h"
#include "caffe2/core/workspace.h"
#include "test_utils.h"
namespace {
// Compares two CPU tensors element by element for exact equality, reading
// both buffers as T. This primary template serves the exact-match element
// types; the trailing tolerance argument exists only so every instantiation
// shares one signature, and it is ignored here.
//
// Throws (via CAFFE_ENFORCE_*) on a shape mismatch or at the first differing
// element; the message includes the offending index.
template <typename T>
void assertTensorEqualsWithType(
    const caffe2::TensorCPU& tensor1,
    const caffe2::TensorCPU& tensor2,
    float /* unused */) {
  CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
  // Fetch the typed buffers once rather than calling data<T>() for every
  // element inside the loop.
  const T* const lhs = tensor1.data<T>();
  const T* const rhs = tensor2.data<T>();
  // int64_t, not int: numel() may exceed INT_MAX for large tensors.
  const int64_t count = tensor1.numel();
  for (int64_t idx = 0; idx < count; ++idx) {
    CAFFE_ENFORCE_EQ(lhs[idx], rhs[idx], "Mismatch at index ", idx);
  }
}

// float specialization: two elements match when they compare equal with
// operator== or when their absolute difference is strictly less than eps.
//
// The operator== shortcut is what lets matching infinities pass: inf - inf
// is NaN, and NaN < eps is false. A NaN in either tensor never compares
// equal and its difference is NaN, so NaNs are always reported as
// mismatches (even NaN vs NaN).
template <>
void assertTensorEqualsWithType<float>(
    const caffe2::TensorCPU& tensor1,
    const caffe2::TensorCPU& tensor2,
    float eps) {
  CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
  const float* const lhs = tensor1.data<float>();
  const float* const rhs = tensor2.data<float>();
  const int64_t count = tensor1.numel();
  for (int64_t idx = 0; idx < count; ++idx) {
    const float a = lhs[idx];
    const float b = rhs[idx];
    if (a == b) {
      continue;
    }
    CAFFE_ENFORCE_LT(
        std::fabs(a - b),
        eps,
        "Mismatch at index ",
        idx,
        " (",
        a,
        " vs ",
        b,
        ") exceeds threshold of ",
        eps);
  }
}
} // namespace
namespace caffe2 {
namespace testing {
// Asserts that value1 lies within epsilon of value2, i.e. in the closed
// interval [value2 - epsilon, value2 + epsilon] (bounds computed in float).
// Throws via CAFFE_ENFORCE_LE / CAFFE_ENFORCE_GE otherwise. If any argument
// is NaN, the check always fails, because comparisons involving NaN are false.
void assertNear(float value1, float value2, float epsilon) {
  // Two one-sided checks instead of a single |a - b| comparison: each failure
  // message then shows the actual value next to the bound it violated.
  CAFFE_ENFORCE_LE(value1, value2 + epsilon);
  CAFFE_ENFORCE_GE(value1, value2 - epsilon);
}
void assertTensorEquals(
const TensorCPU& tensor1,
const TensorCPU& tensor2,
float eps) {
CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
if (tensor1.IsType<float>()) {
CAFFE_ENFORCE(tensor2.IsType<float>());
assertTensorEqualsWithType<float>(tensor1, tensor2, eps);
} else if (tensor1.IsType<int>()) {
CAFFE_ENFORCE(tensor2.IsType<int>());
assertTensorEqualsWithType<int>(tensor1, tensor2, eps);
} else if (tensor1.IsType<int64_t>()) {
CAFFE_ENFORCE(tensor2.IsType<int64_t>());
assertTensorEqualsWithType<int64_t>(tensor1, tensor2, eps);
}
// Add more types if needed.
}
// Asserts that every blob named in tensorNames exists in both workspaces and
// holds equal tensors in each. The comparison calls assertTensorEquals with
// two arguments, so it uses the default eps from that function's declaration.
// Throws on the first missing blob or mismatching tensor. The missing-blob
// error names both the blob and the workspace it is absent from.
void assertTensorListEquals(
    const std::vector<std::string>& tensorNames,
    const Workspace& workspace1,
    const Workspace& workspace2) {
  for (const std::string& tensorName : tensorNames) {
    CAFFE_ENFORCE(
        workspace1.HasBlob(tensorName),
        "Blob '",
        tensorName,
        "' not found in workspace1");
    CAFFE_ENFORCE(
        workspace2.HasBlob(tensorName),
        "Blob '",
        tensorName,
        "' not found in workspace2");
    const auto& tensor1 = getTensor(workspace1, tensorName);
    const auto& tensor2 = getTensor(workspace2, tensorName);
    assertTensorEquals(tensor1, tensor2);
  }
}
// Returns a reference to the caffe2::Tensor stored in blob `name` of
// `workspace`. Throws if the workspace has no blob with that name, and the
// error message includes the name. Blob::Get<caffe2::Tensor>() presumably
// also enforces that the blob really holds a Tensor (TODO confirm).
// The reference points into the workspace's blob. It is presumably valid only
// while that blob is alive.
const caffe2::Tensor& getTensor(
    const caffe2::Workspace& workspace,
    const std::string& name) {
  CAFFE_ENFORCE(
      workspace.HasBlob(name), "Blob '", name, "' not found in workspace");
  return workspace.GetBlob(name)->Get<caffe2::Tensor>();
}
// Creates the blob `name` in `workspace` and returns a mutable CPU tensor
// backed by it (obtained through BlobGetMutableTensor). The returned pointer
// refers to storage inside the workspace's blob. The caller presumably does
// not own it (TODO confirm).
caffe2::Tensor* createTensor(
    const std::string& name,
    caffe2::Workspace* workspace) {
  auto* const blob = workspace->CreateBlob(name);
  return BlobGetMutableTensor(blob, caffe2::CPU);
}
// Appends a new OperatorDef to `net` with the given op type, then wires its
// input and output blob names in the order given. Returns a pointer to the
// newly added op, which lives inside `net`.
caffe2::OperatorDef* createOperator(
    const std::string& type,
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs,
    caffe2::NetDef* net) {
  caffe2::OperatorDef* const newOp = net->add_op();
  newOp->set_type(type);
  for (const std::string& blobName : inputs) {
    newOp->add_input(blobName);
  }
  for (const std::string& blobName : outputs) {
    newOp->add_output(blobName);
  }
  return newOp;
}
// Adds an operator to the mutated net (via createOperator) and records it as
// the most recently created op, so that later calls such as
// setDeviceOptionName apply to it. Returns *this for chaining.
NetMutator& NetMutator::newOp(
    const std::string& type,
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs) {
  auto* const created = createOperator(type, inputs, outputs, net_);
  lastCreatedOp_ = created;
  return *this;
}
// Appends each name in externalInputs, in order, to the net's external input
// list. Returns *this for chaining.
NetMutator& NetMutator::externalInputs(
    const std::vector<std::string>& externalInputs) {
  for (const std::string& inputName : externalInputs) {
    net_->add_external_input(inputName);
  }
  return *this;
}
// Appends each name in externalOutputs, in order, to the net's external
// output list. Returns *this for chaining.
NetMutator& NetMutator::externalOutputs(
    const std::vector<std::string>& externalOutputs) {
  for (const std::string& outputName : externalOutputs) {
    net_->add_external_output(outputName);
  }
  return *this;
}
// Sets the device option's node_name on the op most recently added with
// newOp. Throws if no op has been created yet. Returns *this for chaining.
NetMutator& NetMutator::setDeviceOptionName(const std::string& name) {
  CAFFE_ENFORCE(lastCreatedOp_ != nullptr);
  auto* const deviceOption = lastCreatedOp_->mutable_device_option();
  deviceOption->set_node_name(name);
  return *this;
}
} // namespace testing
} // namespace caffe2
|