1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291
|
//===- AVXTranspose.cpp - Lower Vector transpose to AVX -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements vector.transpose rewrites as AVX patterns for particular
// sizes of interest.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;
using namespace mlir::x86vector::avx2;
using namespace mlir::x86vector::avx2::inline_asm;
using namespace mlir::x86vector::avx2::intrin;
/// Creates an `LLVM::InlineAsmOp` that issues `vblendps` (Intel asm dialect)
/// on `v1` and `v2`. `mask` is spliced into the instruction text as a
/// hexadecimal immediate. The op is created with the type of `v1` as its
/// result type and its first result is returned.
Value mlir::x86vector::avx2::inline_asm::mm256BlendPsAsm(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  // Careful: constraint parser is very brittle: no ws!
  const auto *constraints = "=x,x,x";
  // `{0}` is the formatv placeholder that receives the formatted mask.
  const auto *asmTemplate = "vblendps $0, $1, $2, {0}";
  auto asmText =
      llvm::formatv(asmTemplate, llvm::format_hex(mask, /*width=*/2)).str();

  SmallVector<Value> operands{v1, v2};
  auto intelDialect =
      LLVM::AsmDialectAttr::get(b.getContext(), LLVM::AsmDialect::AD_Intel);
  auto blendOp = b.create<LLVM::InlineAsmOp>(
      v1.getType(), /*operands=*/operands, /*asm_string=*/asmText,
      /*constraints=*/constraints, /*has_side_effects=*/false,
      /*is_align_stack=*/false, /*asm_dialect=*/intelDialect,
      /*operand_attrs=*/ArrayAttr());
  return blendOp.getResult(0);
}
/// Builds a vector.shuffle of `v1` and `v2` with the fixed index pattern
/// {0, 8, 1, 9, 4, 12, 5, 13}. For 8-element operands (where indices >= 8
/// refer to `v2`) this interleaves elements 0, 1, 4 and 5 of the two inputs.
Value mlir::x86vector::avx2::intrin::mm256UnpackLoPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  const int64_t interleaveLo[] = {0, 8, 1, 9, 4, 12, 5, 13};
  return b.create<vector::ShuffleOp>(v1, v2, ArrayRef<int64_t>(interleaveLo));
}
/// Builds a vector.shuffle of `v1` and `v2` with the fixed index pattern
/// {2, 10, 3, 11, 6, 14, 7, 15}. For 8-element operands (where indices >= 8
/// refer to `v2`) this interleaves elements 2, 3, 6 and 7 of the two inputs.
Value mlir::x86vector::avx2::intrin::mm256UnpackHiPs(ImplicitLocOpBuilder &b,
                                                     Value v1, Value v2) {
  const int64_t interleaveHi[] = {2, 10, 3, 11, 6, 14, 7, 15};
  return b.create<vector::ShuffleOp>(v1, v2, ArrayRef<int64_t>(interleaveHi));
}
/// Builds a vector.shuffle of `v1` (a) and `v2` (b) whose result is laid out
/// as
///                            a     a     b     b   |   a     a     b     b
/// `mask` is an 8 bit immediate that `MaskHelper::extractShuffle` splits into
/// four selectors (s0..s3). The first two selectors index into `v1`, the last
/// two index into `v2` (hence the +8), and the second group of four repeats
/// the first with every index shifted by 4:
///                            0:127                 |   128:255
///                            s0   s1   s2+8  s3+8  |   s0+4 s1+4 s2+12 s3+12
Value mlir::x86vector::avx2::intrin::mm256ShufflePs(ImplicitLocOpBuilder &b,
                                                    Value v1, Value v2,
                                                    uint8_t mask) {
  uint8_t sel0, sel1, sel2, sel3;
  MaskHelper::extractShuffle(mask, sel0, sel1, sel2, sel3);

  SmallVector<int64_t> indices;
  // Emit the group of four for offset 0 first, then the one for offset 4.
  for (int64_t groupOffset : {0, 4}) {
    indices.push_back(sel0 + groupOffset);
    indices.push_back(sel1 + groupOffset);
    indices.push_back(sel2 + 8 + groupOffset);
    indices.push_back(sel3 + 8 + groupOffset);
  }
  return b.create<vector::ShuffleOp>(v1, v2, indices);
}
/// Builds a vector.shuffle of `v1` (a) and `v2` (b) that assembles the result
/// from two runs of four consecutive elements.
///
/// `MaskHelper::extractPermute` splits `mask` into two controls; the first one
/// decoded fills the first four result elements, the second one the last four.
/// A control value c in [0, 3] selects shuffle indices [4 * c, 4 * c + 4):
///   0 -> a[0:127]   1 -> a[128:255]   2 -> b[0:127]   3 -> b[128:255]
/// Any other control value is treated as unreachable.
Value mlir::x86vector::avx2::intrin::mm256Permute2f128Ps(
    ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask) {
  SmallVector<int64_t> shuffleMask;
  auto appendRunFor = [&](uint8_t control) {
    if (control > 3)
      llvm_unreachable("control > 3 : overflow");
    const int64_t first = 4 * static_cast<int64_t>(control);
    for (int64_t idx = first; idx < first + 4; ++idx)
      shuffleMask.push_back(idx);
  };

  uint8_t ctrlFirst, ctrlSecond;
  MaskHelper::extractPermute(mask, ctrlFirst, ctrlSecond);
  appendRunFor(ctrlFirst);
  appendRunFor(ctrlSecond);
  return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);
}
/// Builds a vector.shuffle that, for each i in [0, 8), takes element i from
/// `v1` when bit i of `mask` is clear and from `v2` (shuffle index i + 8) when
/// it is set.
Value mlir::x86vector::avx2::intrin::mm256BlendPs(ImplicitLocOpBuilder &b,
                                                  Value v1, Value v2,
                                                  uint8_t mask) {
  SmallVector<int64_t, 8> indices;
  for (int lane = 0; lane < 8; ++lane) {
    const bool takeFromV2 = ((mask >> lane) & 1) != 0;
    indices.push_back(takeFromV2 ? lane + 8 : lane);
  }
  return b.create<vector::ShuffleOp>(v1, v2, indices);
}
/// AVX2 4x8xf32-specific transpose lowering using a "C intrinsics" model.
///
/// `vs` must hold 4 values of type vector<8xf32> (checked by assertions in
/// debug builds). The entries of `vs` are overwritten in place with the
/// results of an unpack -> shuffle -> permute2f128 sequence.
void mlir::x86vector::avx2::transpose4x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
#ifndef NDEBUG
  auto expectedType = VectorType::get({8}, Float32Type::get(ib.getContext()));
  assert(vs.size() == 4 && "expects 4 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == expectedType; }) &&
         "expects all types to be vector<8xf32>");
