File: pointwise_elim.cc

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (40 lines) | stat: -rw-r--r-- 1,298 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include "caffe2/core/logging.h"
#include "caffe2/opt/custom/pointwise_elim.h"
#include "caffe2/opt/nql/graphmatcher.h"
#include "caffe2/opt/passes.h"
#include "nomnigraph/Representations/NeuralNet.h"
#include "nomnigraph/Support/Common.h"
#include "nomnigraph/Transformations/SubgraphMatcher.h"

namespace caffe2 {
namespace opt {

using namespace nom::repr;

// Fuses a Cast -> BatchOneHot -> Cast chain into one CastedBatchOneHot op.
//
// Every match in `nn->dataFlow` of the NQL pattern
//   %cast    = Cast(%input)
//   %one_hot = BatchOneHot(%cast, %lengths, %values)
//   %out     = Cast(%one_hot)
// is considered. When the first Cast's `to` value equals kFirstCastTo and the
// second Cast's `to` value equals kSecondCastTo, the matched subgraph is
// replaced by a single CastedBatchOneHot node. That node takes
// (%input, %lengths, %values) as inputs and produces %out.
// NOM_REQUIRE_OR_CONT presumably skips a match whose condition fails.
//
// @param nn  Module whose data-flow graph is rewritten in place.
// Fails via CAFFE_ENFORCE if the NQL pattern cannot be parsed.
void fuseCastBatchOneHot(NNModule* nn) {
  // Raw `to` codes the outer Casts must carry for the fusion to apply.
  // NOTE(review): these look like caffe2 TensorProto::DataType codes
  // (10 == INT64, 1 == FLOAT) -- confirm against caffe2.proto before
  // switching to the named enumerators.
  constexpr int kFirstCastTo = 10;
  constexpr int kSecondCastTo = 1;

  nom::nql::GraphMatcher gm;
  // The %-prefixed names declared in this query are the keys used to index
  // `match` below. '%' needs no escaping in a C++ string literal; "\%" is
  // not a standard escape sequence (implementation-defined, and warned about).
  gm.initFromString(R"NQL(def nn {
      %cast = Cast(%input)
      %one_hot = BatchOneHot(%cast, %lengths, %values)
      %out = Cast(%one_hot)
  })NQL");
  CAFFE_ENFORCE(gm.getMatcher(), "Unable to parse NQL query.");

  // NOTE(review): all matches are collected before any rewrite happens. If
  // two matches could share nodes, a later match might reference nodes that
  // an earlier replaceSubgraphWithOperator removed. Confirm that the matcher
  // yields disjoint matches.
  for (const auto& match : gm.getMatches(nn->dataFlow)) {
    // This matches most of prod as of H2 2018
    auto first_cast = nn::getProducer(match["%cast"]);
    auto second_cast = nn::getProducer(match["%out"]);
    NOM_REQUIRE_OR_CONT(nn::get<Cast>(first_cast)->getTo() == kFirstCastTo);
    NOM_REQUIRE_OR_CONT(nn::get<Cast>(second_cast)->getTo() == kSecondCastTo);

    nn->replaceSubgraphWithOperator<CastedBatchOneHot>(
        match.subgraph,
        {match["%input"], match["%lengths"], match["%values"]},
        {match["%out"]});
  }
}

REGISTER_OPT_PASS_FROM_FUNC(FuseCastBatchOneHot, fuseCastBatchOneHot);

} // namespace opt
} // namespace caffe2