//===- Transforms.cpp ---------------------------------------------- C++ --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
#include "TransformsDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <iterator>
#include <numeric>
namespace mlir::mesh {
namespace {
/// Lower `mesh.process_multi_index` into an expression using
/// `mesh.process_linear_index` and `mesh.mesh_shape`.
struct ProcessMultiIndexOpLowering
    : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> {
  using OpRewritePatternWithSymbolTableCollection::
      OpRewritePatternWithSymbolTableCollection;
  LogicalResult matchAndRewrite(ProcessMultiIndexOp op,
                                PatternRewriter &rewriter) const override {
    MeshOp mesh = getMesh(op, symbolTableCollection);
    if (!mesh) {
      return failure();
    }
    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
    builder.setInsertionPointAfter(op.getOperation());
    Value linearIndex = builder.create<ProcessLinearIndexOp>(mesh);
    ValueRange meshShape = builder.create<MeshShapeOp>(mesh).getResults();
    SmallVector<Value> completeMultiIndex =
        builder.create<affine::AffineDelinearizeIndexOp>(linearIndex, meshShape)
            .getMultiIndex();
    SmallVector<Value> multiIndex;
    ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
    SmallVector<MeshAxis> opAxesIota;
    if (opMeshAxes.empty()) {
      opAxesIota.resize(mesh.getRank());
      std::iota(opAxesIota.begin(), opAxesIota.end(), 0);
      opMeshAxes = opAxesIota;
    }
    llvm::transform(opMeshAxes, std::back_inserter(multiIndex),
                    [&completeMultiIndex](MeshAxis meshAxis) {
                      return completeMultiIndex[meshAxis];
                    });
    rewriter.replaceAllUsesWith(op.getResults(), multiIndex);
    return success();
  }
};
/// Lower `mesh.all_slice` into a `tensor.extract_slice` of the local operand.
struct AllSliceOpLowering
    : OpRewritePatternWithSymbolTableCollection<AllSliceOp> {
  using OpRewritePatternWithSymbolTableCollection::
      OpRewritePatternWithSymbolTableCollection;
  LogicalResult matchAndRewrite(AllSliceOp op,
                                PatternRewriter &rewriter) const override {
    // 1. Compute the process linear index inside the process group from its
    // multi-index.
    //
    // 2. Extract a slice from the input tensor.
    // Every axis other than the slice axis is taken in full.
    // The slice axis is split into equally sized parts, one for each
    // process in the collective process group induced by the mesh axes.
    // The part for each process is selected by its linear index in the
    // process group.
    //
    // No collective that requires communication is needed.
    // Each process operates on its local tensor.
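    //
    // Worked example with illustrative values (assumed, not taken from the
    // code): for a process group of shape 2x3, the process with in-group
    // multi-index (1, 2) has linear index 1 * 3 + 2 = 5, assuming row-major
    // linearization. With a slice-axis size of 12 and a group size of 6, each
    // part has size 12 / 6 = 2, so this process extracts the slice with
    // offset 5 * 2 = 10, size 2 and stride 1 along the slice axis.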
    MeshOp mesh = getMesh(op, symbolTableCollection);
    if (!mesh) {
      return failure();
    }
    ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
    builder.setInsertionPointAfter(op.getOperation());
    Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
    Operation::result_range processInGroupMultiIndex =
        builder.create<ProcessMultiIndexOp>(mesh.getSymName(), op.getMeshAxes())
            .getResults();
    Operation::result_range processGroupShape =
        builder.create<MeshShapeOp>(mesh.getSymName(), op.getMeshAxes())
            .getResult();
    Value processGroupSize =
        createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder);
    int64_t sliceAxis = op.getSliceAxis().getSExtValue();
    Value operandSliceAxisSize =
        builder.create<tensor::DimOp>(op.getOperand(), sliceAxis);
    Value operandSliceAxisSizeModProcessGroupSize =
        builder.create<arith::RemUIOp>(operandSliceAxisSize, processGroupSize);
    Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
        arith::CmpIPredicate::eq, operandSliceAxisSizeModProcessGroupSize,
        zero);
    builder.create<cf::AssertOp>(isTargetShapeExactlyDivisible,
                                 "Slicing a tensor with axis size that is "
                                 "not exactly divisible by the "
                                 "mesh process group size is not supported.");
    Value resultSliceAxisSize =
        builder.create<arith::DivUIOp>(operandSliceAxisSize, processGroupSize);
    OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
        llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
        llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
    // Insert tensor.extract_slice.
    RankedTensorType operandType =
        cast<RankedTensorType>(op.getOperand().getType());
    SmallVector<OpFoldResult> sizes;
    for (int64_t i = 0; i < operandType.getRank(); ++i) {
      if (i == sliceAxis) {
        sizes.emplace_back(resultSliceAxisSize);
      } else {
        Value dimSize = builder.create<tensor::DimOp>(op.getOperand(), i);
        sizes.emplace_back(dimSize);
      }
    }
    SmallVector<OpFoldResult> offsets(
        operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 0));
    offsets[sliceAxis] =
        ArithBuilder(builder, builder.getLoc())
            .mul(getValueOrCreateConstantIndexOp(builder, builder.getLoc(),
                                                 processInGroupLinearIndex),
                 resultSliceAxisSize);
    SmallVector<OpFoldResult> strides(
        operandType.getRank(), getAsIndexOpFoldResult(builder.getContext(), 1));
    Value slice = builder.create<tensor::ExtractSliceOp>(
        op.getOperand(), offsets, sizes, strides);
    Value newResult =
        builder.create<tensor::CastOp>(op.getResult().getType(), slice);
    rewriter.replaceAllUsesWith(op.getResult(), newResult);
    return success();
  }
};
} // namespace
void populateProcessMultiIndexOpLoweringPatterns(
    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
  patterns.add<ProcessMultiIndexOpLowering>(symbolTableCollection,
                                            patterns.getContext());
}
void registerProcessMultiIndexOpLoweringDialects(DialectRegistry &registry) {
  registry.insert<affine::AffineDialect, mesh::MeshDialect>();
}
void populateAllSliceOpLoweringPatterns(
    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
  patterns.add<AllSliceOpLowering>(symbolTableCollection,
                                   patterns.getContext());
}
void registerAllSliceOpLoweringDialects(DialectRegistry &registry) {
  registry.insert<affine::AffineDialect, arith::ArithDialect,
                  cf::ControlFlowDialect, mesh::MeshDialect,
                  tensor::TensorDialect>();
}
void populateAllOpLoweringPatterns(
    RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
  populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection);
  populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
}
void registerAllOpLoweringDialects(DialectRegistry &registry) {
  registerProcessMultiIndexOpLoweringDialects(registry);
  registerAllSliceOpLoweringDialects(registry);
}
TypedValue<IndexType>
createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
                                 ImplicitLocOpBuilder &builder) {
  Operation::result_range meshShape =
      builder.create<mesh::MeshShapeOp>(mesh, axes).getResults();
  return cast<TypedValue<IndexType>>(arith::createProduct(
      builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape),
      builder.getIndexType()));
}
TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
                                               ArrayRef<MeshAxis> meshAxes,
                                               ImplicitLocOpBuilder &builder) {
  ResultRange processInGroupMultiIndex =
      builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults();
  Operation::result_range processGroupShape =
      builder.create<MeshShapeOp>(mesh, meshAxes).getResult();
  OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
      llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
      llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
  return cast<TypedValue<IndexType>>(processInGroupLinearIndex.get<Value>());
}
} // namespace mlir::mesh