#endif

  // Immediates shared by the steps below.
  const auto shufMask1010 = MaskHelper::shuffle<1, 0, 1, 0>();
  const auto shufMask3232 = MaskHelper::shuffle<3, 2, 3, 2>();
  const auto permMask20 = MaskHelper::permute<2, 0>();
  const auto permMask31 = MaskHelper::permute<3, 1>();

  // Step 1: unpack rows pairwise.
  Value lo01 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value hi01 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value lo23 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value hi23 = mm256UnpackHiPs(ib, vs[2], vs[3]);

  // Step 2: shuffle the unpacked pairs together.
  Value shuf0 = mm256ShufflePs(ib, lo01, lo23, shufMask1010);
  Value shuf1 = mm256ShufflePs(ib, lo01, lo23, shufMask3232);
  Value shuf2 = mm256ShufflePs(ib, hi01, hi23, shufMask1010);
  Value shuf3 = mm256ShufflePs(ib, hi01, hi23, shufMask3232);

  // Step 3: permute the shuffled vectors and write the results back.
  vs[0] = mm256Permute2f128Ps(ib, shuf0, shuf1, permMask20);
  vs[1] = mm256Permute2f128Ps(ib, shuf2, shuf3, permMask20);
  vs[2] = mm256Permute2f128Ps(ib, shuf0, shuf1, permMask31);
  vs[3] = mm256Permute2f128Ps(ib, shuf2, shuf3, permMask31);
}
/// AVX2 8x8xf32-specific transpose lowering using a "C intrinsics" model.
///
/// `vs` must hold 8 values of type vector<8xf32> (checked by assertions). The
/// entries of `vs` are overwritten in place with the results of an
/// unpack -> shuffle -> inline-asm blend -> permute2f128 sequence.
void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
                                             MutableArrayRef<Value> vs) {
  auto expectedType = VectorType::get({8}, Float32Type::get(ib.getContext()));
  (void)expectedType;
  assert(vs.size() == 8 && "expects 8 vectors");
  assert(llvm::all_of(ValueRange{vs}.getTypes(),
                      [&](Type t) { return t == expectedType; }) &&
         "expects all types to be vector<8xf32>");

  // Step 1: unpack rows pairwise.
  Value lo01 = mm256UnpackLoPs(ib, vs[0], vs[1]);
  Value hi01 = mm256UnpackHiPs(ib, vs[0], vs[1]);
  Value lo23 = mm256UnpackLoPs(ib, vs[2], vs[3]);
  Value hi23 = mm256UnpackHiPs(ib, vs[2], vs[3]);
  Value lo45 = mm256UnpackLoPs(ib, vs[4], vs[5]);
  Value hi45 = mm256UnpackHiPs(ib, vs[4], vs[5]);
  Value lo67 = mm256UnpackLoPs(ib, vs[6], vs[7]);
  Value hi67 = mm256UnpackHiPs(ib, vs[6], vs[7]);

  using inline_asm::mm256BlendPsAsm;

  // Step 2: one shuffle per pair of unpacked vectors.
  const auto shufMask = MaskHelper::shuffle<1, 0, 3, 2>();
  Value shufLo0123 = mm256ShufflePs(ib, lo01, lo23, shufMask);
  Value shufHi0123 = mm256ShufflePs(ib, hi01, hi23, shufMask);
  Value shufLo4567 = mm256ShufflePs(ib, lo45, lo67, shufMask);
  Value shufHi4567 = mm256ShufflePs(ib, hi45, hi67, shufMask);

  // Step 3: blend every unpacked vector with the shuffle it contributed to,
  // alternating between the two complementary blend masks.
  const auto blendMaskA = MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>();
  const auto blendMaskB = MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>();
  Value blended[8];
  blended[0] = mm256BlendPsAsm(ib, lo01, shufLo0123, blendMaskA);
  blended[1] = mm256BlendPsAsm(ib, lo23, shufLo0123, blendMaskB);
  blended[2] = mm256BlendPsAsm(ib, hi01, shufHi0123, blendMaskA);
  blended[3] = mm256BlendPsAsm(ib, hi23, shufHi0123, blendMaskB);
  blended[4] = mm256BlendPsAsm(ib, lo45, shufLo4567, blendMaskA);
  blended[5] = mm256BlendPsAsm(ib, lo67, shufLo4567, blendMaskB);
  blended[6] = mm256BlendPsAsm(ib, hi45, shufHi4567, blendMaskA);
  blended[7] = mm256BlendPsAsm(ib, hi67, shufHi4567, blendMaskB);

  // Step 4: permute blended[i] with blended[i + 4]; the first mask produces
  // results 0..3 and the second produces results 4..7.
  const auto permMask20 = MaskHelper::permute<2, 0>();
  const auto permMask31 = MaskHelper::permute<3, 1>();
  for (int i = 0; i < 4; ++i)
    vs[i] = mm256Permute2f128Ps(ib, blended[i], blended[i + 4], permMask20);
  for (int i = 0; i < 4; ++i)
    vs[i + 4] = mm256Permute2f128Ps(ib, blended[i], blended[i + 4], permMask31);
}
/// Pattern that rewrites vector.transpose to the AVX2-specific lowerings, for
/// the supported cases and depending on the `TransposeLoweringOptions`.
///
/// Both plain 2-D transposes and n-D transposes that have been decomposed into
/// 2-D transposition slices are handled. For example, a 3-D transpose:
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<1024x2048x4096xf32> to vector<4096x1024x2048xf32>
///
/// could be sliced into 2-D transposes by tiling two of its dimensions to one
/// of the vector lengths supported by the AVX2 patterns (e.g., 4x8):
///
///   %0 = vector.transpose %arg0, [2, 0, 1]
///     : vector<1x4x8xf32> to vector<8x1x4xf32>
///
/// The pattern inspects the n-D vector.transpose and only fires when it is a
/// supported 2-D transposition slice (4x8 or 8x8 of f32) whose lowering is
/// enabled in the options.
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context,
                      int benefit)
      : OpRewritePattern<vector::TransposeOp>(context, benefit),
        loweringOptions(loweringOptions) {}

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    // AVX2 patterns can only be applied to f32 vector types with two
    // dimensions greater than one.
    VectorType srcType = op.getSourceVectorType();
    Type elemType = srcType.getElementType();
    if (!elemType.isF32())
      return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

    auto slice2D = mlir::vector::isTranspose2DSlice(op);
    if (failed(slice2D))
      return rewriter.notifyMatchFailure(
          op, "expected transposition on a 2D slice");

    // Sizes of the two dimensions greater than one that get transposed.
    const auto &sliceDims = slice2D.value();
    int64_t m = srcType.getDimSize(std::get<0>(sliceDims));
    int64_t n = srcType.getDimSize(std::get<1>(sliceDims));

    // Bail out unless the shape is supported and its lowering is enabled.
    const bool use4x8 =
        loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8;
    const bool use8x8 =
        loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8;
    if (!use4x8 && !use8x8)
      return failure();

    ImplicitLocOpBuilder ib(op.getLoc(), rewriter);

    // Collapse the n-D input (only two dimensions greater than one) to 1-D
    // and then expand it to the 2-D shape m x n.
    auto flatType = VectorType::get({n * m}, elemType);
    auto type2D = VectorType::get({m, n}, elemType);
    auto reshaped = ib.create<vector::ShapeCastOp>(flatType, op.getVector());
    reshaped = ib.create<vector::ShapeCastOp>(type2D, reshaped);

    // Pull the m 1-D rows out of the 2-D vector.
    SmallVector<Value> rows;
    for (int64_t row = 0; row < m; ++row)
      rows.push_back(ib.create<vector::ExtractOp>(reshaped, row));

    // Transpose the set of rows in place; m is either 4 or 8 at this point.
    if (m == 4)
      transpose4x8xf32(ib, rows);
    else
      transpose8x8xf32(ib, rows);

    // Re-assemble the transposed rows into a zero-initialized 2-D vector.
    Value result =
        ib.create<arith::ConstantOp>(type2D, ib.getZeroAttr(type2D));
    for (int64_t row = 0; row < m; ++row)
      result = ib.create<vector::InsertOp>(rows[row], result, row);

    // The assembled vector still has the shape of the input (e.g., 4x8). Go
    // through the flat type to give it the transposed dimensions and the
    // original rank of the result (e.g., 1x8x1x4x1).
    result = ib.create<vector::ShapeCastOp>(flatType, result);
    result = ib.create<vector::ShapeCastOp>(op.getResultVectorType(), result);
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  LoweringOptions loweringOptions;
};
/// Adds the `TransposeOpLowering` pattern to `patterns`, constructed with the
/// given lowering `options` and pattern `benefit`.
void mlir::x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
    RewritePatternSet &patterns, LoweringOptions options, int benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<TransposeOpLowering>(options, context, benefit);
}
|