//===- SparseAssembler.cpp - adds wrapper method around sparse types ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
using namespace sparse_tensor;
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// Converts a type range to a new type range, with each sparse tensor
// externalized into the types of its pos/crd/val arrays.
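//
// For example (a sketch, assuming a standard CSR encoding with index-typed
// positions and coordinates), a sparse tensor argument or result of type
//
//   tensor<100x100xf64, #CSR>
//
// is externalized into the three array types
//
//   tensor<?xindex>, tensor<?xindex>, tensor<?xf64>
//
// (positions, coordinates, and values, respectively), or into the
// corresponding memref types when `directOut` is set.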
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
SmallVectorImpl<Type> *extraTypes, bool directOut) {
for (auto type : types) {
// All "dense" data passes through unmodified.
if (!getSparseTensorEncoding(type)) {
convTypes.push_back(type);
continue;
}
// Convert the external representations of the pos/crd/val arrays.
const SparseTensorType stt(cast<RankedTensorType>(type));
foreachFieldAndTypeInSparseTensor(
stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::PosMemRef ||
kind == SparseTensorFieldKind::CrdMemRef ||
kind == SparseTensorFieldKind::ValMemRef) {
auto rtp = cast<ShapedType>(t);
if (!directOut) {
rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
if (extraTypes)
extraTypes->push_back(rtp);
}
convTypes.push_back(rtp);
}
return true;
});
}
}

// Convert input and output values to [dis]assemble ops for sparse tensors.
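//
// For example (a sketch, assuming a CSR tensor %t with index-typed arrays),
// a sparse input is assembled from its external arrays roughly as
//
//   %t = sparse_tensor.assemble (%pos, %crd), %val
//        : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64>
//          to tensor<100x100xf64, #CSR>
//
// while a sparse output is taken apart again with sparse_tensor.disassemble
// into the buffers passed in through `extraVals`, or simply queried in place
// through ToPositionsOp/ToCoordinatesOp/ToValuesOp when `directOut` is set.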
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
ValueRange fromVals, ValueRange extraVals,
SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
bool directOut) {
unsigned idx = 0;
for (auto type : types) {
// All "dense" data passes through unmodified.
if (!getSparseTensorEncoding(type)) {
toVals.push_back(fromVals[idx++]);
continue;
}
// Handle sparse data.
auto rtp = cast<RankedTensorType>(type);
const SparseTensorType stt(rtp);
SmallVector<Value> inputs;
SmallVector<Type> retTypes;
SmallVector<Type> cntTypes;
if (!isIn)
inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
// Collect the external representations of the pos/crd/val arrays.
foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
SparseTensorFieldKind kind,
Level lv, LevelType) {
if (kind == SparseTensorFieldKind::PosMemRef ||
kind == SparseTensorFieldKind::CrdMemRef ||
kind == SparseTensorFieldKind::ValMemRef) {
if (isIn) {
inputs.push_back(fromVals[idx++]);
} else if (directOut) {
Value mem;
if (kind == SparseTensorFieldKind::PosMemRef)
mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
lv);
else if (kind == SparseTensorFieldKind::CrdMemRef)
mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
lv);
else
mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
toVals.push_back(mem);
} else {
ShapedType rtp = cast<ShapedType>(t);
rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
inputs.push_back(extraVals[extra++]);
retTypes.push_back(rtp);
cntTypes.push_back(builder.getIndexType());
}
}
return true;
});
if (isIn) {
// Assemble multiple inputs into a single sparse tensor.
auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
toVals.push_back(a.getResult());
    } else if (!directOut) {
      // Disassemble the single sparse result into multiple outputs. Note
      // that this includes the counters (the actual lengths of the arrays),
      // which are simply dropped.
unsigned len = retTypes.size();
retTypes.append(cntTypes);
auto d =
builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
for (unsigned i = 0; i < len; i++)
toVals.push_back(d.getResult(i));
}
}
}

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
// t = assemble t1..tn
// _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
// ..., t, ... = _internal_bar(...)
// t1..tn = disassemble t, t1'..tn'
// return ..., t1..tn, ...
// }
//
// (with a direct-out variant that skips the disassemble and instead returns
// the underlying storage, obtained directly through ToPositionsOp,
// ToCoordinatesOp, and ToValuesOp).
//
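// As a concrete sketch (assuming a CSR-encoded argument; all names are
// illustrative only), the rewritten pair looks roughly like
//
//   func.func private @_internal_foo(%t: tensor<10x10xf64, #CSR>) { ... }
//
//   func.func @foo(%pos: tensor<?xindex>, %crd: tensor<?xindex>,
//                  %val: tensor<?xf64>) {
//     %t = sparse_tensor.assemble (%pos, %crd), %val
//          : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64>
//            to tensor<10x10xf64, #CSR>
//     call @_internal_foo(%t) : (tensor<10x10xf64, #CSR>) -> ()
//     return
//   }
//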
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  SparseFuncAssembler(MLIRContext *context, bool dO)
      : OpRewritePattern(context), directOut(dO) {}

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
// Only rewrite public entry methods.
if (funcOp.isPrivate())
return failure();
// Translate sparse tensor types to external types.
SmallVector<Type> inputTypes;
SmallVector<Type> outputTypes;
SmallVector<Type> extraTypes;
convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);
// Only sparse inputs or outputs need a wrapper method.
if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
outputTypes.size() == funcOp.getResultTypes().size())
return failure();
// Modify the original method into an internal, private method.
auto orgName = funcOp.getName();
std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
funcOp.setName(wrapper);
funcOp.setPrivate();
    // Start the new public wrapper method with the original name.
Location loc = funcOp.getLoc();
ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
MLIRContext *context = modOp.getContext();
OpBuilder moduleBuilder(modOp.getBodyRegion());
unsigned extra = inputTypes.size();
inputTypes.append(extraTypes);
auto func = moduleBuilder.create<func::FuncOp>(
loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
func.setPublic();
// Construct new wrapper method body.
OpBuilder::InsertionGuard insertionGuard(rewriter);
Block *body = func.addEntryBlock();
rewriter.setInsertionPointToStart(body);
// Convert inputs.
SmallVector<Value> inputs;
convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);
// Call the original, now private method. A subsequent inlining pass can
// determine whether cloning the method body in place is worthwhile.
auto org = SymbolRefAttr::get(context, wrapper);
auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
inputs);
// Convert outputs and return.
SmallVector<Value> outputs;
convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
rewriter.create<func::ReturnOp>(loc, outputs);
    // Finally, migrate a potential C-interface property.
if (funcOp->getAttrOfType<UnitAttr>(
LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
UnitAttr::get(context));
funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
}
return success();
}

private:
  const bool directOut;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateSparseAssembler(RewritePatternSet &patterns,
                                   bool directOut) {
patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
}
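
// A minimal usage sketch (hypothetical driver code, not part of this pass):
// populate the pattern set and apply it greedily over a module.
//
//   RewritePatternSet patterns(module.getContext());
//   populateSparseAssembler(patterns, /*directOut=*/false);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));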