File: DecomposeAffineOps.cpp

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (172 lines) | stat: -rw-r--r-- 7,226 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
//===- DecomposeAffineOps.cpp - Decompose affine ops into finer-grained ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functionality to progressively decompose coarse-grained
// affine ops into finer-grained ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::affine;

// Debug category name for this file; it is also spliced into the DBGS()
// prefix below. (Presumably the tag LLVM_DEBUG filters on -- see
// llvm/Support/Debug.h.)
#define DEBUG_TYPE "decompose-affine-ops"
// Debug stream with a "[decompose-affine-ops]: " prefix already written to it.
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
// Writes a bare newline to the debug stream.
#define DBGSNL() (llvm::dbgs() << "\n")

/// Return how many of the loops enclosing the owner of `operand`, walking from
/// the innermost one outwards, report the operand's value as defined outside
/// of them, i.e. how many loops the operand could be hoisted above.
/// The walk ends at the first loop that does not satisfy this, or when no
/// enclosing loop-like op remains.
static int64_t numEnclosingInvariantLoops(OpOperand &operand) {
  Value value = operand.get();
  int64_t numLoops = 0;
  for (auto loop = operand.getOwner()->getParentOfType<LoopLikeOpInterface>();
       loop && loop.isDefinedOutsideOfLoop(value);
       loop = loop->getParentOfType<LoopLikeOpInterface>())
    ++numLoops;
  return numLoops;
}

/// Permute the operands of `op` so that operands invariant to more enclosing
/// loops come first, and rewrite the affine map of `op` to match.
/// Every input of the original map (dim or symbol) is re-expressed as a symbol
/// of the new map; the resulting map has no dims. The op is updated in place,
/// bracketed by the rewriter's root-update notifications.
void mlir::affine::reorderOperandsByHoistability(RewriterBase &rewriter,
                                                 AffineApplyOp op) {
  // For each operand, the number of enclosing loops it can be hoisted above.
  SmallVector<int64_t> hoistDepth;
  for (OpOperand &opOperand : op->getOpOperands())
    hoistDepth.push_back(numEnclosingInvariantLoops(opOperand));

  // Build the identity permutation and sort it by decreasing hoist depth.
  // The sort is stable, so operands with equal depth keep their relative
  // order.
  const int64_t operandCount = op.getNumOperands();
  SmallVector<int64_t> order;
  for (int64_t pos = 0; pos < operandCount; ++pos)
    order.push_back(pos);
  llvm::stable_sort(order, [&hoistDepth](size_t lhs, size_t rhs) {
    return hoistDepth[lhs] > hoistDepth[rhs];
  });

  // The operand previously at `oldPos` moves to `newPos` and is referred to as
  // symbol `newPos` in the rewritten map.
  SmallVector<Value> permutedOperands;
  SmallVector<AffineExpr> oldToNew(operandCount);
  for (int64_t newPos = 0; newPos < operandCount; ++newPos) {
    const int64_t oldPos = order[newPos];
    permutedOperands.push_back(op.getOperand(oldPos));
    oldToNew[oldPos] = getAffineSymbolExpr(newPos, op.getContext());
  }

  // The first `dimCount` entries of `oldToNew` substitute the dims of the old
  // map, the rest substitute its symbols.
  AffineMap oldMap = op.getAffineMap();
  const unsigned dimCount = oldMap.getNumDims();
  ArrayRef<AffineExpr> substitutions{oldToNew};
  AffineMap substituted = oldMap.replaceDimsAndSymbols(
      substitutions.take_front(dimCount), substitutions.drop_front(dimCount),
      /*numResultDims=*/0,
      /*numResultSyms=*/operandCount);

  // Simplify the single result expression, then canonicalize the map together
  // with the permuted operand list.
  AffineExpr simplified =
      simplifyAffineExpr(substituted.getResult(0), 0, operandCount);
  AffineMap newMap =
      AffineMap::get(0, operandCount, simplified, op->getContext());
  canonicalizeMapAndOperands(&newMap, &permutedOperands);

  rewriter.startRootUpdate(op);
  op.setMap(newMap);
  op->setOperands(permutedOperands);
  rewriter.finalizeRootUpdate(op);
}

/// Create a new affine.apply computing `expr`, a subexpression of the affine
/// map of `originalOp`, over the same operands as `originalOp`.
/// The new map is declared with the same dim and symbol counts as the original
/// one, then map and operands are canonicalized (to deduplicate and drop dead
/// operands). Maximal composition of AffineApplyOp is deliberately not
/// performed since it would defeat the purpose of the decomposition.
/// The op is created at the rewriter's current insertion point, with the
/// location of `originalOp`.
static AffineApplyOp createSubApply(RewriterBase &rewriter,
                                    AffineApplyOp originalOp, AffineExpr expr) {
  AffineMap originalMap = originalOp.getAffineMap();
  AffineMap subMap =
      AffineMap::get(originalMap.getNumDims(), originalMap.getNumSymbols(),
                     expr, originalOp->getContext());
  SmallVector<Value> subOperands = originalOp->getOperands();
  canonicalizeMapAndOperands(&subMap, &subOperands);
  return rewriter.create<AffineApplyOp>(originalOp.getLoc(), subMap,
                                        subOperands);
}

