File: norm.cpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (68 lines) | stat: -rw-r--r-- 2,092 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <torch/csrc/jit/tensorexpr/operators/misc.h>
#include <torch/csrc/jit/tensorexpr/operators/norm.h>

namespace torch::jit::tensorexpr {

/// Builds the tensor-expression `Compute` for "aten_batch_norm".
///
/// Every output element is evaluated as
///   input * alpha + beta
/// where
///   alpha = rsqrt(var + eps) * weight
///   beta  = bias - mean * alpha
/// and mean, var, weight and bias are indexed by the second axis (the channel
/// axis `c`).
///
/// @param inputs        Operator arguments, laid out as: input, weight, bias,
///                      mean, var, training, momentum, eps, cudnn_enabled.
///                      Only indices 0-4 and 7 are read here; training,
///                      momentum and cudnn_enabled are ignored by this
///                      lowering. weight (1) and bias (2) may be ArgNone, in
///                      which case 1 and 0 are used respectively.
/// @param outputShape   Shape forwarded to `Compute`.
/// @param outputStrides Strides forwarded to `Compute`.
/// @param outputType    Forwarded to `demoteOutput` for the final result.
/// @param device        Not used by this function.
/// @return The Tensor produced by `Compute`.
Tensor computeBatchNorm(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  // The highest index read below is 7 (eps). Fail loudly on a short argument
  // list instead of indexing past the end of the vector.
  TORCH_INTERNAL_ASSERT(
      inputs.size() >= 8,
      "computeBatchNorm expects at least 8 inputs, got ",
      inputs.size());

  // weight and bias are optional; ArgNone marks an absent argument.
  const bool hasWeight = !std::holds_alternative<ArgNone>(inputs[1]);
  const bool hasBias = !std::holds_alternative<ArgNone>(inputs[2]);

  return Compute(
      "aten_batch_norm",
      outputShape,
      outputStrides,
      [&](const std::vector<VarHandle>& axes) {
        // indices[1] (the channel axis) is read below, so rank must be >= 2.
        TORCH_INTERNAL_ASSERT(axes.size() >= 2);
        // axes: N, C, H, W
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        ExprHandle c = indices[1];

        // Parameter list:
        // input, weight, bias, mean, var, training, momentum, eps,
        // cudnn_enabled
        //
        // The first four slots are fixed (input, mean, var, eps) and are read
        // back by position after promoteInputs(); keep this order.
        std::vector<ExprHandle> exprInputs = {
            tensorOrConstant(inputs[0], indices), // input
            tensorOrConstant(inputs[3], {c}), // mean
            tensorOrConstant(inputs[4], {c}), // var
            constant(inputs[7]) // eps
        };

        // Defaults used when weight / bias are not supplied.
        ExprHandle weight = FloatImm::make(1);
        ExprHandle bias = FloatImm::make(0);
        // When present, weight / bias are appended so that they are part of
        // the list handed to promoteInputs().
        // NOTE(review): the arithmetic below uses these local handles, not the
        // entries in exprInputs after promotion; preserved as in the original.
        if (hasWeight) {
          weight = tensorOrConstant(inputs[1], {c});
          exprInputs.push_back(weight);
        }
        if (hasBias) {
          bias = tensorOrConstant(inputs[2], {c});
          exprInputs.push_back(bias);
        }
        promoteInputs(exprInputs);

        ExprHandle input = exprInputs[0];
        ExprHandle mean = exprInputs[1];
        ExprHandle var = exprInputs[2];
        ExprHandle eps = exprInputs[3];

        // Fold the normalization into a single scale (alpha) and shift (beta)
        // so the per-element work is one multiply and one add.
        auto inv_var = rsqrt(var + eps);
        auto alpha = inv_var * weight;
        auto beta = bias - mean * alpha;
        auto output = input * alpha + beta;
        return demoteOutput(output, outputType);
      });
}

} // namespace torch::jit::tensorexpr