File: batch_permutation_dnnlowp_op.cc

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (66 lines) | stat: -rw-r--r-- 1,786 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include "caffe2/quantization/server/batch_permutation_dnnlowp_op.h"

#include <cstdint>
#include <cstring>

namespace caffe2 {

/// Quantized (DNNLOWP) batch permutation.
///
/// Treats X as N "rows" along its leading dimension, each holding
/// X.numel() / N contiguous elements, and writes row indices[i] of X into
/// row i of Y.
///
/// Inputs:  INPUT   - tensor X; its leading dimension N must be > 0.
///          INDICES - 1-d int tensor of length N; every entry must lie in
///                    [0, N).
/// Outputs: OUTPUT  - tensor Y, resized to the shape of X.
///
/// Y carries the same quantization parameters as X (values are only copied,
/// never rescaled).
///
/// @return true on success; violated preconditions raise an error through
///         CAFFE_ENFORCE.
template <typename T>
bool BatchPermutationDNNLowPOp<T>::RunOnDevice() {
  using namespace dnnlowp;

  this->ParseDNNLowPOperatorArguments_();

  // Choose quantization params
  in_qparams_[INPUT] =
      GetInputTensorQuantizationParamsOf(this, INPUT, qfactory_.get());

  const auto& X = InputTensorCPU_(INPUT);
  const auto& indices = Input(INDICES);
  auto* Y = OutputTensorCPU_(OUTPUT);

  CAFFE_ENFORCE(indices.ndim() == 1, "indices must be 1-d");
  CAFFE_ENFORCE(
      X.dim32(0) == indices.dim32(0),
      "X.dim32(0) must be equal to indices.dim32(0)",
      "(",
      X.dim32(0),
      " vs. ",
      indices.dim32(0),
      ")");
  CAFFE_ENFORCE_GT(X.dim32(0), 0);

  Y->ResizeLike(X);
  const T* X_data = X.template data<T>();
  const int* indices_data = indices.template data<int>();
  T* Y_data = Y->template mutable_data<T>();

  const int N = X.dim32(0);
  // Elements per row. Kept 64-bit so that row offsets (row * K) cannot
  // overflow for tensors with more than INT_MAX elements.
  const int64_t K = X.numel() / N;

  // Validate every index up front, serially: an out-of-range index would make
  // the memcpy below read outside X, and an error thrown from inside the
  // OpenMP parallel region would terminate the process instead of propagating.
  for (int i = 0; i < N; ++i) {
    CAFFE_ENFORCE(
        indices_data[i] >= 0 && indices_data[i] < N,
        "indices[",
        i,
        "] = ",
        indices_data[i],
        " is out of range [0, ",
        N,
        ")");
  }

#ifdef _OPENMP
#pragma omp parallel for
#endif
  for (int i = 0; i < N; ++i) {
    const int64_t origIdx = static_cast<int64_t>(i) * K;
    const int64_t permuteIdx = static_cast<int64_t>(indices_data[i]) * K;
    std::memcpy(Y_data + origIdx, X_data + permuteIdx, K * sizeof(T));
  }

  // Any pre-chosen quantization parameters for the output are ignored:
  // batch permutation only moves values around, so the output must use the
  // same quantization as the input.
  PropagateOutputTensorQuantizationParams(this, OUTPUT, in_qparams_[INPUT]);

  return true;
}

// Use this uint8_t implementation as the DNNLOWP engine of the regular
// BatchPermutation operator.
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    BatchPermutation,
    DNNLOWP,
    BatchPermutationDNNLowPOp<uint8_t>);
// Use the same uint8_t implementation for Int8BatchPermutation.
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    Int8BatchPermutation,
    DNNLOWP,
    BatchPermutationDNNLowPOp<uint8_t>);

// Int8BatchPermutation takes two inputs (data, indices) and produces one output.
OPERATOR_SCHEMA(Int8BatchPermutation).NumInputs(2).NumOutputs(1);

} // namespace caffe2