//===- NamedOpConversions.cpp - Implements conversions between named ops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements conversions between named ops that can be seen as
// canonicalizations of named ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGNAMEDOPCONVERSION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;
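
/// Returns the contiguous index range [start, end) as a vector; for example,
/// getIndicesVector(2, 4) returns {2, 3}. Used to build reassociation maps.
static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
  return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
}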
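
/// Rewrites a depthwise convolution whose HWCM kernel has a unit channel
/// multiplier (M == 1) into the equivalent op with an HWC kernel: the unit
/// multiplier dimension is collapsed away from the kernel and init tensors,
/// and the new result is expanded back to the original result shape.
static LogicalResult
matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
                             Value iZp, Value kZp, Value init, Attribute stride,
                             Attribute dilation, PatternRewriter &rewriter) {
  Location loc = operation->getLoc();
  auto linalgOp = dyn_cast<LinalgOp>(operation);
  // Exit out on the memref version of this operation.
  if (!linalgOp || !linalgOp.hasTensorSemantics())
    return failure();

  auto result = operation->getResult(0);

  auto kernelTy = dyn_cast<RankedTensorType>(kernel.getType());
  auto initTy = dyn_cast<RankedTensorType>(init.getType());
  auto resultTy = dyn_cast<RankedTensorType>(result.getType());
  if (!kernelTy || !initTy || !resultTy)
    return failure();

  // Only a static unit channel multiplier (kernel dim 3, M == 1) can be
  // dropped; dynamic or larger multipliers are left untouched.
  if (kernelTy.getDimSize(3) != 1)
    return failure();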

  // Collapse the kernel from [H, W, C, M] to [H, W, C] by merging dims 2 and 3.
  SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
  auto newKernelTy = RankedTensorType::get(
      {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
      kernelTy.getElementType());
  auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
      loc, newKernelTy, kernel, collapsedKernelDims);

  // Collapse the init from [N, OH, OW, C, M] to [N, OH, OW, C] by merging
  // dims 3 and 4.
  SmallVector<ReassociationIndices, 4> collapsedInitDims = {
      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
      getIndicesVector(3, 5)};
  auto newInitTy =
      RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
                             initTy.getDimSize(2), initTy.getDimSize(3)},
                            initTy.getElementType());
  auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
      loc, newInitTy, init, collapsedInitDims);

  // Create the HWC-kernel counterpart of the matched op (plain or quantized)
  // and carry over the attributes returned by getPrunedAttributeList.
  SmallVector<NamedAttribute> preservedAttrs;
  Operation *newConv =
      TypeSwitch<Operation *, Operation *>(operation)
          .Case<DepthwiseConv2DNhwcHwcmOp>([&](auto op) {
            preservedAttrs = getPrunedAttributeList(op);
            return rewriter.create<DepthwiseConv2DNhwcHwcOp>(
                loc, newInitTy, ValueRange{input, collapsedKernel},
                ValueRange{collapsedInit}, stride, dilation);
          })
          .Case<DepthwiseConv2DNhwcHwcmQOp>([&](auto op) {
            preservedAttrs = getPrunedAttributeList(op);
            return rewriter.create<DepthwiseConv2DNhwcHwcQOp>(
                loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
                ValueRange{collapsedInit}, stride, dilation);
          })
          .Default([](Operation *op) { return nullptr; });
  if (!newConv)
    return failure();
  for (auto attr : preservedAttrs)
    newConv->setAttr(attr.getName(), attr.getValue());

  // Expand the result back out to the original [N, OH, OW, C, M] shape.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      operation, resultTy, newConv->getResult(0), collapsedInitDims);
  return success();
}
namespace {
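
/// Illustrative sketch of the rewrite performed by SimplifyDepthwiseConvOp.
/// The shapes are hypothetical, and the exact assembly syntax (particularly
/// for tensor.expand_shape) varies between MLIR versions:
///
///   %0 = linalg.depthwise_conv_2d_nhwc_hwcm
///          {dilations = dense<1> : tensor<2xi64>,
///           strides = dense<1> : tensor<2xi64>}
///          ins(%input, %kernel : tensor<1x10x10x8xf32>, tensor<3x3x8x1xf32>)
///          outs(%init : tensor<1x8x8x8x1xf32>) -> tensor<1x8x8x8x1xf32>
///
/// becomes
///
///   %k = tensor.collapse_shape %kernel [[0], [1], [2, 3]]
///          : tensor<3x3x8x1xf32> into tensor<3x3x8xf32>
///   %i = tensor.collapse_shape %init [[0], [1], [2], [3, 4]]
///          : tensor<1x8x8x8x1xf32> into tensor<1x8x8x8xf32>
///   %c = linalg.depthwise_conv_2d_nhwc_hwc
///          {dilations = dense<1> : tensor<2xi64>,
///           strides = dense<1> : tensor<2xi64>}
///          ins(%input, %k : tensor<1x10x10x8xf32>, tensor<3x3x8xf32>)
///          outs(%i : tensor<1x8x8x8xf32>) -> tensor<1x8x8x8xf32>
///   %0 = tensor.expand_shape %c [[0], [1], [2], [3, 4]]
///          : tensor<1x8x8x8xf32> into tensor<1x8x8x8x1xf32>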
struct SimplifyDepthwiseConvOp
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
  using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
                                PatternRewriter &rewriter) const override {
    Operation *operation = op.getOperation();
    Value input = op.getDpsInputOperand(0)->get();
    Value kernel = op.getDpsInputOperand(1)->get();
    Value init = op.getDpsInitOperand(0)->get();

    auto stride = op.getStrides();
    auto dilation = op.getDilations();

    return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
                                        nullptr, init, stride, dilation,
                                        rewriter);
  }
};
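
/// SimplifyDepthwiseConvQOp reuses matchAndReplaceDepthwiseConv and forwards
/// the input and kernel zero points (iZp, kZp) unchanged. Illustrative sketch
/// with hypothetical shapes and i8/i32 types (exact syntax varies between
/// MLIR versions):
///
///   %0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {...}
///          ins(%input, %kernel, %izp, %kzp
///              : tensor<1x10x10x8xi8>, tensor<3x3x8x1xi8>, i32, i32)
///          outs(%init : tensor<1x8x8x8x1xi32>) -> tensor<1x8x8x8x1xi32>
///
/// is rewritten to linalg.depthwise_conv_2d_nhwc_hwc_q on a kernel collapsed
/// with [[0], [1], [2, 3]] and an init collapsed with [[0], [1], [2], [3, 4]],
/// followed by a tensor.expand_shape back to tensor<1x8x8x8x1xi32>.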
struct SimplifyDepthwiseConvQOp
    : public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
  using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
                                PatternRewriter &rewriter) const override {
    Operation *operation = op.getOperation();
    Value input = op.getDpsInputOperand(0)->get();
    Value kernel = op.getDpsInputOperand(1)->get();
    Value iZp = op.getDpsInputOperand(2)->get();
    Value kZp = op.getDpsInputOperand(3)->get();
    Value init = op.getDpsInitOperand(0)->get();

    auto stride = op.getStrides();
    auto dilation = op.getDilations();

    return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
                                        init, stride, dilation, rewriter);
  }
};

struct LinalgNamedOpConversionPass
    : public impl::LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
  LinalgNamedOpConversionPass() = default;
  LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    RewritePatternSet patterns(op->getContext());
    populateLinalgNamedOpConversionPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace

void mlir::linalg::populateLinalgNamedOpConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
      patterns.getContext());
}
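
// Usage sketch (an assumption about how a client might reuse these patterns
// outside the standalone pass; `funcOp` is a hypothetical operation the caller
// wants to rewrite):
//
//   RewritePatternSet patterns(funcOp->getContext());
//   linalg::populateLinalgNamedOpConversionPatterns(patterns);
//   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
//     return failure(); // e.g., report that the rewrite did not converge
//
// The pass returned by createLinalgNamedOpConversionPass() performs the same
// steps on whatever operation it is scheduled on.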

std::unique_ptr<Pass> mlir::createLinalgNamedOpConversionPass() {
  return std::make_unique<LinalgNamedOpConversionPass>();
}