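#pragma once

#include <cstdint>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <c10/util/irange.h>

#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/predictor/predictor.h"
#include "caffe2/utils/filler.h"
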
namespace caffe2 {
namespace emulator {

typedef caffe2::Predictor::TensorList TensorList_t;

/*
 * A filler to initialize the parameters and inputs of a predictor.
 */
class Filler {
 protected:
  virtual void fill_input_internal(TensorList_t* input_data) const = 0;

 public:
  // Initialize the workspace with the parameters.
  virtual void fill_parameter(Workspace* ws) const = 0;

  // Generate input data and return the total input size in bytes.
  size_t fill_input(TensorList_t* input_data) const {
    CAFFE_ENFORCE(input_data, "input_data is null");
    input_data->clear();

    fill_input_internal(input_data);

    size_t bytes = 0;
    for (const auto& item : *input_data) {
      bytes += item.nbytes();
    }
    if (bytes == 0) {
      LOG(WARNING) << "0 input bytes filled";
    }

    return bytes;
  }

  const std::vector<std::string>& get_input_names() const {
    CAFFE_ENFORCE(!input_names_.empty(), "input names are not initialized");
    return input_names_;
  }

  virtual ~Filler() noexcept {}

 protected:
  std::vector<std::string> input_names_;
};
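
// Sketch of the Filler contract (illustrative only, not part of this API):
// a subclass sets input_names_ and overrides fill_parameter() and
// fill_input_internal(). Callers go through fill_input(), which clears the
// list, delegates to the subclass, and returns the summed nbytes() of the
// filled tensors. `ConstantShapeFiller` and the blob name "data" are
// hypothetical. The code assumes caffe2's `Tensor(dims, CPU)` constructor.
//
//   class ConstantShapeFiller : public Filler {
//    public:
//     ConstantShapeFiller() { input_names_ = {"data"}; }
//     void fill_parameter(Workspace* /*ws*/) const override {}
//
//    protected:
//     void fill_input_internal(TensorList_t* input_data) const override {
//       Tensor t(std::vector<int64_t>{4}, CPU);
//       t.mutable_data<float>();  // allocates 4 floats; values are unset
//       input_data->emplace_back(std::move(t));
//     }
//   };
//
//   ConstantShapeFiller filler;
//   TensorList_t inputs;
//   size_t bytes = filler.fill_input(&inputs);  // should be 4 * sizeof(float)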

/*
 * @init_net: a reader net to generate parameters
 * @data_net: a reader net to generate inputs
 */
class DataNetFiller : public Filler {
 public:
  DataNetFiller(NetDef&& init_net, NetDef&& data_net)
      : init_net_(std::move(init_net)), data_net_(std::move(data_net)) {
    // The outputs of data_net_ serve as the predictor inputs.
    int op_size = data_net_.op_size();
    for (const auto i : c10::irange(op_size)) {
      const OperatorDef& op_def = data_net_.op(i);
      // We rely on Fill ops to generate the inputs.
      CAFFE_ENFORCE(
          op_def.type().find("Fill") != std::string::npos,
          "data_net op ",
          op_def.type(),
          " is not a Fill op");
      int output_size = op_def.output_size();
      for (const auto j : c10::irange(output_size)) {
        input_names_.push_back(op_def.output(j));
      }
    }
  }

  void fill_input_internal(TensorList_t* input_data) const override;
  void fill_parameter(Workspace* ws) const override;

 private:
  const NetDef init_net_;
  const NetDef data_net_;
};
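
// Usage sketch for DataNetFiller (illustrative). It assumes `init_net` and
// `data_net` are NetDef objects loaded elsewhere. It also assumes every op in
// `data_net` is a Fill op such as "GaussianFill", as the constructor enforces.
// The constructor takes rvalue references, so pass std::move(...) or
// temporaries.
//
//   DataNetFiller filler(std::move(init_net), std::move(data_net));
//   Workspace ws;
//   filler.fill_parameter(&ws);  // generate parameters into `ws`
//   TensorList_t inputs;
//   size_t bytes = filler.fill_input(&inputs);
//   // get_input_names() lists every output of every op in data_net, in op
//   // order, as collected by the constructor.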

// Fill `output` using `filler`, with the element type named by `type`.
void fill_with_type(
    const TensorFiller& filler,
    const std::string& type,
    TensorCPU* output);

/*
 * @run_net: the predict net with parameter and input names
 * @input_dims: the input dimensions of all operator inputs of run_net
 * @input_types: the input types of all operator inputs of run_net
 */
class DataRandomFiller : public Filler {
 public:
  DataRandomFiller(
      const NetDef& run_net,
      const std::vector<std::vector<std::vector<int64_t>>>& input_dims,
      const std::vector<std::vector<std::string>>& input_types);

  void fill_input_internal(TensorList_t* input_data) const override;
  void fill_parameter(Workspace* ws) const override;

  static TensorFiller get_tensor_filler(
      const OperatorDef& op_def,
      int input_index,
      const std::vector<std::vector<int64_t>>& input_dims) {
    Workspace ws;
    for (const auto i : c10::irange(op_def.input_size())) {
      // CreateOperator requires all input blobs present
      ws.CreateBlob(op_def.input(i));
    }
    CAFFE_ENFORCE(op_def.has_type());
    const OpSchema* schema = caffe2::OpSchemaRegistry::Schema(op_def.type());
    if (schema == nullptr) {
      throw std::invalid_argument(
          "no OpSchema registered for op type " + op_def.type() +
          ", so it has no input fillers");
    }
    auto fillers = schema->InputFillers(input_dims);
    CAFFE_ENFORCE(
        input_index >= 0 && static_cast<size_t>(input_index) < fillers.size(),
        "input_index ",
        input_index,
        " is out of range for op ",
        op_def.type());
    return fillers[input_index];
  }
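
  // Sketch of using get_tensor_filler() with fill_with_type() to materialize
  // one random input (illustrative). It assumes `op_def` has a registered
  // OpSchema, `dims` holds the shapes of all of its inputs, and "float" is a
  // type name that fill_with_type() accepts:
  //
  //   TensorFiller filler =
  //       DataRandomFiller::get_tensor_filler(op_def, /*input_index=*/0, dims);
  //   Tensor tensor(CPU);
  //   fill_with_type(filler, "float", &tensor);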

 protected:
  DataRandomFiller() {}

  using filler_type_pair_t = std::pair<TensorFiller, std::string>;
  std::unordered_map<std::string, filler_type_pair_t> parameters_;
  std::unordered_map<std::string, filler_type_pair_t> inputs_;
};
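
// Layout sketch for the DataRandomFiller constructor arguments (illustrative).
// The layout is our reading of the parameter docs above and of the
// get_tensor_filler() signature. It assumes a hypothetical `run_net` whose
// op 0 reads {"X", "W"} and whose op 1 reads {"Y"}. The outer index is the op,
// the middle index is that op's input, and the innermost vector is the
// input's shape:
//
//   std::vector<std::vector<std::vector<int64_t>>> input_dims = {
//       {{8, 16}, {32, 16}},  // op 0: shape of X, shape of W
//       {{8, 32}},            // op 1: shape of Y
//   };
//   std::vector<std::vector<std::string>> input_types = {
//       {"float", "float"},  // op 0
//       {"float"},           // op 1
//   };
//   DataRandomFiller filler(run_net, input_dims, input_types);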

// A DataRandomFiller that is more convenient to use in unit tests.
// Callers only need to supply input dimensions and types for non-intermediate
// blobs. It also treats parameters the same way as non-intermediate inputs
// (parameters are not handled separately).
class TestDataRandomFiller : public DataRandomFiller {
 public:
  TestDataRandomFiller(
      const NetDef& net,
      const std::vector<std::vector<std::vector<int64_t>>>& inputDims,
      const std::vector<std::vector<std::string>>& inputTypes);

  // Fill inputs directly into the workspace.
  void fillInputToWorkspace(Workspace* workspace) const;
};

// Convenience helper to fill random input data into a workspace.
TORCH_API void fillRandomNetworkInputs(
    const NetDef& net,
    const std::vector<std::vector<std::vector<int64_t>>>& inputDims,
    const std::vector<std::vector<std::string>>& inputTypes,
    Workspace* workspace);
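
// Usage sketch for unit tests (illustrative). It assumes a hypothetical
// single-op `net` whose only input "X" is a non-intermediate blob of shape
// {2, 3}:
//
//   std::vector<std::vector<std::vector<int64_t>>> dims = {{{2, 3}}};
//   std::vector<std::vector<std::string>> types = {{"float"}};
//   Workspace ws;
//   fillRandomNetworkInputs(net, dims, types, &ws);
//   // `ws` is then expected to hold a randomly filled blob named "X".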
} // namespace emulator
} // namespace caffe2