File: jsd_op.cc

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (98 lines) | stat: -rw-r--r-- 2,940 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include "caffe2/operators/jsd_op.h"

namespace caffe2 {

namespace {

// Probabilities are clamped this far away from 0 and 1 before any log is
// taken, so logit()/entropy() never produce inf/NaN for interior inputs.
// static_cast keeps the exact double(1e-20) -> float conversion of the
// original constant.
constexpr float kLogThreshold = static_cast<float>(1e-20);

// Log-odds of p: log(p / (1 - p)), computed as -log(1/p - 1) on a clamped p.
inline float logit(float p) {
  // Clamp into [kLogThreshold, 1 - kLogThreshold]; the two ternaries below
  // reproduce std::min(std::max(p, lo), hi) step for step.
  const float lo = kLogThreshold;
  const float hi = 1 - kLogThreshold;
  float x = (p < lo) ? lo : p;
  x = (hi < x) ? hi : x;
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  return -log(1. / x - 1.);
}

// Shannon entropy (in nats) of a Bernoulli(p) distribution.
inline float entropy(float p) {
  // Treat (near-)degenerate distributions as zero entropy; this also keeps
  // log(0) out of the expression below.
  if (p < kLogThreshold || 1 - p < kLogThreshold) {
    return 0.f;
  }
  const float q = 1 - p;
  return -p * log(p) - q * log(q);
}
} // namespace

template <>
bool BernoulliJSDOp<float, CPUContext>::RunOnDevice() {
  // Elementwise Jensen-Shannon divergence between Bernoulli(X[i]) and
  // Bernoulli(T[i]):  JSD(p, t) = H((p + t) / 2) - (H(p) + H(t)) / 2,
  // where H is the Bernoulli entropy helper above.
  auto& X = Input(0); // predicted probabilities
  auto& T = Input(1); // target probabilities
  int N = X.numel();
  CAFFE_ENFORCE_EQ(T.numel(), N);
  auto* L = Output(0, X.sizes(), at::dtype<float>()); // JSD loss output
  const auto* pred = X.data<float>();
  const auto* target = T.data<float>();
  auto* loss = L->template mutable_data<float>();
  for (int i = 0; i < N; ++i) {
    // Mixture probability of the two Bernoullis (computed in double, as
    // the original did, before narrowing at the entropy() call).
    const auto mid = (pred[i] + target[i]) / 2.;
    loss[i] = entropy(mid) - (entropy(pred[i]) + entropy(target[i])) / 2.;
  }
  return true;
}

template <>
bool BernoulliJSDGradientOp<float, CPUContext>::RunOnDevice() {
  // Backward pass for BernoulliJSD: dX[i] = dL[i] * dJSD/dp at p = X[i].
  auto& go = Input(0); // upstream gradient dL, elementwise-aligned with X
  auto& X = Input(1); // predicted probabilities
  auto& T = Input(2); // target probabilities

  int N = X.numel();
  // Validate input sizes, mirroring the forward op's check; previously the
  // gradient op silently assumed all three inputs had the same length and
  // would read out of bounds on a mismatch.
  CAFFE_ENFORCE_EQ(T.numel(), N);
  CAFFE_ENFORCE_EQ(go.numel(), N);
  auto* gi = Output(0, X.sizes(), at::dtype<float>());
  auto* go_data = go.data<float>();
  auto* x_data = X.data<float>();
  auto* t_data = T.data<float>();
  auto* gi_data = gi->template mutable_data<float>();
  for (int i = 0; i < N; i++) {
    auto p_mdl = x_data[i];
    auto p_emp = t_data[i];
    auto p_avg = (p_mdl + p_emp) / 2.;
    // d/dp JSD(p, t) = (logit(p) - logit((p + t) / 2)) / 2, chained with go.
    auto g_jsd = (logit(p_mdl) - logit(p_avg)) / 2.;
    gi_data[i] = go_data[i] * g_jsd;
  }
  return true;
}
REGISTER_CPU_OPERATOR(BernoulliJSD, BernoulliJSDOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    BernoulliJSDGradient,
    BernoulliJSDGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(BernoulliJSD)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Computes the Jensen-Shannon divergence (JSD) between two Bernoulli distributions
where each is parametrized by a single probability.
)DOC")
    .Input(0, "X", "array of probabilities for prediction")
    // Fixed: the target is input index 1; the previous ".Input(0, ...)"
    // reused X's slot and clobbered its documentation.
    .Input(1, "T", "array of probabilities for target")
    .Output(0, "L", "array of JSD losses");
OPERATOR_SCHEMA(BernoulliJSDGradient).NumInputs(3).NumOutputs(1);

// Hooks BernoulliJSDGradient up as the gradient of BernoulliJSD:
// it consumes the upstream gradient plus both forward inputs and
// produces the gradient w.r.t. the predictions.
class GetBernoulliJSDGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    const vector<string> grad_inputs{GO(0), I(0), I(1)};
    const vector<string> grad_outputs{GI(0)};
    return SingleGradientDef(
        "BernoulliJSDGradient", "", grad_inputs, grad_outputs);
  }
};
REGISTER_GRADIENT(BernoulliJSD, GetBernoulliJSDGradient);

} // namespace caffe2