File: freeze_quantization_params.cc

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (60 lines) | stat: -rw-r--r-- 1,972 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <caffe2/quantization/server/int8_gen_quant_params.h>
#include <caffe2/utils/proto_utils.h>
#include "freeze_quantization_params.h"

namespace caffe2 {
// Bakes quantization parameters stored in workspace blobs into operator
// arguments.
//
// For every op in `net` that is either an "Int8Quantize" with exactly 2 inputs
// or an "Int8FC" with exactly 4 inputs, the last input is treated as the name
// of a blob in `ws` holding a caffe2::unique_ptr<Int8QuantParamsBlob>. The
// scale and zero point read from that blob are written to the op's "Y_scale"
// and "Y_zero_point" arguments (an existing argument is overwritten, a missing
// one is appended), and the last input is then removed from the op.
//
// Ops whose blob is missing from the workspace, or whose blob has a different
// type, are left untouched and a warning is logged.
void freezeQuantizationParams(NetDef* net, Workspace* ws) {
  using QParamsHolder = caffe2::unique_ptr<Int8QuantParamsBlob>;

  for (auto& op : *net->mutable_op()) {
    const bool quantizeWithQParam =
        op.type() == "Int8Quantize" && op.input_size() == 2;
    const bool fcWithQParam = op.type() == "Int8FC" && op.input_size() == 4;
    if (!quantizeWithQParam && !fcWithQParam) {
      continue;
    }

    // The quantization-parameter blob is always the trailing input.
    const auto& qparamBlobName = op.input(op.input_size() - 1);
    auto* blob = ws->GetBlob(qparamBlobName);
    if (blob == nullptr) {
      LOG(WARNING)
          << "ParamBlob " << qparamBlobName
          << " does not exist in the workspace. Skip freezing current op.";
      continue;
    }
    if (!blob->template IsType<QParamsHolder>()) {
      LOG(WARNING)
          << "ParamBlob " << qparamBlobName
          << " is not of caffe2::unique_ptr<Int8QuantParamsBlob> type. Skip freezing current op.";
      continue;
    }

    // The blob stores a smart pointer; make sure it actually points at
    // something before reading the parameters out of it.
    const auto* param = blob->template Get<QParamsHolder>().get();
    CAFFE_ENFORCE(param);
    const float scale = param->qparam.scale;
    const int zero_point = param->qparam.zero_point;

    // Position of the first argument called `name`, or -1 when the op has no
    // such argument.
    const auto indexOfArg = [&op](const char* name) {
      for (int i = 0; i < op.arg_size(); ++i) {
        if (op.arg(i).name() == name) {
          return i;
        }
      }
      return -1;
    };

    // Overwrite Y_scale in place when present, otherwise append it.
    const int scaleIdx = indexOfArg("Y_scale");
    if (scaleIdx >= 0) {
      op.mutable_arg(scaleIdx)->set_f(scale);
    } else {
      op.add_arg()->CopyFrom(MakeArgument<float>("Y_scale", scale));
    }

    // Same treatment for Y_zero_point. The lookup runs after the Y_scale
    // update so a freshly appended Y_scale is part of the scanned list.
    const int zeroPointIdx = indexOfArg("Y_zero_point");
    if (zeroPointIdx >= 0) {
      op.mutable_arg(zeroPointIdx)->set_i(zero_point);
    } else {
      op.add_arg()->CopyFrom(MakeArgument<int>("Y_zero_point", zero_point));
    }

    // The parameters now live in the arguments; drop the blob input.
    op.mutable_input()->RemoveLast();
  }
}
} // namespace caffe2