/// Decompose `op` into a chain of finer-grained affine.apply ops by
/// reassociating its top-level add or mul expression.
///
/// The result expression of the map of `op` is split into the operands of its
/// top-level chain of same-kind binary expressions. These subexpressions are
/// sorted by the largest symbol position they depend on, one affine.apply is
/// created per subexpression, and the pieces are combined pairwise with
/// two-symbol affine.apply ops of the original kind. `op` is then replaced by
/// the last combined op, which is also returned.
///
/// Fails (through `rewriter.notifyMatchFailure`) without creating any op when:
///   - the map has dims,
///   - the result is not a binary expression,
///   - neither side of the top-level binary expression is itself binary,
///   - the top-level expression is neither an add nor a mul.
///
/// New ops are created at the rewriter's current insertion point; this
/// function does not set it, so callers are presumably expected to position
/// the rewriter (e.g. at `op`) beforehand.
FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
                                                 AffineApplyOp op) {
  // 1. Preconditions: only handle dimensionless AffineApplyOp maps with a
  // top-level binary expression that we can reassociate (i.e. add or mul).
  AffineMap m = op.getAffineMap();
  if (m.getNumDims() > 0)
    return rewriter.notifyMatchFailure(op, "expected no dims");

  AffineExpr remainingExp = m.getResult(0);
  auto binExpr = remainingExp.dyn_cast<AffineBinaryOpExpr>();
  if (!binExpr)
    return rewriter.notifyMatchFailure(op, "terminal affine.apply");

  // Nothing to reassociate if both sides are already leaves.
  if (!binExpr.getLHS().isa<AffineBinaryOpExpr>() &&
      !binExpr.getRHS().isa<AffineBinaryOpExpr>())
    return rewriter.notifyMatchFailure(op, "terminal affine.apply");

  bool supportedKind = ((binExpr.getKind() == AffineExprKind::Add) ||
                        (binExpr.getKind() == AffineExprKind::Mul));
  if (!supportedKind)
    return rewriter.notifyMatchFailure(
        op, "only add or mul binary expr can be reassociated");

  LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n");

  // 2. Iteratively extract the RHS subexpressions while the top-level binary
  // expr kind remains the same. The last LHS that is not a binary expr of the
  // same kind is recorded as the terminal subexpression.
  MLIRContext *ctx = op->getContext();
  SmallVector<AffineExpr> subExpressions;
  while (true) {
    auto currentBinExpr = remainingExp.dyn_cast<AffineBinaryOpExpr>();
    if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
      subExpressions.push_back(remainingExp);
      LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
      break;
    }
    subExpressions.push_back(currentBinExpr.getRHS());
    LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n");
    remainingExp = currentBinExpr.getLHS();
  }

  // 3. Reorder subExpressions by the max symbol they are a function of
  // (ascending; expressions that depend on no symbol sort first).
  // This also takes care of properly reordering local variables.
  // This however won't be able to split expression that cannot be reassociated
  // such as ones that involve divs and multiple symbols.
  // Valid symbol positions are [0, numSymbols). The count is widened to a
  // signed type before subtracting so that a map without symbols yields an
  // empty scan instead of an unsigned wrap-around.
  const int64_t numSymbols = m.getNumSymbols();
  auto getMaxSymbol = [numSymbols](AffineExpr e) -> int64_t {
    for (int64_t i = numSymbols - 1; i >= 0; --i)
      if (e.isFunctionOfSymbol(i))
        return i;
    return -1;
  };
  llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) {
    return getMaxSymbol(e1) < getMaxSymbol(e2);
  });
  LLVM_DEBUG(
      llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: ");
      llvm::dbgs() << "\n");

  // 4. Merge sorted subExpressions iteratively, thus achieving reassociation.
  // `binMap` is `()[s0, s1] -> (s0 <kind> s1)` and is reused for every merge.
  auto s0 = getAffineSymbolExpr(0, ctx);
  auto s1 = getAffineSymbolExpr(1, ctx);
  AffineMap binMap = AffineMap::get(
      /*dimCount=*/0, /*symbolCount=*/2,
      getAffineBinaryOpExpr(binExpr.getKind(), s0, s1), ctx);

  auto current = createSubApply(rewriter, op, subExpressions[0]);
  for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) {
    Value tmp = createSubApply(rewriter, op, subExpressions[i]);
    current = rewriter.create<AffineApplyOp>(op.getLoc(), binMap,
                                             ValueRange{current, tmp});
    LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
  }

  // 5. Replace original op.
  rewriter.replaceOp(op, current.getResult());
  return current;
}