1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
|
//===- StdExpandDivs.cpp - Code to prepare Std for lowering Divs to LLVM -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file Std transformations to expand Divs operation to help for the
// lowering to LLVM. Currently implemented transformations are Ceil and Floor
// for Signed Integers.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDOPS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
using namespace mlir;
namespace {
/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
/// AtomicRMWOpLowering pattern, such as minimum and maximum operations for
/// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
/// code.
///
/// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
///
/// will be lowered to
///
/// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
/// ^bb0(%current: f32):
/// %1 = arith.maximumf %current, %fval : f32
/// memref.atomic_yield %1 : f32
/// }
struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
PatternRewriter &rewriter) const final {
auto loc = op.getLoc();
auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
loc, op.getMemref(), op.getIndices());
OpBuilder bodyBuilder =
OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
Value lhs = genericOp.getCurrentValue();
Value rhs = op.getValue();
Value arithOp =
mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs);
bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);
rewriter.replaceOp(op, genericOp.getResult());
return success();
}
};
/// Converts `memref.reshape` that has a target shape of a statically-known
/// size to `memref.reinterpret_cast`.
struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::ReshapeOp op,
PatternRewriter &rewriter) const final {
auto shapeType = cast<MemRefType>(op.getShape().getType());
if (!shapeType.hasStaticShape())
return failure();
int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
SmallVector<OpFoldResult, 4> sizes, strides;
sizes.resize(rank);
strides.resize(rank);
Location loc = op.getLoc();
Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
for (int i = rank - 1; i >= 0; --i) {
Value size;
// Load dynamic sizes from the shape input, use constants for static dims.
if (op.getType().isDynamicDim(i)) {
Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
if (!isa<IndexType>(size.getType()))
size = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), size);
sizes[i] = size;
} else {
auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
sizes[i] = sizeAttr;
}
strides[i] = stride;
if (i > 0)
stride = rewriter.create<arith::MulIOp>(loc, stride, size);
}
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
sizes, strides);
return success();
}
};
struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
void runOnOperation() override {
MLIRContext &ctx = getContext();
RewritePatternSet patterns(&ctx);
memref::populateExpandOpsPatterns(patterns);
ConversionTarget target(ctx);
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
[](memref::AtomicRMWOp op) {
constexpr std::array shouldBeExpandedKinds = {
arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
});
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
});
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
} // namespace
void mlir::memref::populateExpandOpsPatterns(RewritePatternSet &patterns) {
patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
patterns.getContext());
}
std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
return std::make_unique<ExpandOpsPass>();
}
|