1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
|
#pragma once
#include "caffe2/core/common.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/transforms/single_op_transform.h"
#include "caffe2/utils/proto_utils.h"
namespace caffe2 {
/// Single-operator transform that moves CPU "Conv" operators onto the
/// "NNPACK" engine. Operators already using "NNPACK" are not matched, so
/// they are not rewritten again.
class TORCH_API ConvToNNPackTransform : public SingleOpTransform {
 protected:
  /// Pattern predicate: accept an operator only if it is a "Conv", is
  /// placed on a PROTO_CPU device, and does not already use "NNPACK".
  /// The checks run in that order and stop at the first failure.
  bool MatchOperator(const OperatorDef& op) override {
    if (op.type() != "Conv") {
      return false;
    }
    if (op.device_option().device_type() != PROTO_CPU) {
      return false;
    }
    return op.engine() != "NNPACK";
  }

  /// Rewrite step for a matched operator: only the engine field is
  /// changed (to "NNPACK"); no other field of the OperatorDef is touched.
  void ReplaceOperator(OperatorDef* op) override {
    op->set_engine("NNPACK");
  }
};
} // namespace caffe2
|