File: SparseAssembler.cpp

//===- SparseAssembler.cpp - adds wrapper methods around sparse types -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// Convert a type range to a new type range, with sparse tensors externalized.
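// When `extraTypes` is non-null (and `directOut` is false), the converted
// tensor types are also collected there, so a caller can append them as
// extra output buffer arguments to a wrapper signature. With `directOut`,
// the pos/crd/val buffers keep their memref types.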
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                      SmallVectorImpl<Type> *extraTypes, bool directOut) {
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      convTypes.push_back(type);
      continue;
    }

    // Convert the external representations of the pos/crd/val arrays.
    const SparseTensorType stt(cast<RankedTensorType>(type));
    foreachFieldAndTypeInSparseTensor(
        stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
                                                 SparseTensorFieldKind kind,
                                                 Level, LevelType) {
          if (kind == SparseTensorFieldKind::PosMemRef ||
              kind == SparseTensorFieldKind::CrdMemRef ||
              kind == SparseTensorFieldKind::ValMemRef) {
            auto rtp = cast<ShapedType>(t);
            if (!directOut) {
              rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
              if (extraTypes)
                extraTypes->push_back(rtp);
            }
            convTypes.push_back(rtp);
          }
          return true;
        });
  }
}

// Convert input and output values, inserting [dis]assemble ops for the
// sparse tensors.
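// Here, `extra` indexes into `extraVals`, which supplies the pre-allocated
// output buffers for the disassemble case; `isIn` selects between assembling
// wrapper arguments and disassembling results of the internal call (or, with
// `directOut`, querying the underlying buffers directly).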
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
                     ValueRange fromVals, ValueRange extraVals,
                     SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
                     bool directOut) {
  unsigned idx = 0;
  for (auto type : types) {
    // All "dense" data passes through unmodified.
    if (!getSparseTensorEncoding(type)) {
      toVals.push_back(fromVals[idx++]);
      continue;
    }
    // Handle sparse data.
    auto rtp = cast<RankedTensorType>(type);
    const SparseTensorType stt(rtp);
    SmallVector<Value> inputs;
    SmallVector<Type> retTypes;
    SmallVector<Type> cntTypes;
    if (!isIn)
      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble

    // Collect the external representations of the pos/crd/val arrays.
    foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                     SparseTensorFieldKind kind,
                                                     Level lv, LevelType) {
      if (kind == SparseTensorFieldKind::PosMemRef ||
          kind == SparseTensorFieldKind::CrdMemRef ||
          kind == SparseTensorFieldKind::ValMemRef) {
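        // Three cases: wrapper arguments feed a subsequent AssembleOp;
        // with direct-out, the pos/crd/val memrefs are queried straight
        // from the sparse tensor; otherwise the extra output buffers feed
        // a subsequent DisassembleOp.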
        if (isIn) {
          inputs.push_back(fromVals[idx++]);
        } else if (directOut) {
          Value mem;
          if (kind == SparseTensorFieldKind::PosMemRef)
            mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
                                                               lv);
          else if (kind == SparseTensorFieldKind::CrdMemRef)
            mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
                                                                 lv);
          else
            mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
          toVals.push_back(mem);
        } else {
          ShapedType rtp = cast<ShapedType>(t);
          rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
          inputs.push_back(extraVals[extra++]);
          retTypes.push_back(rtp);
          cntTypes.push_back(builder.getIndexType());
        }
      }
      return true;
    });

    if (isIn) {
      // Assemble multiple inputs into a single sparse tensor.
      auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
      toVals.push_back(a.getResult());
    } else if (!directOut) {
      // Disassemble a single sparse input into multiple outputs.
      // Note that this includes the counters, which are dropped.
      unsigned len = retTypes.size();
      retTypes.append(cntTypes);
      auto d =
          builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
      for (unsigned i = 0; i < len; i++)
        toVals.push_back(d.getResult(i));
    }
  }
}

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {

// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
//   t = assemble t1..tn
//   _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
//   ..., t, ... = _internal_bar(...)
//   t1..tn = disassemble t, t1'..tn'
//   return ..., t1..tn, ...
// }
//
// (with a direct-out variant without the disassemble).
//
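// For illustration only (a sketch with a hypothetical CSR-annotated
// argument; the exact textual form of the sparse_tensor.assemble op is
// abbreviated), an input wrapper looks roughly like
//
//   func.func private @_internal_foo(%t: tensor<8x8xf64, #CSR>) { ... }
//   func.func @foo(%pos: tensor<?xindex>, %crd: tensor<?xindex>,
//                  %val: tensor<?xf64>) {
//     %t = sparse_tensor.assemble ...  // rebuild tensor<8x8xf64, #CSR>
//     call @_internal_foo(%t) : (tensor<8x8xf64, #CSR>) -> ()
//     return
//   }
//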
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  SparseFuncAssembler(MLIRContext *context, bool dO)
      : OpRewritePattern(context), directOut(dO) {}

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite public entry methods.
    if (funcOp.isPrivate())
      return failure();

    // Translate sparse tensor types to external types.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);

    // Only sparse inputs or outputs need a wrapper method.
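    // (The sizes differ exactly when convTypes() expanded at least one
    // sparse tensor into multiple external buffer types.)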
    if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
        outputTypes.size() == funcOp.getResultTypes().size())
      return failure();

    // Modify the original method into an internal, private method.
    auto orgName = funcOp.getName();
    std::string internalName = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(internalName);
    funcOp.setPrivate();

    // Start the new public wrapper method with the original name.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Construct new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);

    // Call the original, now private method. A subsequent inlining pass can
    // determine whether cloning the method body in place is worthwhile.
    auto org = SymbolRefAttr::get(context, internalName);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                              inputs);

    // Convert outputs and return.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Finally, migrate a potential C-interface property.
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }

private:
  const bool directOut;
};

} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateSparseAssembler(RewritePatternSet &patterns,
                                   bool directOut) {
  patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
}
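
// A minimal usage sketch (hypothetical driver code, not part of this file;
// assumes the greedy rewrite driver declared in
// mlir/Transforms/GreedyPatternRewriteDriver.h):
//
//   RewritePatternSet patterns(module.getContext());
//   populateSparseAssembler(patterns, /*directOut=*/false);
//   (void)applyPatternsAndFoldGreedily(module, std::move(patterns));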