File: op_wrapper.h

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (88 lines) | stat: -rw-r--r-- 2,959 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#pragma once

#include "caffe2/core/operator.h"
#include "caffe2/core/tensor_int8.h"
#include "caffe2/quantization/server/caffe2_dnnlowp_utils.h"
#include "caffe2/quantization/server/dnnlowp.h"

namespace caffe2 {

/**
 * Wrap a floating-point operator with quantized inputs with type T.
 * This class is to measure quantization error against fp32 reference.
 */
template <typename OpType, typename T>
class OpWrapper {
 public:
  OpWrapper(OperatorBase* op, dnnlowp::QuantizationFactory* qfactory)
      : op_(op), qfactory_(qfactory) {
    for (auto name : op->debug_def().input()) {
      local_input_blobs_.push_back(local_ws_.CreateBlob(name));
      TORCH_CHECK_NOTNULL(local_input_blobs_.back());
    }
    OperatorDef def = op->debug_def();
    local_op_.reset(new OpType(def, &local_ws_));
    for (auto name : def.output()) {
      local_output_blobs_.push_back(local_ws_.GetBlob(name));
      TORCH_CHECK_NOTNULL(local_output_blobs_.back());
    }
  }

  void DequantizeInput() {
    const OperatorDef& def = op_->debug_def();
    CPUContext context(def.device_option());

    for (int i = 0; i < op_->InputSize(); ++i) {
      if (op_->InputIsType<int8::Int8TensorCPU>(i)) {
        const TensorCPU& qtensor = op_->Input<int8::Int8TensorCPU>(i).t;
        TensorCPU* float_tensor =
            BlobGetMutableTensor(local_input_blobs_[i], CPU);
        // FIXME: doesn't work for bias so we shouldn't quantize bias before
        // model loading when we're running a shadow operator in fp32 for
        // example for measuring quantization error.
        float_tensor->ResizeLike(qtensor);
        fbgemm::Dequantize<T>(
            qtensor.data<T>(),
            float_tensor->template mutable_data<float>(),
            qtensor.numel(),
            dnnlowp::GetInputTensorQuantizationParamsOf(op_, i, qfactory_));
      } else {
        local_input_blobs_[i]->ShareExternal(
            const_cast<void*>(op_->Inputs()[i]->GetRaw()),
            op_->Inputs()[i]->meta());
      }
    }
  }

  OpType* Get() {
    return local_op_.get();
  }

  dnnlowp::TensorQuantizationParams GetOutputQuantizationParams(
      dnnlowp::QuantizationFactory* qfactory,
      int index = 0) {
    using namespace dnnlowp;

    float min, max;
    auto& out_tensor = local_output_blobs_[index]->template Get<TensorCPU>();
    fbgemm::FindMinMax(
        out_tensor.template data<float>(), &min, &max, out_tensor.numel());
    if (op_->OperatorBase::GetSingleArgument<std::string>("followed_by", "") ==
        "Relu") {
      min = std::max(0.0f, min);
      max = std::max(0.0f, max);
    }

    return qfactory->ChooseQuantizationParams(min, max);
  }

 private:
  OperatorBase* op_; /* container quantized op */
  Workspace local_ws_;
  std::vector<Blob*> local_input_blobs_;
  std::vector<Blob*> local_output_blobs_;
  std::unique_ptr<OpType> local_op_; /* contained fp32 reference op */
  dnnlowp::QuantizationFactory* qfactory_;
};

} // namespace caffe2