//===- LowerMemorySpaceAttributes.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// Implementation of a pass that rewrites the IR so that uses of
/// `gpu::AddressSpaceAttr` in memref memory space annotations are replaced
/// with caller-specified numeric values.
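///
/// For example, if the workgroup address space is mapped to 3 (an assumed
/// value chosen only for illustration), the type
///
///   memref<16xf32, #gpu.address_space<workgroup>>
///
/// is rewritten to
///
///   memref<16xf32, 3>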
///
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_GPULOWERMEMORYSPACEATTRIBUTESPASS
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::gpu;

//===----------------------------------------------------------------------===//
// Conversion Target
//===----------------------------------------------------------------------===//

/// Returns true if the given `type` is considered legal during memory space
/// attribute lowering, i.e. it is not a memref whose memory space is a
/// `gpu::AddressSpaceAttr`.
static bool isLegalType(Type type) {
  if (auto memRefType = type.dyn_cast<BaseMemRefType>()) {
    return !memRefType.getMemorySpace()
                .isa_and_nonnull<gpu::AddressSpaceAttr>();
  }
  return true;
}

/// Returns true if the given `attr` is considered legal during memory space
/// attribute lowering.
static bool isLegalAttr(Attribute attr) {
  if (auto typeAttr = attr.dyn_cast<TypeAttr>())
    return isLegalType(typeAttr.getValue());
  return true;
}

/// Returns true if the given `op` is legal during memory space attribute
/// lowering.
static bool isLegalOp(Operation *op) {
  if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
    return llvm::all_of(funcOp.getArgumentTypes(), isLegalType) &&
           llvm::all_of(funcOp.getResultTypes(), isLegalType) &&
           llvm::all_of(funcOp.getFunctionBody().getArgumentTypes(),
                        isLegalType);
  }

  auto attrs = llvm::map_range(op->getAttrs(), [](const NamedAttribute &attr) {
    return attr.getValue();
  });

  return llvm::all_of(op->getOperandTypes(), isLegalType) &&
         llvm::all_of(op->getResultTypes(), isLegalType) &&
         llvm::all_of(attrs, isLegalAttr);
}
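
// Illustration (not exhaustive): isLegalOp() rejects the first op below,
// whose result memref carries a gpu::AddressSpaceAttr, and accepts the
// second, which uses a plain numeric memory space. With the workgroup space
// mapped to 3 (an assumed value), the conversion turns the first form into
// the second.
//
//   %0 = memref.alloc() : memref<4xf32, #gpu.address_space<workgroup>>
//   %1 = memref.alloc() : memref<4xf32, 3>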

void gpu::populateLowerMemorySpaceOpLegality(ConversionTarget &target) {
  target.markUnknownOpDynamicallyLegal(isLegalOp);
}

//===----------------------------------------------------------------------===//
// Type Converter
//===----------------------------------------------------------------------===//

/// Wraps a numeric memory space in the attribute form used for memref memory
/// spaces: a 64-bit signless IntegerAttr.
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
  return IntegerAttr::get(IntegerType::get(ctx, 64), space);
}

void mlir::gpu::populateMemorySpaceAttributeTypeConversions(
    TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
  typeConverter.addConversion([mapping](Type type) -> std::optional<Type> {
    // Types that cannot contain sub-elements are left unchanged.
    auto subElementType = type.dyn_cast_or_null<SubElementTypeInterface>();
    if (!subElementType)
      return type;
    Type newType = subElementType.replaceSubElements(
        [mapping](Attribute attr) -> std::optional<Attribute> {
          // Only gpu::AddressSpaceAttr is replaced by this callback.
          auto memorySpaceAttr = attr.dyn_cast_or_null<gpu::AddressSpaceAttr>();
          if (!memorySpaceAttr)
            return std::nullopt;
          // Map the symbolic address space to its numeric value.
          auto newValue = wrapNumericMemorySpace(
              attr.getContext(), mapping(memorySpaceAttr.getValue()));
          return newValue;
        });
    return newType;
  });
}

namespace {
/// Generic pattern (it matches any op type) that recreates an op with its
/// result types, entry-block argument types, and `TypeAttr` attributes passed
/// through the type converter, so that `gpu::AddressSpaceAttr` memory spaces
/// become numeric ones.
struct LowerMemRefAddressSpacePattern final : public ConversionPattern {
  LowerMemRefAddressSpacePattern(MLIRContext *context, TypeConverter &converter)
      : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the types held in TypeAttr attributes (e.g. a function type).
    SmallVector<NamedAttribute> newAttrs;
    newAttrs.reserve(op->getAttrs().size());
    for (auto attr : op->getAttrs()) {
      if (auto typeAttr = attr.getValue().dyn_cast<TypeAttr>()) {
        auto newAttr = getTypeConverter()->convertType(typeAttr.getValue());
        newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
      } else {
        newAttrs.push_back(attr);
      }
    }

    // Convert the result types; `operands` already carry converted types.
    SmallVector<Type> newResults;
    (void)getTypeConverter()->convertTypes(op->getResultTypes(), newResults);

    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
                         newResults, newAttrs, op->getSuccessors());

    // Move each region into the new op and convert its entry block arguments.
    for (Region &region : op->getRegions()) {
      Region *newRegion = state.addRegion();
      rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
      TypeConverter::SignatureConversion result(newRegion->getNumArguments());
      (void)getTypeConverter()->convertSignatureArgs(
          newRegion->getArgumentTypes(), result);
      rewriter.applySignatureConversion(newRegion, result);
    }

    Operation *newOp = rewriter.create(state);
    rewriter.replaceOp(op, newOp->getResults());
    return success();
  }
};
} // namespace
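
// Illustration (hedged; assumes the global address space is mapped to 1):
// because `func.func` keeps its signature in the `function_type` TypeAttr and
// its arguments as entry-block arguments, the generic pattern above rewrites
// both, turning
//
//   func.func @k(%a: memref<8xf32, #gpu.address_space<global>>) {
//     return
//   }
//
// into
//
//   func.func @k(%a: memref<8xf32, 1>) {
//     return
//   }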

void mlir::gpu::populateMemorySpaceLoweringPatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<LowerMemRefAddressSpacePattern>(patterns.getContext(),
                                               typeConverter);
}

namespace {
class LowerMemorySpaceAttributesPass
    : public mlir::impl::GPULowerMemorySpaceAttributesPassBase<
          LowerMemorySpaceAttributesPass> {
public:
  using Base::Base;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    Operation *op = getOperation();

    ConversionTarget target(getContext());
    populateLowerMemorySpaceOpLegality(target);

    TypeConverter typeConverter;
    // Identity fallback. Conversions are tried in reverse order of
    // registration, so the memory space conversion added below takes
    // precedence.
    typeConverter.addConversion([](Type t) { return t; });
    populateMemorySpaceAttributeTypeConversions(
        typeConverter, [this](AddressSpace space) -> unsigned {
          switch (space) {
          case AddressSpace::Global:
            return globalAddrSpace;
          case AddressSpace::Workgroup:
            return workgroupAddrSpace;
          case AddressSpace::Private:
            return privateAddrSpace;
          }
          llvm_unreachable("unknown address space enum value");
          return 0;
        });

    RewritePatternSet patterns(context);
    populateMemorySpaceLoweringPatterns(typeConverter, patterns);
    if (failed(applyFullConversion(op, target, std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace