File: EliminateEmptyTensors.cpp

package info (click to toggle)
llvm-toolchain-17 1%3A17.0.6-22
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,799,624 kB
  • sloc: cpp: 6,428,607; ansic: 1,383,196; asm: 793,408; python: 223,504; objc: 75,364; f90: 60,502; lisp: 33,869; pascal: 15,282; sh: 9,684; perl: 7,453; ml: 4,937; awk: 3,523; makefile: 2,889; javascript: 2,149; xml: 888; fortran: 619; cs: 573
file content (107 lines) | stat: -rw-r--r-- 4,074 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
//===- EliminateEmptyTensors.cpp - tensor.empty op elimination ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::linalg;

/// Returns the first DPS init ("out") operand of `op` that can stand in for
/// the DPS input operand `in` when a tensor.empty op feeding `in` is
/// eliminated, or nullptr if no init operand qualifies.
///
/// An init operand qualifies only if all of the following hold:
///   * the payload region does not read its value,
///   * its type is identical to the type of `in`, and
///   * its indexing map is identical to the indexing map of `in`.
static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) {
  // Properties of `in` that every candidate must match; both are pure
  // queries, so they are computed once up front.
  auto requiredType = in->get().getType();
  auto requiredMap = op.getMatchingIndexingMap(in);

  for (OpOperand *candidate : op.getDpsInitOperands()) {
    const bool readByPayload = op.payloadUsesValueFromOperand(candidate);
    if (!readByPayload && candidate->get().getType() == requiredType &&
        op.getMatchingIndexingMap(candidate) == requiredMap)
      return candidate;
  }
  return nullptr;
}

/// Eliminates tensor.empty ops that, through a chain of equivalent tensors,
/// feed a DPS input operand of an all-parallel LinalgOp nested in `op`.
///
/// For each ranked-tensor input `in` of such a LinalgOp, an init operand `out`
/// is looked for whose value the payload does not read and whose type and
/// indexing map equal those of `in`. If one is found and the value of `out`
/// properly dominates every discovered tensor.empty op, then:
///   * all uses of those tensor.empty ops are replaced with the value of `out`,
///   * `out` takes over the value previously fed to `in`, and the payload now
///     reads it through the block argument of `out`,
///   * `in` is set to a tensor.empty result that the payload no longer reads
///     (removing that operand is left to existing cleanup patterns).
///
/// The tensor.empty ops themselves are not erased here; they are left to fold
/// away later so that `domInfo` stays valid for the remainder of the walk.
///
/// \param rewriter  Rewriter through which all IR modifications are made.
/// \param op        Root operation whose nested LinalgOps are visited.
/// \param state     One-Shot analysis state used for the reverse use-def
///                  traversal; its cache is reset after every rewrite.
/// \returns success() unconditionally.
LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
    RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
  OpBuilder::InsertionGuard g(rewriter);
  DominanceInfo domInfo;

  // Traversal settings for finding tensor.empty ops on the reverse SSA
  // use-def chain. Only follow equivalent tensors, i.e., stop when there are
  // ops such as extract_slice on the path. The settings do not depend on the
  // operand being inspected, so they are built once.
  TraversalConfig config;
  config.followEquivalentOnly = true;
  config.alwaysIncludeLeaves = false;

  // The callback parameter is named `linalgOp` so that it does not shadow the
  // `op` parameter of this function.
  op->walk([&](LinalgOp linalgOp) {
    // Only ops with all "parallel" iterator types are supported.
    if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
      return WalkResult::skip();

    for (OpOperand *in : linalgOp.getDpsInputOperands()) {
      // Skip non-tensor operands.
      if (!llvm::isa<RankedTensorType>(in->get().getType()))
        continue;

      SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
          in->get(), /*condition=*/
          [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
          config);
      if (emptyTensors.empty())
        continue;

      // Find matching out operand.
      OpOperand *out = getUnusedOutOperand(linalgOp, in);
      if (!out)
        continue;

      // The value of `out` will replace every found tensor.empty, so it must
      // properly dominate each of them; otherwise the rewrite would violate
      // dominance.
      if (!llvm::all_of(emptyTensors, [&](Value v) {
            return domInfo.properlyDominates(out->get(), v.getDefiningOp());
          }))
        continue;

      // Replace all uses of the tensor.empty, but do not delete it yet. It will
      // fold away later (to not invalidate DominanceInfo). Note: if `in->get()`
      // is itself one of the found tensor.empty results, this RAUW also
      // rewrites `in` to the value of `out`, which is why the operand updates
      // below must happen after this loop and read `in->get()` afresh.
      for (Value v : emptyTensors) {
        assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty");
        rewriter.replaceAllUsesWith(v, out->get());
      }

      // Turn the "in" into an "out".
      rewriter.updateRootInPlace(linalgOp, [&]() {
        out->set(in->get());
        // The original "in" could be removed entirely here (because it will no
        // longer have any uses in the payload), but we delegate this to
        // existing cleanup patterns that remove unused operands.
        in->set(emptyTensors.front());
        BlockArgument outArg = linalgOp.getMatchingBlockArgument(out);
        assert(outArg.use_empty() && "expected that out has no uses");
        BlockArgument inArg = linalgOp.getMatchingBlockArgument(in);
        // Redirect every payload read of the input block argument to the
        // output block argument, which now carries the same value.
        rewriter.replaceAllUsesWith(inArg, outArg);
        assert(!linalgOp.payloadUsesValueFromOperand(in) &&
               "expected that the in operand is now unused");
      });

      // The IR was modified; reset the analysis state's cache.
      state.resetCache();
    }

    return WalkResult::advance();
  });
  return success();
}