#pragma once
#include "caffe2/core/logging.h"
#include "caffe2/opt/shape_info.h"
#include "caffe2/proto/caffe2_pb.h"
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
namespace caffe2 {
// Upper bounds that seed bound shape inference ("batch" in the general sense).
//   max_batch_size - upper bound of batch_size.
//   max_seq_size   - upper bound of the length of every item in a batch.
// The upper bound of the length of a whole batch of items is therefore
// max_batch_size * max_seq_size.
struct TORCH_API BoundShapeSpec {
  // Two-bound form: num_embeddings and embedding_length are both set to 0.
  explicit BoundShapeSpec(int64_t b, int64_t q) : BoundShapeSpec(b, q, 0, 0) {}

  // Full form: additionally supplies the two UnPackRecords parameters.
  explicit BoundShapeSpec(int64_t b, int64_t q, int64_t n, int64_t e)
      : max_batch_size(b),
        max_seq_size(q),
        num_embeddings(n),
        embedding_length(e) {}

  int64_t max_batch_size;
  int64_t max_seq_size;

  // The following two parameters are for shape inference of UnPackRecords.
  int64_t num_embeddings;
  int64_t embedding_length;
};
/// \class BoundShapeInferencerBase
/// \brief A class that does bound shape inference given a C2 net.
///
/// Depending on its type, each op has a maximum shape that it accepts. We
/// define some initial bound for certain dimensions, for example max batch
/// size or max sequence lookup size. The inference first infers the input
/// size and then propagates the bound shape down the network. For now the
/// variable part (bound part) is the first dimension of the shape, which
/// usually corresponds to the batch size or sequence lookup size.
class BoundShapeInferencerBase {
 public:
  /// Stores a copy of \p spec and enforces (via CAFFE_ENFORCE_GE) that both
  /// max_batch_size and max_seq_size are >= 0.
  explicit BoundShapeInferencerBase(const BoundShapeSpec& spec) : spec_(spec) {
    CAFFE_ENFORCE_GE(spec_.max_batch_size, 0);
    CAFFE_ENFORCE_GE(spec_.max_seq_size, 0);
  }

  virtual ~BoundShapeInferencerBase() = default;

  // Initializes BoundShapeInferencer and infers bound shape and type.
  // info: shape information of some tensors,
  //   e.g. shape information of external input / output tensors;
  // extract_feature_len:
  //   indicating whether to extract feature length from SigridTransform
  //   and other related operators. When enabled,
  //   extracted feature length information will be used to infer tensor
  //   shapes.
  virtual void InferBoundShapeAndType(
      const NetDef& net,
      const ShapeInfoMap& info,
      caffe2::Workspace* ws,
      bool extract_feature_len = false) = 0;

  /// Read-only access to the shape info map held by this inferencer.
  const ShapeInfoMap& shape_info() const {
    return shape_info_;
  }

  /// Print out all the shape info.
  ///
  /// Emits one line per map entry in the form
  ///   `<name>: dim_type: <dim type>, dims: [<d0>, <d1>, ], dtype: <dtype>`
  /// Every dim is followed by ", " (including the last one).
  std::string PrintShapeInfo() const {
    std::stringstream ss;
    for (const auto& kv : shape_info_) {
      const auto& s = kv.second;
      ss << s.shape.name() << ": dim_type: " << s.getDimType() << ", dims: [";
      for (const auto d : s.shape.dims()) {
        ss << d << ", ";
      }
      ss << "], dtype: " << s.shape.data_type() << "\n";
    }
    return ss.str();
  }

 protected:
  const BoundShapeSpec spec_;
  ShapeInfoMap shape_info_;
  // Default-initialized so that reading the flag before a subclass assigns it
  // is well defined rather than a read of an indeterminate bool.
  bool extract_feature_len_{false};
};
// Concrete BoundShapeInferencerBase that overrides InferBoundShapeAndType().
// Only declarations appear here; every member function except the constructor
// and destructor is defined out of line.
class TORCH_API BoundShapeInferencer : public BoundShapeInferencerBase {
 public:
  // Forwards `spec` to the base class, which validates and stores it.
  explicit BoundShapeInferencer(const BoundShapeSpec& spec)
      : BoundShapeInferencerBase(spec) {}

  ~BoundShapeInferencer() override = default;

  // Override of the base-class entry point; see the base declaration for the
  // meaning of `info` and `extract_feature_len`.
  void InferBoundShapeAndType(
      const NetDef& net,
      const ShapeInfoMap& info,
      caffe2::Workspace* ws,
      bool extract_feature_len = false) override;

