File: int8_conv_op.cc

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (258 lines) | stat: -rw-r--r-- 9,188 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
#include <caffe2/ideep/operators/conv_pool_base_op.h>

using namespace caffe2;

namespace {

class IDEEPInt8ConvOp : public IDEEPConvPoolOpBase {
 public:
  USE_IDEEP_DEF_ALIASES();
  USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  IDEEPInt8ConvOp(const OperatorDef& operator_def, Workspace* ws)
      : IDEEPConvPoolOpBase(operator_def, ws),
        scale_(this->template GetSingleArgument<float>("Y_scale", 1.0)),
        zero_point_(
            this->template GetSingleArgument<int32_t>("Y_zero_point", 0)) {
    OPERATOR_NEEDS_FEATURE(pad_l() == pad_r() && pad_t() == pad_b(),
                           "Uneven padding not supported.");
    fusion_type_ = FUSION_UNKNOWN;
    last_input_ = BIAS_OR_INPUT_S;
    algo_ = ialgo::convolution_direct;
    auto conv_algorithm = OperatorBase::GetSingleArgument<int>(
        "conv_algorithm", CONV_ALGORITHM_AUTO);
    if (conv_algorithm == CONV_ALGORITHM_WINOGRAD) {
      algo_ = ialgo::convolution_winograd;
    }
    CAFFE_ENFORCE(zero_point_ == 128 || zero_point_ == 0);
    Y_scales_ = ConvertScales({scale_});
  }
  // NOLINTNEXTLINE(modernize-use-override,modernize-use-equals-default)
  virtual ~IDEEPInt8ConvOp() {}

  bool RunOnDeviceWithOrderNCHW() override {
    const auto &X = Input(INPUT_X);
    const auto &filter = Input(FILTER);
    auto *Y = Output(OUTPUT);

    CAFFE_ENFORCE(X.has_scale());
    CAFFE_ENFORCE(4 == X.ndims() && 4 == filter.ndims());
    CAFFE_ENFORCE(X.get_data_type() == idtype::s8
        || X.get_data_type() == idtype::u8);
    CAFFE_ENFORCE(filter.get_dim(2) == kernel_h());
    CAFFE_ENFORCE(filter.get_dim(3) == kernel_w());
    CAFFE_ENFORCE(
        X.get_dim(1) == filter.get_dim(1) * group_,
        "Convolution op: input channels does not match: # of input channels ",
        X.get_dim(1), " is not equal to kernel channels * group:",
        filter.get_dim(1), "*", group_);

    bool input_changed = (cached_X_descriptor_ != X.get_descriptor());
    if (input_changed) {
      cached_X_descriptor_ = X.dup_descriptor();
    }

    bool weights_changed = (cached_weights_descriptor_ != filter.get_descriptor());
    if (weights_changed) {
      cached_weights_descriptor_ = filter.dup_descriptor();
      CAFFE_ENFORCE(filter.get_data_type() == idtype::s8 && filter.has_scale());

      auto X_dt = X.get_data_type();
      lowp_kind_ = ilowp_kind::LOWP_U8S8;
      if (X_dt == idtype::s8) {
        lowp_kind_ = ilowp_kind::LOWP_S8S8;
      }

      auto expected_descriptor =
          ideep::convolution_forward::expected_weights_desc(
              filter.get_dims(),
              idtype::s8,
              {stride_.begin(), stride_.end()},
              pad_tl(),
              pad_br(),
              {dilation_.begin(), dilation_.end()},
              group_,
              algo_,
              iprop::forward_inference,
              X_dt, X.get_dims());
      if (filter.get_desc() != expected_descriptor) {
        filter_.init(expected_descriptor);
        filter_.set_scale(filter.get_scale());
        filter_.feed_from(filter);
      } else {
        filter_ = filter;
      }

      if (InputSize() > last_input_) {
        // NOTE: If the bias is shared by other operators in this module,
        // The existing bias scale should not satisfy current operator.
        // Thus, we have to requantize it by current input and filter scales.
        auto bias = Input(BIAS_OR_INPUT_S);
        bias_.init({bias.get_dims(), idtype::s32});
        iscale bias_scales (filter_.get_scale());
        for (auto &scale : bias_scales) { scale *= X.get_scale()[0]; }
        bias_.set_scale(bias_scales);
        bias_.feed_from(bias);
      }
    }

    bool with_bias = InputSize() > last_input_;
    if (input_changed || weights_changed) {
      auto Y_dims = CalcOutputDims(X, filter.get_dim(0));
      if (with_bias) {
        ideep::convolution_forward::prepare(
            conv_param,
            X,
            filter_,
            bias_,
            Y_dims,
            *Y,
            {stride_.begin(), stride_.end()},
            {dilation_.begin(), dilation_.end()},
            pad_tl(),
            pad_br(),
            group_,
            iscale(),
            iscale(),
            Y_scales_,
            attr_,
            algo_,
            iprop::forward_inference,
            lowp_kind_);
      } else {
        ideep::convolution_forward::prepare(
            conv_param,
            X,
            filter_,
            Y_dims,
            *Y,
            {stride_.begin(), stride_.end()},
            {dilation_.begin(), dilation_.end()},
            pad_tl(),
            pad_br(),
            group_,
            iscale(),
            iscale(),
            Y_scales_,
            attr_,
            algo_,
            iprop::forward_inference,
            lowp_kind_);
      }
    }

    if (with_bias) {
      ideep::convolution_forward::compute(conv_param, X, filter_, bias_, *Y);
    } else {
      ideep::convolution_forward::compute(conv_param, X, filter_, *Y);
    }

    if (fusion_type_ != FUSION_CONV_RELU && fusion_type_ != FUSION_UNKNOWN) {
      CAFFE_ENFORCE(
          Y == &(Input(InputSize() - 1)),
          "Convolution fusion op: InPlace is enforced for sum fusion.");
    }

    return true;
  }

