//===- ReconcileUnrealizedCasts.cpp - Eliminate noop unrealized casts -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_RECONCILEUNREALIZEDCASTS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
/// Folds the DAGs of `unrealized_conversion_cast`s whose exit types are the
/// same as the input ones.
/// For example, the DAGs `A -> B -> C -> B -> A` and `A -> B -> C -> A`
/// represent a no-op within the IR, and thus the initial input values can be
/// propagated.
/// The same does not hold for 'open' chains of casts, such as
/// `A -> B -> C`. In this last case there is no cycle among the types and thus
/// the conversion is incomplete. The same holds for 'closed' chains like
/// `A -> B -> A`, but with the result of type `B` being used by some non-cast
/// operations.
/// Bifurcations (that is, when a chain starts in the middle of another one)
/// are also taken into consideration, and all the above considerations remain
/// valid.
/// Special corner cases such as dead casts or single casts with the same input
/// and output types are also covered.
struct UnrealizedConversionCastPassthrough
    : public OpRewritePattern<UnrealizedConversionCastOp> {
  using OpRewritePattern<UnrealizedConversionCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnrealizedConversionCastOp op,
                                PatternRewriter &rewriter) const override {
    // The nodes that either are not used by any operation or have at least
    // one user that is not an unrealized cast.
    DenseSet<UnrealizedConversionCastOp> exitNodes;

    // The nodes whose users are all unrealized casts.
    DenseSet<UnrealizedConversionCastOp> intermediateNodes;

    // Stack used for the depth-first traversal of the use-def DAG.
    SmallVector<UnrealizedConversionCastOp, 2> visitStack;
    visitStack.push_back(op);

    while (!visitStack.empty()) {
      UnrealizedConversionCastOp current = visitStack.pop_back_val();
      auto users = current->getUsers();
      bool isLive = false;

      for (Operation *user : users) {
        if (auto other = dyn_cast<UnrealizedConversionCastOp>(user)) {
          if (other.getInputs() != current.getOutputs())
            return rewriter.notifyMatchFailure(
                op, "mismatching values propagation");
          // Continue traversing the DAG of unrealized casts.
          visitStack.push_back(other);
        } else {
          isLive = true;
        }
      }

      // If the cast is live, then we need to check if the results of the last
      // cast have the same types as the root inputs. If this is the case (e.g.
      // `{A -> B, B -> A}`, but also `{A -> A}`), then the cycle is just a
      // no-op and the inputs can be forwarded. If it's not (e.g.
      // `{A -> B, B -> C}`, `{A -> B}`), then the cast chain is incomplete.
      bool isCycle = current.getResultTypes() == op.getInputs().getTypes();

      if (isLive && !isCycle)
        return rewriter.notifyMatchFailure(op,
                                           "live unrealized conversion cast");

      bool isExitNode = users.empty() || isLive;

      if (isExitNode) {
        exitNodes.insert(current);
      } else {
        intermediateNodes.insert(current);
      }
    }

    // Replace the exit nodes with the root input values.
    for (UnrealizedConversionCastOp exitNode : exitNodes)
      rewriter.replaceOp(exitNode, op.getInputs());

    // Erase all the other casts belonging to the DAG.
    for (UnrealizedConversionCastOp castOp : intermediateNodes)
      rewriter.eraseOp(castOp);

    return success();
  }
};
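
// The fold decision above can be modeled without MLIR. What follows is a
// minimal, self-contained sketch for illustration only; it is not used by
// the pass. It assumes every cast has exactly one operand and one result and
// represents types as small integer IDs (e.g. A = 0, B = 1, C = 2). Under that
// assumption each cast consumes exactly its parent's result, so the
// "mismatching values propagation" check is trivially satisfied and omitted.
// The namespace `cast_fold_sketch` and everything in it are hypothetical.
namespace cast_fold_sketch {

// One cast of the DAG. The root cast sits at index 0 with `parent == -1`;
// every other cast consumes the result of `nodes[parent]`. `hasNonCastUser`
// plays the role of `isLive` in `matchAndRewrite`.
struct CastNode {
  int parent;
  int resultType;
  bool hasNonCastUser;
};

// True if some cast in `nodes` consumes the result of `nodes[index]`.
constexpr bool hasCastUser(const CastNode *nodes, int count, int index) {
  for (int i = 0; i < count; ++i)
    if (nodes[i].parent == index)
      return true;
  return false;
}

// Mirrors the `isLive && !isCycle` check: the DAG can be folded only if every
// cast whose result reaches a non-cast user produces the root input type.
constexpr bool canFold(int rootInputType, const CastNode *nodes, int count) {
  for (int i = 0; i < count; ++i)
    if (nodes[i].hasNonCastUser && nodes[i].resultType != rootInputType)
      return false;
  return true;
}

// Mirrors `isExitNode`: a cast with no users at all, or with a non-cast user,
// is replaced by the root inputs; any other cast is merely erased.
constexpr bool isExitNode(const CastNode *nodes, int count, int index) {
  return nodes[index].hasNonCastUser || !hasCastUser(nodes, count, index);
}

} // namespace cast_fold_sketch

// For example, `A -> B -> A` with only the final `A` result used is foldable
// (the root is an intermediate node, the last cast an exit node), whereas the
// open chain `A -> B -> C` is rejected because the live `C` result does not
// match the root input type `A`.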

/// Pass to simplify and eliminate unrealized conversion casts. The cast op is
/// marked illegal, so the pass fails if any cast cannot be folded away.
struct ReconcileUnrealizedCasts
    : public impl::ReconcileUnrealizedCastsBase<ReconcileUnrealizedCasts> {
  ReconcileUnrealizedCasts() = default;

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateReconcileUnrealizedCastsPatterns(patterns);

    ConversionTarget target(getContext());
    target.addIllegalOp<UnrealizedConversionCastOp>();
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

} // namespace

void mlir::populateReconcileUnrealizedCastsPatterns(
    RewritePatternSet &patterns) {
  patterns.add<UnrealizedConversionCastPassthrough>(patterns.getContext());
}

std::unique_ptr<Pass> mlir::createReconcileUnrealizedCastsPass() {
  return std::make_unique<ReconcileUnrealizedCasts>();
}