 protected:
  // ---- Tensor bound-shape bookkeeping --------------------------------------
  // Both helpers return a reference to a TensorShape for tensor `name`.
  // NOTE(review): judging by the names, the first checks against / sets the
  // recorded shape and the second only sets it when no entry exists yet --
  // confirm against the implementation file.
  TensorShape& CheckAndSetTensorBoundShape(
      const std::string& name,
      const std::vector<TensorBoundShape::DimType>& t,
      std::vector<int64_t> bound_dims,
      TensorProto::DataType type,
      bool is_quantized,
      bool allow_existing_shape = false,
      float scale = 1,
      int offset = 0,
      bool in_place_op = false);

  TensorShape& SetTensorBoundShapeIfNotExist(
      const std::string& name,
      const std::vector<TensorBoundShape::DimType>& t,
      std::vector<int64_t> bound_dims,
      TensorProto::DataType type,
      bool is_quantized);

  // ---- Per-operator inference ----------------------------------------------
  // Virtual hook taking a single operator and the workspace.
  // NOTE(review): presumably dispatches to the Infer* helpers below -- confirm.
  virtual void InferOps(const OperatorDef& op, caffe2::Workspace* ws);

  // Helpers named after the operator (or operator input) they handle; each
  // takes the OperatorDef being processed.
  void InferConcatInputs(const OperatorDef& op);
  void InferInt8QuantizeInput(const OperatorDef& op);
  void InferElementwiseOpInput(const OperatorDef& op);
  void InferElementwiseOp(const OperatorDef& op);
  void InferGivenTensorFill(const OperatorDef& op);
  void InferSparseLengthsSum(const OperatorDef& op);
  void InferFC(const OperatorDef& op);
  void InferConcat(const OperatorDef& op);
  void InferShape(const OperatorDef& op);
  void InferReshape(const OperatorDef& op);
  void InferLengthsRangeFill(const OperatorDef& op);
  void InferQuantizationTransformation(const OperatorDef& op);
  void InferUnPackRecords(const OperatorDef& op);
  void InferTile(const OperatorDef& op);
  void InferSparseLengthsSumSparseLookup(const OperatorDef& op);
  void InferSoftmax(const OperatorDef& op);
  void InferBucketize(const OperatorDef& op);
  void InferLpNorm(const OperatorDef& op);
  void InferClip(const OperatorDef& op);
  void InferMean(const OperatorDef& op);
  void InferDiv(const OperatorDef& op);
  void InferTranspose(const OperatorDef& op);

  // Standard shape/type inference using the shape inference function
  // registered with the op schema.
  void InferCommonOp(
      const OperatorDef& op,
      const OpSchema* schema = nullptr,
      bool bypass_input_check = false,
      bool in_place_op = false);

  // ---- Setup ---------------------------------------------------------------
  // Initialize private parameters, such as shape_info, extract_feature_len_.
  // This is called at the beginning of InferBoundShapeAndType().
  virtual void Initialize(const ShapeInfoMap& info, bool extract_feature_len);

  void EnsureShapeNames(ShapeInfoMap* info) const;

  // ---- Inference state -----------------------------------------------------
  // Starts out as the BATCH dim type.
  TensorBoundShape::DimType current_dim_type_{TensorBoundShape_DimType_BATCH};
  // Starts out as 0.
  int64_t current_max_batch_size_{0};
};
// Returns a shared_ptr, typed as the BoundShapeInferencerBase interface, for
// the given `spec`. The definition is not in this header, so which concrete
// inferencer is produced cannot be determined from here.
// NOTE(review): presumably consults BoundShapeInferencerRegistry declared
// below -- confirm against the implementation file.
TORCH_API std::shared_ptr<BoundShapeInferencerBase> getBoundShapeInferencer(
const BoundShapeSpec& spec);
// Declares the shared registry `BoundShapeInferencerRegistry` for
// BoundShapeInferencerBase objects, with `const BoundShapeSpec&` as the
// argument type passed through the registry macro.
// NOTE(review): presumably that argument is what registered creators receive,
// and a matching C10_DEFINE_SHARED_REGISTRY lives in a .cc file -- confirm
// against the c10 registry macros.
C10_DECLARE_SHARED_REGISTRY(
BoundShapeInferencerRegistry,
BoundShapeInferencerBase,
const BoundShapeSpec&);
} // namespace caffe2