File: batch_moments_op.cc

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (122 lines) | stat: -rw-r--r-- 3,603 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include "caffe2/operators/batch_moments_op.h"

#include <cstdint>
#include <string>
#include <vector>

#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <>
// Computes per-channel raw moments of an NCHW float tensor.
//
// For each channel c, over all M = N * HxW elements of that channel:
//   mu[c]  = sum(x) / M
//   var[c] = sum(x * x) / M
// NOTE: despite its name, `var` holds the uncentered second moment E[X^2];
// no mean is subtracted.
//
// X is walked as N consecutive blocks of C * HxW floats; each block is viewed
// as an (HxW x C) column-major Eigen array, so column c is channel c.
bool BatchMomentsOp<float, CPUContext>::ComputeBatchMomentsNCHW(
    const int N,
    const int C,
    const int HxW,
    const float* X,
    float* mu,
    float* var) {
  math::Set<float, CPUContext>(C, 0.0f, mu, &context_);
  math::Set<float, CPUContext>(C, 0.0f, var, &context_);
  EigenVectorArrayMap<float> mu_arr(mu, C);
  EigenVectorArrayMap<float> var_arr(var, C);
  // Sizes are computed in 64 bits: C * HxW and N * HxW can exceed INT_MAX for
  // large inputs even when every individual dimension fits in an int.
  const std::int64_t stride = static_cast<std::int64_t>(C) * HxW;
  const std::int64_t count = static_cast<std::int64_t>(N) * HxW;
  const float* X_ptr = X;
  for (int i = 0; i < N; ++i) {
    ConstEigenArrayMap<float> X_arr(X_ptr, HxW, C);
    mu_arr += X_arr.colwise().sum();
    var_arr += X_arr.square().colwise().sum();
    X_ptr += stride;
  }
  const float scale = 1.0f / static_cast<float>(count);
  math::Scale<float, float, CPUContext>(C, scale, mu, mu, &context_);
  math::Scale<float, float, CPUContext>(C, scale, var, var, &context_);
  return true;
}

template <>
// Computes per-channel raw moments of an NHWC float tensor.
//
// X is viewed as a (C x N*HxW) column-major Eigen array, so row c holds every
// element of channel c:
//   mu[c]  = mean of row c
//   var[c] = mean of the squares of row c
// NOTE: `var` is the uncentered second moment E[X^2], not the variance.
bool BatchMomentsOp<float, CPUContext>::ComputeBatchMomentsNHWC(
    const int N,
    const int C,
    const int HxW,
    const float* X,
    float* mu,
    float* var) {
  // Column count in 64 bits: N * HxW can overflow int for large inputs.
  const std::int64_t count = static_cast<std::int64_t>(N) * HxW;
  ConstEigenArrayMap<float> X_arr(X, C, count);
  EigenVectorMap<float>(mu, C) = X_arr.rowwise().mean();
  EigenVectorMap<float>(var, C) = X_arr.square().rowwise().mean();
  return true;
}

template <>
// Backward pass of the NCHW moments kernel.
//
// With mu[c] = sum(x) / M and var[c] = sum(x * x) / M over the
// M = N * HxW elements of channel c, every element x of channel c gets
//   dX = (x * dvar[c] * 2 + dmu[c]) / M.
//
// X and dX are walked as N consecutive blocks of C * HxW floats, each viewed
// as an (HxW x C) column-major array (column c = channel c).
bool BatchMomentsGradientOp<float, CPUContext>::ComputeBatchMomentsGradientNCHW(
    const int N,
    const int C,
    const int HxW,
    const float* dmu,
    const float* dvar,
    const float* X,
    float* dX) {
  ConstEigenVectorArrayMap<float> dmu_arr(dmu, C);
  ConstEigenVectorArrayMap<float> dvar_arr(dvar, C);
  // Sizes are computed in 64 bits: C * HxW and N * HxW can exceed INT_MAX for
  // large inputs even when every individual dimension fits in an int.
  const std::int64_t stride = static_cast<std::int64_t>(C) * HxW;
  const std::int64_t count = static_cast<std::int64_t>(N) * HxW;
  const float scale = 1.0f / static_cast<float>(count);
  const float* X_ptr = X;
  float* dX_ptr = dX;
  for (int i = 0; i < N; ++i) {
    ConstEigenArrayMap<float> X_arr(X_ptr, HxW, C);
    EigenArrayMap<float> dX_arr(dX_ptr, HxW, C);
    // Evaluated per element as ((x * dvar) * 2 + dmu) * scale in a single
    // pass, so no second sweep over the whole dX tensor is needed.
    dX_arr = ((X_arr.rowwise() * dvar_arr.transpose() * 2.0f).rowwise() +
              dmu_arr.transpose()) *
        scale;
    X_ptr += stride;
    dX_ptr += stride;
  }
  return true;
}

template <>
// Backward pass of the NHWC moments kernel.
//
// With mu[c] = sum(x) / M and var[c] = sum(x * x) / M over the
// M = N * HxW elements of channel c, every element x of channel c gets
//   dX = (x * dvar[c] * 2 + dmu[c]) / M.
//
// X and dX are viewed as (C x N*HxW) column-major arrays, so row c is
// channel c and the per-channel vectors broadcast column-wise.
bool BatchMomentsGradientOp<float, CPUContext>::ComputeBatchMomentsGradientNHWC(
    const int N,
    const int C,
    const int HxW,
    const float* dmu,
    const float* dvar,
    const float* X,
    float* dX) {
  // Column count in 64 bits: N * HxW can overflow int for large inputs.
  const std::int64_t count = static_cast<std::int64_t>(N) * HxW;
  const float scale = 1.0f / static_cast<float>(count);
  ConstEigenVectorArrayMap<float> dmu_arr(dmu, C);
  ConstEigenVectorArrayMap<float> dvar_arr(dvar, C);
  ConstEigenArrayMap<float> X_arr(X, C, count);
  EigenArrayMap<float> dX_arr(dX, C, count);
  // Evaluated per element as ((x * dvar) * 2 + dmu) * scale in a single pass,
  // replacing the separate full-tensor Scale over N * C * HxW elements.
  dX_arr = ((X_arr.colwise() * dvar_arr * 2.0f).colwise() + dmu_arr) * scale;
  return true;
}

// CPU kernels: only the float specializations defined above are registered.
REGISTER_CPU_OPERATOR(BatchMoments, BatchMomentsOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    BatchMomentsGradient,
    BatchMomentsGradientOp<float, CPUContext>);

// Forward op: one input, two outputs (the per-channel first and second
// moments).
OPERATOR_SCHEMA(BatchMoments)
    .NumInputs(1)
    .NumOutputs(2);
// Backward op: three inputs (see the gradient maker for their order), one
// output.
OPERATOR_SCHEMA(BatchMomentsGradient)
    .NumInputs(3)
    .NumOutputs(1);

namespace {

// Gradient maker for BatchMoments: emits a single BatchMomentsGradient op
// that consumes the gradients of both forward outputs, GO(0) and GO(1),
// followed by the forward input I(0), and produces the input gradient GI(0).
// The input order must match what BatchMomentsGradientOp reads.
class GetBatchMomentsGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;

  std::vector<OperatorDef> GetGradientDefs() override {
    const std::vector<std::string> grad_op_inputs{GO(0), GO(1), I(0)};
    const std::vector<std::string> grad_op_outputs{GI(0)};
    return SingleGradientDef(
        "BatchMomentsGradient", "", grad_op_inputs, grad_op_outputs);
  }
};

} // namespace

REGISTER_GRADIENT(BatchMoments, GetBatchMomentsGradient);

} // namespace caffe2