File: quantization.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (66 lines) | stat: -rw-r--r-- 2,251 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <ATen/Context.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/quantization.h>

namespace torch {
namespace jit {
namespace mobile {
namespace quantization {

/// Replaces method `method_name` of module `m` with its `quantized_`
/// counterpart, after running the module's observe/quantize helper methods.
///
/// The module must expose the following methods (each is verified with
/// TORCH_CHECK before anything is executed):
///   - reset_observers_<method_name>
///   - observe_<method_name>
///   - quantize_<method_name>
///   - quantized_<method_name>
///   - get_all_bundled_inputs
///
/// Steps performed:
///   1. Read the first entry of `get_all_bundled_inputs()` and use its tuple
///      elements as the argument list.
///   2. Call reset_observers_ (no arguments), then observe_ and quantize_
///      with those arguments.
///   3. Remove `method_name`, copy the function behind quantized_ under the
///      name `method_name`, and remove the four helper methods.
///
/// @param m            Module to modify in place.
/// @param method_name  Name of the method to be replaced.
void PTQQuanizationHelper::quantize_dynamic(
    torch::jit::mobile::Module& m,
    const std::string& method_name) {
  // NOTE(review): this changes a process-wide setting and never restores it;
  // presumably the unpacked weights must stay available while quantizing --
  // confirm.
  at::globalContext().setReleaseWeightsWhenPrepacking(false);
  const std::string reset_observers_method_name =
      "reset_observers_" + method_name;
  const std::string observe_method_name = "observe_" + method_name;
  const std::string quantize_method_name = "quantize_" + method_name;
  const std::string quantized_method_name = "quantized_" + method_name;

  // Verify every required helper method up front so that a failure names the
  // method that is actually missing.
  TORCH_CHECK(
      m.find_method(reset_observers_method_name).has_value(),
      "PTQ ready module must have ",
      reset_observers_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(observe_method_name).has_value(),
      "PTQ ready module must have ",
      observe_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(quantize_method_name).has_value(),
      "PTQ ready module must have ",
      quantize_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(quantized_method_name).has_value(),
      "PTQ ready module must have ",
      quantized_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method("get_all_bundled_inputs").has_value(),
      "PTQ ready module must have get_all_bundled_inputs method.");

  // Only the first bundled input (index 0) is used; its tuple elements are
  // copied into the argument vector.
  auto inputs = m.run_method("get_all_bundled_inputs")
                    .toList()
                    .get(0)
                    .toTupleRef()
                    .elements()
                    .vec();
  m.get_method(reset_observers_method_name)({});
  m.get_method(observe_method_name)(inputs);
  m.get_method(quantize_method_name)(inputs);

  // NOTE(review): any value returned by compareMethodSchemas is discarded
  // here, so a mismatch does not stop the replacement -- confirm whether the
  // result should be enforced.
  m.compareMethodSchemas(method_name, quantized_method_name);
  m.unsafeRemoveMethod(method_name);
  // The reference is taken after the removal above and consumed by the copy
  // below, before quantized_<method_name> itself is removed.
  const Function& to_be_copied =
      m.find_method(quantized_method_name).value().function();
  m.unsafeCopyMethod(method_name, to_be_copied);
  m.unsafeRemoveMethod(quantized_method_name);
  m.unsafeRemoveMethod(quantize_method_name);
  m.unsafeRemoveMethod(observe_method_name);
  m.unsafeRemoveMethod(reset_observers_method_name);
}
} // namespace quantization
} // namespace mobile
} // namespace jit
} // namespace torch