1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
|
#include "caffe2/core/blob_serialization.h"
#include "caffe2/core/common.h"
#include "caffe2/core/context.h"
#include "caffe2/core/tensor_int8.h"
#include <c10/util/typeid.h>
#include "caffe2/core/types.h"
namespace caffe2 {
namespace int8 {
class Int8TensorCPUSerializer : public BlobSerializerBase {
public:
void Serialize(
const void* pointer,
TypeMeta typeMeta,
const string& name,
SerializationAcceptor acceptor) override {
CAFFE_ENFORCE(typeMeta.Match<Int8TensorCPU>());
const auto& tensor = *static_cast<const Int8TensorCPU*>(pointer);
BlobProto blob_proto;
blob_proto.set_name(name);
blob_proto.set_type("Int8TensorCPU");
QTensorProto& proto = *blob_proto.mutable_qtensor();
proto.set_name(name);
for (int i = 0; i < tensor.t.dim(); ++i) {
proto.add_dims(tensor.t.dim32(i));
}
proto.set_precision(8);
proto.set_scale(tensor.scale);
proto.set_bias(tensor.zero_point);
proto.set_is_signed(false);
const TensorProto::DataType data_type =
TypeMetaToDataType(tensor.t.dtype());
proto.set_data_type(data_type);
switch (data_type) {
case TensorProto_DataType_INT32:
detail::CopyToProtoAsIs(
tensor.t.numel(),
tensor.t.template data<int32_t>(),
proto.mutable_data(),
&this->context_);
break;
case TensorProto_DataType_UINT8:
detail::CopyToProtoWithCast(
tensor.t.numel(),
tensor.t.template data<uint8_t>(),
proto.mutable_data(),
&this->context_);
break;
default:
CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
}
acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
}
private:
CPUContext context_;
};
class Int8TensorCPUDeserializer : public TensorDeserializer {
public:
void Deserialize(const BlobProto& blob_proto, Blob* blob) override {
const QTensorProto& proto = blob_proto.qtensor();
Int8TensorCPU* tensor = blob->template GetMutable<Int8TensorCPU>();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
tensor->scale = proto.scale();
tensor->zero_point = proto.bias();
vector<int> dims;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
for (const int d : proto.dims()) {
dims.push_back(d);
}
tensor->t.Resize(dims);
switch (proto.data_type()) {
case TensorProto_DataType_INT32:
detail::CopyFromProtoAsIs(
tensor->t.numel(),
proto.data(),
tensor->t.template mutable_data<int32_t>(),
&this->context_);
break;
case TensorProto_DataType_UINT8:
detail::CopyFromProtoWithCast(
tensor->t.numel(),
proto.data(),
tensor->t.template mutable_data<uint8_t>(),
&this->context_);
break;
default:
CAFFE_ENFORCE(false, "Unsupported data type in Int8TensorCPU");
}
}
private:
CPUContext context_;
};
} // namespace int8
// File-local registrations that hook the Int8TensorCPU serializer and
// deserializer into the blob (de)serialization registries.
namespace {
// Serializer is registered under the TypeMeta id of int8::Int8TensorCPU.
REGISTER_BLOB_SERIALIZER(
(TypeMeta::Id<int8::Int8TensorCPU>()),
int8::Int8TensorCPUSerializer);
// Deserializer is registered under the name Int8TensorCPU, the same string
// the serializer writes into BlobProto::type.
REGISTER_BLOB_DESERIALIZER(Int8TensorCPU, int8::Int8TensorCPUDeserializer);
} // namespace
} // namespace caffe2
|