1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
|
//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This is intended to be a simple high-level (target-agnostic) matmul
// transposition transformation.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "linalg-transpose-matmul"
using namespace mlir;
using namespace mlir::linalg;
/// Rewrites
///
///   linalg.matmul(a, b)
///
/// into
///
///   linalg.matmul_transpose_a(linalg.transpose(a), b)
///
/// The LHS is the operand that gets transposed when `transposeLHS` is true.
/// When `transposeLHS` is false the RHS is transposed instead and a
/// `linalg::MatmulTransposeBOp` is created.
///
/// Reports a match failure if `matmulOp` does not have tensor semantics.
/// On success `matmulOp` is replaced and the newly created op is returned.
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                     linalg::MatmulOp matmulOp,
                                                     bool transposeLHS) {
  if (!bufferization::hasTensorSemantics(matmulOp))
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with tensors are supported");

  Location loc = matmulOp.getLoc();
  const unsigned operandIdx = transposeLHS ? 0 : 1;
  Value operand = matmulOp.getInputs()[operandIdx];
  auto operandTy = cast<ShapedType>(operand.getType());

  // Dynamic extents are gathered in the order of the transposed shape, i.e.
  // source dimension 1 first, then source dimension 0.
  SmallVector<Value> dynSizes;
  for (int64_t dim : {1, 0})
    if (operandTy.isDynamicDim(dim))
      dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, operand, dim));

  // Destination for the transpose: same element type, swapped extents.
  ArrayRef<int64_t> operandShape = operandTy.getShape();
  Value init = rewriter.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{operandShape[1], operandShape[0]},
      operandTy.getElementType(), dynSizes);
  auto transposed = rewriter.create<linalg::TransposeOp>(
      loc, operand, init, ArrayRef<int64_t>{1, 0});
  Value transposedVal = transposed->getResult(0);

  // The operand that was not transposed is forwarded unchanged.
  Operation *replacement = nullptr;
  if (transposeLHS)
    replacement = rewriter.create<linalg::MatmulTransposeAOp>(
        loc, matmulOp.getResultTypes(),
        ValueRange{transposedVal, matmulOp.getInputs()[1]},
        matmulOp.getOutputs());
  else
    replacement = rewriter.create<linalg::MatmulTransposeBOp>(
        loc, matmulOp.getResultTypes(),
        ValueRange{matmulOp.getInputs()[0], transposedVal},
        matmulOp.getOutputs());

  rewriter.replaceOp(matmulOp, replacement);
  return replacement;
}
/// Rewrites
///
///   linalg.batch_matmul(a, b)
///
/// into
///
///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
///
/// The transpose uses the permutation [0, 2, 1], so dimension 0 stays in
/// place and only the two trailing dimensions are swapped. The LHS is
/// transposed when `transposeLHS` is true; when it is false the RHS is
/// transposed instead and a `linalg::BatchMatmulTransposeBOp` is created.
///
/// Reports a match failure if `batchMatmulOp` does not have tensor semantics.
/// On success `batchMatmulOp` is replaced and the newly created op is
/// returned.
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                   linalg::BatchMatmulOp batchMatmulOp,
                                   bool transposeLHS) {
  if (!bufferization::hasTensorSemantics(batchMatmulOp))
    return rewriter.notifyMatchFailure(
        batchMatmulOp, "only matmul ops with tensors are supported");

  Location loc = batchMatmulOp.getLoc();
  const unsigned operandIdx = transposeLHS ? 0 : 1;
  Value operand = batchMatmulOp.getInputs()[operandIdx];
  auto operandTy = cast<ShapedType>(operand.getType());

  // Dynamic extents are gathered in the order of the transposed shape, i.e.
  // source dimensions 0, 2, 1.
  SmallVector<Value> dynSizes;
  for (int64_t dim : {0, 2, 1})
    if (operandTy.isDynamicDim(dim))
      dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, operand, dim));

  // Destination for the transpose: same element type, trailing extents
  // swapped.
  ArrayRef<int64_t> operandShape = operandTy.getShape();
  Value init = rewriter.create<tensor::EmptyOp>(
      loc,
      ArrayRef<int64_t>{operandShape[0], operandShape[2], operandShape[1]},
      operandTy.getElementType(), dynSizes);
  auto transposed = rewriter.create<linalg::TransposeOp>(
      loc, operand, init, ArrayRef<int64_t>{0, 2, 1});
  Value transposedVal = transposed->getResult(0);

  // The operand that was not transposed is forwarded unchanged.
  Operation *replacement = nullptr;
  if (transposeLHS)
    replacement = rewriter.create<linalg::BatchMatmulTransposeAOp>(
        loc, batchMatmulOp.getResultTypes(),
        ValueRange{transposedVal, batchMatmulOp.getInputs()[1]},
        batchMatmulOp.getOutputs());
  else
    replacement = rewriter.create<linalg::BatchMatmulTransposeBOp>(
        loc, batchMatmulOp.getResultTypes(),
        ValueRange{batchMatmulOp.getInputs()[0], transposedVal},
        batchMatmulOp.getOutputs());

  rewriter.replaceOp(batchMatmulOp, replacement);
  return replacement;
}
namespace {
/// Rewrite pattern that forwards every `linalg::MatmulOp` to
/// `transposeMatmul`, using the `transposeLHS` choice fixed at construction.
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

  /// Succeeds exactly when `transposeMatmul` succeeds; the created op itself
  /// is not needed here and is discarded.
  LogicalResult matchAndRewrite(linalg::MatmulOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> rewritten =
        transposeMatmul(rewriter, op, transposeLHS);
    return failure(failed(rewritten));
  }

private:
  /// Passed through to `transposeMatmul` on every rewrite.
  bool transposeLHS;
};
/// Rewrite pattern that forwards every `linalg::BatchMatmulOp` to
/// `transposeBatchMatmul`, using the `transposeLHS` choice fixed at
/// construction.
struct TransposeBatchMatmul final
    : public OpRewritePattern<linalg::BatchMatmulOp> {
  TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

  /// Succeeds exactly when `transposeBatchMatmul` succeeds; the created op
  /// itself is not needed here and is discarded.
  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> rewritten =
        transposeBatchMatmul(rewriter, op, transposeLHS);
    return failure(failed(rewritten));
  }

private:
  /// Passed through to `transposeBatchMatmul` on every rewrite.
  bool transposeLHS;
};
} // namespace
/// Adds the `TransposeMatmul` and `TransposeBatchMatmul` patterns to
/// `patterns`, in that order. Both are constructed with the context of
/// `patterns` and the given `transposeLHS` flag.
void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                                   bool transposeLHS) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<TransposeMatmul>(ctx, transposeLHS);
  patterns.add<TransposeBatchMatmul>(ctx, transposeLHS);
}
|