1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
|
//===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Complex dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "complex-to-spirv-pattern"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
namespace {
/// Lowers `complex.constant` to `spirv.Constant`.
///
/// The complex result type is run through the type converter and must come
/// back as a ShapedType. The constant's attribute array is then packed into a
/// DenseElementsAttr of that shaped type.
struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // DenseElementsAttr::get below needs a ShapedType, so reject conversions
    // that produce anything else (or nothing at all).
    ShapedType shapedTy =
        getTypeConverter()->convertType<ShapedType>(constOp.getType());
    if (!shapedTy)
      return rewriter.notifyMatchFailure(constOp,
                                         "unable to convert result type");

    // The op's value is an array attribute; its elements become the dense
    // payload of the SPIR-V constant.
    ArrayRef<Attribute> components = constOp.getValue().getValue();
    auto denseValue = DenseElementsAttr::get(shapedTy, components);
    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, shapedTy,
                                                   denseValue);
    return success();
  }
};
/// Lowers `complex.create` to `spirv.CompositeConstruct`.
///
/// The composite is built directly from the adaptor's (already converted)
/// operands, typed with the converted result type.
struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (Type resultTy = getTypeConverter()->convertType(createOp.getType())) {
      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
          createOp, resultTy, adaptor.getOperands());
      return success();
    }
    // A null Type means the converter has no SPIR-V mapping for the result.
    return rewriter.notifyMatchFailure(createOp,
                                       "unable to convert result type");
  }
};
/// Lowers `complex.re` to `spirv.CompositeExtract` of element 0 of the
/// converted complex value.
struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The converted type is not passed to the new op (the builder used below
    // takes no explicit result type); it is only checked so that results with
    // no SPIR-V mapping make the pattern fail instead of being rewritten.
    if (!getTypeConverter()->convertType(reOp.getType()))
      return rewriter.notifyMatchFailure(reOp, "unable to convert result type");

    // Element 0 of the composite is taken as the real part.
    constexpr int32_t kRealIndex = 0;
    Value composite = adaptor.getComplex();
    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
        reOp, composite, llvm::ArrayRef(kRealIndex));
    return success();
  }
};
/// Lowers `complex.im` to `spirv.CompositeExtract` of element 1 of the
/// converted complex value.
struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The converted type is not passed to the new op (the builder used below
    // takes no explicit result type); it is only checked so that results with
    // no SPIR-V mapping make the pattern fail instead of being rewritten.
    if (!getTypeConverter()->convertType(imOp.getType()))
      return rewriter.notifyMatchFailure(imOp, "unable to convert result type");

    // Element 1 of the composite is taken as the imaginary part.
    constexpr int32_t kImagIndex = 1;
    Value composite = adaptor.getComplex();
    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
        imOp, composite, llvm::ArrayRef(kImagIndex));
    return success();
  }
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
/// Appends the Complex-to-SPIR-V conversion patterns (constant, create, re,
/// im) to `patterns`. Each pattern is constructed with `typeConverter` and the
/// pattern set's own MLIRContext.
void mlir::populateComplexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
  patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>(
      typeConverter, patterns.getContext());
}
|