//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Reduction Tree Pass class. It provides a framework for
// the implementation of different reduction passes in the MLIR Reduce tool. It
// allows for custom specification of the variant generation behavior. It
// implements methods that define the different possible traversals of the
// reduction tree.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ManagedStatic.h"
namespace mlir {
#define GEN_PASS_DEF_REDUCTIONTREE
#include "mlir/Reducer/Passes.h.inc"
} // namespace mlir
using namespace mlir;
/// We implicitly number each operation in the region and if an operation's
/// number falls into rangeToKeep, we need to keep it and apply the given
/// rewrite patterns on it.
static void applyPatterns(Region ®ion,
const FrozenRewritePatternSet &patterns,
ArrayRef<ReductionNode::Range> rangeToKeep,
bool eraseOpNotInRange) {
std::vector<Operation *> opsNotInRange;
std::vector<Operation *> opsInRange;
size_t keepIndex = 0;
for (const auto &op : enumerate(region.getOps())) {
int index = op.index();
if (keepIndex < rangeToKeep.size() &&
index == rangeToKeep[keepIndex].second)
++keepIndex;
if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
opsNotInRange.push_back(&op.value());
else
opsInRange.push_back(&op.value());
}
// `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
// matching in above iteration. Besides, erase op not-in-range may end up in
// invalid module, so `applyOpPatternsAndFold` should come before that
// transform.
for (Operation *op : opsInRange) {
// `applyOpPatternsAndFold` returns whether the op is convered. Omit it
// because we don't have expectation this reduction will be success or not.
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
(void)applyOpPatternsAndFold(op, patterns, config);
}
if (eraseOpNotInRange)
for (Operation *op : opsNotInRange) {
op->dropAllUses();
op->erase();
}
}
/// We will apply the reducer patterns to the operations in the ranges
/// specified by ReductionNode. Note that we are not able to remove an
/// operation without replacing it with another valid operation. However, the
/// validity of a module reduction is decided by the Tester provided by the
/// user, which means that certain invalid modules may still be of interest to
/// the user. Thus we provide an alternative way to remove operations, which is
/// using `eraseOpNotInRange` to erase the operations not in the range
/// specified by ReductionNode.
///
/// The reduction tree is explored with `IteratorType`, starting from a root
/// node that covers every operation of `region`. The smallest interesting node
/// that was visited is remembered, and the path from the root to that node is
/// then replayed on `region` itself.
///
/// If `module` is not interesting to begin with, a warning is emitted and
/// returned as the result and nothing is reduced. It is a fatal error if the
/// module obtained by replaying the path is not interesting or if its size
/// differs from the size recorded for the smallest node.
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region &region,
                                 const FrozenRewritePatternSet &patterns,
                                 const Tester &test, bool eraseOpNotInRange) {
  std::pair<Tester::Interestingness, size_t> initStatus =
      test.isInteresting(module);
  // While exploring the reduction tree, we always branch from an interesting
  // node. Thus the root node must be interesting.
  if (initStatus.first != Tester::Interestingness::True)
    return module.emitWarning() << "uninterested module will not be reduced";

  llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;

  // The root keeps every operation of the region: a single range spanning
  // [0, number of operations).
  std::vector<ReductionNode::Range> ranges{
      {0, std::distance(region.op_begin(), region.op_end())}};

  ReductionNode *root = allocator.Allocate();
  new (root) ReductionNode(nullptr, ranges, allocator);
  // Duplicate the module for root node and locate the region in the copy.
  if (failed(root->initialize(module, region)))
    llvm_unreachable("unexpected initialization failure");
  root->update(initStatus);

  ReductionNode *smallestNode = root;
  IteratorType iter(root);

  while (iter != IteratorType::end()) {
    ReductionNode &currentNode = *iter;
    Region &curRegion = currentNode.getRegion();

    applyPatterns(curRegion, patterns, currentNode.getRanges(),
                  eraseOpNotInRange);
    currentNode.update(test.isInteresting(currentNode.getModule()));

    // Only an interesting node that is strictly smaller replaces the current
    // best candidate.
    if (currentNode.isInteresting() == Tester::Interestingness::True &&
        currentNode.getSize() < smallestNode->getSize())
      smallestNode = &currentNode;

    ++iter;
  }

  // At this point, we have found an optimal path to reduce the given region.
  // Retrieve the path (from the smallest node up to the root) and apply the
  // reducer to it.
  SmallVector<ReductionNode *> trace;
  ReductionNode *curNode = smallestNode;
  trace.push_back(curNode);
  while (curNode != root) {
    curNode = curNode->getParent();
    trace.push_back(curNode);
  }

  // Reduce the region through the optimal path. The trace was collected
  // leaf-first, so popping from the back replays it starting at the root.
  while (!trace.empty()) {
    ReductionNode *top = trace.pop_back_val();
    applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
  }

  // Run the tester a single time on the reduced module and validate both the
  // interestingness and the size from that one result.
  std::pair<Tester::Interestingness, size_t> finalStatus =
      test.isInteresting(module);
  if (finalStatus.first != Tester::Interestingness::True)
    llvm::report_fatal_error("Reduced module is not interesting");
  if (finalStatus.second != smallestNode->getSize())
    llvm::report_fatal_error(
        "Reduced module doesn't have consistent size with smallestNode");
  return success();
}
/// Reduces `region` of `module` in two consecutive phases and returns failure
/// as soon as one of them fails.
///
/// The first phase runs with an empty pattern set and `eraseOpNotInRange`
/// enabled; the second phase runs with `patterns` and `eraseOpNotInRange`
/// disabled.
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region &region,
                                 const FrozenRewritePatternSet &patterns,
                                 const Tester &test) {
  // We separate the reduction process into 2 steps: the first one erases
  // redundant operations and the second one applies the reducer patterns.
  // In the first phase, we don't apply any patterns so that we only select the
  // range of operations to keep for the module to stay interesting.
  if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
                                       /*eraseOpNotInRange=*/true)))
    return failure();

  // In the second phase, we suppose that no operation is redundant, so we try
  // to rewrite the operations into a simpler form.
  return findOptimal<IteratorType>(module, region, patterns, test,
                                   /*eraseOpNotInRange=*/false);
}
namespace {
//===----------------------------------------------------------------------===//
// Reduction Pattern Interface Collection
//===----------------------------------------------------------------------===//
/// A collection of the DialectReductionPatternInterface instances, used to
/// gather the reduction patterns that the individual dialects provide.
class ReductionPatternInterfaceCollection
    : public DialectInterfaceCollection<DialectReductionPatternInterface> {
public:
  using Base::Base;

  /// Asks every interface held by this collection to add its reduction
  /// patterns to `pattern`.
  void populateReductionPatterns(RewritePatternSet &pattern) const {
    for (const DialectReductionPatternInterface &dialectInterface : *this) {
      dialectInterface.populateReductionPatterns(pattern);
    }
  }
};
//===----------------------------------------------------------------------===//
// ReductionTreePass
//===----------------------------------------------------------------------===//
/// This class defines the Reduction Tree Pass. It provides a framework to
/// implement a reduction pass using a tree structure to keep track of the
/// generated reduced variants.
class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> {
public:
  ReductionTreePass() = default;
  ReductionTreePass(const ReductionTreePass &pass) = default;