 protected:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  iattr attr_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  ialgo algo_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  float scale_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  int last_input_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  int32_t zero_point_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  ilowp_kind lowp_kind_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  FusionType fusion_type_;

  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  itensor filter_, bias_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  iscale  Y_scales_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  itensor::descriptor cached_X_descriptor_, cached_weights_descriptor_;
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  ideep::convolution_forward_params conv_param;

  INPUT_TAGS(INPUT_X, FILTER, BIAS_OR_INPUT_S, INPUT_S);
  OUTPUT_TAGS(OUTPUT);
};

class IDEEPInt8ConvReluOp final : public IDEEPInt8ConvOp {
 public:
  USE_IDEEP_DEF_ALIASES();
  USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();

  IDEEPInt8ConvReluOp(const OperatorDef& operator_def, Workspace* ws)
      : IDEEPInt8ConvOp(operator_def, ws) {
    CAFFE_ENFORCE(zero_point_ == 0);
    last_input_ = BIAS_OR_INPUT_S;
    attr_ = iattr::fuse_relu();
    fusion_type_ = FUSION_CONV_RELU;
  }
  // NOLINTNEXTLINE(modernize-use-override,modernize-use-equals-default)
  virtual ~IDEEPInt8ConvReluOp() {}
};

class IDEEPInt8ConvSumOp final : public IDEEPInt8ConvOp {
 public:
  USE_IDEEP_DEF_ALIASES();
  USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();

  IDEEPInt8ConvSumOp(const OperatorDef& operator_def, Workspace* ws)
      : IDEEPInt8ConvOp(operator_def, ws) {
    last_input_ = INPUT_S;
    attr_ = iattr::fuse_sum();
    fusion_type_ = FUSION_CONV_SUM;
  }
  // NOLINTNEXTLINE(modernize-use-override,modernize-use-equals-default)
  virtual ~IDEEPInt8ConvSumOp() {}
};

class IDEEPInt8ConvSumReluOp final : public IDEEPInt8ConvOp {
 public:
  USE_IDEEP_DEF_ALIASES();
  USE_IDEEP_CONV_POOL_BASE_FUNCTIONS();

  IDEEPInt8ConvSumReluOp(const OperatorDef& operator_def, Workspace* ws)
      : IDEEPInt8ConvOp(operator_def, ws) {
    last_input_ = INPUT_S;
    attr_ = iattr::residual();
    fusion_type_ = FUSION_CONV_SUM_RELU;
  }
  // NOLINTNEXTLINE(modernize-use-override,modernize-use-equals-default)
  virtual ~IDEEPInt8ConvSumReluOp() {}
};

// Register each int8 convolution variant for the IDEEP device under the
// DNNLOWP engine: plain, +ReLU, +Sum and +Sum+ReLU.
REGISTER_IDEEP_OPERATOR_WITH_ENGINE(
    Int8Conv,
    DNNLOWP,
    IDEEPInt8ConvOp);
REGISTER_IDEEP_OPERATOR_WITH_ENGINE(
    Int8ConvRelu,
    DNNLOWP,
    IDEEPInt8ConvReluOp);
REGISTER_IDEEP_OPERATOR_WITH_ENGINE(
    Int8ConvSum,
    DNNLOWP,
    IDEEPInt8ConvSumOp);
REGISTER_IDEEP_OPERATOR_WITH_ENGINE(
    Int8ConvSumRelu,
    DNNLOWP,
    IDEEPInt8ConvSumReluOp);

// Schemas for the sum-fusing variants. Inputs are X, filter, [bias,] S, and
// the output must be written in place over S (input 2 without a bias, input
// 3 with one). A 2-input form could only satisfy the operator's in-place
// check by aliasing the filter (input 1), which AllowInplace does not
// permit, so at least three inputs are required.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
OPERATOR_SCHEMA(Int8ConvSum)
    .NumInputs(3, 4)
    .NumOutputs(1)
    .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
    .CostInferenceFunction(OpSchema::CostInferenceFunctionType(
        ConvPoolOpBase<CPUContext>::CostInferenceForConv))
    .AllowInplace({{2, 0}, {3, 0}});

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
OPERATOR_SCHEMA(Int8ConvSumRelu)
    .NumInputs(3, 4)
    .NumOutputs(1)
    .TensorInferenceFunction(ConvPoolOpBase<CPUContext>::TensorInferenceForConv)
    .CostInferenceFunction(OpSchema::CostInferenceFunctionType(
        ConvPoolOpBase<CPUContext>::CostInferenceForConv))
    .AllowInplace({{2, 0}, {3, 0}});

} // namespace