//===-- MyExtension.cpp - Transform dialect tutorial ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines Transform dialect extension operations used in
// Chapter 3 of the Transform dialect tutorial.
//
//===----------------------------------------------------------------------===//

#include "MyExtension.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/TypeSwitch.h"

#define GET_TYPEDEF_CLASSES
#include "MyExtensionTypes.cpp.inc"

#define GET_OP_CLASSES
#include "MyExtension.cpp.inc"

//===---------------------------------------------------------------------===//
// MyExtension
//===---------------------------------------------------------------------===//

// Define a new transform dialect extension. This uses the CRTP idiom to
// identify extensions.
class MyExtension
    : public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
  // The extension must inherit the base constructor.
  using Base::Base;

  // This function initializes the extension, similarly to `initialize` in
  // dialect definitions. List individual operations and dependent dialects
  // here.
  void init();
};

void MyExtension::init() {
  // Similarly to dialects, an extension can declare a dependent dialect. This
  // dialect will be loaded along with the extension and, therefore, along with
  // the Transform dialect. Only declare as dependent the dialects that contain
  // the attributes or types used by transform operations. Do NOT declare as
  // dependent the dialects produced during the transformation.
  // declareDependentDialect<MyDialect>();

  // When transformations are applied, they may produce new operations from
  // previously unloaded dialects. Typically, a pass would need to declare
  // itself dependent on the dialects containing such new operations. To avoid
  // confusion with the dialects the extension itself depends on, the Transform
  // dialect differentiates between:
  //   - dependent dialects, which are used by the transform operations, and
  //   - generated dialects, which contain the entities (attributes,
  //     operations, types) that may be produced by applying the
  //     transformation even when not present in the original payload IR.
  // In the following chapter, we will add operations that generate function
  // calls and structured control flow operations, so let's declare the
  // corresponding dialects as generated.
  declareGeneratedDialect<::mlir::scf::SCFDialect>();
  declareGeneratedDialect<::mlir::func::FuncDialect>();

  // Register the additional transform dialect types with the dialect. List all
  // types generated from ODS.
  registerTypes<
#define GET_TYPEDEF_LIST
#include "MyExtensionTypes.cpp.inc"
      >();

  // ODS generates these helpers for type printing and parsing, but the
  // Transform dialect provides its own support for types supplied by the
  // extension. Reference these functions to avoid a compiler warning.
  (void)&generatedTypeParser;
  (void)&generatedTypePrinter;

  // Finally, we register the additional transform operations with the dialect.
  // List all operations generated from ODS. This call will perform additional
  // checks that the operations implement the transform and memory effect
  // interfaces required by the dialect interpreter and assert if they do not.
  registerTransformOps<
#define GET_OP_LIST
#include "MyExtension.cpp.inc"
      >();
}

//===---------------------------------------------------------------------===//
// ChangeCallTargetOp
//===---------------------------------------------------------------------===//

static void updateCallee(mlir::func::CallOp call, llvm::StringRef newTarget) {
  call.setCallee(newTarget);
}

// Implementation of our transform dialect operation.
// This operation returns a tri-state result that can be one of:
// - success when the transformation succeeded;
// - definite failure when the transformation failed in such a way that
//   following transformations are impossible or undesirable, typically it
//   could have left payload IR in an invalid state; it is expected that a
//   diagnostic is emitted immediately before returning the definite error;
// - silenceable failure when the transformation failed but following
//   transformations are still applicable, typically this means a precondition
//   for the transformation is not satisfied and the payload IR has not been
//   modified. The silenceable failure additionally carries a Diagnostic that
//   can be emitted to the user.
::mlir::DiagnosedSilenceableFailure
mlir::transform::ChangeCallTargetOp::applyToOne(
    // The rewriter that should be used when modifying IR.
    ::mlir::transform::TransformRewriter &rewriter,
    // The single payload operation to which the transformation is applied.
    ::mlir::func::CallOp call,
    // The payload IR entities that will be appended to lists associated with
    // the results of this transform operation. This list contains one entry
    // per result.
    ::mlir::transform::ApplyToEachResultList &results,
    // The transform application state. This object can be used to query the
    // current associations between transform IR values and payload IR
    // entities. It can also carry additional user-defined state.
    ::mlir::transform::TransformState &state) {
  // Dispatch to the actual transformation.
  updateCallee(call, getNewTarget());

  // If everything went well, return success.
  return DiagnosedSilenceableFailure::success();
}

void mlir::transform::ChangeCallTargetOp::getEffects(
    ::llvm::SmallVectorImpl<::mlir::MemoryEffects::EffectInstance> &effects) {
  // Indicate that the `call` handle is only read by this operation because the
  // associated operation is not erased but rather modified in-place, so the
  // reference to it remains valid.
  onlyReadsHandle(getCall(), effects);

  // Indicate that the payload is modified by this operation.
  modifiesPayload(effects);
}

//===---------------------------------------------------------------------===//
// CallToOp
//===---------------------------------------------------------------------===//

static mlir::Operation *replaceCallWithOp(mlir::RewriterBase &rewriter,
                                          mlir::CallOpInterface call) {
  // Construct an operation from an unregistered dialect. This is discouraged
  // and is only used here for brevity of the overall example.
  mlir::OperationState state(call.getLoc(), "my.mm4");
  state.types.assign(call->result_type_begin(), call->result_type_end());
  state.operands.assign(call->operand_begin(), call->operand_end());

  mlir::Operation *replacement = rewriter.create(state);
  rewriter.replaceOp(call, replacement->getResults());
  return replacement;
}

// See above for the signature description.
mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne(
    mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call,
    mlir::transform::ApplyToEachResultList &results,
    mlir::transform::TransformState &state) {
  // Dispatch to the actual transformation.
  Operation *replacement = replaceCallWithOp(rewriter, call);

  // Associate the payload operation produced by the rewrite with the result
  // handle of this transform operation.
  results.push_back(replacement);

  // If everything went well, return success.
  return DiagnosedSilenceableFailure::success();
}

//===---------------------------------------------------------------------===//
// CallOpInterfaceHandleType
//===---------------------------------------------------------------------===//

// The interface declares this method to verify constraints this type has on
// payload operations. It returns the now familiar tri-state result.
mlir::DiagnosedSilenceableFailure
mlir::transform::CallOpInterfaceHandleType::checkPayload(
    // Location at which diagnostics should be emitted.
    mlir::Location loc,
    // List of payload operations that are about to be associated with the
    // handle that has this type.
    llvm::ArrayRef<mlir::Operation *> payload) const {

  // All payload operations are expected to implement CallOpInterface; check
  // this.
  for (Operation *op : payload) {
    if (llvm::isa<mlir::CallOpInterface>(op))
      continue;

    // By convention, these verifiers always emit a silenceable failure since
    // they are checking a precondition.
    DiagnosedSilenceableFailure diag =
        emitSilenceableError(loc)
        << "expected the payload operation to implement CallOpInterface";
    diag.attachNote(op->getLoc()) << "offending operation";
    return diag;
  }

  // If everything is okay, return success.
  return DiagnosedSilenceableFailure::success();
}

//===---------------------------------------------------------------------===//
// Extension registration
//===---------------------------------------------------------------------===//

void registerMyExtension(::mlir::DialectRegistry &registry) {
  registry.addExtensions<MyExtension>();
}