File: MemRefUtils.cpp

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (125 lines) | stat: -rw-r--r-- 4,868 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the MemRef dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"

namespace mlir {
namespace memref {

bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
  // Dynamic dimensions make contiguity undecidable here.
  if (!type.hasStaticShape())
    return false;

  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(getStridesAndOffset(type, strides, offset)))
    return false;

  // Walk dimensions from innermost to outermost. The memref is contiguous
  // row-major iff every trailing dimension carries exactly the stride implied
  // by the product of the dimensions inside it, and any dimensions left over
  // after that all have extent 1 (their strides are then irrelevant).
  int64_t expectedStride = 1;
  int64_t dim = static_cast<int64_t>(strides.size()) - 1;

  // Consume the maximal suffix of dimensions with the expected unit-packed
  // strides.
  for (; dim >= 0 && strides[dim] == expectedStride; --dim)
    expectedStride *= type.getDimSize(dim);

  // Any remaining (outer) dimensions must be degenerate (size 1).
  for (; dim >= 0 && type.getDimSize(dim) == 1; --dim)
    ;

  // Contiguous iff every dimension was accounted for.
  return dim < 0;
}

/// Flattens an n-D memref access into a 1-D one for emulating narrow element
/// types: the source memref holds `srcBits`-wide elements that are to be
/// addressed through a buffer of `dstBits`-wide elements.
///
/// Returns a pair of:
///   - a 1-D memref produced by memref.reinterpret_cast of the source's base
///     buffer (offset and size scaled by dstBits/srcBits), and
///   - the linearized, scaled offset (as an index Value) of the element
///     addressed by `indices` within that flattened memref.
///
/// `stridedMetadata` must be an extract_strided_metadata op applied to the
/// source memref; its sizes/strides/offset drive the linearization.
/// NOTE(review): the scale factor is computed as integer division
/// `dstBits / srcBits` and the offsets use floorDiv — this presumably assumes
/// dstBits is a multiple of srcBits; confirm at callers.
std::pair<Value, Value>
getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
                            int dstBits, SmallVector<Value> indices,
                            memref::ExtractStridedMetadataOp stridedMetadata,
                            OpBuilder &builder) {
  auto srcElementType = sourceType.getElementType();
  unsigned sourceRank = indices.size();

  // Pull the base buffer and the dynamic sizes/strides/offset out of the
  // metadata op; these are the runtime values the affine maps fold over.
  Value baseBuffer = stridedMetadata.getBaseBuffer();
  SmallVector<Value> baseSizes = stridedMetadata.getSizes();
  SmallVector<Value> baseStrides = stridedMetadata.getStrides();
  Value baseOffset = stridedMetadata.getOffset();
  assert(indices.size() == baseStrides.size());

  // Create the affine symbols and values for linearization.
  // Symbol layout: s0 is a zero seed term; for each dim i, s(2i+1) is the
  // index and s(2i+2) the stride; the final symbol is the bit-width scaler.
  SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
  symbols[0] = builder.getAffineSymbolExpr(0);
  // Both expressions start from s0; addMulMap accumulates sum(index*stride),
  // mulMap accumulates the product of sizes (total element count).
  AffineExpr addMulMap = symbols.front();
  AffineExpr mulMap = symbols.front();

  // offsetValues[i] holds the concrete operand bound to symbol s(i); the
  // seed s0 is the constant 0 for the sum, 1 for the product.
  SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
  offsetValues[0] = builder.getIndexAttr(0);
  SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
  sizeValues[0] = builder.getIndexAttr(1);

  for (unsigned i = 0; i < sourceRank; ++i) {
    // Accumulate indices[i] * baseStrides[i] into the linearized offset.
    unsigned offsetIdx = 2 * i + 1;
    addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
    offsetValues[offsetIdx] = indices[i];
    offsetValues[offsetIdx + 1] = baseStrides[i];

    // Accumulate baseSizes[i] into the total linearized size.
    unsigned sizeIdx = i + 1;
    mulMap = mulMap * symbols[sizeIdx];
    sizeValues[sizeIdx] = baseSizes[i];
  }

  // Adjust linearizedOffset by the scale factor (dstBits / srcBits): the
  // narrow-element offset is divided down to an offset in wide elements.
  OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
  AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
  offsetValues.back() = scaler;

  // Fold the maps into (constant where possible) offset and size values.
  OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, scaledAddMulMap, offsetValues);
  OpFoldResult linearizedSize =
      affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);

  // Adjust baseOffset by the scale factor (dstBits / srcBits).
  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);
  OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
      builder, loc, s0.floorDiv(s1), {baseOffset, scaler});

  // Flatten n-D MemRef to 1-D MemRef.
  // The result layout keeps a dynamic offset; the single stride is the
  // source's innermost stride when it is a compile-time constant.
  std::optional<int64_t> stride =
      getConstantIntValue(stridedMetadata.getConstifiedMixedStrides().back());
  auto layoutAttr =
      StridedLayoutAttr::get(sourceType.getContext(), ShapedType::kDynamic,
                             {stride ? stride.value() : ShapedType::kDynamic});
  // Static 1-D shape = total element count when the source shape is static.
  int64_t staticShape = sourceType.hasStaticShape()
                            ? sourceType.getNumElements()
                            : ShapedType::kDynamic;
  auto flattenMemrefType = MemRefType::get(
      staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());

  // Re-view the base buffer as the flattened memref with the scaled offset
  // and linearized size; the runtime stride is the source's innermost one.
  auto reinterpret = builder.create<memref::ReinterpretCastOp>(
      loc, flattenMemrefType, baseBuffer,
      getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
      getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
      baseStrides.back());

  return std::make_pair(reinterpret, getValueOrCreateConstantIndexOp(
                                         builder, loc, linearizedOffset));
}

} // namespace memref
} // namespace mlir