1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
|
//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#define DEBUG_TYPE "convert-to-spirv"
namespace mlir {
#define GEN_PASS_DEF_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
/// A pass to perform the SPIR-V conversion.
struct ConvertToSPIRVPass final
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();
if (runSignatureConversion) {
// Unroll vectors in function signatures to native vector size.
RewritePatternSet patterns(context);
populateFuncOpVectorRewritePatterns(patterns);
populateReturnOpVectorRewritePatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return signalPassFailure();
}
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;
// Populate patterns for each dialect.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
populateVectorToSPIRVPatterns(typeConverter, patterns);
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
|