1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
|
#include "caffe2/operators/multi_class_accuracy_op.h"
namespace caffe2 {
template <>
bool MultiClassAccuracyOp<float, CPUContext>::RunOnDevice() {
  // For every class c in [0, D):
  //   accuracies[c] = (# instances labeled c whose arg-max prediction is c)
  //                   / (# instances labeled c), or 0 if no instance is labeled c
  //   amounts[c]    = # instances labeled c
  //
  // Inputs:
  //   PREDICTION: float tensor of shape (N, D) with per-class scores.
  //   LABEL:      int tensor of shape (N,) with labels in [0, D).
  // Outputs:
  //   0: float tensor of shape (D,) with per-class accuracies.
  //   1: int tensor of shape (D,) with per-class instance counts.
  auto& X = Input(PREDICTION);
  auto& label = Input(LABEL);

  // Shape checks are enforced (not merely debug-checked): a mismatch would
  // otherwise lead to out-of-bounds reads in release builds.
  CAFFE_ENFORCE_EQ(X.dim(), 2, "prediction must be a 2-D tensor");
  const int N = X.dim32(0); // number of instances (batch size)
  const int D = X.dim32(1); // number of classes
  CAFFE_ENFORCE_EQ(label.dim(), 1, "labels must be a 1-D tensor");
  CAFFE_ENFORCE_EQ(
      label.dim32(0), N, "labels must have one entry per prediction row");

  auto* Y0 = Output(0, {D}, at::dtype<float>());
  auto* Y1 = Output(1, {D}, at::dtype<int>());
  const auto* Xdata = X.data<float>();
  const auto* labeldata = label.data<int>();
  auto* accuracies = Y0->template mutable_data<float>();
  auto* amounts = Y1->template mutable_data<int>();
  std::fill(accuracies, accuracies + D, 0.0f);
  std::fill(amounts, amounts + D, 0);

  for (int i = 0; i < N; ++i) {
    // 64-bit offset: i * D can exceed INT_MAX for large inputs.
    const float* row = Xdata + static_cast<int64_t>(i) * D;
    // Arg-max over the row. Strict '>' keeps the first maximum on ties;
    // NaN scores never win, so a row with no score above lowest() (e.g. all
    // NaN) falls back to class 0.
    float maxval = std::numeric_limits<float>::lowest();
    int maxid = 0;
    for (int j = 0; j < D; ++j) {
      if (row[j] > maxval) {
        maxval = row[j];
        maxid = j;
      }
    }

    const int labelid = labeldata[i];
    // Enforced rather than DCHECK'd: a negative or too-large label would
    // otherwise write out of bounds in release builds.
    CAFFE_ENFORCE(
        labelid >= 0 && labelid < D,
        "Label ",
        labelid,
        " of instance ",
        i,
        " is out of range [0, ",
        D,
        ")");
    if (maxid == labelid) {
      accuracies[labelid] += 1.0f;
    }
    amounts[labelid]++;
  }

  // Turn per-class hit counts into ratios; classes with no instances stay 0.
  for (int c = 0; c < D; ++c) {
    if (amounts[c] > 0) {
      accuracies[c] /= static_cast<float>(amounts[c]);
    }
  }
  return true;
}
// Register the float CPU implementation under the "MultiClassAccuracy" name.
REGISTER_CPU_OPERATOR(
MultiClassAccuracy, MultiClassAccuracyOp<float, CPUContext>);
// Schema: two inputs (prediction scores, labels) and two outputs (per-class
// accuracies, per-class instance counts).
OPERATOR_SCHEMA(MultiClassAccuracy)
    .NumInputs(2)
    .NumOutputs(2)
    .SetDoc(R"DOC(
Computes the accuracy score separately for each class, given the predicted
scores of every class for a batch of instances. The accuracy of class c is the
fraction of instances labeled c whose highest-scoring predicted class is c.
)DOC")
    .Input(
        0,
        "prediction",
        "2-D float tensor (N, D) of predicted scores of each class for "
        "each instance. N is the number of instances, i.e., batch size. D is "
        "the number of possible classes/labels.")
    .Input(
        1,
        "labels",
        "1-D int tensor (N,) of labels for each instance; every label must be "
        "in the range [0, D).")
    .Output(
        0,
        "accuracies",
        "1-D float tensor (D,) of accuracy for each class. If a class has no "
        "instance in the batch, its accuracy score is set to zero.")
    .Output(
        1,
        "amounts",
        "1-D int tensor (D,) of number of instances for each class in the batch.");
// Declares that no gradient is defined for this operator (it reports a
// metric; presumably used for evaluation only — confirm with callers).
SHOULD_NOT_DO_GRADIENT(MultiClassAccuracy);
} // namespace caffe2
|