1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
|
#pragma once
#include "caffe2/opt/backend_transformer_base.h"
#include <unordered_set>
namespace caffe2 {
// Option bundle for the TVM backend transform. It inherits every generic
// setting from BackendTransformOptions and adds the TVM-specific switches
// declared below.
struct TvmTransformOptions final : public BackendTransformOptions {
  explicit TvmTransformOptions() : BackendTransformOptions() {}

  // Turns on profiling based jit when set; disabled unless the caller opts in.
  bool profiling_based_jit = false;
};
// Backend transformer that groups the TVM-runnable parts of a c2 predict net
// and replaces each group with a TVMJit op.
class TORCH_API TvmTransformer final : public BackendTransformerBase {
 public:
  // Stores a private copy of `opts`; the caller's object is not referenced
  // after construction.
  explicit TvmTransformer(const TvmTransformOptions& opts)
      : BackendTransformerBase(), opts_(opts) {}

  ~TvmTransformer() override = default;

  // Clusters the continuous parts of the predict net that TVM can run and
  // creates one TVMJit op per clustered subgraph.
  // \param ws c2 workspace
  // \param pred_net c2 predict net
  // \param weight_names names of the constant weights
  // \param shape_hints shape info supplied by the user, usually for the
  //   primary inputs, so that bound shape inference has a starting point
  // \param blocklisted_ops net positions of ops that must not be lowered to
  //   TVM. Very useful for debugging; expected to be empty in normal runs
  void transform(
      Workspace* ws,
      NetDef* pred_net,
      const std::vector<std::string>& weight_names,
      const ShapeInfoMap& shape_hints,
      const std::unordered_set<int>& blocklisted_ops) override;

  // Returns a reference to a set of strings; presumably the op types TVM can
  // run -- confirm against the definition.
  static const std::unordered_set<std::string>& getSupportedOps();

  // Judging by the name, reports whether all of `net` can be converted while
  // honoring `blocklisted_ops` -- exact criteria live in the definition.
  static bool canConvertFullGraph(
      const caffe2::NetDef& net,
      const std::unordered_set<int>& blocklisted_ops);

 private:
  // Contracts the given TVM runnable subnets into one TVMJitOp.
  NetDef buildTvmOp(
      const caffe2::NetDef& net,
      const std::unordered_set<std::string>& weights,
      const ShapeInfoMap& shape_hints);

  // Applies the transform that clusters connected TVM runnable ops into one
  // TVMJitOp.
  NetDef applyTvmTransform(
      NetDef* pred_net,
      const std::unordered_set<std::string>& weights,
      const std::unordered_set<int>& blocklisted_ops,
      const ShapeInfoMap& shape_hints);

  // Options this transformer was constructed with.
  TvmTransformOptions opts_;

  // Count of TVMJitOp instances created so far.
  int tvm_op_id_ = 0;

  // Model id.
  std::string model_id_;
};
// Helper function to clean up a net and run tvm transform.
//
// `net` is taken by mutable pointer, so the result is expected to be written
// back into it. The definition is not part of this header, so the notes below
// are inferred from parameter names only -- TODO confirm against the
// implementation:
// - input_names / output_names / weight_names: presumably the external
//   inputs, external outputs and constant weights of `net`.
// - shape_hints / blocklisted_ops: presumably forwarded to
//   TvmTransformer::transform.
// - max_batch_size / max_seq_size / num_embeddings / embedding_size /
//   tvm_min_ops: presumably sizing and tuning knobs for the transform.
// - tvm_profiling_based_jit: presumably maps to
//   TvmTransformOptions::profiling_based_jit.
// - debug: presumably toggles debug behavior of the transform.
TORCH_API void tvmTransform(
NetDef* net,
Workspace* ws,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::string>& weight_names,
const ShapeInfoMap& shape_hints,
const std::unordered_set<int>& blocklisted_ops,
int32_t max_batch_size,
int32_t max_seq_size,
int32_t num_embeddings,
int32_t embedding_size,
int32_t tvm_min_ops,
bool tvm_profiling_based_jit,
bool debug);
// Cleans up a predict net. `net` is taken by mutable pointer, so it is
// presumably modified in place using the given input, output and weight
// names; the exact cleanup steps are in the definition, which is not visible
// from this header -- TODO confirm.
TORCH_API void cleanUpPredictNet(
NetDef* net,
const std::vector<std::string>& input_names,
const std::vector<std::string>& output_names,
const std::vector<std::string>& weight_names);
} // namespace caffe2
|