1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
|
#include <caffe2/ideep/ideep_utils.h>
using namespace caffe2;
namespace {
// NOTE(review): presumably brings the ideep type aliases used below (itensor,
// idtype, iscale, ...) into this anonymous namespace; the macro is expected to
// come from caffe2/ideep/ideep_utils.h — confirm.
USE_IDEEP_DEF_ALIASES();
class IDEEPInt8FullyConnectedOp final : public IDEEPOperator {
public:
USE_IDEEP_DEF_ALIASES();
USE_IDEEP_OPERATOR_FUNCTIONS();
IDEEPInt8FullyConnectedOp(const OperatorDef &operator_def, Workspace *ws)
: IDEEPOperator(operator_def, ws),
axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)),
axis_w_(OperatorBase::GetSingleArgument<int32_t>("axis_w", 1)),
scale_(this->template GetSingleArgument<float>("Y_scale", 1.0)),
zero_point_(
this->template GetSingleArgument<int32_t>("Y_zero_point", 0)) {
CAFFE_ENFORCE(zero_point_ == 128 || zero_point_ == 0);
if (zero_point_ == 0) {
Y_data_type_ = idtype::u8;
} else {
Y_data_type_ = idtype::s8;
}
Y_scales_ = ConvertScales({scale_});
}
// NOLINTNEXTLINE(modernize-use-override,modernize-use-equals-default)
virtual ~IDEEPInt8FullyConnectedOp() {}
bool RunOnDevice() override {
const auto& X = Input(INPUT);
const auto& filter = Input(FILTER);
auto* Y = Output(OUTPUT);
itensor X_in = X;
auto X_dims = CanonicalDims(X_in.get_dims(), axis_);
if (X_in.get_dims() != X_dims) {
X_in.reshape(X_dims);
}
if (cached_X_descriptor_ != X.get_descriptor()) {
cached_X_descriptor_ = X.dup_descriptor();
Y_.init({{X.get_dim(0), filter.get_dim(0)}, idtype::f32});
}
if (cached_weights_descriptor_ != filter.get_descriptor()) {
cached_weights_descriptor_ = filter.dup_descriptor();
CAFFE_ENFORCE(filter.get_data_type() == idtype::s8 && filter.has_scale());
// INT8 FC is not supported so far.
filter_ = filter.to_public();
auto filter_dims = CanonicalDims(filter_.get_dims(), axis_w_);
if (filter_.get_dims() != filter_dims) {
filter_.reshape(filter_dims);
}
if (InputSize() > BIAS) {
bias_ = Input(BIAS).to_public();
}
Y_.init({{X.get_dim(0), filter.get_dim(0)}, idtype::f32});
}
if (InputSize() > BIAS) {
ideep::inner_product_forward::compute(
X_in, filter_, bias_, Y_);
} else {
ideep::inner_product_forward::compute(X_in, filter_, Y_);
}
Y->init({Y_.get_dims(), Y_data_type_});
Y->set_scale(Y_scales_);
Y->feed_from(Y_);
return true;
}
private:
size_t axis_{1};
size_t axis_w_{1};
float scale_;
int32_t zero_point_;
idtype Y_data_type_;
itensor filter_, bias_, Y_;
iscale Y_scales_;
itensor::descriptor cached_X_descriptor_, cached_weights_descriptor_;
INPUT_TAGS(INPUT, FILTER, BIAS);
OUTPUT_TAGS(OUTPUT);
};
// Registers IDEEPInt8FullyConnectedOp as the IDEEP implementation of the
// Int8FC operator for the DNNLOWP engine.
REGISTER_IDEEP_OPERATOR_WITH_ENGINE(Int8FC, DNNLOWP, IDEEPInt8FullyConnectedOp);
} // namespace
|