File: multi_class_accuracy_op.cc

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (86 lines) | stat: -rw-r--r-- 2,415 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include "caffe2/operators/multi_class_accuracy_op.h"

#include <algorithm>
#include <cstdint>
#include <limits>

namespace caffe2 {

// CPU implementation of MultiClassAccuracy.
//
// Inputs:
//   PREDICTION: (N, D) float scores, one row per instance, one column per class.
//   LABEL:      (N,) int labels; each must lie in [0, D).
// Outputs:
//   0: (D,) float; for class c, the fraction of instances labelled c whose
//      predicted class (row argmax) equals c, or 0 if no instance is labelled c.
//   1: (D,) int; the number of instances labelled c.
//
// Argmax details: the first index wins on ties; NaN scores never win (every
// comparison with NaN is false); a row with no score strictly greater than
// std::numeric_limits<float>::lowest() predicts class 0.
template <>
bool MultiClassAccuracyOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(PREDICTION);
  auto& label = Input(LABEL);

  // These used to be debug-only checks. In release builds a shape mismatch
  // made the loops below read past the end of the inputs, so enforce them
  // unconditionally.
  CAFFE_ENFORCE_EQ(X.dim(), 2, "prediction must be a 2-D tensor");
  // amount, number of instances
  const int N = X.dim32(0);
  // dimension, number of classes
  const int D = X.dim32(1);
  CAFFE_ENFORCE_EQ(label.dim(), 1, "labels must be a 1-D tensor");
  CAFFE_ENFORCE_EQ(
      label.dim32(0), N, "labels must have one entry per prediction row");
  auto* Y0 = Output(0, {D}, at::dtype<float>());
  auto* Y1 = Output(1, {D}, at::dtype<int>());

  const auto* Xdata = X.data<float>();
  const auto* labeldata = label.data<int>();
  auto* accuracies = Y0->template mutable_data<float>();
  auto* amounts = Y1->template mutable_data<int>();
  // accuracies first accumulates per-class hit counts; they are turned into
  // ratios after the main loop.
  std::fill(accuracies, accuracies + D, 0.0f);
  std::fill(amounts, amounts + D, 0);

  for (int i = 0; i < N; ++i) {
    // 64-bit offset: i * D may exceed INT_MAX for large inputs.
    const float* row = Xdata + static_cast<int64_t>(i) * D;
    float maxval = std::numeric_limits<float>::lowest();
    int maxid = 0;
    for (int j = 0; j < D; ++j) {
      if (row[j] > maxval) {
        maxval = row[j];
        maxid = j;
      }
    }
    const int labelid = labeldata[i];
    // labelid indexes the outputs directly; an out-of-range (including
    // negative) label would write out of bounds, so reject it in all builds.
    CAFFE_ENFORCE(
        labelid >= 0 && labelid < D,
        "Label ",
        labelid,
        " of instance ",
        i,
        " is out of range [0, ",
        D,
        ")");
    if (maxid == labelid) {
      accuracies[labelid] += 1.0f;
    }
    amounts[labelid]++;
  }

  // Convert hit counts into ratios; classes absent from the batch stay at 0.
  for (int i = 0; i < D; ++i) {
    const int amount = amounts[i];
    if (amount) {
      accuracies[i] /= static_cast<float>(amount);
    }
  }

  return true;
}

REGISTER_CPU_OPERATOR(
  MultiClassAccuracy, MultiClassAccuracyOp<float, CPUContext>);

OPERATOR_SCHEMA(MultiClassAccuracy)
  .NumInputs(2)
  .NumOutputs(2)
  .SetDoc(R"DOC(
Respectively compute accuracy score for each class given a number of instances
and predicted scores of each class for each instance.
)DOC")
  .Input(
    0,
    "prediction",
    "2-D float tensor (N,D,) of predicted scores of each class for "
    "each data. N is the number of instances, i.e., batch size. D is number of "
    "possible classes/labels.")
  .Input(
    1,
    "labels",
    "1-D int tensor (N,) of labels for each instance.")
  .Output(
    0,
    "accuracies",
    "1-D float tensor (D,) of accuracy for each class. If a class has no "
    "instance in the batch, its accuracy score is set to zero.")
  .Output(
    1,
    "amounts",
    "1-D int tensor (D,) of number of instances for each class in the batch.");

// Accuracy is a non-differentiable evaluation metric.
SHOULD_NOT_DO_GRADIENT(MultiClassAccuracy);
}  // namespace caffe2