1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
|
//===- TestDialectConversion.cpp - Test DialectConversion functionality ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace test;
//===----------------------------------------------------------------------===//
// Test PDLL Support
//===----------------------------------------------------------------------===//
#include "TestDialectConversionPDLLPatterns.h.inc"
namespace {
/// Type converter used by the PDLL conversion test: signless 64-bit integers
/// become f64, and every other type is passed through as-is. Casts between
/// the original and converted types are materialized as
/// `builtin.unrealized_conversion_cast` ops.
struct PDLLTypeConverter : public TypeConverter {
  PDLLTypeConverter() {
    // Type mapping first, then the argument/source materialization hooks.
    addConversion(convertType);
    addArgumentMaterialization(materializeCast);
    addSourceMaterialization(materializeCast);
  }

  /// Appends the converted form of `t` to `results`; this never fails.
  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
    // Identity mapping unless `t` is i64, in which case it becomes f64.
    Type converted = t;
    if (t.isSignlessInteger(64))
      converted = FloatType::getF64(t.getContext());
    results.push_back(converted);
    return success();
  }

  /// Materializes a conversion from `inputs` to `resultType` by creating an
  /// unrealized conversion cast at `loc` and returning its single result.
  static std::optional<Value> materializeCast(OpBuilder &builder,
                                              Type resultType,
                                              ValueRange inputs, Location loc) {
    auto castOp =
        builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs);
    return castOp.getResult(0);
  }
};
/// Test pass that runs a full dialect conversion driven by the patterns
/// generated from PDLL (`TestDialectConversionPDLLPatterns.h.inc`), using
/// `PDLLTypeConverter` to decide type legality.
struct TestDialectConversionPDLLPass
    : public PassWrapper<TestDialectConversionPDLLPass, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectConversionPDLLPass)

  StringRef getArgument() const final { return "test-dialect-conversion-pdll"; }
  StringRef getDescription() const final {
    return "Test DialectConversion PDLL functionality";
  }

  /// Declares the PDL and PDLInterp dialects as dependencies of this pass.
  // Fixed: the parameter was mis-encoded as `®istry` (HTML `&reg` entity),
  // which does not compile; it is a reference named `registry`.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<pdl::PDLDialect, pdl_interp::PDLInterpDialect>();
  }

  /// Builds and freezes the pattern set once per pass instance.
  LogicalResult initialize(MLIRContext *ctx) override {
    // Build the pattern set within the `initialize` to avoid recompiling PDL
    // patterns during each `runOnOperation` invocation.
    RewritePatternSet patternList(ctx);
    registerConversionPDLFunctions(patternList);
    // The generated patterns receive a pointer to the `converter` member.
    populateGeneratedPDLLPatterns(patternList, PDLConversionConfig(&converter));
    patterns = std::move(patternList);
    return success();
  }

  /// Applies a full conversion to the current operation and signals pass
  /// failure if the conversion fails.
  void runOnOperation() final {
    mlir::ConversionTarget target(getContext());
    // Module, func.func and func.return are unconditionally legal.
    target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
    // Test-dialect ops are legal only when the type converter reports them
    // legal.
    target.addDynamicallyLegalDialect<TestDialect>(
        [this](Operation *op) { return converter.isLegal(op); });
    if (failed(mlir::applyFullConversion(getOperation(), target, patterns)))
      signalPassFailure();
  }

  /// Patterns built in `initialize`, reused by every `runOnOperation`.
  FrozenRewritePatternSet patterns;
  /// Type converter shared by the legality callback and the PDLL patterns.
  PDLLTypeConverter converter;
};
} // namespace
namespace mlir {
namespace test {
void registerTestDialectConversionPasses() {
PassRegistration<TestDialectConversionPDLLPass>();
}
} // namespace test
} // namespace mlir
|