  /// Collects the reduction patterns into `reducerPatterns`.
  LogicalResult initialize(MLIRContext *context) override;

  /// Runs the pass instance in the pass pipeline.
  void runOnOperation() override;

private:
  /// Reduces `region`, which belongs to `module`, using the traversal mode
  /// selected for this pass.
  LogicalResult reduceOp(ModuleOp module, Region &region);

  /// The reduction patterns gathered by `initialize`.
  FrozenRewritePatternSet reducerPatterns;
};
} // namespace
/// Gathers the reduction patterns through a
/// ReductionPatternInterfaceCollection built on `context` and stores them in
/// `reducerPatterns`. Always returns success.
LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
  RewritePatternSet collected(context);

  ReductionPatternInterfaceCollection dialectReducers(context);
  dialectReducers.populateReductionPatterns(collected);

  // Hand the collected patterns over to the pass.
  reducerPatterns = std::move(collected);
  return success();
}
void ReductionTreePass::runOnOperation() {
Operation *topOperation = getOperation();
while (topOperation->getParentOp() != nullptr)
topOperation = topOperation->getParentOp();
ModuleOp module = dyn_cast<ModuleOp>(topOperation);
if (!module) {
emitError(getOperation()->getLoc())
<< "top-level op must be 'builtin.module'";
return signalPassFailure();
}
SmallVector<Operation *, 8> workList;
workList.push_back(getOperation());
do {
Operation *op = workList.pop_back_val();
for (Region ®ion : op->getRegions())
if (!region.empty())
if (failed(reduceOp(module, region)))
return signalPassFailure();
for (Region ®ion : op->getRegions())
for (Operation &op : region.getOps())
if (op.getNumRegions() != 0)
workList.push_back(&op);
} while (!workList.empty());
}
/// Reduces `region` of `module` with a Tester built from `testerName` and
/// `testerArgs`, dispatching on `traversalModeId`. Only
/// TraversalMode::SinglePath is handled; any other mode emits an error on the
/// module, which is returned as the result.
LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
  Tester test(testerName, testerArgs);
  switch (traversalModeId) {
  case TraversalMode::SinglePath:
    return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
        module, region, reducerPatterns, test);
  default:
    return module.emitError() << "unsupported traversal mode detected";
  }
}
/// Creates a new, default-constructed instance of the ReductionTreePass.
std::unique_ptr<Pass> mlir::createReductionTreePass() {
  std::unique_ptr<Pass> pass = std::make_unique<ReductionTreePass>();
  return pass;
}