1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
|
#include "caffe2/core/logging.h"
#include "caffe2/opt/custom/pointwise_elim.h"
#include "caffe2/opt/nql/graphmatcher.h"
#include "caffe2/opt/passes.h"
#include "nomnigraph/Representations/NeuralNet.h"
#include "nomnigraph/Support/Common.h"
#include "nomnigraph/Transformations/SubgraphMatcher.h"
namespace caffe2 {
namespace opt {
using namespace nom::repr;

// Fuses the operator chain
//   %cast    = Cast(%input)
//   %one_hot = BatchOneHot(%cast, %lengths, %values)
//   %out     = Cast(%one_hot)
// into a single CastedBatchOneHot(%input, %lengths, %values) -> %out.
//
// A match is only rewritten when the first Cast has `to == 10` and the second
// Cast has `to == 1`; all other matches are left unchanged. This pattern matched
// most of prod as of H2 2018.
//
// Throws (via CAFFE_ENFORCE) if the NQL pattern fails to parse.
void fuseCastBatchOneHot(NNModule* nn) {
  // NOTE(review): 10 and 1 are presumably TensorProto::INT64 and
  // TensorProto::FLOAT respectively — confirm against caffe2.proto.
  constexpr int kFirstCastTo = 10;
  constexpr int kSecondCastTo = 1;

  nom::nql::GraphMatcher gm;
  gm.initFromString(R"NQL(def nn {
%cast = Cast(%input)
%one_hot = BatchOneHot(%cast, %lengths, %values)
%out = Cast(%one_hot)
})NQL");
  CAFFE_ENFORCE(gm.getMatcher(), "Unable to parse NQL query.");

  for (const auto& match : gm.getMatches(nn->dataFlow)) {
    // Keys are the plain NQL variable names; "\%" is not a standard C++
    // escape sequence, so the '%' must not be backslash-escaped.
    auto* first_cast = nn::get<Cast>(nn::getProducer(match["%cast"]));
    auto* second_cast = nn::get<Cast>(nn::getProducer(match["%out"]));
    // Skip (rather than dereference null) if a producer is not a Cast.
    NOM_REQUIRE_OR_CONT(first_cast != nullptr);
    NOM_REQUIRE_OR_CONT(second_cast != nullptr);
    NOM_REQUIRE_OR_CONT(first_cast->getTo() == kFirstCastTo);
    NOM_REQUIRE_OR_CONT(second_cast->getTo() == kSecondCastTo);

    // Replace the three matched operators with one fused operator that keeps
    // the original external inputs and output.
    nn->replaceSubgraphWithOperator<CastedBatchOneHot>(
        match.subgraph,
        {match["%input"], match["%lengths"], match["%values"]},
        {match["%out"]});
  }
}

REGISTER_OPT_PASS_FROM_FUNC(FuseCastBatchOneHot, fuseCastBatchOneHot);

} // namespace opt
} // namespace caffe2
|