1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
|
#ifndef CAFFE2_CORE_BLOB_H_
#define CAFFE2_CORE_BLOB_H_
#include <cstddef>
#include <sstream>
#include <typeinfo>
#include <type_traits>
#include <vector>
#include "caffe2/core/common.h"
#include <ATen/core/blob.h>
#include <c10/util/typeid.h>
#include "caffe2/core/logging.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/tensor_int8.h"
namespace caffe2 {
inline bool BlobIsInt8TensorCPUType(const Blob& blob) {
return blob.meta().Match<int8::Int8TensorCPU>();
}
inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) {
bool is_match = blob.meta().Match<Tensor>();
if (!is_match) {
return false;
}
const Tensor* tensor = &blob.Get<Tensor>();
return tensor && *tensor && tensor->GetDeviceType() == device_type;
}
// Move-constructs a heap-allocated Tensor from `tensor`, passes it to
// Blob::Reset<Tensor>, and returns the pointer Reset hands back (presumably
// the Tensor now stored in `blob`; see ATen/core/blob.h for Reset's ownership
// contract).
inline Tensor* BlobSetTensor(Blob* blob, Tensor&& tensor) {
  Tensor* const replacement = new Tensor(std::move(tensor));
  return blob->Reset<Tensor>(replacement);
}
// Returns a Tensor of shape `dims` with the device and dtype requested by
// `options`, reusing `previous_tensor` when possible.
//
// `previous_tensor` is reused only if it is defined, is on the requested
// device (or has no device index and matches the requested device type), and
// already has the requested dtype. In that case it is resized to `dims` if
// its shape differs, raw_mutable_data() is called on it, and it is returned.
// In every other case a fresh caffe2::empty(dims, options) is returned.
inline Tensor GetSizedTensorWithOptions(
    Tensor&& previous_tensor,
    at::IntArrayRef dims,
    at::TensorOptions options) {
  Tensor tensor = std::move(previous_tensor);
  if (!tensor.defined()) {
    return caffe2::empty(dims, options);
  }
  // A tensor without a device index is matched by device type only.
  const auto device = tensor.GetDevice();
  const bool same_device = device == options.device() ||
      (!device.has_index() &&
       tensor.GetDeviceType() == options.device().type());
  // Check the dtype *before* resizing: on a mismatch the old tensor is
  // discarded, so resizing it first would be wasted work and would mutate a
  // TensorImpl that other handles (cf. UnsafeSharedInstance) may still share.
  if (!same_device || tensor.dtype() != options.dtype()) {
    return caffe2::empty(dims, options);
  }
  if (tensor.sizes() != dims) {
    // Resize when the dims don't match.
    tensor.Resize(dims);
  }
  tensor.raw_mutable_data();
  return tensor;
}
// Both variants are kept for the clangr codemod: BlobGetMutableTensor below
// returns Tensor*, and XBlobGetMutableTensor returns Tensor.
// Returns the Tensor stored in `blob`, sized to `dims` and typed per
// `options`, creating and storing a new one when the existing one cannot be
// reused.
//
// An existing Tensor is reused in place when it is defined and is on the
// requested device (or has no device index and matches the requested device
// type). It is resized if its shape differs. Then raw_mutable_data() is
// called on it, with options.dtype() passed when its dtype differs. If the
// blob holds no Tensor, an undefined one, or one on a different device, the
// blob is reset to caffe2::empty(dims, options) via BlobSetTensor.
inline Tensor*
BlobGetMutableTensor(Blob* blob, at::IntArrayRef dims, at::TensorOptions options) {
  if (blob->IsType<Tensor>()) {
    Tensor* tensor = blob->GetMutable<Tensor>();
    if (*tensor) {
      // Match on the full device when possible. If the stored Tensor has no
      // device index, compare only the device type, presumably because some
      // Tensors are created without an index.
      // TODO: remove the extra check when all the Tensors are properly initialized
      const auto tensor_device = tensor->GetDevice();
      if (tensor_device == options.device() ||
          (!tensor_device.has_index() &&
           tensor->GetDeviceType() == options.device().type())) {
        if (tensor->sizes() != dims) {
          // Resize when the dims don't match.
          tensor->Resize(dims);
        }
        if (tensor->dtype() == options.dtype()) {
          tensor->raw_mutable_data();
        } else {
          // Request a buffer of the requested dtype from the existing tensor.
          tensor->raw_mutable_data(options.dtype());
        }
        return tensor;
      }
      // Device doesn't match: fall through and create a new Tensor.
    }
  }
  VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
          << " dims: " << dims;
  // << " options: " << options; (operator<< for Options is in at:: now)
  return BlobSetTensor(blob, caffe2::empty(dims, options));
}
// Tensor-returning counterpart of BlobGetMutableTensor(blob, dims, options).
// It performs the same lookup or creation, then returns the
// UnsafeSharedInstance() of the Tensor stored in `blob`.
inline Tensor
XBlobGetMutableTensor(Blob* blob, at::IntArrayRef dims, at::TensorOptions options) {
  Tensor* const stored = BlobGetMutableTensor(blob, dims, options);
  return stored->UnsafeSharedInstance();
}
// Returns the Tensor stored in `blob` if it is defined and on `device_type`.
// Otherwise (no Tensor, an undefined Tensor, or one on another device type)
// the blob is reset to a new Tensor(device_type), which is returned.
inline Tensor* BlobGetMutableTensor(Blob* blob, DeviceType device_type) {
  Tensor* current =
      blob->IsType<Tensor>() ? blob->GetMutable<Tensor>() : nullptr;
  const bool reusable = current != nullptr && *current &&
      current->GetDeviceType() == device_type;
  if (reusable) {
    return current;
  }
  VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
          << " DeviceType:" << device_type;
  return BlobSetTensor(blob, Tensor(device_type));
}
// Returns a reference to the Tensor stored in `blob` when it is defined and
// its device type equals `device_type`. Otherwise it throws via CAFFE_THROW.
inline const Tensor& BlobGetTensor(const Blob& blob, DeviceType device_type) {
  if (blob.IsType<Tensor>()) {
    const auto& tensor = blob.Get<Tensor>();
    // Check definedness before GetDeviceType(), consistent with
    // BlobIsTensorType and the BlobGetMutableTensor overloads. An undefined
    // Tensor then reaches this function's own error path.
    if (tensor.defined() && tensor.GetDeviceType() == device_type) {
      return tensor;
    }
  }
  CAFFE_THROW("Blob didn't contain a Tensor or the device_type doesn't match");
}
// Returns UnsafeSharedInstance() of the Tensor stored in `blob`, or a
// default-constructed Tensor when the blob does not hold a Tensor.
inline Tensor BlobGetTensorOrUndefined(const Blob& blob) {
  return blob.IsType<Tensor>() ? blob.Get<Tensor>().UnsafeSharedInstance()
                               : Tensor();
}
} // namespace caffe2
#endif // CAFFE2_CORE_BLOB_H_
|